1 决策树的构造

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能产生过度匹配问题
适用数据类型:数值型和标称型

解决的首要问题:当前数据集上哪个特征在划分数据分类时起决定性作用
创建分支的伪代码函数createBranch():

检测数据集中的每个子项是否属于同一分类;
if so return 类标签;
Else寻找划分数据集的最好特征划分数据集创建分支节点for 每个划分的子集调用函数createBranch并增加返回结果到分支节点中return 分支节点

决策树的一般流程

  1. 收集数据
  2. 准备数据:只适用于标称型数据,数值型数据必须离散化
  3. 分析数据:检查图形是否符合预期
  4. 训练算法:构造树的数据结构
  5. 测试算法:使用经验树计算错误率
  6. 使用算法:可以适用于任何监督学习算法,使用决策树可以更好地理解数据的内在含义

1.1 信息增益

信息增益:划分数据集之后信息发生的变化
通过计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择
符号xi的信息定义:
I(xi)=-log2p(xi)
计算熵:
H=- ∑ i = 1 n \sum_{i=1}^{n} ∑i=1n​p(xi)log2p(xi)
定义calcShannonEnt函数计算给定数据集的香农熵:

def calcShannonEnt(dataSet):numEntries = len(dataSet)  # 计算数据集中实例的总数labelCounts = {}  # 新建字典,记录每个分类下的数据个数for featVec in dataSet:  # 为所有可能的分类创建字典currentLabel = featVec[-1]  # 将dataSet每一个元素的最后一个元素选择出来if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0  # 当没有该键时,使用字典的自动添加添加值为0的项labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key]) / numEntries  # 取概率shannonEnt -= prob * log(prob, 2)  # log(x,2)表示以2为底求x的对数return shannonEnt




熵越高,则混合的数据也越多

得到熵后,就可以按照获得最大信息增益的方法划分数据集

注:另一个度量集合无序程度的方法是基尼不纯度:从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率

1.2 划分数据集

划分数据集:

def splitDataSet(dataSet, axis, value):  # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值retDataSet = []  # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)for featVec in dataSet:  # 将符合特征的数据抽取出来if featVec[axis] == value:reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis + 1:])  # extend用于在列表末尾一次性追加另一个序列的多个值retDataSet.append(reducedFeatVec)  # append用于在列表末尾添加新的对象return retDataSet



选择最好的数据集划分方式:

def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1  # 计算所有特征数baseEntropy = calcShannonEnt(dataSet)  # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]  # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)uniqueVals = set(featList)  # 第i条特征的取值(去重)   set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等newEntropy = 0.0for value in uniqueVals:  # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet) / float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif (infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature

trees.chooseBestFeatureToSplit(myDat)
输出:0

1.3 递归构建决策树

需要注意的点:由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分
递归结束的条件:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,则得到一个叶子结点或终止块
需要考虑的特殊情况:当数据集已经处理了所有的属性,但类标签依然不是唯一的,此时需要决定如何定义该叶子节点。通常采用多数表决的方法决定该叶子结点的分类。

def majorityCnt(classList):classCount = {}#创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]
def createTree(dataSet, labels):  # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供classList = [example[-1] for example in dataSet]  # 取标签值# 第一个停止条件:所有的类标签完全相同,则直接返回该类标签if classList.count(classList[0]) == len(classList):  # count函数用于统计某个元素在列表中出现的次数return classList[0]  # 列表完全相同则停止继续划分# 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值if len(dataSet[0]) == 1:return majorityCnt(classList)  # 遍历完所有特征时返回出现次数最多的bestFeat = chooseBestFeatureToSplit(dataSet)  # 当前数据集选取的最好特征bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel: {}}  # 用字典存储树的结构del (labels[bestFeat])  # 删除取出过的标签,避免重复计算featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)  # 得到列表包含的所有属性值,利用set去重for value in uniqueVals:subLabels = labels[:]  # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  # 递归构建树return myTree

myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat, labels)
print(myTree)


可以看出myTree包含了很多树结构信息的嵌套字典
第一个关键字是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典;第二个关键字是no surfacing特征划分的数据集,这些关键字的值是no surfacing节点的子节点:值是类标签时,说明该子节点是叶子节点;值是另一个数据字典时,则子节点是一个判断节点

2 使用matplotlib注解绘制数图形

matplotlib提供的注解工具:在数据图形上添加文本注释

2.1 使用文本注解绘制树节点

treePlotter.py:

import matplotlib.pyplot as plt#定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8")  # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-")  # 剪头的属性def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)# plt.annotate(str, xy=data_point_position, xytext=annotate_position,
#va="center",  ha="center", xycoords="axes fraction",#              textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)# str是给数据点添加注释的内容,支持输入一个字符串# xy=是要添加注释的数据点的位置# xytext=是注释内容的位置# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典# va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)# xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同# arrowprops可以指定箭头的风格支持,输入一个字典# plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)def createPlot():  # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形fig = plt.figure(1, facecolor='white')  # 1表示图形编好/名称fig.clf()  # 表示清除所有轴createPlot.ax1 = plt.subplot(111, frameon=False)  # 为对象添加属性  frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)plt.show()

2.2 构造注解树

def getNumLeafs(myTree):  # 获取叶节点的数目numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():  # 测试节点的数据类型是否为字典if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):  # 获取树的层数maxDepth = 0firstStr = list(myTree.keys())[0]  # keys()返回一个字典的所有键secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef retrieveTree(i):  # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]return listOfTrees[i]

Terminal:

print(treePlotter.retrieveTree(1))
myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))
print(treePlotter.getTreeDepth(myTree))

def plotMidText(cntrPt, parentPt, txtString):  # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):  # 完成绘制树形的大多数工作numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)# 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 按比例减少全局变量plotTree.yOff# 由于是自顶向下绘制图形,因此需要依次递减y坐标值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':  # 当节点不是叶子节点时递归调用plotTreeplotTree(secondDict[key], cntrPt, str(key))else:  # 当节点时叶子节点时在图形上画出叶子节点plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # 在绘制了所有子节点之后,增加全局变量Y的偏移def createPlot(inTree):  # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # .ax1相当于对函数对象添加属性plotTree.totalW = float(getNumLeafs(inTree))  # 全局变量plotTree.totalW存储树的宽度plotTree.totalD = float(getTreeDepth(inTree))  # 全局变量plotTree.totalD存储树的深度plotTree.xOff = -0.5 / plotTree.totalW  # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.show()

myTree = treePlotter.retrieveTree(0)
treePlotter.createPlot(myTree)

myTree = treePlotter.retrieveTree(0)
myTree[‘no surfacing’][3] = ‘maybe’
treePlotter.createPlot(myTree)

3 测试和存储分类器

3.1 使用决策树执行分类

def classify(inputTree, featLabels, testVec):# 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签firstStr = list(inputTree.keys())[0]secondDist = inputTree[firstStr]featIndex = featLabels.index(firstStr)  # 找到根特征在featLabels的位置,将标签字符串转换为索引for key in secondDist.keys():if testVec[featIndex] == key:if type(secondDist[key]).__name__ == 'dict':classLabel = classify(secondDist[key], featLabels, testVec)else:classLabel = secondDist[key]return classLabel

myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
print(myTree)
print(trees.classify(myTree, labels, [1, 0]))
print(trees.classify(myTree, labels, [1, 1]))

3.2 使用算法:决策树的存储

每次使用分类器时重新构造决策树是很耗时的任务。为了解决该问题需要使用python模块pickle序列化对象,可以在磁盘上保存对象,并在需要时读取出来。任何对象都可以执行序列化操作。

def storeTree(inputTree, filename):import picklefw = open(filename, 'wb+')pickle.dump(inputTree, fw)  # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本fw.close()def grabTree(filename):import picklefr = open(filename,'rb+')return pickle.load(fr)  # 用于反序列化对象,将文件中的数据解析为一个python对象

myDat, labels = trees.createDataSet()
print(labels)
myTree = treePlotter.retrieveTree(0)
trees.storeTree(myTree, ‘classifierStorage.txt’)
trees.grabTree(‘classifierStorage.txt’)

4 示例:使用决策树预测隐形眼镜类型

fr = open('lenses.txt')lenses = [inst.strip().split('\t') for inst in fr.readlines()]lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = trees.createTree(lenses, lensesLabels)print(lensesTree)treePlotter.createPlot(lensesTree)


沿着决策树的不同分支即可得到不同患者需要佩戴的隐形眼镜类型

附:全部代码

treePlotter.py:

import matplotlib.pyplot as plt# 定义树节点格式的常量
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 决策节点的属性。boxstyle为文本框的类型,sawtooth为锯齿形,fc为边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8")  # 决策树叶子结点的属性
arrow_args = dict(arrowstyle="<-")  # 剪头的属性def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 执行绘图功能。绘图区域由全局变量createPlot.ax1定义createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)# plt.annotate(str, xy=data_point_position, xytext=annotate_position,
#              va="center",  ha="center", xycoords="axes fraction",
#              textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
# str是给数据点添加注释的内容,支持输入一个字符串
# xy=是要添加注释的数据点的位置
# xytext=是注释内容的位置
# bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
# va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
# xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
# arrowprops可以指定箭头的风格支持,输入一个字典
# plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)def createPlot():  # 代码核心。首先创建一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面用这两个结点绘制树图形fig = plt.figure(1, facecolor='white')  # 1表示图形编好/名称fig.clf()  # 表示清除所有轴createPlot.ax1 = plt.subplot(111, frameon=False)  # 为对象添加属性  frameon=true时图示被绘制在一个patch实体上;=false则图示直接被绘制在图形上plotNode('a dicision node', (0.5, 0.1), (0.1, 0.5), decisionNode)plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)plt.show()def getNumLeafs(myTree):  # 获取叶节点的数目numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():  # 测试节点的数据类型是否为字典if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):  # 获取树的层数maxDepth = 0firstStr = list(myTree.keys())[0]  # keys()返回一个字典的所有键secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef retrieveTree(i):  # 输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦listOfTrees = [{'no surfacing': {0: 'no', 1: {'flipppers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flipppers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]return listOfTrees[i]def plotMidText(cntrPt, parentPt, txtString):  # 用于计算父节点和子节点的中间位置,并在此添加简单的文本标签信息xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):  # 完成绘制树形的大多数工作numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)# 按照图形比例绘图,根据叶子结点的数目划分图形的宽度,从而计算得到当前结点的中心位置plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 按比例减少全局变量plotTree.yOff# 由于是自顶向下绘制图形,因此需要依次递减y坐标值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':  # 当节点不是叶子节点时递归调用plotTreeplotTree(secondDict[key], cntrPt, str(key))else:  # 当节点时叶子节点时在图形上画出叶子节点plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # 在绘制了所有子节点之后,增加全局变量Y的偏移def createPlot(inTree):  # 创建绘图区,计算树形图的全局尺寸,并递归调用函数plotTree()fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # .ax1相当于对函数对象添加属性plotTree.totalW = float(getNumLeafs(inTree))  # 全局变量plotTree.totalW存储树的宽度plotTree.totalD = float(getTreeDepth(inTree))  # 全局变量plotTree.totalD存储树的深度plotTree.xOff = -0.5 / plotTree.totalW  # 全局变量plotTree.xOff和plotTree.yOff用于追踪已经绘制的结点位置,以及放置下一个节点的恰当位置plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.show()

trees.py:

from math import log
import operatordef createDataSet():dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]labels = ['no surfacing', 'flipppers']return dataSet, labelsdef calcShannonEnt(dataSet):numEntries = len(dataSet)  # 计算数据集中实例的总数labelCounts = {}  # 新建字典,记录每个分类下的数据个数for featVec in dataSet:  # 为所有可能的分类创建字典currentLabel = featVec[-1]  # 将dataSet每一个元素的最后一个元素选择出来if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0  # 当没有该键时,使用字典的自动添加添加值为0的项labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key]) / numEntries  # 取概率shannonEnt -= prob * log(prob, 2)  # log(x,2)表示以2为底求x的对数return shannonEntdef splitDataSet(dataSet, axis, value):  # 按照给定特征划分数据集。dataSet是待划分的数据集,axis是划分数据集的特征,value是需要返回的特征值retDataSet = []  # 创建新的list对象(为了不修改原始数据集,数据集这个列表的各个元素也是列表)for featVec in dataSet:  # 将符合特征的数据抽取出来if featVec[axis] == value:reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis + 1:])  # extend用于在列表末尾一次性追加另一个序列的多个值retDataSet.append(reducedFeatVec)  # append用于在列表末尾添加新的对象return retDataSetdef chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1  # 计算所有特征数baseEntropy = calcShannonEnt(dataSet)  # 计算原始香农熵,保存最初的无序度量值,用于与划分完之后的数据集计算的熵值进行比较bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]  # 遍历所有样本的第i个特征的取值情况(使用列表推导创建新的列表)uniqueVals = set(featList)  # 第i条特征的取值(去重)   set函数用于创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可计算交集、差集、并集等newEntropy = 0.0for value in uniqueVals:  # 计算每种划分方式的信息熵。对每个特征划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet) / float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif (infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeaturedef majorityCnt(classList):classCount = {}  # 创建键值为classList中唯一值的数据字典,存储classList中每个类标签出现的频率,然后利用operator操作键值排序字典,返回出现次数最多的分类名称for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]def createTree(dataSet, labels):  # 参数为数据集和标签列表 标签列表包含了数据集中所有特征的标签,为了给出数据明确的含义将其作为输入参数提供classList = [example[-1] for example in dataSet]  # 取标签值# 第一个停止条件:所有的类标签完全相同,则直接返回该类标签if classList.count(classList[0]) == len(classList):  # count函数用于统计某个元素在列表中出现的次数return classList[0]  # 列表完全相同则停止继续划分# 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,因此挑选出现次数最多的类别作为返回值if len(dataSet[0]) == 1:return majorityCnt(classList)  # 遍历完所有特征时返回出现次数最多的bestFeat = chooseBestFeatureToSplit(dataSet)  # 当前数据集选取的最好特征bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel: {}}  # 用字典存储树的结构del (labels[bestFeat])  # 删除取出过的标签,避免重复计算featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)  # 得到列表包含的所有属性值,利用set去重for value in uniqueVals:subLabels = labels[:]  # 复制所有的子标签,因为是引用类型,以避免改变原始标签数据# 在python中函数参数是列表类型时,参数是按照引用方式传递。为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量subLabels代替原始列表myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  # 递归构建树return myTreedef classify(inputTree, featLabels, testVec):# 在存储带有特征的数据时,程序无法确定特征在数据集中的位置,因此使用特征标签列表解决该问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素,然后代码递归遍历整棵树,比较testVec变量中的值域树节点的值,如果到达叶子节点则返回节点的分类标签firstStr = list(inputTree.keys())[0]secondDist = inputTree[firstStr]featIndex = featLabels.index(firstStr)  # 找到根特征在featLabels的位置,将标签字符串转换为索引for key in secondDist.keys():if testVec[featIndex] == key:if type(secondDist[key]).__name__ == 'dict':classLabel = classify(secondDist[key], featLabels, testVec)else:classLabel = secondDist[key]return classLabeldef storeTree(inputTree, filename):import picklefw = open(filename, 'wb+')pickle.dump(inputTree, fw)  # pickle.dump(obj, file, [,protocol])将对象obj保存到file中.proctol为序列化使用的协议版本fw.close()def grabTree(filename):import picklefr = open(filename,'rb+')return pickle.load(fr)  # 用于反序列化对象,将文件中的数据解析为一个python对象

main.py:

 fr = open('lenses.txt')lenses = [inst.strip().split('\t') for inst in fr.readlines()]lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = trees.createTree(lenses, lensesLabels)print(lensesTree)treePlotter.createPlot(lensesTree)

机器学习实战|决策树相关推荐

  1. 机器学习实战-决策树-22

    机器学习实战-决策树-叶子分类 import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplo ...

  2. 机器学习实战 —— 决策树(完整代码)

    声明: 此笔记是学习<机器学习实战> -- Peter Harrington 上的实例并结合西瓜书上的理论知识来完成,使用Python3 ,会与书上一些地方不一样. 机器学习实战-- 决策 ...

  3. [机器学习数据挖掘]机器学习实战决策树plotTree函数完全解析

    [机器学习&数据挖掘]机器学习实战决策树plotTree函数完全解析 http://www.cnblogs.com/fantasy01/p/4595902.html点击打开链接 import ...

  4. 机器学习实战--决策树ID3的构建、画图与实例:预测隐形眼镜类型

    声明 本文参考了<机器学习实战>书中代码,结合该书讲解,并加之自己的理解和阐述 机器学习实战系列博文 机器学习实战--k近邻算法改进约会网站的配对效果 机器学习实战--决策树的构建.画图与 ...

  5. python神经网络算法pdf_Python与机器学习实战 决策树、集成学习、支持向量机与神经网络算法详解及编程实现.pdf...

    作 者 :何宇健 出版发行 : 北京:电子工业出版社 , 2017.06 ISBN号 :978-7-121-31720-0 页 数 : 315 原书定价 : 69.00 主题词 : 软件工具-程序设计 ...

  6. [机器学习实战]决策树

    1. 简介 决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种 ...

  7. 机器学习实战——决策树:matplotlib绘图

    书籍:<机器学习实战>中文版 IDE:PyCharm Edu 4.02 环境:Adaconda3  python3.6 第一个例子: import matplotlib.pyplot as ...

  8. 机器学习实战决策树画图理解

    机器学习实战第二章决策树难点 第二章决策树用matplotlib画图的理解 决策树matplotlib画图代码 第二章决策树用matplotlib画图的理解 作为一个小白呢,确实对于我们来说第二章画图 ...

  9. 机器学习实战——决策树(代码)

    最近在学习Peter Harrington的<机器学习实战>,代码与书中的略有不同,但可以顺利运行. from math import log import operator# 计算熵 d ...

  10. 机器学习实战-决策树(二)Python实现

    转载请注明作者和出处: http://blog.csdn.net/c406495762 运行平台: Windows Python版本: Python3.x IDE: Sublime text3 一 前 ...

最新文章

  1. 2021年大数据Spark(三十五):SparkStreaming数据抽象 DStream
  2. 看图说OpenGL之三:是什么在改变物体的颜色
  3. SD-WAN为企业业务出海提供网络保障
  4. 机器学习-数据科学库(第五天)
  5. 《精彩绝伦的CSS》读书笔记(二)
  6. 谈跨平台C++动态连接库的实现
  7. Multi_thread--Linux下多线程编程互斥锁和条件变量的简单使用
  8. 大数据分块_谷歌卫星影像金字塔分块下载原理说明
  9. linux大一实验报告,linux实验报告
  10. Python | 实现pdf文件分页
  11. 服务器装系统就蓝屏,重装了下系统老是蓝屏
  12. Keil 5安装教程,搭建单片机环境
  13. 小米官网竖直导航栏html,手把手教你做小米商城顶部导航栏
  14. 不需要写代码,快速批量修改文件夹中图片的格式
  15. 前端基础入门课程推荐
  16. ps怎么加底部阴影_ps物体底部阴影怎么做阴影有立体感
  17. Win10强制更新关闭方法
  18. JVM调优-GC基本原理和调优关键分析
  19. x264 i_thead
  20. 倚天鸿蒙系统,倚天屠龙记强化系统详解

热门文章

  1. Linux九阴真经之大伏魔拳残卷4 nginx(模型,安装配置,模块)
  2. 【RNN】基于RNN的动态系统参数辨识matlab仿真
  3. 南京柳树湾与云南汉族人
  4. PAT_乙级_1008_筱筱
  5. 部署搭建DNS服务器
  6. Java集成PayPal支付
  7. iphone/ipad网站开发技巧整理
  8. 请收下这 72 个炫酷的 CSS 技巧
  9. vim 基本够用的操作命令
  10. 有助于睡眠的产品,失眠一定要知道的几样东西