本文转载于“ 一个拉风的名字”的“Regression Tree 回归树”

1. 引言

AI时代,机器学习算法成为了研究、应用的热点。当前,最火的两类算法莫过于神经网络算法(CNN、RNN、LSTM等)与树形算法(随机森林、GBDT、XGBoost等),树形算法的基础就是决策树。决策树因其易理解、易构建、速度快的特性,被广泛应用于统计学、数据挖掘、机器学习领域。因此,对决策树的学习,是机器学习之路必不可少的一步。

根据处理数据类型的不同,决策树又分为两类:分类决策树与回归决策树,前者可用于处理离散型数据,后者可用于处理连续型数据,下面的英文引用自维基百科。

Classification tree analysis is when the predicted outcome is the class to which the data belongs.
Regression tree analysis is when the predicted outcome can be considered a real number (e.g. the price of a house, or a patient’s length of stay in a hospital).

网络上有关于分类决策树的介绍可谓数不胜数,但是对回归决策树(回归树)的介绍却少之又少。李航教授的统计学习方法 对回归树有一个简单介绍,可惜篇幅较短,没有给出一个具体实例;Google搜索回归树,有一篇介绍回归树的博客(点击),该博客所举的实例有误,其过程事实上是基于残差的GBDT。

基于以上原因,本文简单介绍了回归树(Regression Tree),简单描述了CART算法,给出了回归树的算法描述,辅以简单实例以加深理解,最后是总结部分。

2. 回归树

决策树实际上是将空间用超平面进行划分的一种方法,每次分割的时候,都将当前的空间一分为二, 这样使得每一个叶子节点都是在空间中的一个不相交的区域,在进行决策的时候,会根据输入样本每一维feature的值,一步一步往下,最后使得样本落入N个区域中的一个(假设有N个叶子节点),如下图所示。


三种比较常见的分类决策树分支划分方式包括:ID3, C4.5, CART。

分类与回归树(classificationandregressiontree, CART)模型由Breiman等人在1984年提出,是应用广泛的决策树学习方法。CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归。下面的英文引用自维基百科

The term Classification And Regression Tree (CART) analysis is an umbrella term used to refer to both of the above procedures, first introduced by Breiman et al. Trees used for regression and trees used for classification have some similarities - but also some differences, such as the procedure used to determine where to split.

下面介绍回归树。

2.1 原理概述

既然是决策树,那么必然会存在以下两个核心问题:如何选择划分点?如何决定叶节点的输出值?

一个回归树对应着输入空间(即特征空间)的一个划分以及在划分单元上的输出值。分类树中,我们采用信息论中的方法,通过计算选择最佳划分点。而在回归树中,采用的是启发式的方法。假如我们有n个特征,每个特征有
si(i∈(1,n))s_i(i \in (1,n))si​(i∈(1,n)) 个取值,那我们遍历所有特征,尝试该特征所有取值,对空间进行划分,直到取到特征j的取值s,使得损失函数最小,这样就得到了一个划分点。描述该过程的公式如下:(如果看不到图请点击永久地址)

假设将输入空间划分为M个单元:R1,R2,...,RmR_1,R_2,...,R_mR1​,R2​,...,Rm​那么每个区域的输出值就是:cm=ave(yi∣xi∈Rm)c_m=ave(y_i|x_i \in R_m)cm​=ave(yi​∣xi​∈Rm​)也就是该区域内所有点y值的平均数。

举个例子。如下图所示,假如我们想要对楼内居民的年龄进行回归,将楼划分为3个区域R1,R2,R3R_1, R_2, R_3R1​,R2​,R3​(红线),那么R1R_1R1​的输出就是第一列四个居民年龄的平均值,R2R_2R2​的输出就是第二列四个居民年龄的平均值,R3R_3R3​的输出就是第三、四列八个居民年龄的平均值。

2.2 算法描述

截图来自李航教授的统计学习方法

2.3 一个简单实例

为了便于理解,下面举一个简单实例。训练数据见下表,目标是得到一棵最小二乘回归树。

x 1 2 3 4 5 6 7 8 9 10
y 5.56 5.7 5.91 6.4 6.8 7.05 8.9 8.7 9 9.05

1.选择最优切分变量j与最优切分点s

在本数据集中,只有一个变量,因此最优切分变量自然是x。

接下来我们考虑9个切分点[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5][1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5][1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5],[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5] 你可能会问,为什么会带小数点呢?类比于篮球比赛的博彩,倘若两队比分是96:95,而盘口是“让1分 A队胜B队”,那A队让1分之后,到底是A队赢还是B队赢了?所以我们经常可以看到“让0.5分 A队胜B队”这样的盘口。在这个实例中,也是这个道理。

损失函数定义为平方损失函数Loss(y,f(x))=(f(x)−y)2Loss(y, f(x))=(f(x)-y)^2Loss(y,f(x))=(f(x)−y)2,将上述9个切分点一依此代入下面的公式,其中 cm=ave(yi∣xi∈Rm)c_m=ave(y_i|x_i \in R_m)cm​=ave(yi​∣xi​∈Rm​)(如果看不到图请点击永久地址)

例如,取 s=1.5s=1.5s=1.5。此时R1={1},R2={2,3,4,5,6,7,8,9,10}R_1=\{1\} , R_2=\{2,3,4,5,6,7,8,9,10\}R1​={1},R2​={2,3,4,5,6,7,8,9,10},这两个区域的输出值分别为:c1=5.56,c2=19(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50c_1=5.56, c_2= \frac{1}{9}(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50c1​=5.56,c2​=91​(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)=7.50,得到下表:

s 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5
c1 5.56 5.63 5.72 5.89 6.07 6.24 6.62 6.88 7.11
c2 7.5 7.73 7.99 8.25 8.5 4 8.91 8.92 9.03 9.05

把c1c_1c1​,c2c_2c2​的值代入到上式,如:m(1.5)=0+15.72=15.72m(1.5)=0+15.72=15.72m(1.5)=0+15.72=15.72。同理,可获得下表:

s 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5
m(s) 15.72 12.07 8.36 5.78 3.91 1.93 8.01 11.73 15.74

显然取 s=6.5s=6.5s=6.5时,m(s)m(s)m(s)最小。因此,第一个划分变量j=xj=xj=x,s=6.5s=6.5s=6.5

  1. 用选定的(j,s)划分区域,并决定输出值
    两个区域分别是:R1={1,2,3,4,5,6}R_1=\{1,2,3,4,5,6\}R1​={1,2,3,4,5,6}, R2={7,8,9,10}R_2=\{7,8,9,10\}R2​={7,8,9,10},输出值cm=ave(yi∣xi∈Rm),c1=6.24,c2=8.91c_m=ave(y_i|x_i \in R_m),c_1=6.24,c_2=8.91cm​=ave(yi​∣xi​∈Rm​),c1​=6.24,c2​=8.91
  2. 对两个子区域继续调用步骤1、步骤2
    对R1R_1R1​继续进行划分:
x 1 2 3 4 5 6
y 5.56 5.7 5.91 6.4 6.8 7.05

取切分点[1.5,2.5,3.5,4.5,5.5][1.5,2.5,3.5,4.5,5.5][1.5,2.5,3.5,4.5,5.5],则各区域的输出值ccc如下表

s 1.5 2.5 3.5 4.5 5.5
c1 5.56 5.63 5.72 5.89 6.07
c2 6.37 6.54 6.75 6.93 7.05

计算m(s):

s 1.5 2.5 3.5 4.5 5.5
m(s) 1.3087 0.754 0.2771 0.4368 1.0644

s=3.5时m(s)最小。
之后的过程不再赘述。

  1. 生成回归树
    假设在生成3个区域之后停止划分,那么最终生成的回归树形式如下:
    T={5.72x≤3.56.753.5⩽x≤6.58.916.5<xT=\left\{\begin{matrix}5.72 & x\leq 3.5\\ 6.75 &3.5\leqslant x\leq 6.5\\ 8.91 & 6.5<x\end{matrix}\right.T=⎩⎨⎧​5.726.758.91​x≤3.53.5⩽x≤6.56.5<x​

2.4 回归树VS线性回归
不多说了,直接看图甩代码

3. 总结
实际上,回归树总体流程类似于分类树,分枝时穷举每一个特征的每一个阈值,来寻找最优切分特征j和最优切分点s,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。

当然,处理具体问题时,单一的回归树肯定是不够用的。可以利用集成学习中的boosting框架,对回归树进行改良升级,得到的新模型就是提升树(Boosting Decision Tree),在进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。

转载——Regression Tree 回归树相关推荐

  1. Regression Tree 回归树

    1. 引言 AI时代,机器学习算法成为了研究.应用的热点.当前,最火的两类算法莫过于神经网络算法(CNN.RNN.LSTM等)与树形算法(随机森林.GBDT.XGBoost等),树形算法的基础就是决策 ...

  2. 学习——Regression Tree 回归树

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_40604987/arti ...

  3. 机器学习系列之手把手教你实现一个分类回归树

    https://www.ibm.com/developerworks/cn/analytics/library/machine-learning-hands-on5-cart-tree/index.h ...

  4. CART分类与回归树

    一.CART分类与回归树 资料转载: http://dataunion.org/5771.html http://blog.sina.com.cn/s/blog_afe2af380102x020.ht ...

  5. 机器学习实战(八)分类回归树CART(Classification And Regression Tree)

    目录 0. 前言 1. 回归树 2. 模型树 3. 剪枝(pruning) 3.1. 预剪枝 3.2. 后剪枝 4. 实战案例 4.1. 回归树 4.2. 模型树 学习完机器学习实战的分类回归树,简单 ...

  6. 用决策树模型求解回归问题(regression tree)

    How do decision trees for regression work? 决策树模型既可以求解分类问题(对应的就是 classification tree),也即对应的目标值是类别型数据, ...

  7. tree | 分类回归树模型

    专注系列化.高质量的R语言教程 推文索引 | 联系小编 | 付费合集 分类回归树(Classification and Regression Trees,CART)模型分为分类树模型和回归树模型:当因 ...

  8. CART决策树(分类回归树)分析及应用建模

    一.CART决策树模型概述(Classification And Regression Trees)   决策树是使用类似于一棵树的结构来表示类的划分,树的构建可以看成是变量(属性)选择的过程,内部节 ...

  9. 机器学习算法之CART(分类回归树)概要

    分类回归树  classification and regression tree(C&RT)  racoon 优点 (1)可自动忽略对目标变量没有贡献的属性变量,也为判断属性变量的重要性,减 ...

最新文章

  1. 解决1px的border在移动端变粗的问题
  2. SQL How to get the current day month and year
  3. 亲戚再也看不见我一个人食吉野家了
  4. java post流_Java后端HttpClient Post提交文件流 及服务端接收文件流
  5. 请结合计算机硬件论述指令执行的过程,【计算机组成原理】计算机软硬件组成...
  6. android 7.0 按钮崩溃,Android 7.0调用相机崩溃详解及解决办法
  7. python no module named pandas_【原创】大叔经验分享(11)python引入模块报错ImportError: No module named pandas numpy...
  8. JMETER badboy 录制脚本
  9. 【实用工具】GLIBC降级
  10. 华为eNSP BUG——关于OSPF Router ID选择问题
  11. 百度统计挂了,分布式数据库异常引起,数据显示为空!
  12. Android—构建安全的Android客户端请求,避免非法请求
  13. pandas读取与存储操作详解
  14. error while loading shared libraries: libtinfo.so.5
  15. EMD 经验模态分解
  16. 最简单直接粗暴的Mothur分析OTU教程
  17. 拓展:将simulink的仿真图_在matlab画出_复制到word
  18. 双节本世纪仅有三次,特送福利!
  19. matlab——红绿灯颜色及数字识别(一)
  20. 微信分享解决wx not defined

热门文章

  1. 【博客话题】爱上Linux的N+1个理由
  2. 80年代出生人坦白十大尴尬事80一代全搜集
  3. 实现 Virtual DOM 下的一个 VNode 节点
  4. 国内最火5款Java微服务开源项目
  5. web容器 ejb容器_容器实用指南
  6. JavaScript类型强制解释
  7. react开发_我如何在#100DaysOfCode挑战期间找到React开发人员的工作
  8. 我五年来都没来过 我的意志力飞涨。
  9. 远程桌面连接一个域网的计算机,怎样远程控制局域网的另一台电脑(远程桌面)win10...
  10. 小米mysql安装教程_小米 SOAR 开源SQL优化工具安装