现在我们碰到这样一个问题,一个人去医院想配一副隐形眼镜。我们需要通过问他4个问题,决定他需要带眼镜的类型。那么如何解决这个问题呢?我们决定用决策树。首先我们去下载一个隐形眼镜数据集,数据来源于UCI数据库。下载了lenses.data文件,如下:

1  1  1  1  1  3
2  1  1  1  2  2
3  1  1  2  1  3
4  1  1  2  2  1
5  1  2  1  1  3
6  1  2  1  2  2
7  1  2  2  1  3
8  1  2  2  2  1
9  2  1  1  1  3
10  2  1  1  2  2
11  2  1  2  1  3
12  2  1  2  2  1
13  2  2  1  1  3
14  2  2  1  2  2
15  2  2  2  1  3
16  2  2  2  2  3
17  3  1  1  1  3
18  3  1  1  2  3
19  3  1  2  1  3
20  3  1  2  2  1
21  3  2  1  1  3
22  3  2  1  2  2
23  3  2  2  1  3
24  3  2  2  2  3

我们可以看到,第一列的1到24,对应数据的ID

第二列的1到3,分别对应病人的年龄(age of patient),分别是青年(young),中年(pre-presbyopic),老年(presbyopic)

第三列的1和2,分别对应近视情况(spectacle prescription),近视(myope),远视(hypermetrope)

第四列的1和2,分别对应眼睛是否散光(astigmatic),不散光(no),散光(yes)

第五列的1和2,分别对应分泌眼泪的频率(tear production rate),很少(reduce),普通(normal)

第六列的1到3,则是最终根据以上数据得到的分类,分别是硬性的隐形眼镜(hard),软性的隐形眼镜(soft),不需要带眼镜(no lenses)

数据我们获取到了,那么我们写一个函数去打开文件设定好数据集,以下是代码:

from numpy import *
import operator
from math import logdef createLensesDataSet():#创建隐形眼镜数据集fr = open('lenses.data')allLinesArr = fr.readlines()linesNum = len(allLinesArr)returnMat = zeros((linesNum, 4))statusLabels = ['age of the patient', 'spectacle prescription', 'astigmatic', 'tear production rate']classLabelVector = []classLabels = ['hard', 'soft', 'no lenses']index = 0for line in allLinesArr:line = line.strip()lineList = line.split('  ')returnMat[index, :] = lineList[1:5]classIndex = int(lineList[5]) - 1classLabelVector.append(classLabels[classIndex])  # 索引-1代表列表最后一个元素index += 1return ndarray.tolist(returnMat), statusLabels, classLabelVectordef createLensesAttributeInfo():parentAgeList = ['young', 'pre', 'presbyopic']spectacleList = ['myope', 'hyper']astigmaticList = ['no', 'yes']tearRateList = ['reduced', 'normal']return parentAgeList, spectacleList, astigmaticList, tearRateList

那么接下来我们应该设定决策树的分支,如何确定以上哪一个特征是第一个分支呢,我们要提到一个概念,香农熵(Shannon entropy)。熵这个概念代表信息的不确定性的大小,在划分数据集中经常会运用到。

它的公式是:

那么我们先写一个计算香农熵的函数:

def calcShannonEnt(dataSet):#计算香农熵numEntries = len(dataSet)labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob * log(prob, 2)return shannonEnt

经过计算,我们可以得到我们当前使用的数据集,熵为:1.32608752536

然后,我们写一个划分数据集的函数,可以根据数据集,特征索引和特征值来划分数据集:

def splitDataSet(dataSet, axis, value):#按照特征值划分数据集,参数为数据集,特征索引,特征值retDataSet = []for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec)return retDataSet

说到取最佳特征值,我们就要提到一个概念信息增益(information divergence)

他的公式是:

即将单独一个特征值提取出来,计算该特征值每个分支划分出数据集的熵的求和,然后用总数据集的熵减去它

计算四个特征值的信息增益我们得到以下数据:

0:0.0393965036461
1:0.0395108354236
2:0.377005230011
3:0.548794940695

以下是计算信息增益的代码:

def chooseBestFeatureToSplit(dataSet):#选择最佳分割特征值numFeatures =  len(dataSet[0]) - 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet) / float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyprint(str(i)+':'+str(infoGain))if (infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature

通过计算我们可以得出特征值的优先级,tear production rate>astigmatic>spectacle prescription>age of patient

接下来,有了以上的计算函数,我们就可以开始创建决策树了,创建决策树,我们使用字典类型去存储,用键代表分支节点,值代表下一个节点或者叶子节点,代码如下:

def createTree(dataSet, labels):#创建决策树classList = [example[-1] for example in dataSet]if classList.count(classList[0]) == len(classList):print(classList[0])return classList[0]if len(dataSet[0]) == 1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel:{}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)return myTreedef majorityCnt(classList):#对于单个特征值的列表,按出现次数进行排序classCount = {}for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)return sortedClassCount[0][0]

主要函数写完以后,我们写一段测试代码,打印我们创建出的决策树:

import trees
import treePlotter
from numpy import *lensesData, labels, vector = trees.createLensesDataSet()
parentAgeList, spectacleList, astigmaticList, tearRateList = trees.createLensesAttributeInfo()
lensesAttributeList = [parentAgeList, spectacleList, astigmaticList, tearRateList]for i in range(len(lensesData)):for j in range(len(lensesData[i])):index = int(lensesData[i][j]) - 1lensesData[i][j] = lensesAttributeList[j][index]lensesData[i].append(str(vector[i]))myTree = trees.createTree(lensesData, labels)
print(myTree)

我们看一下输出:

{'tear production rate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'spectacle prescription': {'hyper': {'age of the patient': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age of the patient': {'pre': 'soft', 'presbyopic': {'spectacle prescription': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}

可以看出这是一个比较长的字典嵌套结构,但是这样看上去很不直观,为了让这个决策树能直观的显示出来,我们要导入图形化模块matplotlib,用来把决策树画出来。

我们新写一个treePlotter脚本,脚本中添加计算决策树叶节点数量及深度的函数,用以计算画布的高宽布局。通过计算两个节点中点坐标的函数,确定分支属性的位置,最终画出决策树。以下是脚本代码:

import matplotlib.pyplot as plt
import matplotlibfrom pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']# 定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlotPlus.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction', \va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)def getNumLeafs(myTree):#获取叶节点的总数量numLeafs = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for k in secondDict.keys():if type(secondDict[k]).__name__ == 'dict':#判断节点数据类型是否为字典numLeafs += getNumLeafs(secondDict[k])else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):#判断决策树的深度maxDepth = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for k in secondDict.keys():if type(secondDict[k]).__name__ == 'dict':  # 判断节点数据类型是否为字典thisDepth = 1 + getTreeDepth(secondDict[k])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef plotMidText(cntrPt, parentPt, txtString):#计算给定两个坐标的中点坐标xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlotPlus.ax1.text(xMid-0.05, yMid, txtString, rotation = 30)def plotTree(myTree, parentPt, nodeTxt):#根据树,父节点,节点文本,绘制一个分支节点numLeafs = getNumLeafs(myTree)firstStr = myTree.keys()[0]cntrPt = (plotTree.xOff +(1.0 + float(numLeafs)) / 2.0 /plotTree.totalW, plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor k in secondDict.keys():if type(secondDict[k]).__name__ =='dict':plotTree(secondDict[k], cntrPt, str(k))else:plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[k], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(k))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalDdef createPlotPlus(inTree):#根据给定决策树创建图像fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks = [], yticks = [])createPlotPlus.ax1 = plt.subplot(111, frameon = False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalWplotTree.yOff = 1.0plotTree(inTree, (0.5, 1.0), '')plt.show()

经过这个脚本的处理,我们在测试代码上调用创建决策树图像的函数:

treePlotter.createPlotPlus(myTree)

得到最终图像:

以上,完成。

参考书籍:《机器学习实战》

Python创建决策树—解决隐形眼镜选择问题相关推荐

  1. 【数据挖掘】决策树算法简介 ( 决策树模型 | 模型示例 | 决策树算法性能要求 | 递归创建决策树 | 树根属性选择 )

    文章目录 I . 决策树模型 II . 决策树模型 示例 III . 决策树算法列举 IV . 决策树算法 示例 V . 决策树算法性能要求 VI . 决策树模型创建 ( 递归创建决策树 ) VII ...

  2. 03-0004 CART决策树解决银行贷款问题(Python)

    CART决策树解决银行贷款问题-Python 1.问题描述 2.术语解释 2.1.CART 2.2.基尼系数 2.3.基尼指数 2.4.基尼系数增益 3.算法步骤 3.1语言描述 3.2举例说明 4. ...

  3. python创建文件夹 覆盖_Python 创建新文件时避免覆盖已有的同名文件的解决方法...

    思路:创建文件时,先检查是否有同名文件(使用os.path.isfile),如果有,则在文件名后加上编号n来创建. 关键点: 1. 使用os.path.isfile判断文件是否存在 2. 使用递归函数 ...

  4. 决策树(四):使用决策树预测隐形眼镜类型

    使用决策树预测隐形眼镜类型 介绍 代码部分 总结 介绍 本节我们将通过一个例子讲解决策树如何预测患者需要佩戴的隐形眼镜类型.使用小数据集 ,我们就可以利用决策树学到很多知识:眼科医生是如何判断患者需要 ...

  5. 数据分享|PYTHON用决策树分类预测糖尿病和可视化实例

    全文下载链接:http://tecdat.cn/?p=23848 在本文中,决策树是对例子进行分类的一种简单表示.它是一种有监督的机器学习技术,数据根据某个参数被连续分割.决策树分析可以帮助解决分类和 ...

  6. 《机器学习实战》学习笔记:绘制树形图使用决策树预测隐形眼镜类型

    上一节实现了决策树,但只是使用包含树结构信息的嵌套字典来实现,其表示形式较难理解,显然,绘制直观的二叉树图是十分必要的.Python没有提供自带的绘制树工具,需要自己编写函数,结合Matplotlib ...

  7. python初学者代码-Python-为什么Python是初学者的完美选择?

    原标题:Python-为什么Python是初学者的完美选择? 在过去的两年中,Python部落已经教会了成百上千个开发人员学会了他们的第一个编程语言.我们一直关注的一件事就是为一个初学程序员找到最好的 ...

  8. python实现决策树数据直接赋值导入_Python3.0 实现决策树算法的流程

    决策树的一般流程 检测数据集中的每个子项是否属于同一个分类 if so return 类标签 Else 寻找划分数据集的最好特征 划分数据集 创建分支 节点 from math import log ...

  9. 徒手写代码之《机器学习实战》-----决策树算法(2)(使用决策树预测隐形眼镜类型)

    使用决策树预测隐形眼镜类型 说明: 将数据集文件 'lenses.txt' 放在当前文件夹 from math import log import operator 熵的定义 "" ...

最新文章

  1. error while loading shared libraries: libopencv_core.so.3.4: cannot open 报错解决方法
  2. c char转int_C指针精华知识大汇总
  3. SpringBoot笔记整理(四)
  4. 5G 比 4G 快,不只是因为......
  5. Ubuntu/Mac彻底解决手机ADB识别问题
  6. ICLR'17 | 在特征空间增强数据集
  7. java网页内容不能复制_win7系统禁用Java小程序脚本网页内容复制不了的解决方法...
  8. 高通开发-烧写及异常启动分析
  9. 【题解】codeforces 1023G. Pisces 最长反链 数据结构维护差分 启发式合并
  10. 论劳动生产力进步的原因,兼论劳动产品在不同阶级人民之间自然分配顺序(读后感)
  11. centos7 安装7z压缩命令
  12. python删除列表元素delete_Python3 tkinter基础 Listbox delete 删除单个、所有元素
  13. php发布编辑删除功能,php实现添加修改删除
  14. 学编程用什么系统好?
  15. 春运赶火车小心这些骗术
  16. 大家好,我是新人,请多多关照,(*  ̄3)(ε ̄ *)么么
  17. 什么是站群服务器?站群服务器与普通服务器的区别,为什么站长都爱用站群服务器
  18. ​SIGIR 2021 | 多样性推荐:增强领域级别和用户级别的自适应性
  19. 关于智能家居布线 这些你都知道吗
  20. markdown设置字体颜色大小和背景色

热门文章

  1. springboot整合jett实现模板excel数据导出
  2. MarkDown 的常用高阶教程
  3. 网站证书过期导致WordPress后台无法登录问题解决方法
  4. SSL证书会不会过期?域名SSL证书过期了怎么办?
  5. ubuntu Aria2 AriaNg安装
  6. 同步光网络(SONET,Synchronous Optical Networking)简介
  7. Python 程序员需要知道的 30 个技巧
  8. 优美库图片小程序 Version1.0
  9. Android 9.0 一定要适配htpps请求?
  10. 异国他乡-写在即将离开芬兰的早晨