目录

1 GBDT算法核心思想

2 GBDT算法的数学原理

3 GBDT算法数学原理举例

补充知识点:梯度提升树中梯度的理解

4 使用sklearn实现GBDT算法

5 案例:产品定价模型

5.1 模型搭建

5.1.1 读取数据

5.1.2 分类型文本变量的处理

5.1.3 提取特征变量和目标变量

5.1.4 划分训练集的测试集

5.1.5 模型训练及搭建

5.2 模型预测及评估

6 模型参数介绍

参考书籍


1 GBDT算法核心思想

GBDT是Gradient Boosting Decision Tree(梯度提升树)的缩写。

GBDT算法也是一种非常实用的Boosting算法,它与AdaBoost算法的区别在于:AdaBoost算法根据分类效果调整权重并不断迭代,最终生成强学习器;GBDT算法则将损失函数的负梯度作为残差的近似值,不断使用残差迭代和拟合回归树,最终生成强学习器。

简单来说,AdaBoost算法是调整权重,而GBDT算法则是拟合残差。

通过一个简单案例理解GBDT算法的核心思想。

下表中有4个样本客户的数据,特征变量X1为年龄,X2为月收入(元),目标变量y是实际信用卡额度(元)。现在要利用GBDT算法根据样本数据构造模型,用于预测信用卡额度。

假设建立的第1棵决策树如下图所示。

A、C被划分到左节点,A的实际信用卡额度为8000,而预测值为10000,因此,A的残差为8000-10000=-2000,同理,C的残差为25000-20000=5000。B、D被划分到右节点,B的残差为30000-35000=-5000,D的残差为40000-35000=5000。

接下来就是GBDT算法的核心思想:构造第2棵决策树来拟合第1棵树产生的残差,注意这里拟合的是残差,构造的拟合残差的决策树如下图所示。

在这棵树中,A、B被划分到左节点,A的实际残差为-2000,而预测的残差为-3000,那么此时A的新残差,即残差的残差为-2000-(-3000)=1000,同理,B的新残差为-5000-(-5000)=0。C、D被划分到右节点,C的新残差为5000-5000=0,D的新残差为5000-5000=0。继续用第2棵树产生的新残差去拟合第3棵树,并不断重复此步骤,使残差变小。

因此,最终的模型就是如下图所示的集成在一起的多个模型,这也充分体现了集成算法的集成思想。

2 GBDT算法的数学原理

简单介绍GBDT算法的数学原理。

迭代模型为:

fm-1(x)是第m-1次迭代模型,即上一次的迭代模型;

Tm(x)是本次待搭建的决策树,其实也是拟合上一个模型残差值的决策树;

fm(x)是本次迭代后产生的新模型。

对GBDT算法来说,只需要简单地拟合当前模型的残差,算法步骤如下。

步骤1:初始化f0(x)=0

步骤2:当m=1,2,…,M,计算残差rmi=yi-fm-1(x);拟合残差,得到决策树Tm(x);更新fm(x)=fm-1(x)+Tm(x)。

步骤3:当误差或迭代次数达到指定要求时,得到回归问题提升树,如下所示。

3 GBDT算法数学原理举例

结合具体的回归问题,详细讲解GBDT算法的数学计算步骤。

给定的训练数据见下表,其中x为特征变量,y为目标变量,因为y为连续值,所以这是一个回归预测问题。

1.构造第一个模型

首先初始化f0(x)=0,然后构造第1个回归决策树模型f1(x),其实也就是T1(x)。

回归决策树模型的划分标准。回归决策树模型与分类决策树模型最大的不同就是其划分标准不是信息熵或基尼系数,而是均方误差MSE,其计算公式如下。

(1)寻找合适的初始切分点

构造第1棵决策树时,我们需要判断在x=1.5、x=2.5、x=3.5、x=4.5这4个位置中的哪个位置“切一刀”,才能使整体的均方误差MSE最小。

先在x=1.5处“切一刀”划分类别,即设置阈值v=1.5,意味着弱学习器f1(x)如下。

其实这就是如下图所示的一棵深度为1的回归决策树。

回归决策树中某个节点的预测值是该节点中所有数据的均值,因此,在上图右边的节点中,所有满足x>1.5的值都被预测为(0+2+2+4)/4=2。

此时的残差yi-f(xi)见下表。

此时的均方误差MSE为:

仿照上面的方法,计算出4种阈值下的均方误差MSE,见下表。

由上表可知,当v=2.5时均方误差MSE取得最小值,此时第1棵决策树f1(x)如下图所示。

(2)查看此时的残差

前面已经求得了第1棵决策树f1(x),因此可以获得其对应的残差,见下表。

获得残差之后,接着就需要构建新的决策树来拟合残差,此时整个系统的均方误差MSE为0.53。

2.拟合残差获得第2个模型

现在需要根据下表中的x(特征变量)和残差(目标变量)拟合出决策树T2(x)。

注意:此时第二棵树的实际值y为上一棵树的残差值。

(1)寻找合适的初始切分点

使用与前面相同的计算方法计算出4种阈值下的均方误差MSE,见下表。

由上表可知,当v=4.5时均方误差MSE取得最小值,此时第2棵决策树T2(x)如下图所示。

(2)查看此时的残差

此时残差的预测值T(x)及残差的残差(残差-T(x))见下表

注意上表中“残差的残差”即系统的残差,此时该拟合残差的决策树的均方误差MSE为0.088,这也是整个系统的均方误差MSE。

(3)集成模型

此时的集成模型f2(x)如下。

或者:

有了新模型的残差后,便可以继续构造新的决策树来拟合残差,直到系统的均方误差MSE达到指定要求或者迭代次数达到指定条件时,便停止迭代,形成最终模型。最终模型如下图所示。

因为GBDT算法是不停地拟合新模型的残差,所以随着新的迭代,整个系统的残差会越来越小,或者更精确地说,系统的均方误差MSE会越来越小,从而使得模型更加准确。

补充知识点:梯度提升树中梯度的理解

之前定义的残差为y-f(xi),而实际应用中,GBDT(Gradient Boosting DecisionTree)梯度提升树使用损失函数的负梯度在当前模型的值作为残差近似值。负梯度的定义如下,其中L(y,f(xi))为损失函数。

其实这个负梯度在特定损失函数的情况下,就是之前定义的残差y-f(xi),令损失函数为:

此时对损失函数求负梯度:

此时负梯度就等于残差y-f(xi),也就是说,当损失函数是平方函数时,负梯度就是残差。

不过当损失函数不是平方函数时,负梯度只是残差的近似值,并不完全等于残差。

4 使用sklearn实现GBDT算法

GBDT算法既能做分类分析,又能做回归分析。

对应的模型分别为GBDT分类模型(GradientBoostingClassifier)和GBDT回归模型(GradientBoostingRegressor)。

GBDT分类模型的弱学习器是分类决策树模型,GBDT回归模型的弱学习器则是回归决策树模型。

代码如下:

# GBDT分类模型
from sklearn.ensemble import GradientBoostingRegressor
X = [[1,2],[3,4],[5,6],[7,8],[9,10]]
y = [1,2,3,4,5]
model = GradientBoostingRegressor(random_state=123)
model.fit(X,y)
model.predict([[5,5]])# 输出
# array([2.54911351])

5 案例:产品定价模型

5.1 模型搭建

5.1.1 读取数据

首先读取1000种图书的数据,这里为了方便演示,只选取了4个特征变量,包括图书的页数、类别、彩印和纸张,目标变量是图书的价格。

查看图书类别以及纸张分类情况:

5.1.2 分类型文本变量的处理

因为“类别”和“纸张”两列是分类型文本变量,所以可以用LabelEncoder()函数进行数值化处理,便于后续进行模型拟合。

可以看到,“类别”列中的“技术类”被转换为数字1,“教辅类”被转换为数字2,“办公类”被转换为数字0。

此时df的前5行见下表。

5.1.3 提取特征变量和目标变量

5.1.4 划分训练集的测试集

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,random_state=123)

5.1.5 模型训练及搭建

5.2 模型预测及评估

模型搭建完毕后,就可以对测试集数据进行预测,代码如下。

通过如下代码汇总预测值和实际值,以便进行对比。

用模型自带的score()函数查看模型的预测效果,代码如下。

获得的模型准确度评分score为0.874,说明模型的预测效果不错。

这个评分其实就是模型的R-squared值(即统计学中的R2,一元线性回归中有讲解)。

也可以通过如下代码查看模型的R-squared值,来评估模型的拟合程度。

为了更科学合理地进行产品定价,可以通过如下代码查看各个特征变量的特征重要性,以便筛选出对价格影响最大的特征变量。

6 模型参数介绍

GBDT回归模型参数,更多查看可以在Jupyter Notebook中输入并运行如下代码:

from sklearn.ensemble import GradientBoostingRegressor
GradientBoostingRegressor ?

参考书籍

《Python大数据分析与机器学习商业案例实战》

GBDT模型及案例(Python)相关推荐

  1. 用通俗易懂的方式讲解:决策树模型及案例(Python 代码)

    文章目录 1 决策树模型简介 2 Gini系数(CART决策树) 3 信息熵.信息增益 4 决策树模型代码实现 4.1 分类决策树模型(DecisionTreeClassifier) 4.2 回归决策 ...

  2. 逻辑回归模型及案例(Python)

    1 简介 逻辑回归也被称为广义线性回归模型,它与线性回归模型的形式基本上相同,最大的区别就在于它们的因变量不同,如果是连续的,就是多重线性回归:如果是二项分布,就是Logistic回归. Logist ...

  3. 用通俗易懂的方式讲解:逻辑回归模型及案例(Python 代码)

    目录 1 简介 2 优缺点 3 适用场景 加入方式 4 案例:客户流失预警模型 4.1 读取数据 4.2 划分特征变量和目标变量 4.3 模型搭建与使用 4.3.1 划分训练集与测试集 4.3.2 模 ...

  4. 逻辑回归 + GBDT模型融合实战!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:吴忠强,东北大学,Datawhale成员 一.GBDT+LR简介 ...

  5. opencv 训练人脸对比_【项目案例python与人脸识别】基于OpenCV开源计算机视觉库的人脸识别之python实现...

    " 本项目是一个基于OpenCV开源库使用python语言程序实现人脸检测的项目,该项目将从[项目基础知识](即人脸识别的基本原理).[项目实践](人脸识别所需要的具体步骤及其python程 ...

  6. 【机器学习基础】逻辑回归 + GBDT模型融合实战!

    作者:吴忠强,东北大学,Datawhale成员 一.GBDT+LR简介 协同过滤和矩阵分解存在的劣势就是仅利用了用户与物品相互行为信息进行推荐, 忽视了用户自身特征, 物品自身特征以及上下文信息等,导 ...

  7. Seq2Seq模型应用案例

    Seq2Seq模型应用案例( : Seq2Seq是Encoder-Decoder(编码器与解码器)模型,输入是一个序列,输出也是一个序列,适用于输入序列与输出序列长度不等长的场景,如机器翻译.人机对话 ...

  8. 联邦学习算法介绍-FedAvg详细案例-Python代码获取

    联邦学习算法介绍-FedAvg详细案例-Python代码获取 一.联邦学习系统框架 二.联邦平均算法(FedAvg) 三.联邦随梯度下降算法 (FedSGD) 四.差分隐私随联邦梯度下降算法 (DP- ...

  9. R语言使用caret包构建gbdt模型(随机梯度提升树、Stochastic Gradient Boosting )构建回归模型、通过method参数指定算法名称

    R语言使用caret包构建gbdt模型(随机梯度提升树.Stochastic Gradient Boosting )构建回归模型.通过method参数指定算法名称 目录

最新文章

  1. 关于logrotate工具的日志切割
  2. springboot集成测试时@RunWith和@SpringBootTest爆红不能测试
  3. Chrome 隐藏 SSL 证书信息 禁止禁用 DRM
  4. postgis启动_PostgreSQL的安装和启动方法大全
  5. f1 score 代码_腾讯广告算法大赛冠军代码解读:稠密特征工程
  6. 男厕改女厕能多敷衍......
  7. php定义数据表类,phpwind中的数据库操作类
  8. 【原创】modb 功能设计之“支持多消费者单生产者”
  9. solr之搭建企业搜索平台,配置文件详细solrconfig.xml
  10. linux 终端 画圆,Linux Bash Shell快速入门
  11. Netpas:不一样的SD-WAN+ 保障网络通讯品质
  12. java编程给三个数字排序_JAVA程序.输入3个数字,有IF语句,从小到大排序
  13. 白嫖 Moss 斯坦福文件查重
  14. c语言factors函数的意思,factors是什么意思_factors的翻译_音标_读音_用法_例句_爱词霸在线词典...
  15. 香港坚固金业的黑幕,属于非法投资平台。
  16. 关于Android的post,get、cookie网络获取的一些坑
  17. AR红包大战一触即发,2017年会成AR营销元年吗?
  18. android studio linux 字体,Android Studio代码字体模糊解决方法
  19. 工作多年,对程序员“未来”的一些看法
  20. [图像] 金字塔模型

热门文章

  1. SQL Loader的使用详解
  2. (SQL入门经典+SQL必知必会+视频)笔记之一
  3. MySQL 正则表达式(REGEXP)与 like
  4. 橘子平台origin安装闪退?
  5. 仿生蛇形机器人03、Dynamixel MX-64AR舵机串联两个修改Demo(例程)进行调节
  6. 原生js做无限弹窗(娱乐)
  7. STM32(HAL库)流水灯配置及代码
  8. 欧文分校的计算机科学,有关美国加州大学欧文分校计算机科学专业.pdf
  9. 网络安全技术——ACL技术(访问控制列表)
  10. wordPress数据结构 数据库中的表、字段、类型及说明