目录

一、决策树原理

二、实验过程

2.1信息增益

2.2增益率

2.3基尼指数

2.4数据集

2.5创建决策树

2.6保存和读取决策树

2.7绘制决策树

2.8使用决策树进行分类

2.9主函数


一、决策树原理

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。

决策树学习的关键在于如何选择最优划分属性。一般而言,随着划分过程不断进行,我们希望决策树的分支结点 所包含的样本尽可能属于同一类别,即结点的“纯度 ”(purity)越来越高。经典的属性划分方法有三种:信息增益;增益率;基尼指数。

二、实验过程

2.1信息增益

离散属性a有V个可能的取值{a1, a2, ..., aV},用a来进行划分,则会产 生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为 av的样本,记为Dv。则可计算出用属性a对样本集D进行划分所获得的 “信息增益”:

一般而言,信息增益越大,则意味着使用属性a来进行划分所获得的 “纯度提升”越大。

实现代码:
def chooseBestFeatureToSplit(dataSet):#特征数量,-1是因为最后一列是类别标签numFeatures = len(dataSet[0]) - 1#计算数据集的原始香农熵baseEntropy = calcShannonEnt1(dataSet)bestInfoGain = 0.0                       #信息增益赋初值0bestFeature = -1                         #最优特征的索引值for i in range(numFeatures):#获取dataSet的第i个所有特征存到featList中featList = [example[i] for example in dataSet]#print(featList)  #每个特征的15项特征值列表#创建set集合{},元素不可重复uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:#subDataSet划分后的子集subDataSet = splitDataSet(dataSet,i,value)#计算子集的概率=子集个数/整个训练样本数prob = len(subDataSet)/float(len(dataSet))#计算香农熵newEntropy += prob * calcShannonEnt(subDataSet)#计算信息增益infoGain = baseEntropy - newEntropy#print("第%d个特征的增益为%.3f" %(i,infoGain))#C4.5算法:计算增益比(信息增益率)#infoGain2 = (baseEntropy - newEntropy)/baseEntropyif (infoGain >bestInfoGain):bestInfoGain = infoGain      #更新信息增益,找到最大的信息增益bestFeature = i              #记录信息增益最大的特征的索引值return bestFeature                   #返回信息增益最大的特征的索引值

2.2增益率

称为属性a的“固有值” [Quinlan, 1993],属性a的可能取值数 目越多(即V越大),则IV(a)的值通常就越大。

增益率准则对可取值数目较少的属性有所偏好。

C4.5 [Quinlan, 1993]采用了一个启发式方法:先从候选划分属性中找出信 息增益高于平均水平的属性,再从中选取增益率最高的

实现代码:

def calcShannonEnt1(dataSet, method = 'none'):numEntries = len(dataSet)labelCount = {}for feature in dataSet:if method =='prob': #当参数为prob时转而计算增益率label =  featureelse:label = feature[-1]if label not in labelCount.keys():labelCount[label]=1else:labelCount[label]+=1shannonEnt = 0.0for key in labelCount:numLabels = labelCount[key]prob = numLabels/numEntriesshannonEnt -= prob*(log(prob,2))return shannonEnt
#增益率
def chooseBestFeatureToSplit2(dataSet): #使用增益率进行划分数据集numFeatures = len(dataSet[0]) -1 #最后一个位置的特征不算baseEntropy = calcShannonEnt(dataSet) #计算数据集的总信息熵bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]newEntropyProb = calcShannonEnt1(featList, method='prob') #计算内部增益率uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:# 通过不同的特征值划分数据子集subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob *calcGini(subDataSet)newEntropy  = newEntropy*newEntropyProbinfoGain = baseEntropy - newEntropy #计算每个信息值的信息增益if(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature #返回信息增益的最佳索引

2.3基尼指数

定义:分类问题中,假设D有K个类,样本点属于第k类的概率为Pk, 则概率 分布的基尼值定义为:

Gini(D)越小,数据集D的纯度越高; 给定数据集D,属性a的基尼指数定义为:

在候选属性集合A中,选择那个使得划分后基尼指数最小的属性作为最有划分属性。

实现代码:

#基尼指数
def calcGini(dataset):feature = [example[-1] for example in dataset]uniqueFeat = set(feature)sumProb =0.0for feat in uniqueFeat:prob = feature.count(feat)/len(uniqueFeat)sumProb += prob*probsumProb = 1-sumProbreturn sumProb
def chooseBestFeatureToSplit3(dataSet): #使用基尼系数进行划分数据集numFeatures = len(dataSet[0]) -1 #最后一个位置的特征不算bestInfoGain = np.InfbestFeature = 0.0for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:# 通过不同的特征值划分数据子集subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob *calcGini(subDataSet)infoGain = newEntropyif(infoGain < bestInfoGain): # 选择最小的基尼系数作为划分依据bestInfoGain = infoGainbestFeature = ireturn bestFeature #返回决策属性的最佳索引

2.4数据集

关于购买手机的144条数据,六个属性分别为:颜色, 运行内存, 内存, 价格, 品牌,指纹解锁位置

用于验证的十条数据:

读入数据的代码:

import pandas as pdmyData = pd.read_excel('H:\Python project\shuju.xlsx',header = None)print(myData.head())myData = np.array(myData).tolist()for d in myData:for i in range(len(d)):d[i] = d[i].strip()Labels = ['颜色', '运行内存', '内存', '价格', '品牌','指纹解锁位置']

2.5创建决策树

#创建决策树
def createTree(dataSet,labels):#取分类标签classList = [example[-1] for example in dataSet]#特征可能存在多个属性,需要判断一下,如果类别完全相同则停止继续划分if classList.count(classList[0]) == len(classList):return classList[0]if len(dataSet[0]) == 1:return majorityCnt(classList)                 #遍历完所有特征时返回出现次数最多的类标签bestFeat = chooseBestFeatureToSplit3(dataSet)      #选择最优特征bestFeatLabel = labels[bestFeat]                  #最优特征的类标签myTree = {bestFeatLabel:{}}                       #根据最有特征的标签生成树#得到训练集中所有最优特征的属性值featValues = [example[bestFeat] for example in dataSet]#去掉重复的属性值uniqueVals = set(featValues)#遍历特征,创建决策树for value in uniqueVals:subLabels = labels[:]myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)return myTree

2.6保存和读取决策树

#存储决策树
def storeTree(inputTree, filename):import picklefw = open(filename, 'wb')pickle.dump(inputTree, fw)fw.close()
#读取决策树
def grabTree(filename):import picklefr = open(filename)return pickle.load(fr)  #决策树字典

2.7绘制决策树


#绘制决策树
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']decisionNode = dict(boxstyle="sawtooth", fc="0.8")   #设置结点格式
leafNode = dict(boxstyle="round4", fc="0.8")         #设置叶节点格式
arrow_args = dict(arrowstyle="<-")                   #设置箭头格式#获取决策树叶子结点数目
def getNumLeafs(myTree):numLeafs = 0                                     #初始化叶子结点#firstStr = myTree.keys()[0]firstStr = next(iter(myTree))                    #获取结点属性secondDict = myTree[firstStr]                    #获取下一组字典for key in secondDict.keys():#测试该结点是否为字典,如果不是字典,代表此结点为叶子结点if type(secondDict[key]).__name__=='dict':numLeafs += getNumLeafs(secondDict[key])else:   numLeafs +=1return numLeafs#获取决策树的深度
def getTreeDepth(myTree):maxDepth = 0                              #初始化决策树层数firstStr = next(iter(myTree))secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:   thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepth              #更新树的层数return maxDepth#绘制结点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)#标注有向边属性值
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)def plotTree(myTree, parentPt, nodeTxt):#获取决策树叶结点数目,决定了树的宽度numLeafs = getNumLeafs(myTree)#获取决策树层数depth = getTreeDepth(myTree)#获取下一个字典firstStr = next(iter(myTree))#中心位置cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)#标注有向边属性值plotMidText(cntrPt, parentPt, nodeTxt)#绘制结点plotNode(firstStr, cntrPt, parentPt, decisionNode)#下一个字典,继续绘制子结点secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():# 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点if type(secondDict[key]).__name__=='dict':#不是叶子结点,递归调用继续绘制plotTree(secondDict[key],cntrPt,str(key))else:      #是叶子结点则进行绘制,标注有向边plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#创建绘制面板
def createPlot(inTree):fig = plt.figure(1, facecolor='white')            #创建figfig.clf()                                         #清空figaxprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)     #去掉x,y轴plotTree.totalW = float(getNumLeafs(inTree))      #决策树叶子结点数目plotTree.totalD = float(getTreeDepth(inTree))     #决策树层数plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;   #x偏移plotTree(inTree, (0.5,1.0), '')                   #绘制决策树plt.show()

2.8使用决策树进行分类

#使用决策树的分类函数
def classify(inputTree,featLabels,testVec):#firstStr = next(iter(inputTree))            #获取决策树结点firstStr = list(inputTree.keys())[0]#print(firstStr)secondDict = inputTree[firstStr]            #下一个字典featIndex = featLabels.index(firstStr)      #获取存储选择的最优特征标签的索引classLabel = -1for key in secondDict.keys():if testVec[featIndex] == key:# 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点if type(secondDict[key]).__name__=='dict':classLabel = classify(secondDict[key],featLabels,testVec)else:classLabel = secondDict[key]# 标记classLabel为-1当循环过后若仍然为-1,表示未找到该数据对应的节点则我们返回他兄弟节点出现次数最多的类别if classLabel == -1:return (getLeafBestCls(inputTree))else:return classLabel#求该节点下所有叶子节点的列表
def getLeafscls(myTree, clsList):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':clsList =getLeafscls(secondDict[key],clsList)else:clsList.append(secondDict[key])return clsList#返回出现次数最多的类别
def getLeafBestCls(myTree):clsList = []resultList = getLeafscls(myTree,clsList)return max(resultList,key = resultList.count)

2.9主函数

if __name__ == '__main__':import pandas as pdmyData = pd.read_excel('H:\Python project\shuju.xlsx',header = None)print(myData.head())myData = np.array(myData).tolist()for d in myData:for i in range(len(d)):d[i] = d[i].strip()Labels = ['颜色', '运行内存', '内存', '价格', '品牌','指纹解锁位置']myTree = createTree(myData, Labels)storeTree(storeTree, 'saveTree.txt')  #将决策树存储进入txt文件treePlotter.createPlot(myTree)print("预测结果:")print(classify(myTree, Labels, ['白色', '8G', '256G', '3000以下', '华为','正面解锁']))dataSet2 = [['白色', '8G', '256G', '3000以上', '小米', '侧面解锁'],['白色', '16G', '128G', '3000以下', '华为', '正面解锁'],['黑色', '8G', '128G', '3000以上', '小米', '反面解锁'],['蓝色', '8G', '256G', '3000以下', '华为', '反面解锁'],['黑色', '16G', '256G', '3000以上', '小米', '反面解锁'],['黑色', '8G', '256G', '3000以下', '小米', '侧面解锁'],['黑色', '16G', '128G', '3000以下', '小米', '侧面解锁'],['蓝色', '16G', '256G', '3000以上', '华为', '反面解锁'],['白色', '8G', '256G', '3000以下', '华为', '反面解锁']]for dataVet in dataSet2:print(classify(myTree,Labels,dataSet2))

使用信息增益方法生成的决策树:

对验证数据的验证结果:

正确率为60%

使用增益率方法生成的决策树:

对验证数据的验证结果:

正确率为60%

使用基尼指数方法生成的决策树:

对验证数据的验证结果:

正确率为60%

可能是代码原因或者是数据出现问题,导致三种方法的正确率都一样

信息增益的缺点:通过信息增益来选择特征会偏好取值较多的特征,一个极端情况的例子就是如果对于某个特征所有样本在此特征上的取值都不相同,这样通过这个特征计算得到的信息增益式最大的。

增益率的缺点:信息增益率朝着信息增益相反的方向发展,偏好于取值较少的特征。

决策树---使用三种方法对数据建立决策树相关推荐

  1. 三种方法实现数据离散化-python实现

    #-*- coding: utf-8 -*- #数据规范化 import pandas as pddatafile = '../data/discretization_data.xls' #参数初始化 ...

  2. MATLAB笔记:打开数据文件的三种方法+读取数据文件的两种方法+保存数据文件的两种方法

    1.打开数据文件 1.1 直接打开文件 PATHNAME = 'C:\Users\s55\Desktop\dat'; FILENAME = '\data_1.dat'; str0=strcat(PAT ...

  3. oracle数据迁移过程中,把表中数据导出为txt文件的三种方法

    在数据迁过程中需求需要将oracle数据库数据导出程txt格式然后再导入db2库中,经查询实验汇总三种方法: 1.plsqldev 里面有一个选项可以把表以excel格式到时 2.使用spool sq ...

  4. 处理数据中分类变量的三种方法(附代码实现)

    本文是将kaggle Courses中 Categorical Variables | Kaggle进行了翻译并且加入自己的理解,如有地方不清楚,可以查阅原文 文章目录 介绍 三种方法 1)删除分类变 ...

  5. 前端js调用后端API获取数据的三种方法(2022.7.25)

    前端js调用后台API获取数据的三种方法(2022.7.25) 前言 需求分析 一个Get实例 浏览器请求 SoapUI软件请求 一个Post实例 浏览器请求 SoapUI软件请求 1.Http简介( ...

  6. java数据输入的步骤_Java学习日志1.4 Scanner 数据输入的三种方法

    Scanner sc = new Scanner(System.in); /注意in 是InputStream的缩写,是字节输入流的意思. 整句话的含义就是: new 一个对象,接受从键盘输入的数据, ...

  7. mysql如何防止插入重复数据_防止MySQL重复插入数据的三种方法

    新建表格 CREATE TABLE `person` ( `id` int NOT NULL COMMENT '主键', `name` varchar(64) CHARACTER SET utf8 C ...

  8. python csv库,Python 中导入csv数据的三种方法

    Python 中导入csv数据的三种方法,具体内容如下所示: 1.通过标准的Python库导入CSV文件: Python提供了一个标准的类库CSV文件.这个类库中的reader()函数用来导入CSV文 ...

  9. mysql改存储引擎教程_MySQL中修改数据表存储引擎的三种方法

    第一种方法:ALTER TABLE 将表从一个引擎修改为另一个引擎最简单的办法是使用ALTER TABLE语句,转换表的存储引擎会导致失去原引擎相关的所有特性. 例:将mytable的引擎修改为Inn ...

  10. 查看Mat对象的数据的三种方法

    我们有了Mat的对象之后,就可以开始对图像进行处理. 在图像的处理过程中,对数据的查看并且对其进行修改,这应当是比较频繁的操作了. 这里讲讲官方手册当中给出的三种方法. 第一种方法:使用指向Mat数据 ...

最新文章

  1. Android:Activity(页面)的生存周期
  2. Java中需要全部小写的是,下列哪些是Java中常量的命名约定?A、全部为小写字母B、全部为大写字母C、大小写混合D、字与字之间...
  3. 查找算法之四 斐波那契查找(C++版本)
  4. easyui常用控件样式收藏
  5. 一、Java语言基础(5)_数组高级——方法参数的值传递机制
  6. 带485接口伺服电机使用MODBUS协议控制
  7. matlab 力学,力学专业程序实践:用MATLAB解决力学问题的方法与实例
  8. matlab 生成hasse图,Hasse图详解
  9. vmware linux dns,vmware环境下在linux中创建dns服务器
  10. 男孩子读博士的十大好处
  11. win10共享文件夹无法访问问题
  12. 微信扫码登录自定义二维码显示信息
  13. 联想电脑thinkPad开机黑屏
  14. 桌面多出一个IE图标无法删除的解决办法
  15. pyltp学习笔记——中文语言处理工具
  16. JavaScript - 判断浏览器内的页面是在手机端还是电脑(PC)上打开的(判断用户访问设备是什么)
  17. 2011级《软件设计模式》考试试题(开卷)
  18. AppStore发布app问题汇总(一)
  19. arm9 中断向量 重定位_ARM的启动和中断向量表
  20. 一起学习Hive基础(多知识点)

热门文章

  1. 平方和误差函数--代价函数(机器学习)
  2. 诺诺开放平台(电子发票、智能编码、发票查验接口调用)
  3. 中国大陆身份证号码有效性验证
  4. 【PS】证件照修改尺寸
  5. 思科网络模拟器7.3.1版本的下载和安装
  6. 大数据与云计算的关系
  7. 关于magic-api的使用入门
  8. Transformer 权重共享
  9. 高数__已知2个平面方程, 求这2个平面的夹角
  10. 安科瑞【节能学院】电气火灾监控系统在太焦铁路博爱站房项目的应用