决策树

kNN算法可以完成很多分类任务,但是它最大的缺点就是给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解

决策树的构造

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
  • 缺点:可能会产生过度匹配问题。

适用数据类型:数值型和标称型。

创建分支的伪代码函数createBranch()

Check if every item in the dataset is in the same class:If so return the class labelElsefind the best feature to split the datasplit the datasetcreate a branch nodefor each splitcall createBranch() and add the result to the branch nodereturn branch node

示例数据

海洋生物数据

不浮出水面是否可以生存 是否有脚蹼 属于鱼类

信息增益 Information gain

划分数据集的大原则是:将无序的数据变得更加有序。

组织杂乱无章数据的一种方法就是使用 信息论 度量信息。

在划分数据集之前之后信息发生的变化称为 信息增益.

知道如何计算信息增益,就可以计算某个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

冯 诺依曼 建议使用 这术语

信息增益是熵(数据无序度)的减少,大家肯定对于将熵用于度量数据无序度的减少更容易理解。

集合信息的度量称为香农熵 或者 简称 熵(entropy)。(更多熵知识请移步至 What Is Information Entropy)

熵定义为信息的期望值

信息定义

如果待分类的事务可能划分在多个分类之中,则符号Xi的信息定义为

其中p(Xi)是选择该分类的概率。

为了计算熵,我们需要计算所有分类别所有可能值包含的信息期望值,通过下面的公式得到

trees.py

计算给定数据集的香农熵

def calcShannonEnt(dataSet):#实例总数numEntries = len(dataSet)labelCounts = {}#the the number of unique elements and their occurance#统计目标变量的值出现的次数for featVec in dataSet: #每个实例的最后一项是目标变量currentLabel = featVec[-1]if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0#利用上面的公式计算出香农熵for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob * log(prob,2) #log base 2return shannonEnt

创建数据集

def createDataSet():dataSet = [[1, 1, 'yes'],[1, 1, 'yes'],[1, 0, 'no'],[0, 1, 'no'],[0, 1, 'no']]labels = ['no surfacing','flippers']#change to discrete valuesreturn dataSet, labels

运行

testTree.py

# -*- coding: utf-8 -*-
import treesdataSet, labels = trees.createDataSet()print dataSet
#[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print labels
#['no surfacing', 'flippers']#计算熵
print trees.calcShannonEnt(dataSet)
#0.970950594455#改变多一个数据
dataSet[0][-1] = 'maybe'print dataSet
#[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print trees.calcShannonEnt(dataSet)
#1.37095059445

熵越大,则混合的数据越多

延伸:另一个度量集合无序程度的方法是基尼不纯度,简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组的概率。

划分数据集

#axis表示第n列
#返回剔除第n列数据的数据集
def splitDataSet(dataSet, axis, value):retDataSet = []for featVec in dataSet:if featVec[axis] == value:#剔除第n列数据reducedFeatVec = featVec[:axis] reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec)return retDataSet

运行

testTree.py

print dataSet
#[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#划分数据集#当第0列时,值为0 的实例
print trees.splitDataSet(dataSet, 0, 0)
#[[1, 'no'], [1, 'no']]#当第0列时,值为1 的实例
print trees.splitDataSet(dataSet, 0, 1)
#[[1, 'yes'], [1, 'yes'], [0, 'no']]
#append 和 extend 区别
>>> a=[1,2,3]
>>> b=[4,5,6]
>>> a.append(b)
>>> a
[1, 2, 3, [4, 5, 6]]
>>> a=[1,2,3]
>>> a.extend(b)
>>> a
[1, 2, 3, 4, 5, 6]
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):#有多少个特征数量,最后一个目标变量numFeatures = len(dataSet[0]) - 1#计算基准 香农熵 目标变量的熵baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0; bestFeature = -1#迭代特征,i是列数for i in range(numFeatures):        #该特征(一列)下 所有值#使用 列表推倒 (List Comprehension)featList = [example[i] for example in dataSet] #特征值去重uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:#返回剔除第i列数据的数据集subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))#新的香农熵#有点不清楚这公式newEntropy += prob * calcShannonEnt(subDataSet)     #计算增益infoGain = baseEntropy - newEntropy#选择最大增益,增益越大,区分越大if (infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature

运行

print trees.chooseBestFeatureToSplit(dataSet)
#0
#这运行结果告诉我们,第0特征是最好的用于划分数据集的特征#chooseBestFeatureToSplit(dataSet)的一些中间变量的值
#baseEntropy: 0.970950594455    #第0列
#value: 0
#value: 1
#newEntropy: 0.550977500433#第1列
#value: 0
#value: 1
#newEntropy: 0.8

递归构建决策树

返回出现次数最多的分类名称

def majorityCnt(classList):classCount={}for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]

创建树的函数代码

def createTree(dataSet,labels):#目标变量的值classList = [example[-1] for example in dataSet]#stop splitting when all of the classes are equalif classList.count(classList[0]) == len(classList): return classList[0]#stop splitting when there are no more features in dataSetif len(dataSet[0]) == 1: return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel:{}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:#copy all of labels, so trees don't mess up existing labelssubLabels = labels[:]       myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)return myTree

运行结果

print "---createTree---"print trees.createTree(dataSet, labels)"""
---createTree---
classList:
['yes', 'yes', 'no', 'no', 'no']
baseEntropy: 0.970950594455
value: 0
value: 1
newEntropy: 0.550977500433
value: 0
value: 1
newEntropy: 0.8
---
classList:
['no', 'no']
---
classList:
['yes', 'yes', 'no']
baseEntropy: 0.918295834054
value: 0
value: 1
newEntropy: 0.0
---
classList:
['no']
---
classList:
['yes', 'yes']---最终运行结果---
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""

在Python中使用Matplotlib注解绘制树形图

Matplotlib注解annotate

testCreatePlot.py

import matplotlib.pyplot as pltdecisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")#绘制节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )def createPlot():fig = plt.figure(1, facecolor='white')fig.clf()createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)plt.show()print createPlot#<function createPlot at 0x0000000007636F98>createPlot()print createPlot.ax1#AxesSubplot(0.125,0.11;0.775x0.77)

注意:红色的坐标是后来加上去的,不是上面程序生成的。

构造注解树

获取叶节点的数目和树的层数

testPlotTree.py

def getNumLeafs(myTree):numLeafs = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodesnumLeafs += getNumLeafs(secondDict[key])else:   numLeafs +=1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodesthisDepth = 1 + getTreeDepth(secondDict[key])else:   thisDepth = 1if thisDepth > maxDepth: maxDepth = thisDepthreturn maxDepthdef retrieveTree(i):listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]return listOfTrees[i]myTree = retrieveTree(0)
print "myTree: "
print myTreeprint "getNumLeafs(myTree): "
print getNumLeafs(myTree)print "getTreeDepth(myTree): "
print getTreeDepth(myTree)# myTree:
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# getNumLeafs(myTree):
# 3
# getTreeDepth(myTree):
# 2

treePlotter.py

#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on#this determines the x width of this treenumLeafs = getNumLeafs(myTree)  depth = getTreeDepth(myTree)#the text label for this node should be thisfirstStr = myTree.keys()[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():#test to see if the nodes are dictonaires, if not they are leaf nodesif type(secondDict[key]).__name__=='dict':#recursion递归调用plotTree(secondDict[key],cntrPt,str(key))else:#it's a leaf node print the leaf node绘制叶节点plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dictdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])#**axprops 表示 no ticks 不绘制坐标点createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;#绘制跟节点plotTree(inTree, (0.5,1.0), '')plt.show()

运用

testPlotTree2.py

# -*- coding: utf-8 -*- import treePlottermyTree = treePlotter.retrieveTree(0)
print myTree
#{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}#开始绘制决策树
treePlotter.createPlot(myTree)

测试算法:使用决策树执行分类

def classify(inputTree,featLabels,testVec):firstStr = inputTree.keys()[0]secondDict = inputTree[firstStr]featIndex = featLabels.index(firstStr)key = testVec[featIndex]valueOfFeat = secondDict[key]if isinstance(valueOfFeat, dict): classLabel = classify(valueOfFeat, featLabels, testVec)else: classLabel = valueOfFeatreturn classLabel

运行

testClassify.py

import treePlotter
import treesdataSet, labels = trees.createDataSet()
myTree = treePlotter.retrieveTree(0)print myTreeprint trees.classify(myTree, labels, [1, 0])
#noprint trees.classify(myTree, labels, [1, 1])
#yes

决策树存储到本地

为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。

def storeTree(inputTree,filename):import picklefw = open(filename,'w')pickle.dump(inputTree,fw)fw.close()def grabTree(filename):import picklefr = open(filename)return pickle.load(fr)

运用

testStoreTree.py

import trees
import treePlottermyTree = treePlotter.retrieveTree(0)
#存储到'classifierStorage.txt'文件
trees.storeTree(myTree, 'classifierStorage.txt')#再读取
print trees.grabTree('classifierStorage.txt')
#{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

示例:使用决策树预测隐形眼镜类型

数据集lenses.txt

age  prescript药方 astigmatic散光的 tearRate撕裂率young  myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses
  • pre 之前
  • presbyopic 远视眼的
  • myope 近视眼
  • hyper 超级

运用

testLenses.py

import trees
import treePlotterfr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses,lensesLabels)print "lensesTree: "
print lensesTree
#{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}treePlotter.createPlot(lensesTree)

生成的决策树图

总结

上面决策树非常好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配Overfitting

为了减少过度匹配问题,可以裁剪决策树,去掉一些不必要的叶子节点。

如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。

上述阐述的是ID3算法,它是一个瑕不掩瑜的算法。

ID3算法无法直接处理数值型int,double的数据,尽管我们可以通过量化的方法将数值型转换为标称型数值,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。

《机器学习实战》笔记(03):决策树相关推荐

  1. 机器学习实战笔记:决策树(Decision Tree)

    PS 该部分内容所设计到的程序源码已经存在我的github上,地址奉上: https://github.com/AdventureSJ/ML-Notes/tree/master/DecisionTre ...

  2. 天池龙珠训练营-机器学习学习笔记-03 LightGBM 分类

    天池龙珠训练营-机器学习学习笔记-03 LightGBM 分类 本学习笔记为阿里云天池龙珠计划机器学习训练营的学习内容,学习链接为:训练营 一 原理简介: 它是一款基于GBDT(梯度提升决策树)算法的 ...

  3. 机器学习实战笔记(Python实现)-04-Logistic回归

    转自:机器学习实战笔记(Python实现)-04-Logistic回归 转自:简单多元线性回归(梯度下降算法与矩阵法) 转自:人工神经网络(从原理到代码) Step 01 感知器 梯度下降

  4. 机器学习实战3.4决策树项目案例03:使用Sklearn预测隐形眼镜类型

    搜索微信公众号:'AI-ming3526'或者'计算机视觉这件小事' 获取更多人工智能.机器学习干货 csdn:https://blog.csdn.net/baidu_31657889/ github ...

  5. 机器学习实战笔记(Python实现)-01-机器学习实战

    今天发布一篇图片博客,看一下效果如何,如果效果,以后的博客尽量发图片上来. 机器学习实战 本博客来自于CSDN:http://blog.csdn.net/niuwei22007/article/det ...

  6. 机器学习实战笔记(Python实现)-03-朴素贝叶斯

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  7. python实现线性回归预测_机器学习实战笔记(Python实现)-08-线性回归

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  8. 机器学习读书笔记:决策树

    文章目录 如何形成一颗决策树 划分选择 信息熵 & 信息增益 基尼指数 剪枝 预剪枝 后剪枝 连续值 & 属性缺失处理 连续值 属性缺失 多变量决策树 如何形成一颗决策树 ​ 决策树从 ...

  9. 机器学习实战笔记(Python实现)-02-k近邻算法(kNN)

    k近邻算法(kNN) 本博客来源于CSDN:http://blog.csdn.net/niuwei22007/article/details/49703719 本博客源代码下载地址:CSDN免费下载. ...

  10. 机器学习实战——笔记(第一章)

    机器学习基础 目录 机器学习的简单概述 机器学习的主要任务 算法选择与开发步骤 Python语言的优势 一.机器学习的简单概述 机器学习应用领域十分广泛,如人脸识别.推荐系统.手写数字辨识及垃圾邮件过 ...

最新文章

  1. puppet之service管理
  2. Java数据库查询简介
  3. org.simpleframework.xml.core.Persister
  4. Linux 系统线程数量上限是多少?
  5. oracle进程的作用,浅析Oracle10g新进程 MMON 和 MMNL的作用
  6. ListView控件使用简介(转载)
  7. 建立mysql服务器连接失败_解决项目启动无法创建到数据库服务器的连接错误
  8. 分布式session之redis解决方案实现
  9. vuejs 开发中踩到的坑
  10. android定位!每个程序员都必须掌握的8种数据结构!完整版开放下载
  11. rot移位密码c++实现
  12. 百度人脸识别Java版
  13. 详述 Kafka 基本原理
  14. 转载:各个手机尺寸版本
  15. 如何用Mathpix 和 MathType在WPS快速输入数学公式
  16. 小米公司不管老用户的手机了。小米1
  17. 关于有氧运动的误区 你知道几个
  18. 基于GEE(Google earth engine)的 GIMMS NDVI月合成
  19. 3本实体书、10个csdn定制笔记本丨20211101期开奖
  20. php模拟炒股网站源码,stock 模 拟 炒 股 网 站 源 码(WEB版) WEB(ASP,PHP,...) 251万源代码下载- www.pudn.com...

热门文章

  1. 蜂窝注册表和永久存储
  2. JOJ的2042面试题目的数学推导过程
  3. 电脑用linux命令大全,电脑操作时常用的一些Linux命令
  4. Arcmap格式转arcgis的shp格式
  5. 【转】重新打包DebianISO实现无人应答安装(UEFI+BIOS)
  6. 用于科研的移动机器人平台推荐
  7. 第八节: Quartz.Net五大构件之SimpleThreadPool及其四种配置方案
  8. 计算机卡在无法显示网页,我的电脑上网上银行一直“无法显示网页”
  9. python列表添加数字_Python-识别列表中的连续数字组
  10. JS成员函数声明位置优化