Gradient-Boosted Trees(GBT或者GBDT) 和 RandomForests 都属于集成学习的范畴,相比于单个模型有限的表达能力,组合多个base model后表达能力更加丰富。

关于集成学习的理论知识,包括GBT和Random Forests的一些比较好的参考资料:

周志华教授的"Ensemble Methods: Foundations and Algorithms",系统的介绍了集成学习的理论及方法

Random Forests

Greedy Function Approximation: A GradientBoosting Machine

Stochastic GradientBoosting,Spark GBT实现所参考的算法

GBT和Random Forests二者的区别:

二者的理论思想在spark.mllib源码阅读-bagging方法中从模型的方差和偏差的角度做了一些简要的介绍,在Spark官方文档上也有一段关于二者优劣的描述:

1、GBT比RandomForests的训练成本更高,原因在于GBT中各个模型以序列串行的方式进行训练,通常我们说的并行GBT是指base model的并行训练,各个base model之间是无法做到并行的。而Random Forests

中各个子模型之间可以做到并行化。

2、Random Forests的base model越多越有助于降低过拟合,而GBT中base model越多会提高过拟合的程度。

3、二者训练的时间成本不同,因此调参的成本不同。有限的时间内Random Forests可以实验更多的参数组合。

4、通常来看,Random Forests的base model会得到一棵规模适中的树,而GBT为了降低在basemodel数量多时引发的过拟合,会限制其base model的规模。

下面来看看Spark中GBT的实现过程,主要包括3部分:GBT模型、GBT参数配置、GBT训练算法:

GradientBoostedTrees:
GBT的实现过程由GradientBoostedTrees类驱动并向用户暴露模型的训练方法。GradientBoostedTrees的2个关键方法是train和run,在run中,根据用户定义的模型配置类boostingStrategy来调用ml包下的GradientBoostedTrees类进行模型的训练,最后根据训练得到的参数来新建一个GradientBoostedTreesModel:

def train(
           input: RDD[LabeledPoint],
           boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
  new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
  }
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
  val algo = boostingStrategy.treeStrategy.algo
  //import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
  val (trees, treeWeights) = NewGBT.run(input.map { point =>
    NewLabeledPoint(point.label, point.features.asML)
  }, boostingStrategy, seed.toLong)
  new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
GradientBoostedTreesModel:
GradientBoostedTreesModel用来保存训练后的模型,其继承自TreeEnsembleModel。各个Base model保存在trees数组中,每个base model的权重在treeWeights数组中,

其父类TreeEnsembleModel实现的predict方法即是对各个base model的预测值加权treeWeights 得到最终的预测值。

class GradientBoostedTreesModel @Since("1.2.0") (
    @Since("1.2.0") override val algo: Algo, //模型算法:分类 or 回归
    @Since("1.2.0") override val trees: Array[DecisionTreeModel], //base model的数组
    @Since("1.2.0") override val treeWeights: Array[Double]) //每个base model的权重
  extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
BoostingStrategy
GBT的配置信息类,可配置的信息包括

treeStrategy:base tree的配置信息

Loss:损失函数,默认参数为2分类问题用LogLoss, 回归问题用SquaredError

numIterations:GBT的迭代次数,默认值为100

learningRate:学习速率,默认值为0.1

validationTol:通过验证集判断训练终止的条件:验证集上历史最小的残差 - 验证集当前残差 < validationTol*max(验证集当前残差, 0.01) 即提前终止训练

在训练GBT时,base tree的参数设置也很重要,base tree的参数由Strategy类维护,Strategy的默认值如下,在训练GBT时,务必要重新设置Strategy的值,这里我对可以设定的值都做了备注,方便初次使用的同学进行调参:

@Since("1.0.0") @BeanProperty var algo: Algo,//算法的类别:分类还是回归 {Classification、Regression}

@Since("1.0.0") @BeanProperty var impurity: Impurity,//计算信息增益的准则 分类{基尼指数、信息增益} 回归{impurity.Variance}

@Since("1.0.0") @BeanProperty var maxDepth: Int, //树的最大深度

@Since("1.2.0") @BeanProperty var numClasses: Int = 2,//类别数

@Since("1.0.0") @BeanProperty var maxBins: Int = 32,//连续特征离散化的分箱数

@Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,//计算分裂点的算法,待定

@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),//存储每个分类特征的值数目

@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,//子结点拥有的最小样本实例数,一个终止条件

@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,//最小的信息增益值,这个应该是用来控制迭代终止的

@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,//聚合使用的内存大小。待定

@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,//用于训练数据的抽样率

@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,//待定

@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10 //checkpoint

模型的损失函数在BoostingStrategy类中自动设置,在二分类模型中损失函数被定义为LogLoss(对数损失函数)、在回归问题中损失函数被定义为SquaredError(平方损失函数)。在Spark2.1.0版本中还没有实现对多分类GBT的损失函数及多分类GBT模型。对于自定义损失函数,需要继承org.apache.spark.mllib.tree.loss.Loss这个类,并覆写gradient和computeError方法。

GradientBoostedTrees:
GradientBoostedTrees类是Spark训练GBT模型参数的类,模型的训练主要分为2步:1、将分类问题转化为回归问题,在GradientBoostedTrees的run方法中完成:

def run(
    input: RDD[LabeledPoint],
    boostingStrategy: OldBoostingStrategy,
    seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
  val algo = boostingStrategy.treeStrategy.algo
  //都转化为回归问题
  algo match {
    case OldAlgo.Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
    case OldAlgo.Classification =>
      // Map labels to -1, +1 so binary classification can be treated as regression.
      val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
      GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, seed)
    case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
  }
}
2、问题统一转化为回归问题后,调用GradientBoostedTrees的boost进行参数的训练,看一下整个训练过程的核心代码(在源码的基础上有删减):

// Initialize gradient boosting parameters
val numIterations = boostingStrategy.numIterations //总的迭代次数,决定了生成
val baseLearners = new Array[DecisionTreeRegressionModel](numIterations) //保存每次迭代的base模型的数组
val baseLearnerWeights = new Array[Double](numIterations)//模型权重?
val loss = boostingStrategy.loss //定义的损失函数
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
treeStrategy.algo = OldAlgo.Regression //org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
treeStrategy.impurity = OldVariance
treeStrategy.assertValid()
// Cache input
val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
  input.persist(StorageLevel.MEMORY_AND_DISK)
  true
} else {
  false
}
// Prepare periodic checkpointers 定期Checkpointer
val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
  treeStrategy.getCheckpointInterval, input.sparkContext)
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
  treeStrategy.getCheckpointInterval, input.sparkContext)
 
val firstTree = new DecisionTreeRegressor().setSeed(seed)
//实际是用随机森林训练的一棵树,GBT中树的深度通常较小
//RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val firstTreeModel = firstTree.train(input, treeStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
//(预测值,误差值)
//如改成多分类的话应该是(list<pred>, list<Error>) 即每棵树的预测值和误差值
var predError: RDD[(Double, Double)] =
  computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
var validatePredError: RDD[(Double, Double)] =
  computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
var m = 1
var doneLearning = false
while (m < numIterations && !doneLearning) {
  // Update data with pseudo-residuals
  //predError (预测值,误差值) 预测值是前m-1轮的预测值之和,误差值为lable-预测值
  //如改成多分类的话 此时该样本的loss即可以用logitloss来表示,并对f1~fk都可以算出一个梯度,f1~fk便可以计算出当前轮的残差,供下一轮迭代学习。
  val data = predError.zip(input).map { case ((pred, _), point) =>
    LabeledPoint(-loss.gradient(pred, point.label), point.features)//
  }
  val dt = new DecisionTreeRegressor().setSeed(seed + m)
  val model = dt.train(data, treeStrategy)//训练下一个base model
  // Update partial model
  baseLearners(m) = model
  // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
  //       Technically, the weight should be optimized for the particular loss.
  //       However, the behavior should be reasonable, though not optimal.
  // 这里learningRate是一个固定值,没有使用shrinkage技术
  baseLearnerWeights(m) = learningRate // learningRate同时作为model的权重
  predError = updatePredictionError(
    input, predError, baseLearnerWeights(m), baseLearners(m), loss)
  predErrorCheckpointer.update(predError)
  if (validate) {//验证集,验证是否提前终止训练
    // Stop training early if
    // 1. Reduction in error is less than the validationTol or
    // 2. If the error increases, that is if the model is overfit.
    // We want the model returned corresponding to the best validation error.
    validatePredError = updatePredictionError(
      validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
    validatePredErrorCheckpointer.update(validatePredError)
    val currentValidateError = validatePredError.values.mean()
    if (bestValidateError - currentValidateError < validationTol * Math.max(
      currentValidateError, 0.01)) {
      doneLearning = true
    } else if (currentValidateError < bestValidateError) {
      bestValidateError = currentValidateError
      bestM = m + 1
    }
  }
  m += 1
}
GBT的训练是一个串行的过程,base treemodel在前一轮迭代残差的基础上逐棵生成。每次生成一棵树之后需要更新整个数据集的残差,再进行下一轮的训练。在数据集规模较大,并且迭代轮次比较多时,训练比较耗时,这在一定程度上增加了模型调参的成本。

截至Spark2.0.0,Spark的GBT模型比较初级,在分类问题上目前只支持2分类问题,梯度下降的过程控制也比较简单,难于适应一些精度要求高的的机器学习任务,因此目前版本下的Spark来做GBT模型并不是一个好的选择。相比较而言,XGBOOST是一个更好的选择,当然,有条件的情况下顺着Spark GBT的思路做一些改进也能达到接近的效果。
————————————————
版权声明:本文为CSDN博主「大愚若智_」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/zbc1090549839/article/details/70240989

spark.mllib:GradientBoostedTrees相关推荐

  1. spark.mllib:NaiveBayes

    朴素贝叶斯模型简述: 贝叶斯模型通过使用后验概率和类的概率分布来估计先验概率,具体的以公式表达为 P(Y)可以使用训练样本的类分布进行估计.如果X是单特征也很好估计,但如果X={x1,x2,..,xn ...

  2. spark.mllib:bagging方法

    在训练集成分类器时,关键的一步在于如何从全训练样本集中构建子样本集提供给子分类器进行训练.目前主流的两种子样本集构造方式是bagging方法和boosting方法.bagging方法的思想是从全样本集 ...

  3. spark.mllib:回归算法

    Spark实现了三类线性回归方法: 1.LinearRegression:普通线性回归模型 2.LassoRegression:加L1正则化的线性回归 3.RidgeRegression:加L2正则化 ...

  4. spark.mllib:Optimizer

    Spark中的求解器,根据输入的训练数据及设定的迭代次数.正则化项.参数收敛精度等进行迭代求解模型的参数.Spark内部实现来两类求解器,基于随机梯度下降(miniBatch选取样本)的Gradien ...

  5. Spark MLlib实现的广告点击预测–Gradient-Boosted Trees

    关键字:spark.mllib.Gradient-Boosted Trees.广告点击预测 本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点 ...

  6. spark mllib实现 广告点击率预测

    本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点击广告. 训练和测试数据使用Kaggle Avazu CTR 比赛的样例数据,下载地址:ht ...

  7. spark.mllib源码阅读:GradientBoostedTrees

    Gradient-Boosted Trees(GBT或者GBDT) 和 RandomForests 都属于集成学习的范畴,相比于单个模型有限的表达能力,组合多个base model后表达能力更加丰富. ...

  8. 机器学习_机器不学习:从Spark MLlib到美图机器学习框架实践

    / 机器学习简介 / 在深入介绍 Spark MLlib 之前先了解机器学习,根据维基百科的介绍,机器学习有下面几种定义: 机器学习是一门人工智能的科学,该领域的主要研究对象是人工智能,特别是如何在经 ...

  9. Spark大数据分析与实战:基于Spark MLlib 实现音乐推荐

    Spark大数据分析与实战:基于Spark MLlib 实现音乐推荐 基于Spark MLlib 实现音乐推荐 一.实验背景: 熟悉 Audioscrobbler 数据集 基于该数据集选择合适的 ML ...

最新文章

  1. 【转】Android OTA 升级之一:编译升级包
  2. TensorFlow-Bitcoin-Robot:Tensorflow 比特币交易机器人
  3. AtCoder AGC032F One Third (组合计数、DP、概率期望、微积分)
  4. 452. 用最少数量的箭引爆气球(贪心算法+思路+详解)07
  5. python写的小程序怎么封装_微信小程序源码分享之封装request的方法
  6. 165体重_身高165的女性体重多少比较合适?过胖和过瘦都不太好
  7. 联通3G用户破千万 建成全球规模最大WCDMA网络
  8. 正定二次函数的共轭梯度法matlab实现
  9. 游戏策划笔记:工作感受感官引导
  10. CRA图像 Clean Random Access (CRA) Pictures
  11. uniapp 分享到微信、QQ、朋友圈
  12. 杭州php程序员工资一般多少,杭州Android基础一期大黑马强哥,完美收官~~欧巴,卡几嘛...
  13. 观看:使用治具和工厂管理样本数据
  14. 技嘉服务器主板型号,服务器配件 全面认识技嘉服务器主板
  15. P1287 盒子与球题解【python】
  16. C#:Krypton控件使用方法详解(第九讲) ——kryptonRadioButton
  17. SMBIOS信息概述 -- DMI
  18. 个人简介(北京求职中)
  19. 图片提取文字怎么做?两种途径可以一试
  20. 基于DMA通道的连续ADC扫描读取

热门文章

  1. 电信服务器维修人员职责,维修人员岗位职责
  2. final const java_Java中的final关键字 与 C#中的const, readonly关键字
  3. 系统安全及应用(账户安全控制,系统引导和登录,弱口令检测和登录控制,PAM认证,端口扫描,用户切换和提权)
  4. Apache网页优化概述
  5. Linux文件系统与日志分析(inode、inode节点耗尽故障处理、文件备份和恢复、日志文件管理)
  6. 传统公司部署OpenStack(t版)简易介绍(一)——环境部署
  7. git status 不能显示中文
  8. android如何不用系统签名,更新Android系统应用程序,带/不带平台签名
  9. 对于局部变量_2020年对于JDK ,大家觉得哪个版本好用?
  10. spark 广播变量大数据_Spark基础知识(三)--- Spark的广播变量和累加器