机器学习实战 —— 决策树(完整代码)
声明: 此笔记是学习《机器学习实战》 —— Peter Harrington 上的实例并结合西瓜书上的理论知识来完成,使用Python3 ,会与书上一些地方不一样。
机器学习实战—— 决策树
Coding: Jungle
样本集合: D D D
第k类样本所占比例L: p k p_k pk
属性a对样本D进行划分产生分支节点个数: V V V
信息熵 : E n t ( D ) = − ∑ k = 1 ∣ y ∣ p k l o g 2 p k Ent(D) = - \sum_{k=1}^{|y|} p_k log_2p_k Ent(D)=−∑k=1∣y∣pklog2pk
信息增益: G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D,a) = Ent(D) - \sum_{v = 1}^{V} \frac{|D^v|}{|D|}Ent(D^v) Gain(D,a)=Ent(D)−∑v=1V∣D∣∣Dv∣Ent(Dv)
数据集
不浮出水面是否可以生存 | 是否有脚蹼 | 是否属于鱼类 | |
---|---|---|---|
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
4 | 否 | 是 | 否 |
1. 计算给定数据集的熵
#trees.py
from math import log
def calShannonEnt(dataSet):numEntries = len(dataSet)labelCounts = {}#为所有可能的分类创建字典for featVec in dataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:#计算熵,先求pprob = float(labelCounts[key])/numEntriesshannonEnt -= prob *log(prob,2)return shannonEnt
2. 构建数据集
def creatDataSet():dataSet = [[1,1,'maybe'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]labels = ['no surfacing','flippers']return dataSet,labels
myData,labels = creatDataSet()
print("数据集:{}\n 标签:{}".format(myData,labels))
print("该数据集下的香农熵为:{}".format(calShannonEnt(myData)))
数据集:[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]标签:['no surfacing', 'flippers']
该数据集下的香农熵为:1.3709505944546687
相同数据量下,减少属性值类型及特征值,对比熵的变化
myData[0][-1] = 'yes'
print("数据为:{}\n 该数据集下的香农熵为:{}".format(myData,calShannonEnt(myData)))
数据为:[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]该数据集下的香农熵为:0.9709505944546686
3. 划分数据集
# 根据属性及其属性值划分数据集
def splitDataSet(dataSet, axis, value):'''dataSet : 待划分的数据集axis : 属性及特征value : 属性值及特征的hasattr值'''retDataSet = []for featVet in dataSet:if featVet[axis] == value:reducedFeatVec = featVet[:axis]reducedFeatVec.extend(featVet[axis+1:])retDataSet.append(reducedFeatVec)return retDataSet
print("划分前的数据集:{}\n \n按照“离开水是否能生存”为划分属性,得到下一层待划分的结果为:\n{}--------{}".format(myData,splitDataSet(myData,0,0),splitDataSet(myData,0,1)))
划分前的数据集:[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]按照“离开水是否能生存”为划分属性,得到下一层待划分的结果为:
[[1, 'no'], [1, 'no']]--------[[1, 'yes'], [1, 'yes'], [0, 'no']]
# 选择最好的数据集划分方式,及根绝信息增益选择划分属性
def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1baseEntropy = calShannonEnt(dataSet)bestInfoGain, bestFeature = 0, -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0.0# 计算每种划分方式的信息熵for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet) / float(len(dataSet))newEntropy += prob * calShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature
chooseBestFeatureToSplit(myData)
0
递归构建决策树
# 找到出现次数最多的分类名称
import operatordef majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]
# 创建树的函数
def creatTree(dataSet, labels):classList = [example[-1] for example in dataSet]# 类别完全相同停止划分if classList.count(classList[0]) == len(classList):return classList[0]if len(dataSet[0]) == 1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel: {}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:sublabels = labels[:] myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet, bestFeat, value), sublabels)return myTree
creatTree(myData,labels)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
画出决策树
# treePlotter.py
import matplotlib.pyplot as plt
from pylab import*decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt, centerPt,parentPt, nodeType):mpl.rcParams['font.sans-serif']=['SimHei'] createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():fig = plt.figure(111, facecolor='white')fig.clf()createPlot.ax1 = plt.subplot(111, frameon=False)plotNode(U'Decision Node', (0.5, 0.1),(0.1, 0.5), decisionNode)plotNode(U'Leaf Node', (0.8, 0.1),(0.3, 0.8), leafNode)plt.show
createPlot()
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2QZJ7Upl-1596631929903)(output_19_0.png)]
# treePlotter.py
# 计算叶节点的个数def getNumLeaves(myTree):numLeafs= 0# 截取到树字典中的key值#firstStr = str(myTree.keys())[13:-3]firstStr = eval(str(myTree.keys()).replace('dict_keys(','').replace(')',''))[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeaves(secondDict[key])else:numLeafs += 1return numLeafs
# 计算树的深度def getTreeDepth(myTree):maxDepth = 0firstStr = eval(str(myTree.keys()).replace('dict_keys(','').replace(')',''))[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepth
#测试深度计算和叶节点记述函数
def retrieveTree(i):listOftrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]return listOftrees[i]mytree = retrieveTree(1)
getNumLeaves(mytree)
getTreeDepth(mytree)
3
# treePlotter.py
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):numLeafs = getTreeDepth(myTree)firstStr = eval(str(myTree.keys()).replace('dict_keys(','').replace(')',''))[0]cntrPt = (plotTree.xOff+(1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff +1.0/plotTree.totalWplotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))plotTree.yOff = plotTree.yOff +1.0/plotTree.totalDdef createPlot(inTree):fig = plt.figure(1,facecolor='white')fig.clf()axprops = dict(xticks = [],yticks = [])createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)plotTree.totalW = float(getNumLeaves(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalWplotTree.yOff = 1.0plotTree(inTree,(0.5,1.0),"")plt.show
myTree = retrieveTree(0)
createPlot(myTree)
测试功能
#trees.py
def classify(inputTree,featLbabels,testVec):firstStr = eval(str(myTree.keys()).replace('dict_keys(','').replace(')',''))[0]secondDict = inputTree[firstStr]featIndex = featLbabels.index(firstStr)for key in secondDict.keys():if testVec[featIndex] == key:if type(secondDict[key]).__name__ == 'dict':classLabel = classify(secondDict[key],featLbabels,testVec)else:classLabel = secondDict[key]return classLabel
#来测试一下
myDat1,labels1 = creatDataSet()
mytree1 = retrieveTree(0)
classify(mytree1,labels1,[0,0])
'no'
西瓜书西瓜的决策树构建
数据集
编号 | 色泽 | 根蒂 | 敲声 | 纹理 | 脐部 | 触感 | 瓜型 |
---|---|---|---|---|---|---|---|
1 | 青绿 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 好瓜 |
2 | 乌黑 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 好瓜 |
3 | 乌黑 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 好瓜 |
4 | 青绿 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 好瓜 |
5 | 浅白 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 好瓜 |
6 | 青绿 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 好瓜 |
7 | 乌黑 | 稍蜷 | 浊响 | 稍糊 | 稍凹 | 软粘 | 好瓜 |
8 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 硬滑 | 好瓜 |
9 | 乌黑 | 稍蜷 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 坏瓜 |
10 | 青绿 | 硬挺 | 清脆 | 清晰 | 平坦 | 软粘 | 坏瓜 |
11 | 浅白 | 硬挺 | 清脆 | 模糊 | 平坦 | 硬滑 | 坏瓜 |
12 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 软粘 | 坏瓜 |
13 | 青绿 | 稍蜷 | 浊响 | 稍糊 | 凹陷 | 硬滑 | 坏瓜 |
14 | 浅白 | 稍蜷 | 沉闷 | 稍糊 | 凹陷 | 硬滑 | 坏瓜 |
15 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 坏瓜 |
16 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 硬滑 | 坏瓜 |
17 | 青绿 | 蜷缩 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 坏瓜 |
#DT_ID3_pumpkin .py
def createDatePumpKin():dataSet = [# 1['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],# 2['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],# 3['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],# 4['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],# 5['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],# 6['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],# 7['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],# 8['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],# ----------------------------------------------------# 9['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],# 10['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],# 11['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],# 12['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],# 13['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],# 14['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],# 15['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],# 16['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],# 17['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']]labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']return dataSet,labels
#这里在前面计算香农熵的函数,好像是会在log函数的第二个参数的类型上报错,进行一下修改,但是主要问题原因需要看底层的代码
from math import log2
def calShannonEnt(dataSet):numEntries = len(dataSet)labelCounts = {}#为所有可能的分类创建字典for featVec in dataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:#计算熵,先求pprob = float(labelCounts[key])/numEntriesshannonEnt -= prob *log2(prob)return shannonEnt
myData_P,labels_P = createDatePumpKin()
print("该数据集下的香农熵为:{}".format(myData_P,calShannonEnt(myData_P)))
该数据集下的香农熵为:0.9975025463691153
print("按照“色泽”为划分属性,得到下一层待划分的结果为:\n-------->{}\n-------->{}\n--------->{}".format(myData_P,splitDataSet(myData_P,0,'浅白'),splitDataSet(myData_P,0,'青绿'),splitDataSet(myData_P,0,'乌黑')))
按照“色泽”为划分属性,得到下一层待划分的结果为:
-------->[['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'], ['稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'], ['蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜']]
-------->[['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'], ['硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'], ['稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'], ['蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']]
--------->[['蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'], ['蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'], ['稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'], ['稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'], ['稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜']]
myTree_P=creatTree(myData_P,labels_P)
print(myTree_P)
createPlot(myTree_P)
{'纹理': {'模糊': '坏瓜', '稍糊': {'触感': {'硬滑': '坏瓜', '软粘': '好瓜'}}, '清晰': {'根蒂': {'稍蜷': {'色泽': {'青绿': '好瓜', '乌黑': {'触感': {'硬滑': '好瓜', '软粘': '坏瓜'}}}}, '硬挺': '坏瓜', '蜷缩': '好瓜'}}}}
机器学习实战 —— 决策树(完整代码)相关推荐
- 机器学习实战——决策树(代码)
最近在学习Peter Harrington的<机器学习实战>,代码与书中的略有不同,但可以顺利运行. from math import log import operator# 计算熵 d ...
- apriori算法代码_资源 | 《机器学习实战》及代码(基于Python3)
〇.<机器学习实战> 今天推荐给大家的是<机器学习实战>这本书. 机器学习作为人工智能研究领域中一个极其重要的研究方向(一文章看懂人工智能.机器学习和深度学习),在当下极其热门 ...
- [机器学习数据挖掘]机器学习实战决策树plotTree函数完全解析
[机器学习&数据挖掘]机器学习实战决策树plotTree函数完全解析 http://www.cnblogs.com/fantasy01/p/4595902.html点击打开链接 import ...
- 机器学习实战--决策树ID3的构建、画图与实例:预测隐形眼镜类型
声明 本文参考了<机器学习实战>书中代码,结合该书讲解,并加之自己的理解和阐述 机器学习实战系列博文 机器学习实战--k近邻算法改进约会网站的配对效果 机器学习实战--决策树的构建.画图与 ...
- python神经网络算法pdf_Python与机器学习实战 决策树、集成学习、支持向量机与神经网络算法详解及编程实现.pdf...
作 者 :何宇健 出版发行 : 北京:电子工业出版社 , 2017.06 ISBN号 :978-7-121-31720-0 页 数 : 315 原书定价 : 69.00 主题词 : 软件工具-程序设计 ...
- 机器学习实战-决策树-22
机器学习实战-决策树-叶子分类 import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplo ...
- 《机器学习实战》配套代码下载
<机器学习实战>配套代码资源下载网址:http://www.ituring.com.cn/book/1021(图灵社区),网址里有随书下载,可以下载配套资源.
- 机器学习实战 支持向量机SVM 代码解析
机器学习实战 支持向量机SVM 代码解析 <机器学习实战>用代码实现了算法,理解源代码更有助于我们掌握算法,但是比较适合有一定基础的小伙伴.svm这章代码看起来风轻云淡,实则对于新手来说有 ...
- 机器学习实战——决策树构建过程,信息熵及相关代码
目录 1. 决策树基本概念 2. 决策树类别 3. 构建决策树 3.1 选择最优特征 3.1.1 信息熵 3.1.2 信息增益 3.2 决策树的生成 3.3 剪枝 1. 决策树基本概念 决策树就是一棵 ...
最新文章
- String复习笔记
- C++编程思想:指针,引用,拷贝构造函数,赋值运算符
- Python中如何把一个UTC时间转换为本地时间
- java 如何跟多个字符串比较_Stack Overflow上370万浏览量的一个问题:如何比较Java的字符串...
- 10无法勾选隐藏的项目_Excel超好用的隐藏操作,不可多得
- emlog博客主题价值358元lu1.3模板
- python gil锁_python--GIL锁
- 微型计算机的主要,微型计算机组成,微型计算机主要由什么组成
- SQL Server舍入功能概述– SQL舍入,上限和下限
- Centos7下安装DB2
- 嵌入式软件设计第7次实验报告
- ai文件图片连接丢失怎么处理_未来美学丨点亮你的AI技能点(一)
- python list 查找子列_寻找列表连续的子列
- ARM开发经典学习网站推荐【转】
- 将移动硬盘变为与系统硬盘等同的存在
- 如何让客户接受你的价格比别人更高?
- 移动端Touch (触摸)事件
- 女朋友撒娇让我教她HashMap
- 那些 996 公司的员工怎么样了?
- sap badi s4 MIGO屏幕实施测试