决策树在生长过程中有可能长得过于茂盛,对训练集学习的很好,但对新的数据集的预测效果不好,即过拟合,此时生成的模型泛化能力较差。因此,我们需要对决策树进行剪枝,使得生成的模型具有较强的泛化能力。

为了检测剪枝前后模型的分类效果是否提升,我们需要将全部数据集划分为三个子集:训练集、验证集(剪枝集)、测试集。

训练集:用于生成决策树。

验证集:用于检验剪枝前后模型分类效果变动情况。

测试集:当模型完全确定后,用于检验模型在新样本数据上的分类效果。

决策树剪枝的主要方法包括两大类:后剪枝、预剪枝。

1.后剪枝

指先允许决策树自由生长,然后对其进行修剪。主要剪枝思想是将“子树”替换为叶子节点,即取消某个内部节点将其变为叶子节点,观察模型的分类效果是否有提升。

(1)后剪枝的优点:

后剪枝比预剪枝会保留更多的节点,不容易出现“欠拟合”问题。

(2)后剪枝的缺点:

因为是先生成树,再对树进行剪枝,因此,时间开销比预剪枝大。

(3)常用的是,先构建一棵决策树,然后根据特征重要性,筛选出重要的特征,再构建新的决策树。

具体操作,参见我的另一篇文章:Cara:随机森林Python实战​zhuanlan.zhihu.com

(4)一些后剪枝算法包括:REP(Reduced-Error Pruning)、PEP(Pesimistic-Error Pruning)、CCP(cost complexity pruning)。

①REP(Reduced-Error Pruning):自下而上剪枝,即将叶子节点合并到上位节点中去。

现有一个生成好的决策树如下,共1~10个节点。

Step 1: 尝试将叶子节点4删掉,则叶节点8、9、10合并为一个节点(命名为节点11),节点11的类别我们用节点8、9、10包含的样本中数量最多的那个类来确定(大多数原则来确定),然后,节点11与节点5一起成为节点2的叶子节点。然后,测试剪枝后的决策树在验证集上的表现,若分类效果(可以用分类的准确率度量)更好或者效果没有变差,则将节点4删掉(如下图),若表现不好则保留原树的形状。剪枝后效果如下:

STEP2:同理,尝试将节点3删掉,则叶子节点6和7会成为合并为一个节点(命名为节点12),节点12会成为节点1的叶子节点,然后,测试剪枝后的决策树在验证集上的表现。若分类效果更好或者效果没有变差,则将节点3删掉。

以此类推,形成最简决策树。

评价:REP是最简单的后剪枝方法之一,不过由于需要使用验证集,如果本身数据量较少,剪枝效果也不会很好。

②PEP(Pesimistic Error Pruning):没看懂,不敢说话……

③CCP(Cost Complexity Pruning):原理没看很明白……下面put一个github上比较流行的实操方法。

CCP思想已经被写入python的sklearn库下的tree模块的DecisionTreeClassifier类,在对类进行实例化时可以根据如下提示进行相关设置,重点中的重点在于如何选择一个较优的ccp_alpha值。

小例子:使用sklearn库中自带的数据集breast_cancer进行决策树剪枝演示。

#导入工程

from sklearn.tree import DecisionTreeClassifier

from sklearn.model_selection import train_test_split

from sklearn.datasets import load_breast_cancer

import matplotlib as plt

from sklearn.metrics import accuracy_score

#数据拆分为特征集和标签列

X,y = load_breast_cancer(return_X_y=True)

#数据拆分为训练集和测试集

X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=8)

#实例化决策树,参数全部用默认值

clf = DecisionTreeClassifier(random_state=0)#默认是基于Gini值生成决策树

#拟合样本数据

clf.fit(X_train,y_train)

#对测试集进行预测

pred = clf.predict(X_test)

#对预测结果进行打分

print(accuracy_score(y_test,pred))

#结果:0.9090909090909091

#绘制出生成的决策树

plt.rcParams['savefig.dpi'] = 200 #图片像素

plt.rcParams['figure.dpi'] = 200 #分辨率

# 默认的像素:[6.0,4.0],分辨率为100,图片尺寸为 600*400

from sklearn import tree

tree.plot_tree(clf,filled=True)

不剪枝时,采用默认参数生成的决策树,共8层,因此,树的深度为8:

#以下代码主要目的是找到较优的ccp_alpha,然后就可以在DecisionTreeClassifier类实例化的时候设置ccp_alpha参数为找到的ccp_alpha值。

path = clf.cost_complexity_pruning_path(X_train, y_train)

ccp_alphas, impurities = path.ccp_alphas, path.impurities

print(ccp_alphas)

'''

[0. 0.00232752 0.00375587 0.00375587 0.00391236 0.00438185

0.00629975 0.00730308 0.00946792 0.01198161 0.01788007 0.04055896

0.31844825]

'''

clfs = []

for ccp_alpha in ccp_alphas:

clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)

clf.fit(X_train, y_train)

clfs.append(clf)

print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(

clfs[-1].tree_.node_count, ccp_alphas[-1]))

#返回:Number of nodes in the last tree is: 1 with ccp_alpha: 0.3184482456829032

#绘制不同ccp_alpha取值下,clf在训练样本和测试样本上的精确度

train_scores = [clf.score(X_train, y_train) for clf in clfs]

test_scores = [clf.score(X_test, y_test) for clf in clfs]

from matplotlib import pyplot

plt.rcParams['savefig.dpi'] = 80 #图片像素

plt.rcParams['figure.dpi'] = 200 #分辨率

# 默认的像素:[6.0,4.0],分辨率为100,图片尺寸为 600*400

fig, ax = pyplot.subplots()

ax.set_xlabel("alpha")

ax.set_ylabel("accuracy")

ax.set_title("Accuracy vs alpha for training and testing sets")

ax.plot(ccp_alphas, train_scores, marker='o', label="train",

drawstyle="steps-post")

ax.plot(ccp_alphas, test_scores, marker='o', label="test",

drawstyle="steps-post")

ax.legend()

pyplot.show()

不同ccp_alpha取值,对应的clf的测试集正确率和训练集正确率如同所示:

根据该图可以发现,当alpha介于0.05~0.3之间时,对应的决策树分类器在测试集的正确率和训练集的正确率均比较高,正确率稳定在0.90~0.95之间。

#重新生成决策树,根据ccp_alpha进行剪枝

clf2 = DecisionTreeClassifier(random_state=0, ccp_alpha=0.006)

clf2.fit(X_train,y_train)

pred2=clf2.predict(X_test)

print(accuracy_score(y_test, pred2))

#结果:0.916083916083916

#绘制剪枝后的决策树

tree.plot_tree(clf2,filled=True)

剪枝后的树如下,共6层:

2.预剪枝

指在决策树生成之前,就设置一些参数来限制决策树的生长,如设定树的最大深度、叶子节点最小数量等。

在python中,sklearn库下的tree模块,可以加载DecisionTreeClassifier类,该类在实例化时可以设置以下参数进行预剪枝:

预剪枝由于过多的人为设定,可能导致生成的模型存在“欠拟合”问题,即生成的模型对数据的反映不够,模型对训练集的预测效果不佳,对测试集的预测效果也不佳。

(1)设定树的最大深度

该方法会在起初设置一个计数器,这里命名为counter,并设置初始值为0。在决策树生长过程中,会重点关注两个问题:①数据集是否pure了?即y标签是否只有1个类别。②决策树是否达到设定的最大深度。

如果数据集pure,则将数据集设为叶子节点。

如果数据集不pure,就会接着问树是否达到最大深度。

如果数据集不pure且树没有达到最大深度,就会从多个属性中根据某种指标(信息增益、Gini指标)选择出最佳划分节点,并对counter加1。然后再判断划分后的子集分别是否pure,以此类推。

如果数据集不pure,但是树达到最大深度,这时,就会强制将数据集设为叶子节点,代价就是生成的树存在“欠拟合”问题。

(2)设定叶子节点最小数量

参考:https://www.youtube.com/watch?v=SLOyyFHbiqo​www.youtube.com【机器学习】算法原理详细推导与实现(七):决策树算法 - TTyb - 博客园​www.cnblogs.com

决策树剪枝python实现_决策树剪枝问题python代码相关推荐

  1. 机器学习中决策树的随机森林_决策树和随机森林在机器学习中的使用

    机器学习中决策树的随机森林 机器学习 (Machine Learning) Machine learning is an application of artificial intelligence ...

  2. pythoncookbook和流畅的python对比_为什么你学Python效率比别人慢?因为你没有这套完整的学习资料...

    以下资源免费获取方式! 关注!转发!私信"资料"即可免费领取! 入门书籍 1.<Python基础教程>(Beginning Python From Novice to ...

  3. 零基础学python 视频_全网最全Python视频教程真正零基础学习Python视频教程 490集...

    Python Web开发-进阶提升 490集超强Python视频教程 真正零基础学习Python视频教程 [课程简介] 这是一门Python Web开发进阶课程,手把手教你用Python开发完整的商业 ...

  4. 3 x 10的python表达式_这道数学题用PYTHON编程语言怎么写? 编程语言python是用

    我觉着,这个应该这样解决比较符合计算机解题思路. 下面的回答的,思考的东西太多. # -*- coding: utf-8 -*- __author__ = 'lpe234' __date__ = '2 ...

  5. 为什么要学python语言_我们为什么要学习Python语言?

    原标题:我们为什么要学习Python语言? 聊到我们为什么要学习Python语言?小编不禁又想起大佬潘石屹准备开启Python学习旅程时所发布的微博. 我们为什么要学习Python语言? 在农业社会时 ...

  6. 下载python步骤_下载及安装Python详细步骤

    安装python分三个步骤: *下载python *安装python *检查是否安装成功 1.下载python (1)python下载地址 (2)选择下载的版本 (3)点开download后,找到下载 ...

  7. ubuntu更改默认python版本_更改Ubuntu默认python版本的方法

    1.查看基本信息 # 列出所有已安装python ls /usr/bin/python* #查看默认的 Python 版本信息: python --version 2.基于用户修改 默认Python ...

  8. python编辑器_推荐一款Python编辑器,集Pycharm和Sublime优点于一身的王者

    编程里面的编辑器就像是武林大会里面的高手,每一年都有新秀,黑马出现!比如有练习霸道的天罡之气的榜首Pycharm,力量雄厚霸道战斗力极强,但是对斗气消耗很大,占内存大而且启动速度有点慢!还有练习灵巧的 ...

  9. 人工智能只能用python吗_为什么人工智能用Python?

    主要原因: 1.人工智能适应Python的编程语言. 2.人工智能需要利用Python的高层语言,实现可移植性.面向对象.可扩展性.可嵌入型等功能,来实现人机交流. Python:是一种面向对象的解释 ...

最新文章

  1. (二)Cacti监控
  2. PL/SQL developer连接oracle出现“ORA-12154:TNS:could not resolve the connect identifier specified”问题的解决
  3. JavaScript中的数组循环方法
  4. ubuntu 16.04 安装教程
  5. restful规范和APIView
  6. 【招聘(上海)】美团酒店招聘 .NET 高级开发
  7. centos7搭建easy-mock服务
  8. 收藏!目标检测优质综述论文总结!
  9. Hive数据导出的三种方式
  10. ubuntu服务器+apache2绑定域名(以腾讯云域名为例)
  11. paip.python错误解决5
  12. python模拟ssh登录
  13. QQ拼音输入法词库和搜狗输入法词库[相互导入](使用Excel公式)
  14. Blender程序化建模教程【Python】
  15. torch.flatten
  16. 【Android开发】Android基本UI组件
  17. 【Python 字符视频】Python 实现将抖音视频转换成字符视频
  18. SQL 一条SQL语句 统计 各班总人数、男女各总人数 、该班级男女 比例
  19. ams1117-3.3v三端稳压芯片低压差线性稳压器
  20. flask部署阿里云服务器,公网ip访问不了(一些问题及解答)

热门文章

  1. 玉米社:SEM竞价推广转化成本高?做好细节转化率蹭蹭往上涨
  2. 在godot的canvas_item着色器中构建逆投影矩阵和逆视图矩阵
  3. 教你百度网盘文件转阿里云
  4. EXCEL如何批量调整图片大小?
  5. Node.js 实现登录校验 + 选项卡(改进版)
  6. ckplayer html5 添加广告,ewebeditor下利用ckplayer增加html5 (mp4)全平台的支持
  7. 央视《每周质量报告》:揭秘假宽带真相
  8. surface pro win10 重装系统并解决屏幕亮度闪烁和降频的问题
  9. Linux 启动项管理
  10. html css做一个简历表,HTML table制做我的简历