【机器学习实战】决策树算法:预测隐形眼镜类型

0.收集数据

这里采用的数据集是《机器学习实战》提供的 lenses.txt 文件,该文件内容如下:

young    myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses

每列数据类型分别是 age、prescript、astigmatic、tearRateage、prescript、astigmatic、tearRateage、prescript、astigmatic、tearRate ,而最后一列的类型是隐形眼镜的类型。

1.准备数据:解析tab键分隔的数据行

首先由于我们的数据文件是以 TabTabTab 分割开各列之间的数据的,所以我们首先需要获取被分隔的数据行。

代码如下,其中 strip()strip()strip() 表示删除掉数据中的换行符,则 split('\t') 是数据中遇到 '\t' (既 TabTabTab) 就隔开。

fr = open('lenses.txt') # 打开数据集文件
lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 解析tab键分割的数据行

由于 lenses.txtlenses.txtlenses.txt 文件中并没有对每列数据进行命名,这里我将每列数据的名称准备在 lensesLabelslensesLabelslensesLabels 变量中。

lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

数据都准备好了,接下来就可以开始我们的决策树构造了。

2.决策树的构造

决策树算法(DecisionTreeDecision TreeDecisionTree):决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题。

适用数据类型:数值型和标称型。

2.1 信息增益

划分数据集的大原则是:将无序的数据变得更加有序。在划分数据集之前之后信息发生的变化称为信息增益,这里我们采用 香农熵 来计算信息的增益。

如果待分类的事务可能划分在多个分类中,则符号 xix_ixi​ 的信息定义为:l(xi)=−log2p(xi)l(x_i)=-log_2p(x_i)l(xi​)=−log2​p(xi​)

其中 p(xi)p(x_i)p(xi​) 是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到(其中 nnn 是分类的数目):
H=−∑i=1np(xi)log2p(xi)H=-\sum^{n}_{i=1}p(x_i)log_2p(x_i)H=−∑i=1n​p(xi​)log2​p(xi​)

from math import log#计算给定数据集的香农熵
def calcShannonEnt(dataSet):numEntries = len(dataSet) # 获取数据集中实例的总数labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1] # featVec[-1]是指获取最后一个数值if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0 # 新添加的值,所以计数为 0labelCounts[currentLabel] += 1shannonEnt = 0.0 # shannonEnt用于记录计算的香农熵for key in labelCounts:prob = float(labelCounts[key])/numEntries # 计算P(xi)的概率shannonEnt -= prob * log(prob, 2) # 计算香农熵return shannonEnt

由于熵越高,则混合的数据也越多,因此我们可以通过计算香农熵来划分数据集。

2.2 划分数据集

首先先把当作特征值的属性进行抽取。

# 输入参数分别是:待划分的数据集、划分数据集的特征,需要返回的特征的值
def splitDataSet(dataSet, axis, value):retDataSet = [] # 创建新的list对象for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis] # 获取关键特征前面的属性reducedFeatVec.extend(featVec[axis + 1 :]) # 填加关键特征后面的属性retDataSet.append(reducedFeatVec) # 以上步骤相当于对特征值进行抽取return retDataSet # 返回抽取特征后的数据集

然后再依次计算以不同属性值为特征值时的香农熵,判断以何种属性为特征值时是最优的数据划分。

# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1 #获取每个数据集拥有几个特征(排除最后一个)beseEntropy = calcShannonEnt(dataSet) # 计算以最后一个数值为特征的香农熵bestInfoGain = 0.0;bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]# 将dataSet中的数据先按行依次放入example中,然后取得example中的example[i]元素,放入列表featList中uniqueVals = set(featList) # set() 函数创建一个无序不重复元素集newEntropy = 0.0for value in uniqueVals: # 计算每种划分方式的信息熵subDataSet = splitDataSet(dataSet, i, value) # 按照给定特征划分数据集prob = len(subDataSet) / float(len(dataSet)) # 计算当前结果的可能性newEntropy += prob * calcShannonEnt(subDataSet) # 不同可能性的香农熵的和infoGain = beseEntropy - newEntropyif(infoGain > bestInfoGain): # 判断是否是当前最小香农熵,计算出最好的信息增益bestInfoGain = infoGainbestFeature = ireturn bestFeature

到这里,我们已经可以计算当前数据的最好划分方式了,但决策树不是只划分一次就好了,而是层层递进的划分下去,因此接下来就开始实现递归构建决策树。

2.3 递归构建决策树

工作原理:得到原始数据,然后基于最好的属性值划分数据集,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,再这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。

递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

首先使用分类名称的列表,然后创建值为 classListclassListclassList 中唯一值的数据字典,字典对象存储了 classListclassListclassList 中每个类标签出现的频率,最后利用 operatoroperatoroperator 操作键值排序字典,并返回出现次数最多的分类名称。

import operatordef majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)return sortedClassCount # 返回出现次数最多的分类名称

接着就可以创建树了,其中变量 myTreemyTreemyTree 包含了很多代表树结构信息的嵌套字典,至此我们已经正确的构建好了树。

# 创建树的函数代码,两个输入参数:数据集和标签列表
def creatTree(dataSet, labels):classList = [example[-1] for example in dataSet]# 将dataSet中的数据先按行依次放入example中,然后取得example中的example[-1]元素,放入列表classList中if classList.count(classList[0]) == len(classList): # 类别完全相同则停止继续划分return classList[0]if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的类别return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet) # 选择最好的数据集划分方式bestFeatLabel = labels[bestFeat] # 获取属性文字标签myTree = {bestFeatLabel : {}}# 得到列表包含的所有属性值del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet, bestFeat, value), subLabels)return myTree

3.在Python中使用Matplotlib注解绘制树形图

由于这里使用的主要是 MatplotlibMatplotlibMatplotlib 绘图的知识,与机器学习关系不大,故这里不对代码进行详细讲解。

import matplotlib.pyplot as plt
import matplotlib# 定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',xytext = centerPt, textcoords = 'axes fraction',va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else: numLeafs += 1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else: thisDepth = 1if thisDepth > maxDepth: maxDepth = thisDepthreturn maxDepth# plotTree函数
# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# 计算宽与高
def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)#标记子节点属性值plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodesplotTree(secondDict[key], cntrPt, str(key))  # recursionelse:  # it's a leaf node print the leaf nodeplotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 这个是真正的绘制,上边是逻辑的绘制
def createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticksplotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalW;plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.axis('off') # 去掉坐标轴plt.show()

4.使用算法

主函数代码:

if __name__ == "__main__":fr = open('lenses.txt') # 打开数据集文件lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 解析tab键分割的数据行lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = creatTree(lenses, lensesLabels)createPlot(lensesTree)

运行过后就可以得到我们的结果,如下图片:

5.总结

这个算法的思想本质其实并不复杂,但我在阅读代码的过程中却是困难重重

【机器学习实战】决策树算法:预测隐形眼镜类型相关推荐

  1. 机器学习——决策树实践(预测隐形眼镜类型)

    前言 之前把<机器学习实战>这本书的分类部分学完了,想自己动手实践一下,所以从前面的章节开始,慢慢熟悉代码. 今天在学习决策树的时候,发现书中并没有直接给出预测隐形眼镜类型的代码,于是想借 ...

  2. 机器学习实战--决策树算法

    决策树 决策树(decision tree)是一种基本的分类与回归方法.从判断模块引出的左右箭头称为分支,它可以达到另外一个判断模块或者终止模块.分类决策树模型是一种描述对实例进行分类的树形结构.决策 ...

  3. 机器学习实战-决策树算法

    本篇决策树算法是依据ID3算法来的,所以在看之间建议先了解ID3算法:https://blog.csdn.net/qq_27396861/article/details/88226296 文章目录 一 ...

  4. 机器学习实战——KNN算法预测电影类型

    预测电影类型 现有爱情片和动作片(不是爱情动作片,雾)的打斗场面和接吻场面的次数统计,然后给出一个电影打斗场面和接吻场面出现的次数,预测其类型. 那么如何预测呢?当然用KNN了. KNN算法的原理就是 ...

  5. Python3:《机器学习实战》之决策树算法(3)预测隐形眼镜类型

    Python3:<机器学习实战>之决策树算法(3)预测隐形眼镜类型 转载请注明作者和出处:http://blog.csdn.net/u011475210 代码地址:https://gith ...

  6. python3.5《机器学习实战》学习笔记(五):决策树算法实战之预测隐形眼镜类型

    转载请注明作者和出处:http://blog.csdn.net/u013829973 系统版本:window 7 (64bit) 文章出现的所有代码在我的GitHub:https://github.c ...

  7. 徒手写代码之《机器学习实战》-----决策树算法(2)(使用决策树预测隐形眼镜类型)

    使用决策树预测隐形眼镜类型 说明: 将数据集文件 'lenses.txt' 放在当前文件夹 from math import log import operator 熵的定义 "" ...

  8. Python3:《机器学习实战》之决策算法(3)预测隐形眼镜类型

    Python3:<机器学习实战>之决策树算法(3)预测隐形眼镜类型 转载请注明作者和出处:http://blog.csdn.net/u011475210 代码地址:https://gith ...

  9. 机器学习实战--决策树ID3的构建、画图与实例:预测隐形眼镜类型

    声明 本文参考了<机器学习实战>书中代码,结合该书讲解,并加之自己的理解和阐述 机器学习实战系列博文 机器学习实战--k近邻算法改进约会网站的配对效果 机器学习实战--决策树的构建.画图与 ...

  10. 《机器学习实战》学习笔记:绘制树形图使用决策树预测隐形眼镜类型

    上一节实现了决策树,但只是使用包含树结构信息的嵌套字典来实现,其表示形式较难理解,显然,绘制直观的二叉树图是十分必要的.Python没有提供自带的绘制树工具,需要自己编写函数,结合Matplotlib ...

最新文章

  1. django链接数据库报错Error loading MySQLdb module:No module named MySQLdb
  2. [NodeJs] 如何获取项目的根路径?
  3. 复制中文到linux系统,怎么是的window文字复制到linux下
  4. android.content.res.Resources$NotFoundException: String resource ID #0x0
  5. 初步认识图像的直方图
  6. 微信小程序DEMO初体验
  7. 开放源代码是如何吞噬软件的
  8. 蓝桥杯2015年第六届C/C++省赛A组第三题-奇妙的数字
  9. office2010安装提示需要MSXML版本6.10.1129.0
  10. 数据库1_五大主流数据库模型
  11. 可用等式为:html+java=jsp表示jsp[8]._在 JSP 中 , 对 jsp:setProperty 标记描述正确的是 ()_学小易找答案...
  12. 企业招聘黑名单的9类人,你中枪了么?
  13. C语言二叉排序树的中序遍历,C语言实现二叉树的中序线索化及遍历中序线索二叉树...
  14. 如何修改PDF中图片的大小尺寸
  15. 【爆品】馥兰朵想买找谁?代理能月入过万吗?
  16. nexus上传jar总是读条而上传不成功的问题
  17. IT行业招聘技巧--渠道篇
  18. Android文字广告(Textview上下滚动),使用TextSwitcher控件实现
  19. 一篇写给迷茫的你或者想入门java的同学们
  20. stm32中断源有哪些_STM32常见问题汇总

热门文章

  1. typecho插件,typechoSEO插件,typecho程序插件
  2. 一起学libcef--libcef的基本类和方法介绍(如何产生一个你自己的浏览器)
  3. 《MFC 控件透明处理》
  4. python生成PDF报告
  5. msvcp71.dll、msvcr71.dll丢失解决方法
  6. linux玩安卓游戏下载,在Deepin 20.2系统下可用Xdroid on Linux来玩王者荣耀游戏
  7. 中职学校计算机课程标准,中等职业学校课程标准发布
  8. 使用油猴插件,屏蔽网页上的禁止右键操作
  9. 全自动高清录播服务器,高清高清录播服务器 高清全自动录播系统 方便携带 搭建快捷...
  10. 个人计算机组装主板,电脑主板安装详细图解 可以自己组装电脑了