Gradient Boosting Neural Networks: GrowNet,Preprint, 2021

  • 文章亮点
  • 模型结构
  • 原理
    • 对于Regression task
    • 对于Classification task
    • 对于Learning to task 任务
  • 模型优化方法
  • 资源
  • Refernces

文章亮点

1.借助Gradient boosting的技巧,利用浅层的network来增量式的搭建复杂的网络模型GrowNet。所提网络模型GrowNet可以处理各种机器学习任务(分类、回归等)。
2.提出相应的训练算法来更快、更容易的训练GrowNet。具体包括:利用二阶梯度信息更新网络参数以及全局的纠正步骤。

模型结构


如上图所示,利用浅层网络以gradient boosting的方式增量式地搭建网络: 使用浅层神经网络(例如,含有一个或两个隐藏层)作为弱学习者。我们使用当前迭代倒数第二层的输出来扩充原始输入。然后,通过boosting机制使用当前的残差,将扩充后的特征集(图中的虚线)作为输入来训练下一个弱学习者。该模型的最终输出是所有这些顺序训练模型的加权组合。

原理

假设模型在t−1t-1t−1 step对xix_{i}xi​的输出为:

然后在第ttt step基于贪婪策略通过最小化以下损失来训练第ttt个network—f(t)f(t)f(t):

Gradient Boosting 的技巧就是:将负梯度作为残差的近似,然后训练网络来拟合残差的近似。论文中采用了二阶梯度信息来拟合梯度,并采用MSE loss来计算损失,因此上式可以简化如下:

其中y~i=−gi/hi\tilde y_{i}=-g_{i}/h_{i}y~​i​=−gi​/hi​ , 其中gig_{i}gi​和hih_{i}hi​分别为loss关于ft−1(x)f_{t-1}(x)ft−1​(x)的的一阶和二阶偏导数, 具体形式取决于任务的类型:

对于Regression task

对于Classification task

对于Learning to task 任务

对于给定的的查询, 成对损失paired loss定义如下:

对应的梯度计算如下:

相应的损失和一阶、二阶偏导数计算如下:

注: 一般的Neural network模型model architecture都是fixed/ predefined, 训练是在model的参数空间不断迭代进行寻优(即梯度下降)以最小化Target/Loss function, 而结合了Gradient boosting的Neural network模型model architecture是增量式扩展的not fixed, 训练是在函数空间进行Gradient descent, 具体来说,就是第k step,通过重新训练一个NN (前k-1个NN fixed)来拟合Loss function关于当前model 预测的负梯度。 特别的, 当Loss function为MSE时,该负梯度实际上就是当前model预测与ground truth target之间的残差

模型优化方法


如上所示, 在构建第kkk个network时, 共包含两个steps: (1)独立训练第k个network: individual model training; (2)全局的校正步骤Corrective step;
第1个步骤: 可以看到在训练第k个network时,目标值−gi/hi-g_{i}/h_{i}−gi​/hi​为负二阶梯度,即残差的近似。更新f(t)f(t)f(t)时,前t−1t-1t−1个模型都是固定的, 可以看做为feature extractor, 然后独立地训练f(t)f(t)f(t)。这个可以看做是局部的更新模型
第2个步骤:将这k个network看做一个整体,然后利用原始数据xt,ytx_{t}, y_{t}xt​,yt​通过反向传播来更新所有的网络参数。同时更新步长αk(即每个network对应的权重)\alpha_{k}(即每个network对应的权重)αk​(即每个network对应的权重)。

资源

1.官方代码, https://github.com/sbadirli/GrowNet
2.实际上在这个之前已经有很多工作将Gradient boosting技巧与Deep learning 相结合的工作:
例如:
(1)NIPS-2016-Incremental Boosting Convolutional Neural Networkfor Facial Action Unit Recognition
(2)理论工作, ICML-2017-AdaNet_ Adaptive Structural Learningof Artificial Neural Networks
(3)ICLR-2021-Boost then convolution-Gradient boosting meets graphs neural networks

Refernces

1.Gradient Boosting Neural Networks: GrowNet, Preprint, 2021

论文评析-Gradient Boosting Neural Networks: GrowNet,Preprint, 2021和Gradient boosting原理介绍相关推荐

  1. 论文阅读2018-Deep Convolutional Neural Networks for breast cancer screening 重点:利用迁移学习三个网络常规化进行分类

    论文阅读2018-Deep Convolutional Neural Networks for breast cancer screening 摘要:我们探讨了迁移学习的重要性,并通过实验确定了在训练 ...

  2. G1D7-云计算与虚拟化技术pagerank算法作图GNN@LAB0Intriguing properties of neural networks算法美亚2021个人赛ATP论文@TT

    一.虚拟化技术与云计算 上数据挖掘课,觉得好玩,查一查 https://www.zhihu.com/question/22793847 二.pagerank算法 在做gnn的lab,复习一下~看一下a ...

  3. 对抗样本论文学习:Deep Neural Networks are Easily Fooled

    近日看了一些对抗样本(adversarial examples)方面的论文,在这里对这些论文进行一下整理和总结. 以下仅代表个人理解,本人能力有限难免有错,还请大家给予纠正,不胜感激.欢迎一起讨论进步 ...

  4. 【读点论文】CMT: Convolutional Neural Networks Meet Vision Transformers

    CMT: Convolutional Neural Networks Meet Vision Transformers Abstract 视觉transformer已经成功地应用于图像识别任务,因为它 ...

  5. 论文分享 MetaBalance: High-Performance Neural Networks for Class-Imbalanced Data

    摘要 类不平衡数据,其中一些类包含比其他类多得多的样本,在现实世界的应用程序中无处不在.处理类不平衡的标准技术通常通过对重新加权损失或重新平衡数据进行训练来工作. 不幸的是,针对此类目标训练过度参数化 ...

  6. 经典DL论文研读(part3)--Improving neural networks by preventing co-adaptation of feature detectors

    学习笔记,仅供参考,有错必纠 文章目录 Improving neural networks by preventing co-adaptation of feature detectors Abstr ...

  7. 【论文阅读】Deep Neural Networks for Learning Graph Representations | day14,15

    <Deep Neural Networks for Learning Graph Representations>- (AAAI-16)-2016 文章目录 一.模型 1.1解决了两个问题 ...

  8. [论文阅读笔记]Deep Neural Networks are Easily Fooled:High Confidence Predictions for Unrecognizable Images

    Deep Neural Networks are Easily Fooled:High Confidence Predictions for Unrecognizable Images(CVPR201 ...

  9. 论文笔记:Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering

    前言 初代频域GCN简单粗暴的将diag(g^(λl))diag(\hat{g}{(\lambda_l)})diag(g^​(λl​))变成了卷积核diag(θl)diag(\theta_l)diag ...

  10. 论文阅读:Recurrent Neural Networks for Time Series Forecasting Current Status and Future Directions

    typora-copy-images-to: ./ Recurrent Neural Networks for Time Series Forecasting: Current Status and ...

最新文章

  1. C#--多线程--2
  2. 计算机等级考试属于什么培训,计算机等级是什么
  3. Angular - 如何在页面加载后马上做初始化
  4. vue node --- 前后端联系的知识梳理
  5. Repeater片段
  6. Django之 RESTful规范
  7. Oracle数据库安装图文操作步骤
  8. 位运算解决二进制位上不同数字的个数问题
  9. 使用Google zxing生成二维码
  10. 管理新语:主管不要当传声筒,要检查、核实
  11. php中的魔术常量__FILE__
  12. k3导入账套_金蝶K3财务操作流程
  13. 在使用Assimp库时编译器报错:C2589 “(”:“::”右边的非法标记 AssimpLoadStl
  14. 群控云控SDK开发包(快速开发群控云控微信SCRM客服系统)
  15. 小麦亩产一千八(kela)
  16. php七牛云,php七牛云
  17. 东南大学计算机学院学办董烨,东南大学计算机教学实验中心(国家级)
  18. cuda FORTRAN 统一内存 managed
  19. Android开发中WIFI和GPRS网络的切换
  20. Linux 软链接——ln命令详解

热门文章

  1. 洛谷P1880 石子合并(区间DP)(环形DP)
  2. oracle常用命令(比较常见好用)
  3. ORA-12514: TNS:listener does not currently know of service …
  4. c语言学习-猜数字游戏
  5. Ajax 读取.ashx 返回404
  6. (转)ASP.NET程序中常用代码汇总
  7. Multipart生成的临时文件
  8. make files touse cmd line to protect exe
  9. DOS命令taskkill
  10. java安全管理器视频_安全管理器 (Security Manager)