看了一些市面上的经典教材,感觉决策树剪枝这一部分讲的都特别晦涩,很不好理解。本文以理论白话+具体案例的形式来讲清楚这个重要知识点,打好决策树这个基础,有助于理解之后我们要讲解的随机森林、gbdt、xgboost、lightgbm等模型。

阅读本文前,可以顺便回顾一下前文:机器学习基础:决策树的可视化

剪枝

如果不对决策树设置任何限制,它可以生成一颗非常庞大的树,决策树的树叶节点所覆盖的训练样本都是“纯”的。这样决策树在训练样本上非常精准,但是在测试集上就没那么好了。

层数越多,叶结点越多,分的越细致,对训练数据分的也越深,越容易过拟合,导致对测试数据预测时反而效果差。要解决这个问题就需要对决策树进行「剪枝」。

剪枝的方案主流的有两种,一种是预剪枝,一种是后剪枝。

所谓的预剪枝,即是在生成树的时候就对树的生长进行限制,防止过度拟合。比如我们可以限制决策树在训练的时候每个节点的数据只有在达到一定数量的情况下才会进行分裂,否则就成为叶子节点保留。或者我们可以限制数据的比例,当节点中某个类别的占比超过阈值的时候,也可以停止生长。

下面我们重点讲后剪枝,因为CART采用的就是用的这个方法。

CART剪枝算法流程

CART树采用的是后剪枝方法,即先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来泛化性能提升,则将该子树替换为叶结点。

李航老师《统计学习方法》中具体介绍了 CART 剪枝算法的步骤流程。

看起来是不是很复杂?其实核心思想就是对原始的决策时T0,从底部根节点不断向上剪枝,直到根节点为止。在此过程中,就形成了很多子树{T0,T1,...,Tn};然后通过交叉验证法在验证集上对{T0,T1,...,Tn}测试,从中选择最优子树。

怎么度量最优呢?就要先了解一下决策树损失函数:

剪枝前是以 t 为根结点的子树 Tt 的损失函数是:

C(Tt)为训练数据的预测误差,分类树是用基尼系数度量,回归树是均方差度量。|Tt|是子树T的叶子节点的数量。式中唯一的未知变量是正则化参数 α ,其值越大,就意味着剪枝力度越大。当 α 从 0 慢慢增大到 ∞ 时,最优子树会慢慢从最开始的整体树,一点一点剪枝,直到变成单结点树。对于固定的 α,一定存在损失函数Cα(T)最小的子树,我们称之为最优子树,记为 Tα 。

两种剪枝策略对比

后剪枝决策树通常比预剪枝决策树保留了更多的分支;

后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树;

后剪枝决策树训练时间开销比未剪枝决策树和预剪枝决策树都要大的多。其实,只需掌握后剪枝就行了。

CART决策树剪枝(参数解读)

sklearn.tree.DecisionTreeClassifier (criterion=’gini’, splitter=’best’, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)

max_depth:限制树的最大深度

决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。在高维度低样本量时非常有效;建议从=3开始尝试。

min_samples_leaf:一个节点在分枝后,每个子节点都必须至少包含的训练样本数量

一个节点在分枝后,每个子节点都必须包含至少min_samples_leaf个训练样本,两种取值:(1)整数 (2)浮点型:如果叶节点中含有的样本量变化很大,输入浮点数表示样本量的百分比。如果分支后的子节点不满足参数条件,分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生。

这个参数可以保证每个叶子的最小尺寸,在回归问题中避免低方差,过拟合的叶子节点出现。搭配max_depth使用,在回归树中可以让模型变得更加平滑;建议从=5开始;对于类别不多的分类问题,=1通常就是最佳选择。

min_samples_split:一个节点必须要至少包含的训练样本数量

如果小于这个数量,这个节点才允许被分枝,否则分枝就不会发生。

max_features:分枝时考虑的最大特征个数

即在分支时,超过限制个数的特征都会被舍弃。但是在不知道决策树中的各个特征的重要性的情况下,强行设定这个参数可能会导致模型学习不足。

min_impurity_decrease:子父节点信息增益的最小值

信息增益是父节点的信息熵与子节点信息熵之差,信息增益越大,说明这个分支对模型的贡献越大;相反的,如果信息增益非常小,则说明该分支对模型的建立贡献不大。又由于分支需要的计算量又非常大,所以如果信息增益非常小时,我们就选择放弃该分支。

以上便是剪枝常用到的参数了。

实例

如果不对决策树设置任何限制,生成结果如下:

每个叶子结点gini指数都等于 0 。

iris = load_iris()

clf = tree.DecisionTreeClassifier(random_state=66,min_samples_leaf=15)

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,

feature_names=iris.feature_names,

class_names=iris.target_names,

filled=True, rounded=True,

special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data)

Image(graph.create_png())

设置叶子节点最少样本数min_samples_leaf=15,这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。

python决策树剪枝_机器学习基础:可视化方式理解决策树剪枝相关推荐

  1. Python深度学习之机器学习基础

    Python深度学习之机器学习基础 一.前言 本文记录 弗朗索瓦·肖莱的<Python深度学习>第四章 机器学习基础有关笔记. 二.笔记 2.1机器学习的四个分支 监督学习 序列生成(se ...

  2. python决策树预测模型_机器学习:决策树(Decision Tree)

    决策树(decision tree)是一种基本的分类与回归方法.在分类问题中,它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布.在学习时,利用训练数据,根据损失 ...

  3. 第二章_机器学习基础

    文章目录 第二章 机器学习基础 2.1 各种常见算法图示 2.2 监督学习.非监督学习.半监督学习.弱监督学习? 2.3 监督学习有哪些步骤 2.4 多实例学习? 2.5 分类网络和回归的区别? 2. ...

  4. cart算法_机器学习十大算法之一——决策树CART算法

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第23篇文章,我们今天分享的内容是十大数据挖掘算法之一的CART算法. CART算法全称是Classification ...

  5. 机器学习学习吴恩达逻辑回归_机器学习基础:逻辑回归

    机器学习学习吴恩达逻辑回归 In the previous stories, I had given an explanation of the program for implementation ...

  6. 机器学习:分类_机器学习基础:K最近邻居分类

    机器学习:分类 In the previous stories, I had given an explanation of the program for implementation of var ...

  7. python bokeh教程_交互式数据可视化在Python中用Bokeh实现

    Bokeh是一个专门针对Web浏览器的呈现功能的交互式可视化Python库.这是Bokeh与其它可视化库最核心的区别.正如下图所示,它说明了Bokeh如何将数据展示到一个Web浏览器上的流程. 正如你 ...

  8. airbnb机器学习模型_机器学习基础:预测Airbnb价格

    airbnb机器学习模型 Machine learning is easily one of the biggest buzzwords in tech right now. Over the pas ...

  9. 机器学习朴素贝叶斯_机器学习基础朴素贝叶斯分类

    机器学习朴素贝叶斯 In the previous stories, I had given an explanation of the program for implementation of v ...

最新文章

  1. 205页PPT,看5G+AI引领的下一个时代!
  2. 当程序员产崽后...
  3. 存clob为空的值_给Oracle数据库中CLOB字段插入空值
  4. HDU 4930 Fighting the Landlords(扯淡模拟题)
  5. 机器学习从入门到精通50讲(九)-基于 ANTLR 自己实现一个 SQL 解析器
  6. mysql2教程_mySQL 教程 第2章 安装和介绍mySQL
  7. 《Scikit-Learn与TensorFlow机器学习实用指南》第9章 启动并运行TensorFlow
  8. 计算机统考第五次作业操作题,计算机基础第5次作业 第五章 Powerpoint知识题
  9. CCF2015-12-2 消除类游戏
  10. Javascript第三章循环最后一种方法for..in与for区别第二课
  11. mysql有热备吗_mysql备份方法(热备)
  12. css子元素选择父元素的实现
  13. 海关179号出口清单报文CEB603Message描述规范
  14. Android7(N)中webview导致应用内语言切换失效
  15. 浅谈机器人控制与仿真设计----RDS和ROS
  16. 青龙面板关闭青龙二级验证
  17. IT企业面试常见逻辑推理题智力题及详解答案(二)
  18. medusa破解ssh
  19. “XXX程序包不存在”解决方法
  20. 科大讯飞踩过的“坑”,还有多少AI企业要踩?

热门文章

  1. python直接获得文件夹下子目录的文件名
  2. 库存优化中如何判断哪些SKU的库存水位需要改善
  3. ubuntu 18.04 root登录
  4. SparkStreaming使用SQL
  5. 简单理解光会产生折射的原因及折射定律的推导
  6. 农村电子商务的发展思路ppt
  7. matlab难在哪,心理学实验范式?matlab搞不定?那别的不用试。
  8. 蓝桥杯2013JAVA_B省赛真题详解
  9. Oracle ebs 常用标准表
  10. 【python爬虫】猫眼电影TOP100电影封面下载