决策树剪枝python实现_决策树剪枝问题python代码
决策树在生长过程中有可能长得过于茂盛,对训练集学习的很好,但对新的数据集的预测效果不好,即过拟合,此时生成的模型泛化能力较差。因此,我们需要对决策树进行剪枝,使得生成的模型具有较强的泛化能力。
为了检测剪枝前后模型的分类效果是否提升,我们需要将全部数据集划分为三个子集:训练集、验证集(剪枝集)、测试集。
训练集:用于生成决策树。
验证集:用于检验剪枝前后模型分类效果变动情况。
测试集:当模型完全确定后,用于检验模型在新样本数据上的分类效果。
决策树剪枝的主要方法包括两大类:后剪枝、预剪枝。
1.后剪枝
指先允许决策树自由生长,然后对其进行修剪。主要剪枝思想是将“子树”替换为叶子节点,即取消某个内部节点将其变为叶子节点,观察模型的分类效果是否有提升。
(1)后剪枝的优点:
后剪枝比预剪枝会保留更多的节点,不容易出现“欠拟合”问题。
(2)后剪枝的缺点:
因为是先生成树,再对树进行剪枝,因此,时间开销比预剪枝大。
(3)常用的是,先构建一棵决策树,然后根据特征重要性,筛选出重要的特征,再构建新的决策树。
具体操作,参见我的另一篇文章:Cara:随机森林Python实战zhuanlan.zhihu.com
(4)一些后剪枝算法包括:REP(Reduced-Error Pruning)、PEP(Pesimistic-Error Pruning)、CCP(cost complexity pruning)。
①REP(Reduced-Error Pruning):自下而上剪枝,即将叶子节点合并到上位节点中去。
现有一个生成好的决策树如下,共1~10个节点。
Step 1: 尝试将叶子节点4删掉,则叶节点8、9、10合并为一个节点(命名为节点11),节点11的类别我们用节点8、9、10包含的样本中数量最多的那个类来确定(大多数原则来确定),然后,节点11与节点5一起成为节点2的叶子节点。然后,测试剪枝后的决策树在验证集上的表现,若分类效果(可以用分类的准确率度量)更好或者效果没有变差,则将节点4删掉(如下图),若表现不好则保留原树的形状。剪枝后效果如下:
STEP2:同理,尝试将节点3删掉,则叶子节点6和7会成为合并为一个节点(命名为节点12),节点12会成为节点1的叶子节点,然后,测试剪枝后的决策树在验证集上的表现。若分类效果更好或者效果没有变差,则将节点3删掉。
以此类推,形成最简决策树。
评价:REP是最简单的后剪枝方法之一,不过由于需要使用验证集,如果本身数据量较少,剪枝效果也不会很好。
②PEP(Pesimistic Error Pruning):没看懂,不敢说话……
③CCP(Cost Complexity Pruning):原理没看很明白……下面put一个github上比较流行的实操方法。
CCP思想已经被写入python的sklearn库下的tree模块的DecisionTreeClassifier类,在对类进行实例化时可以根据如下提示进行相关设置,重点中的重点在于如何选择一个较优的ccp_alpha值。
小例子:使用sklearn库中自带的数据集breast_cancer进行决策树剪枝演示。
#导入工程
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
import matplotlib as plt
from sklearn.metrics import accuracy_score
#数据拆分为特征集和标签列
X,y = load_breast_cancer(return_X_y=True)
#数据拆分为训练集和测试集
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=8)
#实例化决策树,参数全部用默认值
clf = DecisionTreeClassifier(random_state=0)#默认是基于Gini值生成决策树
#拟合样本数据
clf.fit(X_train,y_train)
#对测试集进行预测
pred = clf.predict(X_test)
#对预测结果进行打分
print(accuracy_score(y_test,pred))
#结果:0.9090909090909091
#绘制出生成的决策树
plt.rcParams['savefig.dpi'] = 200 #图片像素
plt.rcParams['figure.dpi'] = 200 #分辨率
# 默认的像素:[6.0,4.0],分辨率为100,图片尺寸为 600*400
from sklearn import tree
tree.plot_tree(clf,filled=True)
不剪枝时,采用默认参数生成的决策树,共8层,因此,树的深度为8:
#以下代码主要目的是找到较优的ccp_alpha,然后就可以在DecisionTreeClassifier类实例化的时候设置ccp_alpha参数为找到的ccp_alpha值。
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
print(ccp_alphas)
'''
[0. 0.00232752 0.00375587 0.00375587 0.00391236 0.00438185
0.00629975 0.00730308 0.00946792 0.01198161 0.01788007 0.04055896
0.31844825]
'''
clfs = []
for ccp_alpha in ccp_alphas:
clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
clf.fit(X_train, y_train)
clfs.append(clf)
print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
clfs[-1].tree_.node_count, ccp_alphas[-1]))
#返回:Number of nodes in the last tree is: 1 with ccp_alpha: 0.3184482456829032
#绘制不同ccp_alpha取值下,clf在训练样本和测试样本上的精确度
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]
from matplotlib import pyplot
plt.rcParams['savefig.dpi'] = 80 #图片像素
plt.rcParams['figure.dpi'] = 200 #分辨率
# 默认的像素:[6.0,4.0],分辨率为100,图片尺寸为 600*400
fig, ax = pyplot.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker='o', label="train",
drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker='o', label="test",
drawstyle="steps-post")
ax.legend()
pyplot.show()
不同ccp_alpha取值,对应的clf的测试集正确率和训练集正确率如同所示:
根据该图可以发现,当alpha介于0.05~0.3之间时,对应的决策树分类器在测试集的正确率和训练集的正确率均比较高,正确率稳定在0.90~0.95之间。
#重新生成决策树,根据ccp_alpha进行剪枝
clf2 = DecisionTreeClassifier(random_state=0, ccp_alpha=0.006)
clf2.fit(X_train,y_train)
pred2=clf2.predict(X_test)
print(accuracy_score(y_test, pred2))
#结果:0.916083916083916
#绘制剪枝后的决策树
tree.plot_tree(clf2,filled=True)
剪枝后的树如下,共6层:
2.预剪枝
指在决策树生成之前,就设置一些参数来限制决策树的生长,如设定树的最大深度、叶子节点最小数量等。
在python中,sklearn库下的tree模块,可以加载DecisionTreeClassifier类,该类在实例化时可以设置以下参数进行预剪枝:
预剪枝由于过多的人为设定,可能导致生成的模型存在“欠拟合”问题,即生成的模型对数据的反映不够,模型对训练集的预测效果不佳,对测试集的预测效果也不佳。
(1)设定树的最大深度
该方法会在起初设置一个计数器,这里命名为counter,并设置初始值为0。在决策树生长过程中,会重点关注两个问题:①数据集是否pure了?即y标签是否只有1个类别。②决策树是否达到设定的最大深度。
如果数据集pure,则将数据集设为叶子节点。
如果数据集不pure,就会接着问树是否达到最大深度。
如果数据集不pure且树没有达到最大深度,就会从多个属性中根据某种指标(信息增益、Gini指标)选择出最佳划分节点,并对counter加1。然后再判断划分后的子集分别是否pure,以此类推。
如果数据集不pure,但是树达到最大深度,这时,就会强制将数据集设为叶子节点,代价就是生成的树存在“欠拟合”问题。
(2)设定叶子节点最小数量
参考:https://www.youtube.com/watch?v=SLOyyFHbiqowww.youtube.com【机器学习】算法原理详细推导与实现(七):决策树算法 - TTyb - 博客园www.cnblogs.com
决策树剪枝python实现_决策树剪枝问题python代码相关推荐
- 机器学习中决策树的随机森林_决策树和随机森林在机器学习中的使用
机器学习中决策树的随机森林 机器学习 (Machine Learning) Machine learning is an application of artificial intelligence ...
- pythoncookbook和流畅的python对比_为什么你学Python效率比别人慢?因为你没有这套完整的学习资料...
以下资源免费获取方式! 关注!转发!私信"资料"即可免费领取! 入门书籍 1.<Python基础教程>(Beginning Python From Novice to ...
- 零基础学python 视频_全网最全Python视频教程真正零基础学习Python视频教程 490集...
Python Web开发-进阶提升 490集超强Python视频教程 真正零基础学习Python视频教程 [课程简介] 这是一门Python Web开发进阶课程,手把手教你用Python开发完整的商业 ...
- 3 x 10的python表达式_这道数学题用PYTHON编程语言怎么写? 编程语言python是用
我觉着,这个应该这样解决比较符合计算机解题思路. 下面的回答的,思考的东西太多. # -*- coding: utf-8 -*- __author__ = 'lpe234' __date__ = '2 ...
- 为什么要学python语言_我们为什么要学习Python语言?
原标题:我们为什么要学习Python语言? 聊到我们为什么要学习Python语言?小编不禁又想起大佬潘石屹准备开启Python学习旅程时所发布的微博. 我们为什么要学习Python语言? 在农业社会时 ...
- 下载python步骤_下载及安装Python详细步骤
安装python分三个步骤: *下载python *安装python *检查是否安装成功 1.下载python (1)python下载地址 (2)选择下载的版本 (3)点开download后,找到下载 ...
- ubuntu更改默认python版本_更改Ubuntu默认python版本的方法
1.查看基本信息 # 列出所有已安装python ls /usr/bin/python* #查看默认的 Python 版本信息: python --version 2.基于用户修改 默认Python ...
- python编辑器_推荐一款Python编辑器,集Pycharm和Sublime优点于一身的王者
编程里面的编辑器就像是武林大会里面的高手,每一年都有新秀,黑马出现!比如有练习霸道的天罡之气的榜首Pycharm,力量雄厚霸道战斗力极强,但是对斗气消耗很大,占内存大而且启动速度有点慢!还有练习灵巧的 ...
- 人工智能只能用python吗_为什么人工智能用Python?
主要原因: 1.人工智能适应Python的编程语言. 2.人工智能需要利用Python的高层语言,实现可移植性.面向对象.可扩展性.可嵌入型等功能,来实现人机交流. Python:是一种面向对象的解释 ...
最新文章
- (二)Cacti监控
- PL/SQL developer连接oracle出现“ORA-12154:TNS:could not resolve the connect identifier specified”问题的解决
- JavaScript中的数组循环方法
- ubuntu 16.04 安装教程
- restful规范和APIView
- 【招聘(上海)】美团酒店招聘 .NET 高级开发
- centos7搭建easy-mock服务
- 收藏!目标检测优质综述论文总结!
- Hive数据导出的三种方式
- ubuntu服务器+apache2绑定域名(以腾讯云域名为例)
- paip.python错误解决5
- python模拟ssh登录
- QQ拼音输入法词库和搜狗输入法词库[相互导入](使用Excel公式)
- Blender程序化建模教程【Python】
- torch.flatten
- 【Android开发】Android基本UI组件
- 【Python 字符视频】Python 实现将抖音视频转换成字符视频
- SQL 一条SQL语句 统计 各班总人数、男女各总人数 、该班级男女 比例
- ams1117-3.3v三端稳压芯片低压差线性稳压器
- flask部署阿里云服务器,公网ip访问不了(一些问题及解答)
热门文章
- 玉米社:SEM竞价推广转化成本高?做好细节转化率蹭蹭往上涨
- 在godot的canvas_item着色器中构建逆投影矩阵和逆视图矩阵
- 教你百度网盘文件转阿里云
- EXCEL如何批量调整图片大小?
- Node.js 实现登录校验 + 选项卡(改进版)
- ckplayer html5 添加广告,ewebeditor下利用ckplayer增加html5 (mp4)全平台的支持
- 央视《每周质量报告》:揭秘假宽带真相
- surface pro win10 重装系统并解决屏幕亮度闪烁和降频的问题
- Linux 启动项管理
- html css做一个简历表,HTML table制做我的简历