注:文中相关截图或内容取自《统计学习方法》李航编
简介:
决策树(decision tree)是基本的分类与回归方法。对分类和回归的理解,通俗的讲就是最终结果是离散的为分类任务,结果是连续的是回归任务。决策树中每个非叶节点用一个(多个)特征进行选择向下探索的分支(每个节点相当于“switch 特征i:”的语句或者多个if 、else if的组合),最后探索至叶节点将实例判定为其分类。
如图为关于是否批准贷款的简单决策树:

决策树中节点的分支数根据该节点所选取的特征的值域确定。

信息增益

1.熵

在节点的特征选择上,需要现了解熵、条件熵和信息增益的概念。当前节点的划分特征选择信息增益最大的特征。

其中pi=训练集中该特征值为xi的样本的数量 / 总数。熵越大,随机变量的不确定性越大。

2.条件熵


pij = 在类yi的样本集中该特征值为xi的样本的数量 / 类yi的样本数量

3.信息增益

生成决策树的步骤

  1. 载入数据,抽象数据特征(数据集D,特征A)
  2. 计算熵
  3. 计算每个特征的信息增益,选择信息增益最大的特征作为划分子树集的依据
  4. 递归构建决策树

    其中(1)(2)(4)为构建叶节点的三种情况,(5)(6)递归构建内节点。

  5. 给样本分类

案例:

下面通过一个实例来实现这个算法。
项目数据下载及说明,如下链接:
http://archive.ics.uci.edu/ml/datasets/Car+Evaluation
请自行下载数据,以及了解数据的相关内容。

数据样例:

Class Values:
unacc, acc, good, vgood

Attributes:
buying: vhigh, high, med, low.
maint: vhigh, high, med, low.
doors: 2, 3, 4, 5more.
persons: 2, 4, more.
lug_boot: small, med, big.
safety: low, med, high.

样本:

vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
,,,

代码实现

部分函数

1、载入数据,抽象数据特征

#从文档中读取数据,每条数据转成列表的形式
def readData(path):dataList = []with open(path,'r') as f:dataSet = f.readlines()for d in dataSet:d = d[:-1]d = d.split(',')dataList.append(d)return dataList#映射属性值,方便代码处理
Cls = {'unacc':0, 'acc':1, 'good':2, 'vgood':3}   #分类值映射
#特征值映射,共6个特征值,每个特征表示为X[i],X[i][xiv]表示特征Xi的取值。
X = [{'vhigh':0, 'high':1, 'med':2, 'low':3},{'vhigh':0, 'high':1, 'med':2, 'low':3},{'2':0, '3':1, '4':2, '5more':3},{'2':0, '4':1, 'more':2},{'small':0, 'med':1, 'big':2},{'low':0, 'med':1, 'high':2}]

2、 计算熵

def CountEachClass(dataSet):numEachClass = [0]*len(Cls)  #列表初始化for d in dataSet:numEachClass[Cls[d[-1]]] +=1return numEachClassdef caculateEntropy(dataSet):NumEachClass = CountEachClass(dataSet)dataNum = len(dataSet)ent = 0for numC in NumEachClass:temp = numC/dataNumif(temp != 0):ent -= temp * log(temp,2)return ent

3、 计算每个特征的信息增益,选择信息增益最大的特征作为划分子树集的依据

def calGain(dataset,xi):    #计算信息增益res = 0ent = caculateEntropy(dataset)subDataSet = splitData(dataset,xi)for xivDataSet in subDataSet:if(xivDataSet):res += len(xivDataSet)/len(dataset) * caculateEntropy(xivDataSet)return ent - resdef getMaxGain(dataSet,usedX=[]):   #获得最大的信息增益值和对应的特征序号gains = []for xi in range(len(X)):if(xi not in usedX):gains.append(calGain(dataSet,xi))else:gains.append(0)mg = max(gains)mx = gains.index(mg)return mx,mgdef splitData(dataSet,xi):subDataSets = [ [] for i in range(len(X[xi]))]  #子数据集列表初始化for d in dataSet:subDataSets[ X[xi][d[xi]] ].append(d)return subDataSets

4、递归构建决策树

def createTree(dataSet,r,usedX=[]):   #以字典的结构构建决策树if (len(dataSet)==0):return {}     #空树tree = {}numEachClass = CountEachClass(dataSet)c = numEachClass.index(max(numEachClass))tree['class'] = c  #该树各分类中数据最多的类,记为该根节点的分类mx,mg = getMaxGain(dataSet,usedX)print("max gain:",mg)if len(usedX) == len(X) or numEachClass[c] == len(dataSet) or mg < r:tree['factureX'] = -1    #不在继续分支,即为叶节点return treeelse:tree['factureX']= mx  #记录该根节点用于划分的特征subDataSet = splitData(dataSet, mx)  #用该特征的值划分子集,用于构建子树for xiv in range(len(X[mx])):xivDataSet = subDataSet[xiv]newusedX = usedX.copy()newusedX.append(mx)tree[xiv] = createTree(xivDataSet,r,newusedX)return tree

5、给样本分类

def classify(tree,data):xi = tree['factureX']  #根节点用于划分子树的特征if(xi>=0):subtree = tree[X[xi][data[xi]]]if subtree=={}: #节点没有该子树return tree['class']  #以该节点的分类作为数据的分类return classify(subtree,data)  #否则遍历子树else: #叶节点return tree['class']

完整代码

from math import log
import numpy as np#从文档中读取数据,每条数据转成列表的形式
def readData(path):dataList = []with open(path,'r') as f:dataSet = f.readlines()for d in dataSet:d = d[:-1]d = d.split(',')dataList.append(d)return dataList#将数据集划分为训练集和测试集
def splitTestData(dataList,testnum):trainData = []testData = []dataNum = len(dataList)pred_ind = np.random.randint(0,dataNum,testnum)for d in pred_ind:testData.append(dataList[d])for d in range(dataNum):if d not in pred_ind:trainData.append(dataList[d])print("dataSetNum:",dataNum,len(trainData),len(testData))return trainData,testData#映射属性值,方便代码处理
Cls = {'unacc':0, 'acc':1, 'good':2, 'vgood':3}   #分类值映射
#特征值映射,共6个特征值,每个特征表示为X[i],X[i][xiv]表示特征Xi的取值。
X = [{'vhigh':0, 'high':1, 'med':2, 'low':3},{'vhigh':0, 'high':1, 'med':2, 'low':3},{'2':0, '3':1, '4':2, '5more':3},{'2':0, '4':1, 'more':2},{'small':0, 'med':1, 'big':2},{'low':0, 'med':1, 'high':2}]def CountEachClass(dataSet):numEachClass = [0]*len(Cls)  #列表初始化for d in dataSet:numEachClass[Cls[d[-1]]] +=1return numEachClassdef caculateEntropy(dataSet):NumEachClass = CountEachClass(dataSet)dataNum = len(dataSet)ent = 0for numC in NumEachClass:temp = numC/dataNumif(temp != 0):ent -= temp * log(temp,2)return entdef splitData(dataSet,xi):subDataSets = [ [] for i in range(len(X[xi]))]  #子数据集列表初始化for d in dataSet:subDataSets[ X[xi][d[xi]] ].append(d)return subDataSetsdef calGain(dataset,xi):    #计算信息增益res = 0ent = caculateEntropy(dataset)subDataSet = splitData(dataset,xi)for xivDataSet in subDataSet:if(xivDataSet):res += len(xivDataSet)/len(dataset) * caculateEntropy(xivDataSet)return ent - resdef getMaxGain(dataSet,usedX=[]):   #获得最大的信息增益值和对应的特征序号gains = []for xi in range(len(X)):if(xi not in usedX):gains.append(calGain(dataSet,xi))else:gains.append(0)mg = max(gains)mx = gains.index(mg)return mx,mgdef createTree(dataSet,r,usedX=[]):   #以字典的结构构建决策树if (len(dataSet)==0):return {}     #空树tree = {}numEachClass = CountEachClass(dataSet)c = numEachClass.index(max(numEachClass))tree['class'] = c  #该树各分类中数据最多的类,记为该根节点的分类mx,mg = getMaxGain(dataSet,usedX)print("max gain:",mg)if len(usedX) == len(X) or numEachClass[c] == len(dataSet) or mg < r:tree['factureX'] = -1    #不在继续分支,即为叶节点return treeelse:tree['factureX']= mx  #记录该根节点用于划分的特征subDataSet = splitData(dataSet, mx)  #用该特征的值划分子集,用于构建子树for xiv in range(len(X[mx])):xivDataSet = subDataSet[xiv]newusedX = usedX.copy()newusedX.append(mx)tree[xiv] = createTree(xivDataSet,r,newusedX)return treedef classify(tree,data):xi = tree['factureX']  #根节点用于划分子树的特征if(xi>=0):subtree = tree[X[xi][data[xi]]]if subtree=={}: #节点没有该子树return tree['class']  #以该节点的分类作为数据的分类return classify(subtree,data)  #否则遍历子树else: #叶节点return tree['class']#测试:
testNum = 100
err = 0
right = 0dataSet = readData('car.data.txt')
trainDataSet,testDataSet = splitTestData(dataSet,testNum)
tree = createTree(trainDataSet,0.2)for d in testDataSet:c = classify(tree,d)if c ==Cls[d[-1]]:right +=1else:err +=1print("分类:",c)print('实际分类',Cls[d[-1]])print("err:",err,"right:",right)
print("total:",testNum)
print("错误率:",err/testNum)

运行结果:

...
...
分类: 0
实际分类 0
分类: 1
实际分类 1
err: 3 right: 97
total: 100
错误率: 0.03

机器学习--决策树python实现案例相关推荐

  1. python鸢尾花数据集_鸢尾花经典机器学习分类Python实现案例

    作者|Nature 出品|AI机器思维 由Fisher在1936年整理的Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例. 其数据集变量包含4个特征(Sepal.Len ...

  2. python决策树 value_直播案例 | 决策树、随机森林和 AdaBoost 的 Python 实现

    获取案例链接.直播课件.数据集在本公众号内发送"机器学习". 本案例使用 Python 逐步实现了三种基于树的模型:分类回归决策树(CART).随机森林和 AdaBoost .在实 ...

  3. 机器学习——决策树模型:Python实现

    机器学习--决策树模型:Python实现 1 决策树模型的代码实现 1.1 分类决策树模型(DecisionTreeClassifier) 1.2 回归决策树模型(DecisionTreeRegres ...

  4. 史上最简单的spark教程第二十三章-运行第一个机器学习Java和Python代码案例

    [提前声明] 文章由作者:张耀峰 结合自己生产中的使用经验整理,最终形成简单易懂的文章 写作不易,转载请注明,谢谢! 代码案例地址: ?https://github.com/Mydreamandrea ...

  5. 机器学习实验二---决策树python

    机器学习实验二---决策树python 一.了解一下决策树吧 决策树基本流程 信息增益 决策树的优缺点 二.数据处理 三.决策树的构建 计算给定数据集的香农熵 按照给定特征划分数据集 选择最好的数据划 ...

  6. 机器学习 学习曲线 Python实现学习曲线及案例解析

    机器学习 学习曲线 Python实现学习曲线及案例解析 学习曲线 如果数据集的大小为 mmm,则通过下面的流程即可画出学习曲线: 把数据集分成训练数据集和交叉验证数据集. 取训练数据集的 20%20\ ...

  7. 模式识别与机器学习(Python实现):决策树分男女

    模式识别与机器学习(Python实现):决策树分男女 欢迎大家来到安静到无声的<模式识别与人工智能(程序与算法)>,如果对所写内容感兴趣请看模式识别与人工智能(程序与算法)系列讲解 - 总 ...

  8. 机器学习-决策树之回归树python实战(预测泰坦尼克号幸存情况)(三)

    本文用通俗易懂的方式来讲解分类树中的回归树,并以"一维回归的图像绘制"和"泰坦尼克号幸存者预测"两个例子来说明该算法原理. 以下是本文大纲: 1 Decisio ...

  9. 机器学习决策树DecisionTree以及python代码实现

    机器学习决策树DecisionTree以及python代码实现 1.基本算法原理 2.选择最优特征进行划分 2.1信息增益 2.2信息增益率 2.3基尼系数 4.连续值以及缺失值的处理 4.1连续值的 ...

最新文章

  1. ActivityGroup是如何对嵌入的Activitys进行管理的
  2. 千米感知误差低于5%,嬴彻发布全球领先的超长距精准3D感知技术
  3. python读什么英文-django的英文读法是什么
  4. 【剑指offer】面试题49:丑数
  5. html推箱子怎么清除走过的,第九讲:HTML5该canvas推箱子原型实现
  6. 微信小程序简介、发展史、小程序的优点、申请账号、开发工具、初识wxml文件和wxss文件
  7. java方法中与参数怎么调用_与Java方法调用中的类型参数有关的问题
  8. C 语言实例 - 计算 int, float, double 和 char 字节大小
  9. java类静态初始化_Java静态代码块和类初始化、实例初始化过程
  10. J2EE Architecture(6)
  11. [POJ3580]SuperMemo
  12. 6.Linux性能诊断 --- 远程通信gRPC,kafka,docker
  13. spring:注解配置AOP
  14. The process cannot access the file '' because it is being used by another process.....
  15. 苹果电脑系统如果删除驱动
  16. SwiftUI调用UIKit
  17. unity——UI拖拽实现拼图
  18. 润乾报表Api导出word只读
  19. 企业邮箱给国外发邮件注册哪个好?如何群发邮件?
  20. 安卓+ios系统--手机端页面自适应手机屏幕大小,禁止手动放大和缩小VUE

热门文章

  1. 获取一个字符串的拼音码
  2. SQL语句中except是怎样用的?
  3. C++ 实现俄罗斯方块(附详细解析)
  4. 锐捷——Telent登录时使用 用户名及密码登陆路由器
  5. UDP协议、UDP和TCP优缺点对比
  6. 74hc595级联实现原理
  7. 计算机基础word目录操作题,Word综合操作题 计算机基础
  8. 最好的网站宣传方法:网摘精灵
  9. 优维科技实力入选《2023深圳金融业信息技术融合创新案例汇编》
  10. 2017届中兴综合面试