python决策树代码实现

实现一个简单的决策树,可以同时处理属性值是连续和离散的情况。
使用sklearn里面的鸢尾花等数据集验证,正确率还不错(90%+)

Github地址:https://github.com/nhjydywd/DecisionTree


使用方式:


import DecisionTreenode = DecisionTree.trainDecisionTree(labels, attrs)result = node.predict(attr)

以下为决策树的代码(DecisionTree.py):

import numpy as npdef trainDecisionTree(np_label, np_attrs):print("Data shape: " + str(np.shape(np_attrs)))# To decide whether an attribute is discreteb_discrete = []TH_DISCRETE = 10for i in range(0,np.shape(np_attrs)[1]):s = set()b_discrete.append(True)col = np_attrs[:,i]for x in col:s.add(x)if(len(s) > TH_DISCRETE):b_discrete[-1] = Falsenode = TreeNode()processTreeNode(node, np_label, np_attrs, b_discrete)return nodedef compareEqual(left, right):return left == right# def compareNotEqual(left, right):
#     return left != rightdef compareLessOrEqual(left, right):return left <= right# def compareBiggerOrEqual(left, right):
#     return left >= rightclass TreeNode:def __init__(self):self.label = Noneself.lChild = Noneself.rChild = Noneself.compareIndexAttr = Noneself.compareValue = Noneself.compareMethod = Nonedef accept(self, attrs):attr = attrs[self.compareIndexAttr]if self.compareMethod(attr, self.compareValue):return Truereturn Falsedef predict(self, attrs):if(self.label != None):return self.labelif self.lChild.accept(attrs):return self.lChild.predict(attrs)else:return self.rChild.predict(attrs)# Impossible!print("TreeNode Error: no child accept!")print("arrts is: " + attrs)exit(-1)def devide(np_label, np_attrs, compareIndexAttr, compareMethod, compareValue):left_label = []left_attrs = []right_label = []right_attrs = []for i in range(0,np.shape(np_attrs)[0]):value = np_attrs[i][compareIndexAttr]label = np_label[i]attr = np_attrs[i]if(compareMethod(value, compareValue)):left_label.append(label)left_attrs.append(attr)else:right_label.append(label)right_attrs.append(attr)left_np_label = np.array(left_label)left_np_attrs = np.array(left_attrs)right_np_label = np.array(right_label)right_np_attrs = np.array(right_attrs)return left_np_label, left_np_attrs, right_np_label, right_np_attrsdef countDistinctValues(np_values):s = dict()for v in np_values:if v in s:s[v] += 1else:s[v] = 1return sdef findDevidePoint(np_label, np_attrs, indexAttr, bDiscrete):if bDiscrete:compareMethod = compareEqualcandidateValue = countDistinctValues(np_attrs[:,indexAttr])else:compareMethod = compareLessOrEqualsorted_a = (np_attrs[np_attrs[:,indexAttr].argsort()])[:,indexAttr]candidateValue = []for i in range(0, len(sorted_a) - 1):v = (sorted_a[i] + sorted_a[i+1]) / 2candidateValue.append(v)minGiniIndex = 1for v in candidateValue:l_label, l_attr, r_label, r_attr = devide(np_label, np_attrs, indexAttr, compareMethod, v)ls_label = [l_label, r_label]theGiniIndex = giniIndex(ls_label)if theGiniIndex < minGiniIndex:minGiniIndex = theGiniIndexcompareValue = vreturn compareMethod, compareValue, minGiniIndexdef processTreeNode(node, np_label, np_attrs, b_discrete):if len(np_label) != len(np_attrs):print("Error: label size != attr size")exit(-1)if len(np_label) <= 0:print("Error: label size <= 0!")exit(-1)if np.shape(np_attrs)[1] != len(b_discrete):print("Error: numbers of attrs != length of b_discrete!")exit(-1)if isArrayElementIdentity(np_label):node.label = np_label[0]returnNUM_END = 5;if len(np_label) <= NUM_END:node.label = getMostElement(np_label)returnif len(np_label) > 1000:print("Current recursion data size: " + str(len(np_label)))# Find the best attribute to divide.minGiniIndex = 1# ls_thread = []for i in range(0, np.shape(np_attrs)[1]):compareMethod, compareValue, giniIndex = findDevidePoint(np_label, np_attrs, i, b_discrete[i])if giniIndex < minGiniIndex:minGiniIndex = giniIndexchooseAttrIndex = ichooseCompareMethod = compareMethodchooseCompareValue = compareValue# Divide the datasetl_label, l_attrs, r_label, r_attrs = devide(np_label,np_attrs,chooseAttrIndex,chooseCompareMethod,chooseCompareValue)# Generate subtreesnode.lChild = TreeNode()node.lChild.compareIndexAttr = chooseAttrIndexnode.lChild.compareMethod = chooseCompareMethodnode.lChild.compareValue = chooseCompareValueif np.shape(l_label)[0] == 0:node.lChild.label = getMostElement(np_label)else:processTreeNode(node.lChild, l_label, l_attrs, b_discrete)node.rChild = TreeNode()if np.shape(r_label)[0] == 0:node.rChild.label = getMostElement(np_label)else:processTreeNode(node.rChild, r_label, r_attrs, b_discrete)def isArrayElementIdentity(np_array):e = np_array[0]for x in np_array:if x != e:return Falsereturn Truedef getMostElement(np_array):dictCount = {}for x in np_array:if x in dictCount.keys():dictCount[x] += 1else:dictCount[x] = 1max = -1result = Nonefor key in dictCount:if dictCount[key] > max:result = keymax = dictCount[key]return resultdef gini(ls_p):result = 1for p in ls_p:result -= p*preturn resultdef giniIndex(ls_devide_np_label):countTotal = 0for np_label in ls_devide_np_label:countTotal += np.shape(np_label)[0]result = 0for np_label in ls_devide_np_label:countValues = countDistinctValues(np_label)ls_p = []for v in countValues:p = countValues[v] / np.shape(np_label)[0]ls_p.append(p)result += gini(ls_p) * np.shape(np_label)[0] / countTotalreturn result

python决策树代码实现相关推荐

  1. python决策树代码解读_建模分析之机器学习算法(附pythonR代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

  2. Python决策树代码

    说明 这个是今天上课的代码,记录一下吧,以前都没有这种意识,学过就忘了. 环境 这里用的是anaconda,要用到这里面的代码,还需要下载额外的软件.当你下载好了anaconda之后在"开始 ...

  3. python神经网络代码

    说明 这里有一点点神经网络的代码,写出来记录一下. 环境 Anaconda3-5.2.0 对运行环境有疑问的可以看我的Python决策树代码那块,会比较详细一点 重要内容 这里要补充一点重要内容,就是 ...

  4. 【机器学习入门】(5) 决策树算法实战:sklearn实现决策树,实例应用(沉船幸存者预测)附python完整代码及数据集

    各位同学好,今天和大家分享一下python机器学习中的决策树算法,在上一节中我介绍了决策树算法的基本原理,这一节,我将通过实例应用带大家进一步认识这个算法.文末有完整代码和数据集,需要的自取.那我们开 ...

  5. 决策树代码代码——python源代码,看完你就完全懂了

    决策树 决策树在周志华的西瓜书里面已经介绍的很详细了(西瓜书P73-P79),那也是我看过讲的最清楚的决策树讲解了,我这里就不献丑了,这篇文章主要是分享决策树的代码. 在西瓜书中介绍了三种决策树,分别 ...

  6. 【机器学习入门】(13) 实战:心脏病预测,补充: ROC曲线、精确率--召回率曲线,附python完整代码和数据集

    各位同学好,经过前几章python机器学习的探索,想必大家对各种预测方法也有了一定的认识.今天我们来进行一次实战,心脏病病例预测,本文对一些基础方法就不进行详细解释,有疑问的同学可以看我前几篇机器学习 ...

  7. python基础代码库-Python基础数据处理库-NumPy

    最近更新:2017-07-19 NumPy是Python做数据处理的底层库,是高性能科学计算和数据分析的基础,比如著名的Python机器学习库SKlearn就需要NumPy的支持.掌握NumPy的基础 ...

  8. 什么是CART算法?怎么对CART进行建树?怎么对CART进行减枝叶?CART Python实现代码

    什么是CART算法?怎么对CART进行建树?怎么对CART进行减枝叶?CART Python实现代码 一.什么是决策树? 二.什么是CART树 三.基尼指数 四.基尼指数在这里为什么Gini(D)系数 ...

  9. python决策树 多分类_Python中的决策树分类:您需要了解的一切

    python决策树 多分类 什么是决策树? (What is Decision Tree?) A decision tree is a decision support tool that uses ...

最新文章

  1. oracle性能调优学习0622
  2. 查看数据库、表、索引的物理存储情况
  3. Spring体系结构详解
  4. 动画图解 socket 缓冲区的那些事儿
  5. os-enviroment
  6. Linux中sudo命令设置,Linux下sudo命令的配置与使用方法
  7. PonyAI进军自动驾驶货运,乘用无人车历史性“小马过河”
  8. swing JTable
  9. 【Prufer Sequence +简单排列组合】bzoj 1005: [HNOI2008]明明的烦恼
  10. 如何理解C++中的.h文件和.cpp文件
  11. Javac源码简单分析之解析和填充符号表
  12. ubuntu系统下抓取屏幕
  13. hive中的TextFile转为SequenceFile
  14. Python笔记 | 角谷猜想
  15. 大数据查询工具HBase读写设计与实践
  16. Diagnosing OSGi uses conflicts
  17. html5设计制作作品,16个精美的 HTML5 作品集网站设计案例
  18. PrincipalComponentAnalysis 主成分分析
  19. 学习笔记 | 面对海量数据,为什么无法设计出完美的分布式缓存体系?
  20. Arcgis打包工程文件(map package)

热门文章

  1. IFrame与Frame分析
  2. 纯科普篇!安全防盗电子围栏这些事儿
  3. 事件相关电位(ERP)的简单处理流程(含MATLAB代码)
  4. opencv实战,钢板焊接点寻找1
  5. html 写入 doc 页边距,word文档边距怎么设置 word文档内容两边的间距怎么调?
  6. 固态硬盘的PCIE,SATA,M2,NVMe,AHCI分别都指什么?
  7. python3使用requests登录人人影视网站
  8. iOS-Skill技巧
  9. 如何用纯 CSS 创作从按钮两侧滑入装饰元素的悬停特效
  10. mysql workbench 显示查询结果_MySQLWorkbench如何导出查询结果?(图文)