LambdaMART简介——基于Ranklib源码(二 Regression Tree训练)

上一节中介绍了 λ λ 的计算,lambdaMART就以计算的每个doc的 λ λ 值作为label,训练Regression Tree,并在最后对叶子节点上的样本 lambda lambda 均值还原成 γ γ ,乘以learningRate加到此前的Regression Trees上,更新score,重新对query下的doc按score排序,再次计算deltaNDCG以及 λ λ ,如此迭代下去直至树的数目达到参数设定或者在validation集上不再持续变好(一般实践来说不在模型训练时设置validation集合,因为validation集合一般比训练集合小很多,很容易收敛,达不到效果,不如训练时一步到位,然后另起test集合做结果评估)。

其实Regression Tree的训练很简单,最主要的就是决定如何分裂节点。lambdaMART采用最朴素的最小二乘法,也就是最小化平方误差和来分裂节点:即对于某个选定的feature,选定一个值val,所有<=val的样本分到左子节点,>val的分到右子节点。然后分别对左右两个节点计算平方误差和,并加在一起作为这次分裂的代价。遍历所有feature以及所有可能的分裂点val(每个feature按值排序,每个不同的值都是可能的分裂点),在这些分裂中找到代价最小的。

举个栗子,假设样本只有上一节中计算出 λ λ 的那10个:

 1 qId=1830 features and lambdas
 2 qId=1830    1:0.003 2:0.000 3:0.000 4:0.000 5:0.003 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(1):-0.495
 3 qId=1830    1:0.026 2:0.125 3:0.000 4:0.000 5:0.027 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(2):-0.206
 4 qId=1830    1:0.001 2:0.000 3:0.000 4:0.000 5:0.001 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(3):-0.104
 5 qId=1830    1:0.189 2:0.375 3:0.333 4:1.000 5:0.196 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(4):0.231
 6 qId=1830    1:0.078 2:0.500 3:0.667 4:0.000 5:0.086 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(5):0.231
 7 qId=1830    1:0.075 2:0.125 3:0.333 4:0.000 5:0.078 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(6):-0.033
 8 qId=1830    1:0.079 2:0.250 3:0.667 4:0.000 5:0.085 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(7):0.240
 9 qId=1830    1:0.148 2:0.000 3:0.000 4:0.000 5:0.148 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(8):0.247
10 qId=1830    1:0.059 2:0.000 3:0.000 4:0.000 5:0.059 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(9):-0.051
11 qId=1830    1:0.071 2:0.125 3:0.333 4:0.000 5:0.074 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(10):-0.061

上表中除了第一列是qId,最后一列是lambda外,其余都是feature,比如我们选择feature(1)的0.059做分裂点,则左子节点<=0.059的doc有: 1, 2, 3, 9;而>0.059的被安排到右子节点,doc有4, 5, 6, 7, 8, 10。由此左右两个子节点的lambda均值分别为:

λ L  ¯ =λ 1 +λ 2 +λ 3 +λ 9 4 =−0.495−0.206−0.104−0.0514 =−0.214 λL¯=λ1+λ2+λ3+λ94=−0.495−0.206−0.104−0.0514=−0.214

λ R  ¯ =λ 4 +λ 5 +λ 6 +λ 7 +λ 8 +λ 10 6 =0.231+0.231−0.033+0.240+0.247−0.0616 =0.143 λR¯=λ4+λ5+λ6+λ7+λ8+λ106=0.231+0.231−0.033+0.240+0.247−0.0616=0.143

继续计算左右子节点的平方误差和:

s L =∑ i∈L (λ i −λ L  ¯ ) 2 =(−0.495+0.214) 2 +(−0.206+0.214) 2 +(−0.104+0.214) 2 +(−0.051+0.214) 2 =0.118 sL=∑i∈L(λi−λL¯)2=(−0.495+0.214)2+(−0.206+0.214)2+(−0.104+0.214)2+(−0.051+0.214)2=0.118

s R =∑ i∈R (λ i −λ R  ¯ ) 2 =(0.231−0.143) 2 +(0.231−0.143) 2 +(−0.033−0.143) 2 +(0.240−0.143) 2 +(0.247−0.143) 2 +(0.016−0.143) 2 =0.083 sR=∑i∈R(λi−λR¯)2=(0.231−0.143)2+(0.231−0.143)2+(−0.033−0.143)2+(0.240−0.143)2+(0.247−0.143)2+(0.016−0.143)2=0.083

因此将feature(1)的0.059的均方差(分裂代价)是:

Cost 0.059@feature(1) =s L +s R =0.118+0.083=0.201 Cost0.059@feature(1)=sL+sR=0.118+0.083=0.201

我们可以像上面那样遍历所有feature的不同值,尝试分裂,计算Cost,最终选择所有可能分裂中最小Cost的那一个作为分裂点。然后将 s L  sL 和 s R  sR 分别作为左右子节点的属性存储起来,并把分裂的样本也分别存储到左右子节点中,然后维护一个队列,始终按平方误差和 s 降序插入新分裂出的节点,每次从该队列头部拿出一个节点(并基于这个节点上的样本)进行分裂(即最大均方差优先分裂),直到树的分裂次数达到参数设定(训练时传入的leaf值,叶子节点的个数与分裂次数等价)。这样我们就训练出了一棵Regression Tree。

上面讲述了一棵树的标准分裂过程,需要多提一点的是,树的分裂还有一个参数设定:叶子节点上的最少样本数,比如我们设定为3,则在feature(1)处,0.001和0.003两个值都不能作为分裂点,因为用它们做分裂点,左子树的样本数分别是1和2,均<3。叶子节点的最少样本数越小,模型则拟合得越好,当然也容易过拟合(over-fitting);反之如果设置得越大,模型则可能欠拟合(under-fitting),实践中可以使用cross validation的办法来寻找最佳的参数设定。

LambdaMART简介——基于Ranklib源码(二 Regression Tree训练)相关推荐

  1. LambdaMART简介——基于Ranklib源码(一 lambda计算)

     LambdaMART简介--基于Ranklib源码(一 lambda计算) 时间:2014-08-09 21:01:49      阅读:168      评论:0      收藏:0      ...

  2. java lambdamart库,LambdaMART简介——基于Ranklib源码(一 lambda计算)

    学习Machine Learning,阅读文献,看各种数学公式的推导,其实是一件很枯燥的事情.有的时候即使理解了数学推导过程,也仍然会一知半解,离自己写程序实现,似乎还有一道鸿沟.所幸的是,现在很多主 ...

  3. Java_io体系之BufferedWriter、BufferedReader简介、走进源码及示例——16

    Java_io体系之BufferedWriter.BufferedReader简介.走进源码及示例--16 一:BufferedWriter 1.类功能简介: BufferedWriter.缓存字符输 ...

  4. Java_io体系之RandomAccessFile简介、走进源码及示例——20

    Java_io体系之RandomAccessFile简介.走进源码及示例--20 RandomAccessFile 1.       类功能简介: 文件随机访问流.关心几个特点: 1.他实现的接口不再 ...

  5. Java_io体系之CharArrayReader、CharArrayWriter简介、走进源码及示例——13

    转载自   Java_io体系之CharArrayReader.CharArrayWriter简介.走进源码及示例--13 一:CharArrayReader 1.类功能简介: 字符数组输入流car  ...

  6. 基于Vue源码中e2e测试实践

    您好,如果喜欢我的文章,可以关注我的公众号「量子前端」,将不定期关注推送前端好文~ 基于Vue源码中e2e测试实践 前言 技术选型&对Vue的参考 Puppeteer测试流程 在Concis中 ...

  7. 基于Pytorch源码对SGD、momentum、Nesterov学习

    目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...

  8. spring cloud alibaba中台架构源码二次开发+系统集成、集中式应用权限管理

    基于Spring Cloud Alibaba 分布式微服务高并发数据平台化(中台)思想+多租户saas设计的企业开发架构,支持源码二次开发.支持其他业务系统集成.集中式应用权限管理.支持拓展其他任意子 ...

  9. 丁威: 优秀程序员必备技能之如何高效阅读源码(二更)

    @[toc](丁威: 优秀程序员必备技能之如何高效阅读源码(二更)) 消息中间件 我能熟练使用这个框架/软件/技术就行了, 为什么要看源码?" "平时不用看源码, 看源码太费时间, ...

最新文章

  1. 企业级虚拟机管理——虚拟机安装自动化
  2. android 电话拦截短信验证码,全警出击,只为寻找一位正在买新手机的女士……...
  3. 【转】ARM汇编伪指令介绍
  4. 中值滤波原理及c语言的实现,关于中值滤波算法,以及C语言实现(转)
  5. 华为od python_华为运维开发-华为OD工资待遇怎么样 - 华为技术有限公司 - 职友集...
  6. matlab经验分布函数 教程,经验分布函数.ppt
  7. celery异步发送邮箱
  8. 计算机中解决不匹配,电脑显示屏显示不匹配.怎么办
  9. 怒怼外媒,为中国正名,这个《流浪地球》捧红的犹太小哥太励志了
  10. 人民日报+俞敏洪:教育好自己的孩子,是你最重要的事业(两文)
  11. android模拟器安装
  12. 【MQTT服务器】(一)EMQX平台搭建
  13. Excel使用---excel2016___一般操作(搬,侵删)
  14. html table边框加粗,table加边框记录
  15. fla文件中切记不能使用TLF文本
  16. 基于JAVA二次元文化网站计算机毕业设计源码+系统+lw文档+部署
  17. 贼好用的六款 Linux 远程连接工具介绍
  18. 节拍脉冲发生器的设计
  19. Linux - UAC USB声卡
  20. 最小安装CentOS 7.6 Linux系统(无UI界面纯命令行,虚拟机教学)

热门文章

  1. android 默认光标大小设置,如何默认光标位置设置的EditText
  2. flutter怎么手动刷新_flutter局部刷新的实现示例
  3. php动态引入js文件路径问题,JavaScript_动态加载外部css或js文件,原理解析:第一步:使用dom创 - phpStudy...
  4. java使用getinputstream_java解析数据接口获取json对象
  5. 同济大学计算机学院徐老师,第十八届同济大学程序设计竞赛暨高校网络友谊赛圆满落幕...
  6. linux驱动由浅入深系列链接
  7. anaconda 运行路径
  8. python如何调用文件进行换位加密_python 换位密码算法的实例详解
  9. 32. Leetcode 141. 环形链表 (链表-双指针-环形链表)
  10. 24. Leetcode 61. 旋转链表 (链表-基础操作类-旋转链表)