【机器学习】基于GBDT的数据回归及python实现

  • 一、Boosting技术
  • 二、GBDT原理与应用
    • 2.1、GBDT思想
    • 2.2、基于GBDT的数据回归
    • 2.3、基于GBDT的数据分类
  • 三、基于GBDT的数据回归的python实现
  • 参考资料

本博文中GBDT(Gradient Boosting Decison Tree,梯度提升决策树)的决策树采用CART(Classification and Regression Tree,分类回归树)。CART主要参考 《python机器学习算法》,在这里就不赘述了,有需要的同学可以看大神的书。本博文首先介绍集成学习中的Boosting技术,然后介绍GBDT的原理及应用,在GBDT的应用部分重点介绍基于GBDT的数据回归技术,并顺便提及基于GBDT的数据分类思想(只需对回归做一点小的改进即可)。

一、Boosting技术

在面对一个复杂的问题时,我们可以训练多个模型,将这些模型的结果集成起来作为最终的结果,这种处理问题的思想被称为集成学习(参考资料【2】),集成学习分为两类:Bagging和Boosting。Bagging技术是指在样本中有放回的随机抽取几组样本数据,用这几组样本数据训练几个模型,最终由这几个模型投票决定分类结果或取均值得到回归结果,随机森林(参考资料【3】)就是一个典型的Bagging模型。Boosting技术同样是训练多个弱性能模型,但不同于Bagging技术中弱性能模型是并行且独立训练的,在Boosting中多个弱性能模型是串行的。Boosting对样本进行不同的赋值,对错误学习的样本的权重设置的较大,在后续的学习中集中处理难学的样本,最终得到一系列的预测结果,每个预测结果有一个权重。

二、GBDT原理与应用

2.1、GBDT思想

GBDT由两部分组成,分别是GB(Gradient Boosting)和DT(Decison Tree)。博文开头已经说明,DT采用CART的回归树。GB指的是沿着梯度方向,构造一系列的弱分类器函数,并以一定权重组合起来,形成最终决策的强分类器(参考资料【4】)。

2.2、基于GBDT的数据回归

假设总共有mmm个串行CART回归树模型,每个CART模型单独拿出来的预测结果为:

fi(X),{i=1,2,⋯ ,m}{f_i}\left( X \right),\left\{ {i = 1,2, \cdots ,m} \right\}fi​(X),{i=1,2,⋯,m}

前iii个模型串行在一起的预测结果为:

Fi(X),{i=1,2,⋯ ,m}{F_i}\left( X \right),\left\{ {i = 1,2, \cdots ,m} \right\}Fi​(X),{i=1,2,⋯,m}

其中Fi(X)=Fi−1(X)+fi(X){F_i}\left( X \right) = {F_{i - 1}}\left( X \right) + {f_i}\left( X \right)Fi​(X)=Fi−1​(X)+fi​(X)

假设样本标签为Y,前i−1i-1i−1个模型串行的预测输出与标签真实值之间的损失函数为:

Loss(Y,Fi−1(X))=1N∑j=1N(yj−Fi−1(xj))2Loss\left( {Y,{F_{i-1}}\left( X \right)} \right) = \frac{1}{N}\sum\limits_{j = 1}^N {{{\left( {{y_j} - {F_{i-1}}\left( {{x_j}} \right)} \right)}^2}}Loss(Y,Fi−1​(X))=N1​j=1∑N​(yj​−Fi−1​(xj​))2

当训练第iii个回归树模型时,我们的目的是前iii个模型串行预测输出与样本标签真实值之间的损失函数Loss(Y,Fi−1(X)+fi(X))Loss\left( {Y,{F_{i-1}}\left( X \right) + {f_{i }}\left( X \right)} \right)Loss(Y,Fi−1​(X)+fi​(X))最小。

为了训练出模型fi(X)f_i\left( X \right)fi​(X),GB方法利用最速下降的近似方法,即利用损失函数的负梯度在当前模型的值,作为回归问题中提升树算法的残差的近似值,去拟合当前的回归树模型。

rij=−[∂Loss(Y,f(xj))∂f(xj)]f(x)=fi−1(x){r_{ij}} = - {\left[ {\frac{{\partial Loss\left( {Y,f\left( {{x_j}} \right)} \right)}}{{\partial f\left( {{x_j}} \right)}}} \right]_{f\left( x \right) = {f_{i - 1}}\left( x \right)}}rij​=−[∂f(xj​)∂Loss(Y,f(xj​))​]f(x)=fi−1​(x)​

rijr_{ij}rij​表示求第iii个回归树模型时,第jjj个样本特征对应的残差值。利用{(x1,ri1),(x2,ri2),⋯ ,(xN,riN)}\left\{ {\left( {{x_1},{r_{i1}}} \right),\left( {{x_2},{r_{i2}}} \right), \cdots ,\left( {{x_N},{r_{iN}}} \right)} \right\}{(x1​,ri1​),(x2​,ri2​),⋯,(xN​,riN​)}去拟合第iii个CART模型。

2.3、基于GBDT的数据分类

以GBDT二分类为例,样本类别为(0,1)(0,1)(0,1),将每一轮模型训练后的预测的结果利用sigmoid函数映射到0到1之间,最后预测值大于等于0.5的样本为类别1,预测值小于0.5的样本为类别0。

三、基于GBDT的数据回归的python实现

python代码与样本地址:https://github.com/shiluqiang/GBDT_regression
基于GBDT的数据回归中训练每一个CART模型,直接采用前面所有模型的串行输出与样本标签的残差为训练目标(参考资料【5】)。
GBDTpython代码如下

import CART_regression_tree
import numpy as npdef load_data(data_file):'''导入训练数据input:  data_file(string):保存训练数据的文件output: data(list):训练数据'''data_X = []data_Y = []f = open(data_file)for line in f.readlines():sample = []lines = line.strip().split("\t")data_Y.append(float(lines[-1]))for i in range(len(lines) - 1):sample.append(float(lines[i]))  # 转换成float格式data_X.append(sample)f.close()    return data_X,data_Yclass GBDT_RT(object):'''GBDT回归算法类'''def __init__(self):self.trees = None ##用于存放GBDT的树self.learn_rate = learn_rate ## 学习率,防止过拟合self.init_value = None ##初始数值self.fn = lambda x: xdef get_init_value(self,y):'''计算初始数值为平均值input:y(list):样本标签列表output:average(float):样本标签的平均值'''average = sum(y)/len(y)return averagedef get_residuals(self,y,y_hat):'''计算样本标签标签与预测列表的残差input:y(list):样本标签列表y_hat(list):预测标签列表output:y_residuals(list):样本标签标签与预测列表的残差'''y_residuals = []for i in range(len(y)):y_residuals.append(y[i] - y_hat[i])return y_residualsdef fit(self,data_X,data_Y,n_estimators,learn_rate,min_sample, min_err):'''训练GBDT模型input:self(object):GBDT_RT类data_X(list):样本特征data_Y(list):样本标签n_estimators(int):GBDT中CART树的个数learn_rate(float):学习率min_sample(int):学习CART时叶节点的最小样本数min_err(float):学习CART时最小方差'''## 初始化预测标签和残差self.init_value = self.get_init_value(data_Y)n = len(data_Y)y_hat = [self.init_value] * n ##初始化预测标签y_residuals = self.get_residuals(data_Y,y_hat)self.trees = []self.learn_rate = learn_rate## 迭代训练GBDTfor j in range(n_estimators):idx = range(n)X_sub = [data_X[i] for i in idx] ## 样本特征列表residuals_sub = [y_residuals[i] for i in idx] ## 标签残差列表tree = CART_regression_tree.CART_RT(X_sub,residuals_sub, min_sample, min_err).fit()res_hat = [] ##残差的预测值for m in range(n):res_hat.append(CART_regression_tree.predict(data_X[m],tree))## 计算此时的预测值等于原预测值加残差预测值y_hat = [y_hat[i] + self.learn_rate * res_hat[i] for i in idx]y_residuals = self.get_residuals(data_Y,y_hat)self.trees.append(tree)def GBDT_predict(self,xi):'''预测一个样本'''return self.fn(self.init_value + sum(self.learn_rate * CART_regression_tree.predict(xi,tree) for tree in self.trees))def GBDT_predicts(self,X):'''预测多个样本'''return [self.GBDT_predict(xi) for xi in X]def error(Y_test,predict_results):'''计算预测误差input:Y_test(list):测试样本标签predict_results(list):测试样本预测值output:error(float):均方误差'''Y = np.mat(Y_test)results = np.mat(predict_results)error = np.square(Y - results).sum() / len(Y_test)return errorif __name__ == '__main__':print ("------------- 1.load data ----------------")X_data,Y_data = load_data("sine.txt") X_train = X_data[0:150]Y_train = Y_data[0:150]X_test = X_data[150:200]Y_test = Y_data[150:200]print('------------2.Parameters Setting-----------')n_estimators = 4learn_rate = 0.5min_sample = 30min_err = 0.3print ("--------------3.build GBDT ---------------")gbdt_rt = GBDT_RT()gbdt_rt.fit(X_train,Y_train,n_estimators,learn_rate,min_sample, min_err)print('-------------4.Predict Result--------------')predict_results = gbdt_rt.GBDT_predicts(X_test)print('--------------5.Predict Error--------------')error = error(Y_test,predict_results)print('Predict error is: ',error)

参考资料

1、《python机器学习算法》
2、https://blog.csdn.net/google19890102/article/details/46507387
3、https://blog.csdn.net/qq547276542/article/details/78304454
4、https://blog.csdn.net/legendavid/article/details/78904353
5、GBDT回归的原理及Python实现

【机器学习】基于GBDT的数据回归及python实现相关推荐

  1. 吴恩达机器学习作业2:逻辑回归(Python实现)

    逻辑回归 在训练的初始阶段,将要构建一个逻辑回归模型来预测,某个学生是否被大学录取.设想你是大学相关部分的管理者,想通过申请学生两次测试的评分,来决定他们是否被录取.现在你拥有之前申请学生的可以用于训 ...

  2. 吴恩达机器学习课程-作业1-线性回归(python实现)

    Machine Learning(Andrew) ex1-Linear Regression 椰汁学习笔记 最近刚学习完吴恩达机器学习的课程,现在开始复习和整理一下课程笔记和作业,我将陆续更新. Li ...

  3. Python+Django+Mysql开发在线美食推荐网 协同过滤推荐算法在美食网站中的运用 基于用户、物品的协同过滤推荐算法 个性化推荐算法、机器学习、分布式大数据、人工智能开发

    Python+Django+Mysql开发在线美食推荐网 协同过滤推荐算法在美食网站中的运用 基于用户.物品的协同过滤推荐算法 个性化推荐算法.机器学习.分布式大数据.人工智能开发 FoodRecom ...

  4. Python+Django+Mysql开发在线购物推荐网 协同过滤推荐算法在购物网站中的运用 个性化推荐算法开发 基于用户、物品的协同过滤推荐算法 机器学习、分布式大数据、人工智能开发

    Python+Django+Mysql开发在线购物推荐网 协同过滤推荐算法在购物网站中的运用 个性化推荐算法开发 基于用户.物品的协同过滤推荐算法 机器学习.分布式大数据.人工智能开发 ShopRec ...

  5. 在线新闻推荐网 Python+Django+Mysql开发技术 基于用户、物品的协同过滤推荐算法 个性化新闻推荐系统 协同过滤推荐算法在新闻网站中的运用 个性化推荐算法、机器学习、分布式大数据、人工智

    在线新闻推荐网 Python+Django+Mysql开发技术 基于用户.物品的协同过滤推荐算法 个性化新闻推荐系统 协同过滤推荐算法在新闻网站中的运用 个性化推荐算法.机器学习.分布式大数据.人工智 ...

  6. 在线图书推荐网 Python+Django+Mysql开发技术 个性化图书推荐系统 协同过滤推荐算法在图书网站中的运用 基于用户、物品的协同过滤推荐算法 个性化推荐算法、机器学习、分布式大数据、人工智

    在线图书推荐网 Python+Django+Mysql开发技术 个性化图书推荐系统 协同过滤推荐算法在图书网站中的运用 基于用户.物品的协同过滤推荐算法 个性化推荐算法.机器学习.分布式大数据.人工智 ...

  7. 在线音乐推荐网 Python+Django+Mysql开发技术 基于用户、物品的协同过滤推荐算法 个性化音乐推荐系统 音乐网站+协同过滤推荐算法 机器学习、分布式大数据、人工智能开发

    在线音乐推荐网 Python+Django+Mysql开发技术 基于用户.物品的协同过滤推荐算法 个性化音乐推荐系统 音乐网站+协同过滤推荐算法 机器学习.分布式大数据.人工智能开发 MusicRec ...

  8. 【数据分析与挖掘】基于LightGBM,XGBoost,逻辑回归的分类预测实战:英雄联盟数据(有数据集和代码)

    机器学习-LightGBM 一.LightGBM的介绍与应用 1.1 LightGBM的介绍 1.2 LightGBM的应用 二.数据集来源 三.基于英雄联盟数据集的LightGBM分类实战 Step ...

  9. 【机器学习基础】数学推导+纯Python实现机器学习算法13:Lasso回归

    Python机器学习算法实现 Author:louwill 第13讲和第14讲我们来关注一下回归模型的两个变种模型.本节我们要介绍的是基于L1正则化的Lasso模型,下一节介绍基于L2正则化的Ridg ...

最新文章

  1. leetcode10 为什么p[j-1] == '*'的时候,不能用递推公式dp[i][j] = dp[i][j-1] || dp[i][j-2] || dp[i-1][j]
  2. 想实现高可用?先搞定负载均衡原理
  3. LeetCode MySQL 580. 统计各专业学生人数
  4. 网络通信 netstat
  5. 使用JavaMail技术发送邮件
  6. 利用sobel算子提取图像的水平特征和竖直特征
  7. 二叉树:一入递归深似海,从此offer是路人
  8. 工具-破解pdf密码
  9. Python3.6支付宝账单爬虫
  10. windows_帮助文档【.CHM电子书】打开显示空白解决办法
  11. 2018年世界计算机超算大赛,在世界大学生超级计算机竞赛(ASC18)总决赛中 青海大学超算团队成功获得ASC竞赛全球一等奖...
  12. 快速上手 Android 蓝牙串口 SPP 开发
  13. ssdt函数索引号_【转】SSDT索引号的获取
  14. kal虚拟机统下安装open-vmware-tools
  15. Jetpack Compose 从入门到入门(六)
  16. Windows上安装Linux
  17. Matlab从视频中提取图像,可以设定每多少秒提取1帧。
  18. Linux 中 which、whereis、locate、find的区别
  19. 【ansys workbench】8.刚体平移和弱弹簧
  20. 逆水寒捏脸服务器维护,《逆水寒》2019年3月28日更新公告

热门文章

  1. 使用csscan评测字符集改变
  2. 如何做好数据安全治理
  3. python微博爬虫实战_Python爬虫实战演练:爬取微博大V的评论数据
  4. Javascript中的对象拷贝(对象复制/克隆)
  5. TypeScript算法专题 - [双链表1] - 双链的概念及其实现
  6. 无模型自适应迭代学习控制原理和matlab代码仿真学习记录
  7. JAVA中整型常量的长度,Java基础入门篇(三)——Java常量、变量,
  8. 2021年SWPUACM暑假集训day5单调栈算法
  9. 2020 年百度之星#183;程序设计大赛 - 初赛一
  10. 混淆矩阵评价指标_机器学习模型评价指标 -- 混淆矩阵