机器学习--决策树python实现案例
注:文中相关截图或内容取自《统计学习方法》李航编
简介:
决策树(decision tree)是基本的分类与回归方法。对分类和回归的理解,通俗的讲就是最终结果是离散的为分类任务,结果是连续的是回归任务。决策树中每个非叶节点用一个(多个)特征进行选择向下探索的分支(每个节点相当于“switch 特征i:”的语句或者多个if 、else if的组合),最后探索至叶节点将实例判定为其分类。
如图为关于是否批准贷款的简单决策树:
决策树中节点的分支数根据该节点所选取的特征的值域确定。
信息增益
1.熵
在节点的特征选择上,需要现了解熵、条件熵和信息增益的概念。当前节点的划分特征选择信息增益最大的特征。
其中pi=训练集中该特征值为xi的样本的数量 / 总数。熵越大,随机变量的不确定性越大。
2.条件熵
pij = 在类yi的样本集中该特征值为xi的样本的数量 / 类yi的样本数量
3.信息增益
生成决策树的步骤
- 载入数据,抽象数据特征(数据集D,特征A)
- 计算熵
- 计算每个特征的信息增益,选择信息增益最大的特征作为划分子树集的依据
递归构建决策树
其中(1)(2)(4)为构建叶节点的三种情况,(5)(6)递归构建内节点。给样本分类
案例:
下面通过一个实例来实现这个算法。
项目数据下载及说明,如下链接:
http://archive.ics.uci.edu/ml/datasets/Car+Evaluation
请自行下载数据,以及了解数据的相关内容。
数据样例:
Class Values:
unacc, acc, good, vgood
Attributes:
buying: vhigh, high, med, low.
maint: vhigh, high, med, low.
doors: 2, 3, 4, 5more.
persons: 2, 4, more.
lug_boot: small, med, big.
safety: low, med, high.
样本:
vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
,,,
…
…
代码实现
部分函数
1、载入数据,抽象数据特征
#从文档中读取数据,每条数据转成列表的形式
def readData(path):dataList = []with open(path,'r') as f:dataSet = f.readlines()for d in dataSet:d = d[:-1]d = d.split(',')dataList.append(d)return dataList#映射属性值,方便代码处理
Cls = {'unacc':0, 'acc':1, 'good':2, 'vgood':3} #分类值映射
#特征值映射,共6个特征值,每个特征表示为X[i],X[i][xiv]表示特征Xi的取值。
X = [{'vhigh':0, 'high':1, 'med':2, 'low':3},{'vhigh':0, 'high':1, 'med':2, 'low':3},{'2':0, '3':1, '4':2, '5more':3},{'2':0, '4':1, 'more':2},{'small':0, 'med':1, 'big':2},{'low':0, 'med':1, 'high':2}]
2、 计算熵
def CountEachClass(dataSet):numEachClass = [0]*len(Cls) #列表初始化for d in dataSet:numEachClass[Cls[d[-1]]] +=1return numEachClassdef caculateEntropy(dataSet):NumEachClass = CountEachClass(dataSet)dataNum = len(dataSet)ent = 0for numC in NumEachClass:temp = numC/dataNumif(temp != 0):ent -= temp * log(temp,2)return ent
3、 计算每个特征的信息增益,选择信息增益最大的特征作为划分子树集的依据
def calGain(dataset,xi): #计算信息增益res = 0ent = caculateEntropy(dataset)subDataSet = splitData(dataset,xi)for xivDataSet in subDataSet:if(xivDataSet):res += len(xivDataSet)/len(dataset) * caculateEntropy(xivDataSet)return ent - resdef getMaxGain(dataSet,usedX=[]): #获得最大的信息增益值和对应的特征序号gains = []for xi in range(len(X)):if(xi not in usedX):gains.append(calGain(dataSet,xi))else:gains.append(0)mg = max(gains)mx = gains.index(mg)return mx,mgdef splitData(dataSet,xi):subDataSets = [ [] for i in range(len(X[xi]))] #子数据集列表初始化for d in dataSet:subDataSets[ X[xi][d[xi]] ].append(d)return subDataSets
4、递归构建决策树
def createTree(dataSet,r,usedX=[]): #以字典的结构构建决策树if (len(dataSet)==0):return {} #空树tree = {}numEachClass = CountEachClass(dataSet)c = numEachClass.index(max(numEachClass))tree['class'] = c #该树各分类中数据最多的类,记为该根节点的分类mx,mg = getMaxGain(dataSet,usedX)print("max gain:",mg)if len(usedX) == len(X) or numEachClass[c] == len(dataSet) or mg < r:tree['factureX'] = -1 #不在继续分支,即为叶节点return treeelse:tree['factureX']= mx #记录该根节点用于划分的特征subDataSet = splitData(dataSet, mx) #用该特征的值划分子集,用于构建子树for xiv in range(len(X[mx])):xivDataSet = subDataSet[xiv]newusedX = usedX.copy()newusedX.append(mx)tree[xiv] = createTree(xivDataSet,r,newusedX)return tree
5、给样本分类
def classify(tree,data):xi = tree['factureX'] #根节点用于划分子树的特征if(xi>=0):subtree = tree[X[xi][data[xi]]]if subtree=={}: #节点没有该子树return tree['class'] #以该节点的分类作为数据的分类return classify(subtree,data) #否则遍历子树else: #叶节点return tree['class']
完整代码
from math import log
import numpy as np#从文档中读取数据,每条数据转成列表的形式
def readData(path):dataList = []with open(path,'r') as f:dataSet = f.readlines()for d in dataSet:d = d[:-1]d = d.split(',')dataList.append(d)return dataList#将数据集划分为训练集和测试集
def splitTestData(dataList,testnum):trainData = []testData = []dataNum = len(dataList)pred_ind = np.random.randint(0,dataNum,testnum)for d in pred_ind:testData.append(dataList[d])for d in range(dataNum):if d not in pred_ind:trainData.append(dataList[d])print("dataSetNum:",dataNum,len(trainData),len(testData))return trainData,testData#映射属性值,方便代码处理
Cls = {'unacc':0, 'acc':1, 'good':2, 'vgood':3} #分类值映射
#特征值映射,共6个特征值,每个特征表示为X[i],X[i][xiv]表示特征Xi的取值。
X = [{'vhigh':0, 'high':1, 'med':2, 'low':3},{'vhigh':0, 'high':1, 'med':2, 'low':3},{'2':0, '3':1, '4':2, '5more':3},{'2':0, '4':1, 'more':2},{'small':0, 'med':1, 'big':2},{'low':0, 'med':1, 'high':2}]def CountEachClass(dataSet):numEachClass = [0]*len(Cls) #列表初始化for d in dataSet:numEachClass[Cls[d[-1]]] +=1return numEachClassdef caculateEntropy(dataSet):NumEachClass = CountEachClass(dataSet)dataNum = len(dataSet)ent = 0for numC in NumEachClass:temp = numC/dataNumif(temp != 0):ent -= temp * log(temp,2)return entdef splitData(dataSet,xi):subDataSets = [ [] for i in range(len(X[xi]))] #子数据集列表初始化for d in dataSet:subDataSets[ X[xi][d[xi]] ].append(d)return subDataSetsdef calGain(dataset,xi): #计算信息增益res = 0ent = caculateEntropy(dataset)subDataSet = splitData(dataset,xi)for xivDataSet in subDataSet:if(xivDataSet):res += len(xivDataSet)/len(dataset) * caculateEntropy(xivDataSet)return ent - resdef getMaxGain(dataSet,usedX=[]): #获得最大的信息增益值和对应的特征序号gains = []for xi in range(len(X)):if(xi not in usedX):gains.append(calGain(dataSet,xi))else:gains.append(0)mg = max(gains)mx = gains.index(mg)return mx,mgdef createTree(dataSet,r,usedX=[]): #以字典的结构构建决策树if (len(dataSet)==0):return {} #空树tree = {}numEachClass = CountEachClass(dataSet)c = numEachClass.index(max(numEachClass))tree['class'] = c #该树各分类中数据最多的类,记为该根节点的分类mx,mg = getMaxGain(dataSet,usedX)print("max gain:",mg)if len(usedX) == len(X) or numEachClass[c] == len(dataSet) or mg < r:tree['factureX'] = -1 #不在继续分支,即为叶节点return treeelse:tree['factureX']= mx #记录该根节点用于划分的特征subDataSet = splitData(dataSet, mx) #用该特征的值划分子集,用于构建子树for xiv in range(len(X[mx])):xivDataSet = subDataSet[xiv]newusedX = usedX.copy()newusedX.append(mx)tree[xiv] = createTree(xivDataSet,r,newusedX)return treedef classify(tree,data):xi = tree['factureX'] #根节点用于划分子树的特征if(xi>=0):subtree = tree[X[xi][data[xi]]]if subtree=={}: #节点没有该子树return tree['class'] #以该节点的分类作为数据的分类return classify(subtree,data) #否则遍历子树else: #叶节点return tree['class']#测试:
testNum = 100
err = 0
right = 0dataSet = readData('car.data.txt')
trainDataSet,testDataSet = splitTestData(dataSet,testNum)
tree = createTree(trainDataSet,0.2)for d in testDataSet:c = classify(tree,d)if c ==Cls[d[-1]]:right +=1else:err +=1print("分类:",c)print('实际分类',Cls[d[-1]])print("err:",err,"right:",right)
print("total:",testNum)
print("错误率:",err/testNum)
运行结果:
...
...
分类: 0
实际分类 0
分类: 1
实际分类 1
err: 3 right: 97
total: 100
错误率: 0.03
机器学习--决策树python实现案例相关推荐
- python鸢尾花数据集_鸢尾花经典机器学习分类Python实现案例
作者|Nature 出品|AI机器思维 由Fisher在1936年整理的Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例. 其数据集变量包含4个特征(Sepal.Len ...
- python决策树 value_直播案例 | 决策树、随机森林和 AdaBoost 的 Python 实现
获取案例链接.直播课件.数据集在本公众号内发送"机器学习". 本案例使用 Python 逐步实现了三种基于树的模型:分类回归决策树(CART).随机森林和 AdaBoost .在实 ...
- 机器学习——决策树模型:Python实现
机器学习--决策树模型:Python实现 1 决策树模型的代码实现 1.1 分类决策树模型(DecisionTreeClassifier) 1.2 回归决策树模型(DecisionTreeRegres ...
- 史上最简单的spark教程第二十三章-运行第一个机器学习Java和Python代码案例
[提前声明] 文章由作者:张耀峰 结合自己生产中的使用经验整理,最终形成简单易懂的文章 写作不易,转载请注明,谢谢! 代码案例地址: ?https://github.com/Mydreamandrea ...
- 机器学习实验二---决策树python
机器学习实验二---决策树python 一.了解一下决策树吧 决策树基本流程 信息增益 决策树的优缺点 二.数据处理 三.决策树的构建 计算给定数据集的香农熵 按照给定特征划分数据集 选择最好的数据划 ...
- 机器学习 学习曲线 Python实现学习曲线及案例解析
机器学习 学习曲线 Python实现学习曲线及案例解析 学习曲线 如果数据集的大小为 mmm,则通过下面的流程即可画出学习曲线: 把数据集分成训练数据集和交叉验证数据集. 取训练数据集的 20%20\ ...
- 模式识别与机器学习(Python实现):决策树分男女
模式识别与机器学习(Python实现):决策树分男女 欢迎大家来到安静到无声的<模式识别与人工智能(程序与算法)>,如果对所写内容感兴趣请看模式识别与人工智能(程序与算法)系列讲解 - 总 ...
- 机器学习-决策树之回归树python实战(预测泰坦尼克号幸存情况)(三)
本文用通俗易懂的方式来讲解分类树中的回归树,并以"一维回归的图像绘制"和"泰坦尼克号幸存者预测"两个例子来说明该算法原理. 以下是本文大纲: 1 Decisio ...
- 机器学习决策树DecisionTree以及python代码实现
机器学习决策树DecisionTree以及python代码实现 1.基本算法原理 2.选择最优特征进行划分 2.1信息增益 2.2信息增益率 2.3基尼系数 4.连续值以及缺失值的处理 4.1连续值的 ...
最新文章
- ActivityGroup是如何对嵌入的Activitys进行管理的
- 千米感知误差低于5%,嬴彻发布全球领先的超长距精准3D感知技术
- python读什么英文-django的英文读法是什么
- 【剑指offer】面试题49:丑数
- html推箱子怎么清除走过的,第九讲:HTML5该canvas推箱子原型实现
- 微信小程序简介、发展史、小程序的优点、申请账号、开发工具、初识wxml文件和wxss文件
- java方法中与参数怎么调用_与Java方法调用中的类型参数有关的问题
- C 语言实例 - 计算 int, float, double 和 char 字节大小
- java类静态初始化_Java静态代码块和类初始化、实例初始化过程
- J2EE Architecture(6)
- [POJ3580]SuperMemo
- 6.Linux性能诊断 --- 远程通信gRPC,kafka,docker
- spring:注解配置AOP
- The process cannot access the file '' because it is being used by another process.....
- 苹果电脑系统如果删除驱动
- SwiftUI调用UIKit
- unity——UI拖拽实现拼图
- 润乾报表Api导出word只读
- 企业邮箱给国外发邮件注册哪个好?如何群发邮件?
- 安卓+ios系统--手机端页面自适应手机屏幕大小,禁止手动放大和缩小VUE