前言:

本节使用数据依旧是之前生成的三种球类数据,刚进入这篇文章的小伙伴可以回头看下。链接如下:

机器学习入门之k近邻算法_俺从头开始的博客-CSDN博客

信息营地:

决策树:

百度百科讲决策树:“决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。

决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。”

本质上来讲,决策树还是一个分类模型,所以它的中心工作是利用一种度量手段来区分各个类别的数据。

信息熵:

一位二十世纪的天才——克劳德·香农,提出了一种名叫“熵”的度量标准,至今依旧被广泛应用于信息领域。

熵公式:

其中,为分类的概率。

信息增益:

与信息熵一起诞生的产物。

计算公式:

一个系统越是有序,信息熵就越低,一个系统越是混乱,信息熵就越高,所以信息熵被认为是一个系统有序程度的度量。

动手实践:

主要就是敲代码实现的过程了,展示如下:

from math import log
import operator
import matplotlib.pyplot as plt
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):numEntries = len(dataSet)  # 计算实例总数labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1]  # 键值是最后一列数值if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1 # 将不存在的类别加入字典shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob * log(prob, 2)  # 香农公式return shannonEnt# 划分数据集
def splitDataSet(dataSet, axis, value):  # dataSet:待划分的数据集,axis:划分数据集特征,value:需要返回的特征值retDataSet = []  # 防止破坏原始数据for featVec in dataSet:if featVec[axis] == value:  # 符合要求reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec)  # 抽取return retDataSet# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0; bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)  # 创建唯一的分类标签列表newEntropy = 0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropy  # 计算每种方式的信息熵if(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = i  # 计算最好的信息增益return bestFeature# 取最大值标签
def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), \key=operator.itemgetter(1), reverse=True)  # 根据字典的值降序排列return sortedClassCount[0][0]# 创建树
def createTree(dataSet, labels):classList = [example[-1] for example in dataSet]if classList.count(classList[0]) == len(classList):return classList[0]  # classList只剩下一种值if len(dataSet[0]) == 1:  # dataSet中属性使用完毕,但没有分配完毕return majorityCnt(classList)  # 取数量最多作为分类bestFeat = chooseBestFeatureToSplit(dataSet)labels2 = labels.copy()bestFeatLabel = labels2[bestFeat]myTree = {bestFeatLabel: {}}del(labels2[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels2[:]  # 剩余属性列表myTree[bestFeatLabel][value] = createTree(splitDataSet \(dataSet, bestFeat, value), subLabels)return myTree#导入数据
def sentDataSet(filename):with open(filename, 'r', encoding='utf-8') as file:arrayOLines = file.readlines() #列表型numberOfLines = len(arrayOLines)dataSet = numpy.zeros((numberOfLines, 4))index = 0for line in arrayOLines:line = line.strip()listFromLine = line.split('\t\t')dataSet[index, :] = listFromLine[0: 5]labels = ['圆周长', '重量', '材料', '花纹']return dataSet, labelsmyDat, labels = sentDataSet('./data1')
myTree = createTree(myDat, labels)
print(myTree)

结果如下:

数据集采用了上一节中已生成的数据 ,所以还是分类球类的问题。

主要遇到的问题是,连续的数据没有进行处理,以至于每一组数据都成了一个单独的类别。连续数组的处理将在下一节提到。

导入数据主要就是将之前生成的数据集导入进来,方便后续操作。这一块代码可能有些简陋,毕竟博主水平还有限,不像之前的代码可以照着书敲【手动狗头】。不过嘛,代码毕竟是调通了,皆大欢喜,放心食用。

当然,这个问题也不能一直放着。于是乎,博主决定去掉数据集里的连续型变量,只用后两列数据进行分类,结果如下:

至于代码嘛,倒也不用大改,切片取出数据时修改一下参数即可。

如下:

# 导入数据
def sentDataSet(filename):with open(filename, 'r', encoding='utf-8') as file:arrayOLines = file.readlines()  # 列表型dataSet = []for line in arrayOLines:line = line.strip()listFromLine = line.split('\t\t')dataSet.append(listFromLine[2:5])  # 修改切片labels = ['材料', '花纹']  # 修改这里return dataSet, labels

修改部分已注释。

可视化树:

这一块是希望能过够将分类好的树可视化的输出,也就是直观的看到这棵树。

代码如下:

# 用Matplotlib注解绘制树形图# 定义文本框和箭头格式
decisionNode = dict(boxstyle="square", fc="0.8")  # boxstyle文本框样式、fc=”0.8” 是颜色深度
leafNode = dict(boxstyle="round4", fc="0.8")  # 叶子节点
arrow_args = dict(arrowstyle="<-")  # 定义箭头# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 此函数执行绘制功能# createPlot.ax1是表示: ax1是函数createPlot的一个属性createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)# 获取叶节点的数目和树的层数
def getNumLeafs(myTree):numLeafs = 0  # 初始化firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)secondDict = myTree[firstStr]  # 获得value值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典numLeafs += getNumLeafs(secondDict[key])  # 递归调用函数else:numLeafs += 1return numLeafs# 获取树的深度
def getTreeDepth(myTree):maxDepth = 0  # 初始化firstStr = list(myTree.keys())[0]  # 获得第一个key值(根节点)secondDict = myTree[firstStr]  # 获得value值for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典thisDepth = 1 + getTreeDepth(secondDict[key])  # 递归调用else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepth# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# 画树
def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)  # 获取树高depth = getTreeDepth(myTree)  # 获取树深度firstStr = list(myTree.keys())[0]  # 这个节点的文本标签cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,plotTree.yOff)  # plotTree.totalW, plotTree.yOff全局变量,追踪已经绘制的节点,以及放置下一个节点的恰当位置plotMidText(cntrPt, parentPt, nodeTxt)  # 标记子节点属性plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # 减少y偏移for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 绘制决策树
def createPlot(inTree):fig = plt.figure(1, facecolor='white')  # 创建一个新图形fig.clf()  # 清空绘图区font = {'family': 'MicroSoft YaHei'}plt.rc("font", **font)axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalW;plotTree.yOff = 1.0;plotTree(inTree, (0.5, 1.0), '')plt.show()Createplot = createPlot(myTree)

运行结果:

材料这个分类分的稀烂,但这主要是因为这个类别本身就不适合分类。但由于去掉了前两列的数据,防止生成树过于单调,我最后还是决定将它加上。

代码问题上注意下面这一列:

本身就是字体的选择,但需要你的电脑上有这个可以编译的字体,这里推荐使用

font = {'family': 'SimHei'},一般情况下没问题。

下一节将会对树进行优化,也有连续性变量的处理方式,欢迎追订哦!

机器学习之构造决策树相关推荐

  1. 机器学习之使用sklearn构造决策树模型

    一.任务基础 导入所需要的库 import matplotlib.pyplot as plt import pandas as pd%matplotlib inline 加载sklearn内置数据集 ...

  2. 机器学习实战之决策树(一)构造决策树

    决策树(一)构造决策树 1.简介 1.1 优缺点 1.2 流程 1.3 决策树的构造 1.4 海洋生物数据 2.信息增益 2.1 几个概念 2.2 计算给定数据集的熵 3 划分数据集 3.1 按照给定 ...

  3. 机器学习知识点(七)决策树学习算法Java实现

    为理解机器学习第四章节决策树学习算法,通过网上找到的一份现成代码,主要实现了最优划分属性选择和决策树构造,其中最优划分属性选择采用信息增益准则.决策树构造采用递归实现,代码如下: package sk ...

  4. 秒懂机器学习---当机器学习遇上决策树....

    秒懂机器学习---当机器学习遇上决策树.... 一.总结 一句话总结: 多多看图,图的直观效果很好,很多时候文字实在表达不清 1.决策树(Decision Tree)中的各个节点表示什么意思? 每一个 ...

  5. 机器学习系列文章-决策树

    决策树 由于我们是使用sklearn对决策树代码进行实现,所以并不是很关心其原理部分.但我仍需要对其进行一定的了解.通过查询资料,去学习了下决策树的原理,这里对其原理进行简要介绍. 注:这里决策树的原 ...

  6. 《机器学习实战》—— 决策树

    目录 一.决策树的构造 1. 信息增益 2. 划分数据集 3. 递归构建决策树 二.在 Python 中使用 Matplotlib 注解绘制树形图 1. Matplotlib 2. 构造注解树 三.测 ...

  7. 机器学习算法 04 —— 决策树(ID3、C4.5、CART,剪枝,特征提取,回归决策树)

    文章目录 系列文章 决策树 1 决策树算法简介 2 决策树分类的原理 2.1 信息熵 2.2 决策树划分依据-信息增益(ID3) 2.3 决策树划分依据-信息增益率(C4.5) 2.4 决策树划分依据 ...

  8. python数据分析/机器学习 笔记之决策树(泰坦尼克号旅客生存预测)

    最近在学习用python数据分析,不可避免的接触到了机器学习的一些算法,所以在这里简单整理一些学习的笔记和心得与大家分享! 首先机器学习分为:监督学习和非监督学习,前者有参照物,后者为参照物:主要分为 ...

  9. 决策树(decision tree)(一)——构造决策树方法

    决策树(decision tree)(一)--构造决策树方法 说明:这篇博客是看周志华老师的<机器学习>(西瓜书)的笔记总结,虽然自己写了很多总结性文字包括一些算法细节,但博客中仍有部分文 ...

最新文章

  1. 【Python 第8课】while
  2. 人工智能这条小船何时才能变成航母?
  3. ssm 上传图片到mysql_ssm(Spring+Spring MVC+MyBatis)+Web Uploader开发图片文件上传实例,支持批量上传,拖拽上传,复制粘贴上传...
  4. CentOS-创建yum本地源
  5. zippo油和zorro油的区别_陶瓷轴承润滑油的性能质量最指标油的区别?_搜狐汽车...
  6. CTF(pwn)-格式化字符串漏洞讲解(一)
  7. cannot resolve symbol ‘springframework‘
  8. OAuth2 实现单点登录 SSO
  9. 160 - 10 Andrénalin.3
  10. 计算机休眠怎么唤醒,电脑休眠后,就无法唤醒了?怎么办?快速教你解决这个问题...
  11. 郑大远程计算机应用基础第09,郑大远程教育《计算机应用基础》第09章在线测试...
  12. jpeglib画质的代码分析
  13. 《shell脚本学习指南》学习笔记之入门 一
  14. 贝叶斯优化核极限学习机KELM用于回归预测
  15. Web3到底是什么?
  16. 统一诊断服务(UDS)否定响应也可以被抑制
  17. Linux文件-/etc/login.defs
  18. supermap java,SuperMap iObjects Java 10i 产品介绍
  19. 苹果主屏幕按钮怎么设置_苹果手机屏幕变大怎么恢复
  20. 苏宁零售云,“动物凶猛”

热门文章

  1. 解决应用程序无法正常启动(0xc000007b)。请单击“确定”关闭应用程序。
  2. 【附源码】计算机毕业设计java智能仓储设备管理系统设计与实现
  3. Visual Studio 2010 简体中文旗舰、专业版(MSDN原版下载)
  4. Spark的性能优化案例分析(下)
  5. Tera-Store高速数据采集存储系统
  6. 中科蓝汛 ----POWER 10S复位系统的坑
  7. 云空间服务,助力用户数据存储与协同
  8. 关于新高考中综合素质评价的思考
  9. 多线程Synchronized锁的使用与线程之间的通讯
  10. Nginx+Lua 实现灰度发布详细步骤