机器学习之构造决策树
前言:
本节使用数据依旧是之前生成的三种球类数据,刚进入这篇文章的小伙伴可以回头看下。链接如下:
机器学习入门之k近邻算法_俺从头开始的博客-CSDN博客
信息营地:
决策树:
百度百科讲决策树:“决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。”
本质上来讲,决策树还是一个分类模型,所以它的中心工作是利用一种度量手段来区分各个类别的数据。
信息熵:
一位二十世纪的天才——克劳德·香农,提出了一种名叫“熵”的度量标准,至今依旧被广泛应用于信息领域。
熵公式:
其中,为分类的概率。
信息增益:
与信息熵一起诞生的产物。
计算公式:
一个系统越是有序,信息熵就越低,一个系统越是混乱,信息熵就越高,所以信息熵被认为是一个系统有序程度的度量。
动手实践:
主要就是敲代码实现的过程了,展示如下:
from math import log
import operator
import matplotlib.pyplot as plt
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):numEntries = len(dataSet) # 计算实例总数labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1] # 键值是最后一列数值if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1 # 将不存在的类别加入字典shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob * log(prob, 2) # 香农公式return shannonEnt# 划分数据集
def splitDataSet(dataSet, axis, value): # dataSet:待划分的数据集,axis:划分数据集特征,value:需要返回的特征值retDataSet = [] # 防止破坏原始数据for featVec in dataSet:if featVec[axis] == value: # 符合要求reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec) # 抽取return retDataSet# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0; bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList) # 创建唯一的分类标签列表newEntropy = 0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropy # 计算每种方式的信息熵if(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = i # 计算最好的信息增益return bestFeature# 取最大值标签
def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), \key=operator.itemgetter(1), reverse=True) # 根据字典的值降序排列return sortedClassCount[0][0]# 创建树
def createTree(dataSet, labels):classList = [example[-1] for example in dataSet]if classList.count(classList[0]) == len(classList):return classList[0] # classList只剩下一种值if len(dataSet[0]) == 1: # dataSet中属性使用完毕,但没有分配完毕return majorityCnt(classList) # 取数量最多作为分类bestFeat = chooseBestFeatureToSplit(dataSet)labels2 = labels.copy()bestFeatLabel = labels2[bestFeat]myTree = {bestFeatLabel: {}}del(labels2[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels2[:] # 剩余属性列表myTree[bestFeatLabel][value] = createTree(splitDataSet \(dataSet, bestFeat, value), subLabels)return myTree#导入数据
def sentDataSet(filename):with open(filename, 'r', encoding='utf-8') as file:arrayOLines = file.readlines() #列表型numberOfLines = len(arrayOLines)dataSet = numpy.zeros((numberOfLines, 4))index = 0for line in arrayOLines:line = line.strip()listFromLine = line.split('\t\t')dataSet[index, :] = listFromLine[0: 5]labels = ['圆周长', '重量', '材料', '花纹']return dataSet, labelsmyDat, labels = sentDataSet('./data1')
myTree = createTree(myDat, labels)
print(myTree)
结果如下:
数据集采用了上一节中已生成的数据 ,所以还是分类球类的问题。
主要遇到的问题是,连续的数据没有进行处理,以至于每一组数据都成了一个单独的类别。连续数组的处理将在下一节提到。
导入数据主要就是将之前生成的数据集导入进来,方便后续操作。这一块代码可能有些简陋,毕竟博主水平还有限,不像之前的代码可以照着书敲【手动狗头】。不过嘛,代码毕竟是调通了,皆大欢喜,放心食用。
当然,这个问题也不能一直放着。于是乎,博主决定去掉数据集里的连续型变量,只用后两列数据进行分类,结果如下:
至于代码嘛,倒也不用大改,切片取出数据时修改一下参数即可。
如下:
# 导入数据
def sentDataSet(filename):with open(filename, 'r', encoding='utf-8') as file:arrayOLines = file.readlines() # 列表型dataSet = []for line in arrayOLines:line = line.strip()listFromLine = line.split('\t\t')dataSet.append(listFromLine[2:5]) # 修改切片labels = ['材料', '花纹'] # 修改这里return dataSet, labels
修改部分已注释。
可视化树:
这一块是希望能过够将分类好的树可视化的输出,也就是直观的看到这棵树。
代码如下:
# 用Matplotlib注解绘制树形图# 定义文本框和箭头格式
decisionNode = dict(boxstyle="square", fc="0.8") # boxstyle文本框样式、fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8") # 叶子节点
arrow_args = dict(arrowstyle="<-") # 定义箭头# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 此函数执行绘制功能# createPlot.ax1是表示: ax1是函数createPlot的一个属性createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):numLeafs = 0 # 初始化firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)secondDict = myTree[firstStr] # 获得value值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典numLeafs += getNumLeafs(secondDict[key]) # 递归调用函数else:numLeafs += 1return numLeafs# 获取树的深度
def getTreeDepth(myTree):maxDepth = 0 # 初始化firstStr = list(myTree.keys())[0] # 获得第一个key值(根节点)secondDict = myTree[firstStr] # 获得value值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict': # 测试节点的数据类型是否为字典thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepth# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# 画树
def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree) # 获取树高depth = getTreeDepth(myTree) # 获取树深度firstStr = list(myTree.keys())[0] # 这个节点的文本标签cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,plotTree.yOff) # plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置plotMidText(cntrPt, parentPt, nodeTxt) # 标记子节点属性plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # 减少y偏移for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 绘制决策树
def createPlot(inTree):fig = plt.figure(1, facecolor='white') # 创建一个新图形fig.clf() # 清空绘图区font = {'family': 'MicroSoft YaHei'}plt.rc("font", **font)axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalW;plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.show()Createplot = createPlot(myTree)
运行结果:
材料这个分类分的稀烂,但这主要是因为这个类别本身就不适合分类。但由于去掉了前两列的数据,防止生成树过于单调,我最后还是决定将它加上。
代码问题上注意下面这一列:
本身就是字体的选择,但需要你的电脑上有这个可以编译的字体,这里推荐使用
font = {'family': 'SimHei'},一般情况下没问题。
下一节将会对树进行优化,也有连续性变量的处理方式,欢迎追订哦!
机器学习之构造决策树相关推荐
- 机器学习之使用sklearn构造决策树模型
一.任务基础 导入所需要的库 import matplotlib.pyplot as plt import pandas as pd%matplotlib inline 加载sklearn内置数据集 ...
- 机器学习实战之决策树(一)构造决策树
决策树(一)构造决策树 1.简介 1.1 优缺点 1.2 流程 1.3 决策树的构造 1.4 海洋生物数据 2.信息增益 2.1 几个概念 2.2 计算给定数据集的熵 3 划分数据集 3.1 按照给定 ...
- 机器学习知识点(七)决策树学习算法Java实现
为理解机器学习第四章节决策树学习算法,通过网上找到的一份现成代码,主要实现了最优划分属性选择和决策树构造,其中最优划分属性选择采用信息增益准则.决策树构造采用递归实现,代码如下: package sk ...
- 秒懂机器学习---当机器学习遇上决策树....
秒懂机器学习---当机器学习遇上决策树.... 一.总结 一句话总结: 多多看图,图的直观效果很好,很多时候文字实在表达不清 1.决策树(Decision Tree)中的各个节点表示什么意思? 每一个 ...
- 机器学习系列文章-决策树
决策树 由于我们是使用sklearn对决策树代码进行实现,所以并不是很关心其原理部分.但我仍需要对其进行一定的了解.通过查询资料,去学习了下决策树的原理,这里对其原理进行简要介绍. 注:这里决策树的原 ...
- 《机器学习实战》—— 决策树
目录 一.决策树的构造 1. 信息增益 2. 划分数据集 3. 递归构建决策树 二.在 Python 中使用 Matplotlib 注解绘制树形图 1. Matplotlib 2. 构造注解树 三.测 ...
- 机器学习算法 04 —— 决策树(ID3、C4.5、CART,剪枝,特征提取,回归决策树)
文章目录 系列文章 决策树 1 决策树算法简介 2 决策树分类的原理 2.1 信息熵 2.2 决策树划分依据-信息增益(ID3) 2.3 决策树划分依据-信息增益率(C4.5) 2.4 决策树划分依据 ...
- python数据分析/机器学习 笔记之决策树(泰坦尼克号旅客生存预测)
最近在学习用python数据分析,不可避免的接触到了机器学习的一些算法,所以在这里简单整理一些学习的笔记和心得与大家分享! 首先机器学习分为:监督学习和非监督学习,前者有参照物,后者为参照物:主要分为 ...
- 决策树(decision tree)(一)——构造决策树方法
决策树(decision tree)(一)--构造决策树方法 说明:这篇博客是看周志华老师的<机器学习>(西瓜书)的笔记总结,虽然自己写了很多总结性文字包括一些算法细节,但博客中仍有部分文 ...
最新文章
- 【Python 第8课】while
- 人工智能这条小船何时才能变成航母?
- ssm 上传图片到mysql_ssm(Spring+Spring MVC+MyBatis)+Web Uploader开发图片文件上传实例,支持批量上传,拖拽上传,复制粘贴上传...
- CentOS-创建yum本地源
- zippo油和zorro油的区别_陶瓷轴承润滑油的性能质量最指标油的区别?_搜狐汽车...
- CTF(pwn)-格式化字符串漏洞讲解(一)
- cannot resolve symbol ‘springframework‘
- OAuth2 实现单点登录 SSO
- 160 - 10 Andrénalin.3
- 计算机休眠怎么唤醒,电脑休眠后,就无法唤醒了?怎么办?快速教你解决这个问题...
- 郑大远程计算机应用基础第09,郑大远程教育《计算机应用基础》第09章在线测试...
- jpeglib画质的代码分析
- 《shell脚本学习指南》学习笔记之入门 一
- 贝叶斯优化核极限学习机KELM用于回归预测
- Web3到底是什么?
- 统一诊断服务(UDS)否定响应也可以被抑制
- Linux文件-/etc/login.defs
- supermap java,SuperMap iObjects Java 10i 产品介绍
- 苹果主屏幕按钮怎么设置_苹果手机屏幕变大怎么恢复
- 苏宁零售云,“动物凶猛”