一、介绍

决策树(Decision Tree)是有监督学习中的一种算法,并且是一种基本的分类与回归的方法。有分类树和回归树两种。

决策树的算法本质是树形结构,我们可以把决策树看成是一个if-then规则的集合。将决策树转换成if-then规则的过程是这样的:

  • 由决策树的根节点到叶节点的每一条路径构建一条规则
  • 路径上中间节点的特征对应着规则的条件,叶节点的类标签对应着规则的结论

决策树的路径或者其对应的if-then规则集合有一个重要的性质:互斥且完备。即每一个实例都被有且仅有一条路径或者规则所覆盖。这里的覆盖指实例的特征与路径上的特征一致,或实例满足规则的条件。

二、香农熵和信息增益

香农熵及计算函数:

l(xi)l(x_i)l(xi​) = −log2p(xi)-log_2p(x_i)−log2​p(xi​)

Ent(D)Ent(D)Ent(D) = −∑i=1np(xi)log2p(xi)-\sum_{i=1}^np(x_i)log_2p(x_i)−∑i=1n​p(xi​)log2​p(xi​)

Ent(D)Ent(D)Ent(D)的值越小,则D的不纯度就越低。

信息增益

Gain(D,a)Gain(D,a)Gain(D,a) = Ent(D)Ent(D)Ent(D)−∑v=1V∣Dv∣∣D∣Ent(Dv)-\sum_{v=1}^V \frac{|D^v|}{|D|}Ent(D^v)−∑v=1V​∣D∣∣Dv∣​Ent(Dv)

Python实现:

#创建数据集,书中海洋生物为例
import numpy as np
import pandas as pddef createDataSet():row_data = {'no surfacing':[1,1,1,0,0],'flippers':[1,1,0,1,1],'fish':['yes','yes','no','no','no']}dataSet = pd.DataFrame(row_data)return dataSet# 计算香农熵
def calEnt(dataSet):n = dataSet.shape[0] iset = dataSet.iloc[:,-1].value_counts() p = iset/n ent = (-p*np.log2(p)).sum()return ent# 根据信息增益选择出最佳数据集切分的列
def bestSplit(dataSet):baseEnt = calEnt(dataSet) #计算原始熵bestGain = 0 #初始化信息增益axis = -1 #初始化最佳切分列,标签列for i in range(dataSet.shape[1]-1): #对特征的每一列进行循环,-1是不需要对标签列循环levels= dataSet.iloc[:,i].value_counts().index #提取出当前列的所有取值ents = 0 #初始化子节点的信息熵for j in levels: #对当前列的每一个取值进行循环childSet = dataSet[dataSet.iloc[:,i]==j] #某一个子节点的dataframeent = calEnt(childSet) #计算某一个子节点的信息熵ents += (childSet.shape[0]/dataSet.shape[0])*ent #计算当前列的信息熵#print(f'第{i}列的信息熵为{ents}')infoGain = baseEnt-ents #计算当前列的信息增益#print(f'第{i}列的信息增益为{infoGain}')if (infoGain > bestGain):bestGain = infoGain #选择最大信息增益axis = i #最大信息增益所在列的索引return axis# 按照给定的列划分数据集
def mySplit(dataSet,axis,value):col = dataSet.columns[axis]redataSet = dataSet.loc[dataSet[col]==value,:].drop(col,axis=1)return redataSet

三、递归构建决策树

构建决策树的算法有很多,例如ID3、C4.5、CART。在此处选择ID3。

ID3算法的核心是在决策树各个节点上对应信息增益准则选择特征,递归地构建决策树。
具体做法:从根节点开始,对节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征,由该特征的不同取值建立子节点;再对子节点递归地调用以上方法,构建决策树;直到所有特征信息增益均很小或没有特征可以选择为止。最后得到一个决策树。

递归结束的条件:程序遍历完所有的特征列,或者每个分支下的所有实例都具有相同的分类。如果所有实例均具有相同分类,则得到一个叶节点。任何到达叶节点的数据必然属于叶节点的分类,即叶节点里面必须是标签。

def createTree(dataSet):featlist = list(dataSet.columns) classlist = dataSet.iloc[:,-1].value_counts() #判断最多标签数目是否等于数据集行数,或者数据集是否只有一列if classlist[0]==dataSet.shape[0] or dataSet.shape[1] == 1:return classlist.index[0] axis = bestSplit(dataSet) bestfeat = featlist[axis] myTree = {bestfeat:{}} del featlist[axis] valuelist = set(dataSet.iloc[:,axis]) for value in valuelist: myTree[bestfeat][value] = createTree(mySplit(dataSet,axis,value))return myTreemyTree = createTree(dataSet)
myTree#树的存储
np.save('myTree.npy',myTree)#树的读取
read_myTree = np.load('myTree.npy').item()
read_myTree# 对一个测试实例进行分类
def classify(inputTree,labels, testVec):firstStr = next(iter(inputTree)) secondDict = inputTree[firstStr] featIndex = labels.index(firstStr)for key in secondDict.keys():if testVec[featIndex] == key:if type(secondDict[key]) == dict :classLabel = classify(secondDict[key], labels, testVec)else:classLabel = secondDict[key]return classLabel# 对测试集进行预测,并返回预测后的结果
def acc_classify(train,test):inputTree = createTree(train)labels = list(train.columns)result = []for i in range(test.shape[0]): testVec = test.iloc[i,:-1] classLabel = classify(inputTree,labels,testVec) result.append(classLabel)test['predict']=result acc = (test.iloc[:,-1]==test.iloc[:,-2]).mean() print(f'模型预测准确率为{acc}')return test

四、使用SKlearn中graphviz绘制决策树

#导入相应的包
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
import graphviz#特征
Xtrain = dataSet.iloc[:,:-1]
#标签
Ytrain = dataSet.iloc[:,-1]
labels = Ytrain.unique().tolist()
Ytrain = Ytrain.apply(lambda x: labels.index(x))#绘制树模型
clf = DecisionTreeClassifier()
clf = clf.fit(Xtrain, Ytrain)
tree.export_graphviz(clf)
dot_data = tree.export_graphviz(clf, out_file=None)
graphviz.Source(dot_data)#给图形增加标签和颜色
dot_data = tree.export_graphviz(clf, out_file=None,feature_names=['no surfacing', 'flippers'],class_names=['fish', 'not fish'],filled=True, rounded=True,special_characters=True)
graphviz.Source(dot_data)#利用render方法生成图形
graph = graphviz.Source(dot_data)
graph.render("fish")

五、决策树可视化

# 递归计算叶子节点的数目
def getNumLeafs(myTree):numLeafs = 0 firstStr = next(iter(myTree))secondDict = myTree[firstStr] for key in secondDict.keys():if type(secondDict[key]) == dict: numLeafs += getNumLeafs(secondDict[key]) else:numLeafs +=1 #不是字典,代表此结点为叶子结点return numLeafs# 递归计算树的深度
def getTreeDepth(myTree):maxDepth = 0firstStr = next(iter(myTree))secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]) == dict:thisDepth = 1+getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth>maxDepth:maxDepth = thisDepthreturn maxDepth# 绘制节点
def plotNode(nodeTxt, cntrPt, parentPt, nodeType):arrow_args = dict(arrowstyle="<-") createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',  # axes fractionxytext=cntrPt, textcoords='axes fraction',va="center", ha="center",bbox=nodeType,arrowprops=arrow_args)# 标注有向边属性值
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=0)# 绘制决策树
def plotTree(myTree, parentPt, nodeTxt):decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8")numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = next(iter(myTree))cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys():if type(secondDict[key])== dict: plotTree(secondDict[key],cntrPt,str(key)) else:plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD# 创建绘制面板
def createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf() axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalWplotTree.yOff = 1.0 plotTree(inTree, (0.5,1.0), '') plt.show()

六、使用决策树预测隐形眼镜

# 导入数据集
lenses = pd.read_table('lenses.txt',header = None)
lenses.columns =['age','prescript','astigmatic','tearRate','class']# 划分训练集和测试集
import random
def randSplit(dataSet, rate):l = list(dataSet.index) #提取出索引random.shuffle(l) #随机打乱索引dataSet.index = l #将打乱后的索引重新赋值给原数据集n = dataSet.shape[0] #总行数m = int(n * rate) #训练集的数量train = dataSet.loc[range(m), :] #提取前m个记录作为训练集test = dataSet.loc[range(m, n), :] #剩下的作为测试集dataSet.index = range(dataSet.shape[0]) #更新原数据集的索引test.index = range(test.shape[0]) #更新测试集的索引return train, test#利用训练集生成决策树
lensesTree = createTree(train1)
lensesTree#构造注解树
createPlot(lensesTree)#用决策树进行分类并计算有预测准确率
acc_classify(train1,test1)# 使用SKlearn中graphviz绘制决策树
#特征列
Xtrain1 = train1.iloc[:,:-1]
for i in Xtrain1.columns:labels = Xtrain1[i].unique().tolist()Xtrain1[i]= Xtrain1[i].apply(lambda x: labels.index(x))#标签列
Ytrain1 = train1.iloc[:,-1]
labels = Ytrain1.unique().tolist()
Ytrain1= Ytrain1.apply(lambda x: labels.index(x))#绘制树形图
clf = DecisionTreeClassifier()
clf = clf.fit(Xtrain1, Ytrain1)
tree.export_graphviz(clf)
dot_data = tree.export_graphviz(clf, out_file=None)
graphviz.Source(dot_data)#添加标签和颜色
dot_data = tree.export_graphviz(clf, out_file=None,feature_names=['age', 'prescript', 'astigmatic','tearRate'],class_names=['soft','hard','no lenses'],filled=True, rounded=True,special_characters=True)
graphviz.Source(dot_data)#使用render存储树形图
graph = graphviz.Source(dot_data)
graph.render("lense")

七、算法优缺点

优点:

(1)决策树可以可视化,易于理解和解释;

(2)数据准备工作很少。其他很多算法通常都需要数据规范化,需要创建虚拟变量并删除空值等;

(3)能够同时处理数值和分类数据,既可以做回归又可以做分类。其他技术通常专门用于分析仅具有一种变类型的数据集;

(4)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度;

(5)能够处理多输出问题,即含有多个标签的问题,注意与一个标签中含有 多种标签分类的问题区别开;

(6)是一个白盒模型,结果很容易能够被解释。如果在模型中可以观察到给定的情况,则可以通过布尔逻辑轻松解释条件。相反,在黑盒模型中(例如,在人工神经网络中),结果可能更难以解释。

缺点:

(1)递归生成树的方法很容易出现过拟合。

(2)决策树可能是不稳定的,因为即使非常小的变异,可能会产生一颗完全不同的树

(3)如果某些分类占优势,决策树将会创建一棵有偏差的树。因此,建议在拟合决策树之前平衡数据集。

决策树案例学习(Python实现)相关推荐

  1. python数值运算实例_“每天进步一点点”案例学习python数值操作

    这是树哥讲python系列的第四篇文章. 本质上计算机熟悉的是二进制,也就是我们常说的"0,1"代码,所以无论是执行的命令还是数据本身,都必须转化为0和1他们才会认知.而我们熟悉的 ...

  2. 案例学习|Python实现某医院药品销售分析

    数据分析的基本过程一般分为以下几个部分: 提出问题 获取并理解数据 数据清洗 构建模型 数据可视化 1.提出问题 在数据分析之前,我们先要明确分析目标,可以帮助我们更高效的选取数据,进行分析研究. 本 ...

  3. 在我的新书里,尝试着用股票案例讲述Python爬虫大数据可视化等知识

    我的新书,<基于股票大数据分析的Python入门实战>,预计将于2019年底在清华出版社出版. 如果大家对大数据分析有兴趣,又想学习Python,这本书是一本不错的选择.从知识体系上来看, ...

  4. python全案例学习_Python全案例学习与实践

    第一篇 营造环境 第1章 Python及其安装 1.1 我为什么要学习Python 1.1.1 它的广泛性 1.1.2 它的新颖性 1.1.3 它的生态性 1.1.4 Python的应用领域举例 1. ...

  5. python中continue用法案例_记录今天学习python中for与while循环针对break和continue的用法...

    python中有两个主要的循环for与while,其中针对这两个循环有两种不同的中断用法break与continue. 首先先看下面的循环代码: 1: for i in range(10):#变量i带 ...

  6. 最详细的python案例学习与实践(含详细教程)

    为了学习Python,使用Python开发应用程序,手头必须有得力的工具.在笔者看来,Python和PyCharm是最基本的学习和开发工具.如果要成为教育家或作家型的Python程序员,那么还必须准备 ...

  7. python大神-国内某Python大神自创完整版,系统性学习Python

    很多小伙伴纠结于这个一百天的时间,我觉得完全没有必要,也违背了我最初放这个大纲上来的初衷,我是觉得这个学习大纲还不错,自学按照这个来也能相对系统的学习知识,而不是零散细碎的知识最后无法整合,每个人的基 ...

  8. python网络编程要学吗_总算发现如何学习python网络编程

    为了提高模块加载的速度,每个模块都会在__pycache__文件夹中放置该模块的预编译模块,命名为module.version.pyc,version是模块的预编译版本编码,一般都包含Python的版 ...

  9. python100天-如何系统地学习 Python,100天从新手到大师

    如果你还是迷茫到底如何系统的学习Python,没关系! 为大家整理了Python100天从新手到大师的系统学习教程,让你不用再纠结,一定能帮助到你的问题. 包括从基础的Python脚本到web开发.爬 ...

  10. 用通俗易懂的方式讲解:主成分分析(PCA)算法及案例(Python 代码)

    文章目录 知识汇总 加入方式 一.引入问题 二.数据降维 三.PCA基本数学原理 3.1 内积与投影 3.2 基 3.3 基变换的矩阵表示 3.4 协方差矩阵及优化目标 3.5 方差 3.6 协方差 ...

最新文章

  1. 程序员的求生欲有所强?用 Python 花式哄女友
  2. css通配选择符用什么表示,css的选择符
  3. java 执行linux命令行_10个高效Linux技巧及Vim命令对比
  4. DVWA学习(二)SQL Injection(Blind)
  5. leetCode刷题--两数相加
  6. recv函数_第五十二章、send和recv函数
  7. SpringBoot 迭代输出
  8. jzoj6305-最小值【线段树,dp,双端链表】
  9. python基本语法1.2--数的移位及与或抑或相关计算
  10. 字符串循环右移的一道题目
  11. springboot请求处理
  12. atoi()函数和stoi()函数
  13. Java Spring Security 安全框架:(四)PasswordEncoder 密码解析器详解
  14. linux 内核rps,Linux内核中RPS/RFS代码分析
  15. K-Means聚类算法原理及实现
  16. web前端期末大作业:旅游网页主题网站设计——桂林旅游网站的设计 (7页)HTML+CSS+JavaScript web网页设计与开发 静态网页的制作 web期末作业设计网页 web结课作业的源
  17. flink sql 执行源码走读全流程
  18. 初学opencv 2
  19. 国家计算机二级在线模拟试题,全国计算机二级机试题模拟试题10套(一).pdf
  20. 【UE4】视角制作相关知识点(蓝图)

热门文章

  1. 64位计算机装32位系统,32位装64位系统教程
  2. When Color Constancy Goes Wrong:Correcting Improperly White-Balanced Images阅读札记
  3. sublimetext的文件编码理解reopen/reload with encoding,set encoding,save with encoding, set file encoding to
  4. Photoshop-选区的应用
  5. 统计局解读1月制造业采购经理指数:服务业回升明显
  6. LWIP应用开发|DNS域名解析
  7. Verilog仿真器
  8. 易语言删除全部空白字符
  9. 此处纸薄不经墨,待入章中再续貂
  10. VB中数组的大小排序解析