Project Address:

https://github.com/TheOneAC/ML.git

dataset in ML/ML_ation/tree

 决策树

  • 计算复杂度低,中间值缺失不敏感,可理解不相关数据
  • 可能过度匹配(过度分类)
  • 适用:数值型和标称型

决策树伪代码createbranch

检测数据集中子项是否全部属于一类if so return class_tagelse 寻找数据集最佳划分特征划分数据集创建分支节点对每一个子集,递归调用createbranch返回分支节点

递归结束条件:所有属性遍历完,或者数据集属于同一分类

香农熵

def calcShannonEnt(dataSet):numEntries = len(dataSet)labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob * log(prob,2)return shannonEnt

数据及划分与最优选择(熵最小)

def splitDataSet(dataSet, axis, value):retDataSet = []for featVec in dataSet:if featVec[axis] == value:reduceFeatVec = featVec[:axis]reduceFeatVec.extend(featVec[axis + 1:])retDataSet.append(reduceFeatVec)return retDataSetdef chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0])- 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif infoGain > bestInfoGain:baseInfoGain = infoGainbestFeature = ireturn bestFeature

所有标签用尽无法确定类标签时: 多数表决决定子叶分类


def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys(): classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)return sortedClassCount[0][0]

创建树


def createTree(dataSet, labels):classList = [example[-1] for example in dataSet]if classList.count(classList[0]) == len(classList):return classList[0]if len(dataSet[0]) == 1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatureLabel = labels[bestFeat]myTree = {bestFeatureLabel:{}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeat,value), subLabels)return myTree

测试

def classify(inputTree,featLabels,testVec):firstStr = inputTree.keys()[0]secondDict = inputTree[firstStr]featIndex = featLabels.index(firstStr)for key in secondDict.keys():if testVec[featIndex] == key:if type(secondDict[key]).__name__=='dict':classLabel = classify(secondDict[key],featLabels,testVec)else:classLabel = secondDict[key]return classLabel
>>> import trees
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=treePlotter.retrieveTree (0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'

 存储与重载

def storeTree(inputTree, filename):import picklefw = open(filename, 'w')pickle.dump(inputTree,fw)fw.close()def grabTree(filename):import picklefr = open(filename)return pickle.load(fr)

 test

#!/usr/bin/python
import treesmyDat,labels = trees.createDataSet()myTree = trees.createTree(myDat, labels)trees.storeTree(myTree,'classifierStorage.txt')print(trees.grabTree('classifierStorage.txt'))

图形化显示树结构

#!/usr/bin/pythonimport matplotlib.pyplot as plt decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = "axes fraction",xytext = centerPt, textcoords = "axes fraction",va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

创建节点

def createPlot():fig = plt.figure(1, facecolor = "white")fig.clf()createPlot.ax1 = plt.subplot(111, frameon = False)plotNode("a decision node",(0.5, 0.1), (0.1, 0.5), decisionNode)plotNode("a leaf node",(0.8, 0.1), (0.3, 0.8), leafNode)plt.show()

python command line run command as this

import treeplotter
treePlotter.createPlot()
  • result like this
def getNumLeafs(myTree):numLeafs = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumleafs(secondDict[key])else: numLeafs +=1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1+ getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth: maxDepth = thisDepthreturn maxDepthdef retrieveTree(i):listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': \{0: 'no', 1: 'yes'}}}},{'no surfacing': {0: 'no', 1: {'flippers': \{0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]return listOfTrees[i]def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = myTree.keys()[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__=='dict':plotTree(secondDict[key],cntrPt,str(key))else:plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;plotTree(inTree, (0.5,1.0), '')plt.show()

扩展测试 lens.py

Project Address: ` https://github.com/TheOneAC/ML.git`dataset:  `lens.txt in ML/ML_ation/tree`
#!/usr/bin/pythonimport trees
import treePlotterfr = open("lenses.txt")
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = trees.createTree(lenses,lensesLabels)
print(lensesTree)treePlotter.createPlot(lensesTree)

转载于:https://www.cnblogs.com/zeroArn/p/6691287.html

ML in Action 决策树相关推荐

  1. A.机器学习入门算法(五):基于企鹅数据集的决策树分类预测

    [机器学习入门与实践]入门必看系列,含数据挖掘项目实战:数据融合.特征优化.特征降维.探索性分析等,实战带你掌握机器学习数据挖掘 专栏详细介绍:[机器学习入门与实践]合集入门必看系列,含数据挖掘项目实 ...

  2. Spark之MLlib

    目录 Part VI. Advanced Analytics and Machine Learning Advanced Analytics and Machine Learning Overview ...

  3. 台大林轩田·机器学习技法记要

    台大林轩田·机器学习技法 记要 6/1/2016 7:42:34 PM 第一讲 线性SVM 广义的SVM,其实就是二次规划问题 把SVM问题对应到二次规划的系数 这就是线性SVM,如果想变成非线性,只 ...

  4. 手把手教你使用Flask轻松部署机器学习模型(附代码链接) | CSDN博文精选

    作者 | Abhinav Sagar 翻译 | 申利彬 校对 | 吴金笛 来源 | 数据派THU(ID:DatapiTHU) 本文旨在让您把训练好的机器学习模型通过Flask API 投入到生产环境  ...

  5. 独家 | 手把手教你如何使用Flask轻松部署机器学习模型(附代码链接)

    作者:Abhinav Sagar 翻译:申利彬 校对:吴金笛 本文约2700字,建议阅读7分钟. 本文可以让你把训练好的机器学习模型使用Flask API 投入生产环境. 本文旨在让您把训练好的机器学 ...

  6. Oracle 20c 新特性:XGBoost 机器学习算法和 AutoML 的支持

    墨墨导读:XGBoost是一个高效.可扩展的机器学习算法,用于回归和分类(regression and classification),使得XGBoost Gradient Boosting开源包可用 ...

  7. 机器学习在脊柱的应用现状:从临床的观点

    Current Applications of Machine Learning in Spine: From Clinical View 机器学习在脊柱的应用现状:从临床观点 ----------- ...

  8. 野人与传教士——宽度优先搜索(完整报告,含全部代码)

    题目: 野人与传教士渡河问题:3个野人与3个传教士打算乘一条船到对岸去,该船一次最多能运2个人,在任何时候野人人数超过传教士人数,野人就会把传教士吃掉,如何用这条船把所有人安全的送到对岸?在实现基本程 ...

  9. 自回避随机行走问题 c语言,醉汉随机行走/随机漫步问题(Random Walk Randomized Algorithm Python)...

    世界上有些问题看似是随机的(stochastic),没有规律可循,但很可能是人类还未发现和掌握这类事件的规律,所以说它们是随机发生的. 随机漫步(Random  Walk)是一种解决随机问题的方法,它 ...

最新文章

  1. REVERSE-PRACTICE-BUUCTF-25
  2. matlab数值很小出错,求大神帮忙解决一下,用MATLAB求解动力学数据总是出错~ - 计算模拟 - 小木虫 - 学术 科研 互动社区...
  3. NSString / NSMutableString 字符串处理,常用代码 (实例)
  4. c语言判断回文字符串递归,用递归实现判断一个字符串是否为回文串
  5. 如何让char不要忽略开头的空格_如何使用C语言实现JSON解析库(二)
  6. ho1365_共享力量的四种方法,而不是ho积
  7. 【华为云技术分享】opensuse使用zypper安装软件
  8. oracle在指定列后添加列,oracle添加列到指定位置
  9. php 生成 rtf,PHP 生成Word文档,ODT文档,RTF文档
  10. 5种2D Attention整理(Non-Local、Criss-Cross、SE、CBAM、Dual-Attention)
  11. 2021-2022-2 ACM集训队每周程序设计竞赛(1) - 问题 B: 蹩脚两轮车 - 题解
  12. UnityShader笔记第三课-MVP矩阵原理-M矩阵
  13. python恶搞代码打开对方摄像头_Python 3 利用 Dlib 实现摄像头实时人脸检测和平铺显示...
  14. MATLAB系统仿真其三:Ornstein-Uhlenbeck(OU)噪声
  15. socket.io实现聊天功能——第一章 、群聊
  16. 网络攻防-20169213-刘晶-第五周作业
  17. 安利4款支持Linux的实用绘图软件
  18. Debian10更换软件源
  19. ts promise
  20. iptables匹配iprange

热门文章

  1. 【Python】学习笔记7-异常处理try。。except .. as e ....else
  2. 6.24AppCan移动开发者大会价值30万的展示机会归了谁?
  3. Java异常处理和常用类
  4. 功能强大的滚动播放插件JQ-Slide
  5. 使用firefox遇到的问题
  6. 迈克尔·乔丹,无可复制的篮球之神!
  7. 记一次程序员在办公室里的“撕逼”经历
  8. java中io与nio复制文件性能对比
  9. $(document).ready、body.Onload()和 $(window).load的区别
  10. iOS:(接口适配器3)--iPhone适应不同型号 6/6plus 前