GBDT on Spark on AngelGBDT(Gradient Boosting Decision Tree):梯度提升决策树 是一种集成使用多个弱分类器(决策树)来提升分类效果的机器学习算法,在很多分类和回归的场景中,都有不错的效果。

1. 算法介绍

如图1所示,这是是对一群消费者的消费力进行预测的例子。简单来说,处理流程为:

在第一棵树中,根节点选取的特征是年龄,年龄小于30的被分为左子节点,年龄大于30的被分为右叶子节点,右叶子节点的预测值为1;

第一棵树的左一节点继续分裂,分裂特征是月薪,小于10K划分为左叶子节点,预测值为5;工资大于10k的划分右叶子节点,预测值为10

建立完第一棵树之后,C、D和E的预测值被更新为1,A为5,B为10

根据新的预测值,开始建立第二棵树,第二棵树的根节点的性别,女性预测值为0.5,男性预测值为1.5

建立完第二棵树之后,将第二棵树的预测值加到每个消费者已有的预测值上,比如A的预测值为两棵树的预测值之和:5+0.5=5.5

通过这种方式,不断地优化预测准确率。

2. 分布式训练

GBDT的训练方法中,核心是一种叫梯度直方图的数据结构,需要为每一个特征建立一阶梯度直方图和二阶梯度直方图。梯度直方图的大小与三个因素有关:特征数量、分裂点数量、分类数量和树节点的数量。

已有的分布式GBDT系统使用数据并行的方法,将训练数据按行进行切分,每个计算节点使用分配的数据集建立梯度直方图,通过网络汇总这些梯度直方图后,计算得出最佳的分裂点。但是当训练数据维度高、分类多、树深度大的时候,梯度直方图的大小较大,数据并行的训练方法有几个缺点:

每个计算节点都需要存储一份完整的梯度直方图,存储的开销大。在存储空间有限时,限制了GBDT的适用性。

计算节点之间需要通过网络传输本地的梯度直方图,网络通信的开销大。

为了解决数据并行的训练方式的缺点,Angel实现了特征并行的训练方式。

3. 特征并行GBDT

与数据并行的训练方式不同,Angel按列切分训练数据,我们把这种分布式训练方式叫做特征并行,训练的流程如图2所示:

数据集转换: 由于原始的数据集一般是按行存储于分布式文件系统,我们读取训练数据后做全局的数据转换,每个计算节点分配一个特征子集。

建立梯度直方图: 每个计算节点使用特征子集建立梯度直方图,得益于特征并行的方式,不同计算节点为不同特征建立梯度直方图。

寻找最佳分裂点: 基于本地梯度直方图,每个计算节点计算出本地特征子集的最佳分裂点(分裂特征+分裂特征值);计算节点之间通过网络汇总得到全局的最佳分裂点。

计算分裂结果: 由于每个计算节点只负责一个特征子集,训练数据的分裂结果(左子节点/右子节点)只有一个计算节点能够确定,此计算节点讲训练数据的分裂结果(经过二进制编码)广播给其他计算节点。

分裂树节点: 根据训练数据的分裂结果,更新树结构,如果没有达到停止条件,跳转到第2步继续训练。

与数据并行相比,特征并行使得每个计算节点只需要存储一部分的梯度直方图,减少存储开销,使得可以增大分裂点数量和树深度来提升模型精度。另一方面,特征并行不需要通过网络汇总梯度直方图,在高维场景下更为高效,传输分裂结果的网络开销可以通过二进制编码来降低。

4. 运行 & 性能

输入格式ml.feature.index.range:特征向量的维度

ml.data.type:支持”libsvm”的数据格式,具体参考:Angel数据格式

参数算法参数

ml.gbdt.task.type:任务类型,分类或者回归

ml.gbdt.loss.func:代价函数,支持二分类(binary:logistic)、多分类(multi:logistic)和均方根误差(rmse)

ml.gbdt.eval.metric:模型指标,支持rmse、error、log-loss、cross-entropy、precision和auc

ml.num.class:分类数量,仅对分类任务有用

ml.gbdt.feature.sample.ratio:特征采样比例(0到1之间)

ml.gbdt.tree.num:树的数量

ml.gbdt.tree.depth:树的最大高度

ml.gbdt.split.num:每个特征的分裂点的数量

ml.learn.rate:学习速率

ml.gbdt.min.node.instance:叶子节点上数据的最少数量

ml.gbdt.min.split.gain:分裂需要的最小增益

ml.gbdt.reg.lambda:正则化系数

ml.gbdt.multi.class.strategy:多分类任务的策略,一轮一棵树(one-tree)或者一轮多棵树(multi-tree)

输入输出参数

angel.train.data.path:训练数据的输入路径

angel.validate.data.path:验证数据的输入路径

angel.predict.data.path:预测数据的输入路径

angel.predict.out.path:预测结果的保存路径

angel.save.model.path:训练完成后,模型的保存路径

angel.load.model.path:预测开始前,模型的加载路径

训练任务启动命令示例

使用spark提交任务

./spark-submit \ —master yarn-cluster \ —conf spark.ps.jars=$SONA_ANGEL_JARS \ —conf spark.ps.cores=1 \ —conf spark.ps.memory=10g \ —conf spark.ps.log.level=INFO \ —queue $queue \ —jars $SONA_SPARK_JARS \ —name “GBDT on Spark-on-Angel” \ —driver-memory 5g \ —num-executors 10 \ —executor-cores 1 \ —executor-memory 10g \ —class com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer \ spark-on-angel-mllib-${ANGEL_VERSION}.jar \ ml.gbdt.task.type:classification \ angel.train.data.path:XXX angel.validate.data.path:XXX angel.save.model.path:XXX \ ml.gbdt.loss.func:binary:logistic ml.gbdt.eval.metric:error,log-loss \ ml.learn.rate:0.1 ml.gbdt.split.num:10 ml.gbdt.tree.num:20 ml.gbdt.tree.depth:7 ml.num.class:2 \ ml.feature.index.range:47237 ml.gbdt.feature.sample.ratio:1.0 ml.gbdt.multi.class.strategy:one-tree ml.gbdt.min.node.instance:100

预测任务启动命令示例

使用spark提交任务

./spark-submit \ —master yarn-cluster \ —conf spark.ps.jars=$SONA_ANGEL_JARS \ —conf spark.ps.cores=1 \ —conf spark.ps.memory=10g \ —conf spark.ps.log.level=INFO \ —queue $queue \ —jars $SONA_SPARK_JARS \ —name “GBDT on Spark-on-Angel” \ —driver-memory 5g \ —num-executors 10 \ —executor-cores 1 \ —executor-memory 10g \ —class com.tencent.angel.spark.ml.tree.gbdt.predictor.GBDTPredictor \ spark-on-angel-mllib-${ANGEL_VERSION}.jar \ angel.load.model.path:XXX angel.predict.data.path:XXX angel.predict.out.path:XXX \

5. 性能

评测腾讯的内部的数据集来比较Angel和XGBoost的性能。

训练数据

| 数据集 | 数据集大小 | 数据数量 | 特征数量 | 任务 ||:———:|:—————:|:————:|:————:|:———-:|| UserGender | 145GB | 1.2亿 | 33万 | 二分类 |

实验环境

实验所使用的集群是腾讯的线上Gaia集群(Yarn),单台机器的配置是: *

*CPU:2680*2

*内存:256GB

*网络:10G*2

*磁盘:4T*12(SATA)

参数配置

Angel和XGBoost使用如下的参数配置:

*树的数量:20

*树的最大高度:8

*梯度直方图大小:10

*学习速度:0.1(XGboost)、0.1(Angel)

*工作节点数据:50

*每个工作节点内存:20GB

实验结果

| 系统 | 数据集 | 每棵树时间| 测试集误差 | |:———:|:—————-:|:————:|:—————:| | XGBoost| UserGender | 438s | 0.15 | | Angel | UserGender | 79s | 0.15 |

spark写出分布式的训练算法_Spark on Angel相关推荐

  1. spark写出分布式的训练算法_利用 Spark 和 scikit-learn 将你的模型训练加快 100 倍...

    在 Ibotta,我们训练了许多机器学习模型.这些模型为我们的推荐系统.搜索引擎.定价优化引擎.数据质量等提供动力.它们在与我们的移动应用程序交互时为数百万用户做出预测. 当我们使用 Spark 进行 ...

  2. spark写出分布式的训练算法_Spark0.9分布式运行MLlib的线性回归算法

    1 什么是线性回归 线性回归是另一个传统的有监督机器学习算法.在这个问题中,每个实体与一个实数值的标签 (而不是一个像在二元分类的0,1标签),和我们想要预测标签尽可能给出数值代表实体特征.MLlib ...

  3. java 不用if_Java 不用for不用if写出九九乘法表算法

    Java 不用for不用if写出九九乘法表算法代码如下: public class ss { public static void main(String[] args) { row(); } sta ...

  4. Spark 写出MySQL报错,java.sql.BatchUpdateException

    spark DataFrame 写出到MySQL时报如下错误: java.sql.BatchUpdateException: Column 'name' specified twice at sun. ...

  5. PHP面试题:请写出常见的排序算法,并用PHP实现冒泡排序,将数组$a = array()按照从小到大的方式进行排序。

    常见的排序算法: 冒泡排序法.快速排序法.简单选择排序法.堆排序法.直接插入排序法.希尔排序法.合并排序法. 冒泡排序法的基本思想是:对待排序记录关键字从后往前(逆序)进行多遍扫描,当发现相邻两个关键 ...

  6. 美国南加州大学骆沁毅:构建高性能的异构分布式训练算法

    计算机体系结构领域国际顶级会议每次往往仅录用几十篇论文,录用率在20%左右,难度极大.国内学者在顶会上开始发表论文,是最近十几年的事情. ASPLOS与HPCA是计算机体系结构领域的旗舰会议.其中AS ...

  7. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  8. 若S作主串,P作模式串,试分别写出利用BF算法和KMP算法的匹配过程。

    目   录 题目: 百度文库-答案: (1) (2) MOOC标准答案: (1) (2) mooc答案-截图: 数据结构(C语言版)-严蔚敏2007 题目: 设字符串S='aabaabaabaac', ...

  9. 对下图所示的连通网络G,用克鲁斯卡尔(Kruskal)算法求G的最小生成树T,请写出在算法执行过程中,依次加入T的边集TE中的边。说明该算法的基本思想及贪心策略,并简要分析算法的时间复杂度

    对下图所示的连通网络G,用克鲁斯卡尔(Kruskal)算法求G的最小生成树T,请写出在算法执行过程中,依次加入T的边集TE中的 边.说明该算法的基本思想及贪心策略,并简要分析算法的时间复杂度

  10. 写出TREE-PREDECESSOR的伪代码(算法导论第三版12.2-3)

    写出TREE-PREDECESSOR的伪代码(算法导论第三版12.2-3) TREE-PREDECESSOR(x)if x.left != NILreturn TREE-MAXIMUM(x.left) ...

最新文章

  1. 批量下载文献中的参考文献
  2. 本地连接git 服务器方式:以及git连接时报错
  3. 详解 GNU C 标准中的 typeof 关键字
  4. python爬取app中的音频_Python爬取抖音APP,只需要十行代码
  5. Jmeter之BeanShell
  6. 哈尔滨阳光计算机学院是不是黄了,黑龙江这4所野鸡大学,常被误认为是名校,实则害人不浅...
  7. Linux赋予目录或文件任何人都可以读、写、执行的操作
  8. echarts数据可视化_Golang 数据可视化利器 go-echarts 开源啦
  9. SpringBoot项目瘦身指南,大厂如何面试看出你的水平
  10. 天寒宜早睡,梦醒闻雪声,倒计时83
  11. 浅析busybox如何集成到openwrt
  12. SolidWorks 2018 安装教程
  13. PDF文件如何旋转后保存
  14. forms组件与Dango回顾
  15. STM32学习心得十八:通用定时器基本原理及相关实验代码解读
  16. matlab文件批量重命名并编号排序
  17. 文明游戏5的计算机配置,文明5和文明6哪个好玩 文明5最低电脑配置要求
  18. SQL Anywhere(ASA) 数据库“File is shorter than expected -- transaction rolled back”错误修复...
  19. 离散数学-代数系统总结3-同态
  20. java程序员 达达集团_Java课后项目 达达租车系统

热门文章

  1. 【洛谷 2504】聪明的猴子
  2. qiankun加载react子应用报错[import-html-entry] error occurs while executing normal script
  3. 百度低代码框架amis介绍及实例讲解
  4. eclipse cdt 导入c ++ 工程并建立头头文件 索引
  5. html 每一段首行缩进2字符,设置段落首行缩进2字符,html设置段落首行缩进
  6. Android 获取应用「唯一标识符」——DeviceID「兼容android 10(Q)」
  7. replacestate 后退刷新_关于如何禁止浏览器后退及刷新功能
  8. RuoYi-Vue简介
  9. pgsql依赖性追踪
  10. docker部署time machine服务