目录

前言

ListNet

Methodology

Probability Construction

ListMLE

Methodology

Probability Construction


前言

在之前的专栏中,我们介绍过RankNet,LambdaRank以及LambdaMART,这些方法都是pair-wise的方法,也就是说它们考虑的是两两之间的排序损失。在本次专栏中,我们要介绍的两种方法是list-wise排序损失,它们是考虑每个query对应的所有items的整体排序损失。在实现过程中,你可能会发现ListNet与传统的分类任务中常用的BCE Loss非常相似,事实上确实如此,但是它们也存在一些差异,我们最后会详细说明。

ListNet

在之前的专栏中,我们介绍过RankNet系列算法,它们是pair-wise的方法。无论是pair-wise还是point-wise,都是将每个item独立看待,忽视了整体的关系。对于每一个query,我们要做的是对其所有的items按照相关性进行排序,要考虑整体的结果,这就是ListNet的主要动机。

Methodology

对于一个query,假设与其对应的items有个,即,对应的相关性得分为,相关性得分是真实的标签,往往是人工标注好的相关性等级,这里我们设置为:相关性得分或等级越大,代表item与当前query越相关。不妨用表示打分系统,比如一个神经网络,那么对于当前query我们得到的打分结果为:。通过最小化模型打分与真实打分之间的“误差”,可以得到一个训练好的打分模型,这里的误差可以通过多种方式衡量,我们下面详细介绍ListNet是如何计算的。为了表述方便,记为真实相关性标签,为预测结果,假设有个query,那么ListNet的损失函数可以构造为:

上标代表样本索引,代表ListNet的在单个query上的损失。

Probability Construction

作者将预测得到的结果和真实标签映射为一个概率分布,这样通过最小化两个概率分布之间的差异,就可以使得预测结果更接近真实结果。文中提出了两种概率构造方法:

Permutation Probability

按照上文介绍的,我们有个items,并且每个item与query的相关性我们是已知的,用映射函数来将得分映射成概率。用表示的一个排列,表示这个排列中第个位置是哪一个item。依据以上定义,在得到个items的得分之后,我们可以计算items任意的排列对应的概率:

这个表达式看起来有些复杂,但是很容易理解,例如,对于3个items来讲,不妨假设得分为,对于排列和排列,它们的概率分别为:

利用相同的方法,可以计算出所有的predict的排列概率,同理,也可以计算出所有groundtruth的排列概率,两个概率分布之间可以利用KL散度最小化差异,从而完成训练。但是这样做的代价太高,长度为的列表,它的全排列有种方式,是无法计算的。为了解决这个问题,作者提出了Top1概率,并给出了一些优异的性质。

Top One Probability

Top1概率指的是,某个item排在第一位的概率。对于索引为的item,它的Top1概率就是:所有以为第一个元素的排列的概率之和。表达式可以构造如下:

上述定义有一个非常好的性质,使得我们可以直接计算Top1概率而不需要计算所有排列的概率,这极大减少了计算量:

上式中的代表的是item的得分,这样,无论是对于真实标签还是预测结果,我们都可以计算每个item的Top1概率。ListNet的优化目标就是使得:预测结果中每个item的Top1概率与真实结果中每个item的Top1概率尽量接近。从而实现预测的排序结果与真实的排序结果更加相似。

损失函数与BCE Loss相似,不同的是每个维度上的数值代表的含义不同:分类任务中的score代表的是输入属于当前维度对应类别的概率;ListNet中的score代表的是当前维度对应的item的Top1概率。如果选用指数函数来作为,那么ListNet的损失可以构造如下:

import torch.nn.functional as Fdef listnet_loss(predict, target):# predict : batch x n_items# target : batch x n_itemstop1_target = F.softmax(target, dim=0)top1_predict = F.softmax(predict, dim=0)return torch.mean(-torch.sum(top1_target * torch.log(top1_predict)))

除了使用上面的交叉熵作为ListNet的损失函数之外,还可以直接使用KL散度或者是JS散度。它们的目的都是为了使得各个items预测的结果与真实的结果更加接近。不妨用表示预测结果的分布和真实结果的分布,那么JS散度可以用以下表达式计算:

其中,代表KL散度,其计算方式如下:

在ListNet中,我们使用softmax计算出Top1概率,预测的Top1概率列表可以和真实的Top1概率列表计算JS散度作为损失函数:

import torch.nn.functional as Fdef kld(p, q):# p : batch x n_items# q : batch x n_itemsreturn (p * torch.log2(p / q)).sum()def jsd(predict, target):# predict : batch x n_items# target : batch x n_itemstop1_true = F.softmax(target, dim=0)top1_pred = F.softmax(predict, dim=0)jsd = 0.5 * kld(top1_true, top1_pred) + 0.5 * kld(top1_pred, top1_true)return jsd

ListMLE

经过对ListNet的介绍,我们可以看出list-wise算法与point-wise以及pair-wise的最大区别就是,list-wise以优化整体的排序结果为目标,而不是仅仅关注绝对打分或者是两两之间的排序结果,从而大多数时候,list-wise方法能够得到相对更好的效果。

Methodology

ListMLE的思路非常容易理解。在ListNet中,我们是在最小化预测顺序与真实顺序之间的差异,为了实现这个目的,我们构造了概率分布,然后最小化了两个概率分布的差异。这里,一个更加直接的方法就是,我们以真实标签顺序为目标,最大化预测结果排序与目标一致的概率即可。也就是说,我们只需要定义出预测结果按照目标顺序来排列的概率就可以了,然后直接使用负对数来优化就可以了。

Probability Construction

在ListNet中介绍了Permutation Probability与Top One Probability,这里,我们再介绍一个非常经典的概率分布模型:Plackett-Luce模型,该模型在Learning to Rank中有着非常广泛的应用。

Plackett-Luce

不妨用来表示目标排序结果,用来表示模型对个items的预测得分。将预测得分按照真实标签的顺序排列,注意,这里是按照目标顺序排列,而不是按照自己本身的大小排序,从而得到:,这个排序的概率可以构造如下:

这就是Plackett-Luce模型构造概率的方法。我们只要最大化概率就可以实现让预测结果的排序尽可能接近真实结果,从儿直接构造负对数损失即可:

根据上面的公式,我们可以很容易的得到ListMLE的损失函数代码,这里给出torch的版本:

import torchdef list_mle(y_pred, y_true, k=None):# y_pred : batch x n_items# y_true : batch x n_items if k is not None:sublist_indices = (y_pred.shape[1] * torch.rand(size=k)).long()y_pred = y_pred[:, sublist_indices] y_true = y_true[:, sublist_indices] _, indices = y_true.sort(descending=True, dim=-1)pred_sorted_by_true = y_pred.gather(dim=1, index=indices)cumsums = pred_sorted_by_true.exp().flip(dims=[1]).cumsum(dim=1).flip(dims=[1])listmle_loss = torch.log(cumsums + 1e-10) - pred_sorted_by_truereturn listmle_loss.sum(dim=1).mean()

再上述代码中,因为取完exp再取log后就恢复原始数值了,因此直接转化成了减去预测值。另外,为了增加训练的随机性和鲁棒性,我们可以只计算一个长度为k的子列表对应的损失,期望子列表排序正确,这在样本量很少时可以增加算法的鲁棒性,如果不提供k,那么就默认计算整个列表的ListMLE损失。

Learning to Rank : ListNet与ListMLE相关推荐

  1. Learning to Rank 中Listwise关于ListNet算法讲解及实现

     [学习排序] Learning to Rank 中Listwise关于ListNet算法讲解及实现             版权声明:本文为博主原创文章,转载请注明CSDN博客源地址!共同学习, ...

  2. Learning to Rank 中Listwise关于ListNet算法讲授及实现

     Learning to Rank 中Listwise关于ListNet算法讲授及实现 前一篇文章"Learning to Rank中Pointwise关于PRank算法源码实现&quo ...

  3. 【学习排序】 Learning to Rank 中Listwise关于ListNet算法讲解及实现

    前一篇文章"Learning to Rank中Pointwise关于PRank算法源码实现"讲述了基于点的学习排序PRank算法的实现.该篇文章主要讲述Listwise Appro ...

  4. Learning to Rank简介

    机器学习有三大问题,分类.回归和排序.分类和回归之前了解了很多的算法,但排序还没有深入的了解过. Learning to Rank有很多种典型的应用.包括: document retrieval ex ...

  5. Learning to Rank:X-wise

    LTR(Learning to Rank)学习排序已经被广泛应用到文本挖掘.搜索推荐系统的很多领域,比如IR中排序返回的相似文档,推荐系统中的候选产品召回.用户排序等,机器翻译中排序候选翻译结果等等. ...

  6. lightGBM用于排序(Learning to Rank )

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx Learning to Rank 简介 去年实习时,因为项目需要,接触了一下Learning ...

  7. 学习排序 Learning to Rank:从 pointwise 和 pairwise 到 listwise,经典模型与优缺点

    Ranking 是信息检索领域的基本问题,也是搜索引擎背后的重要组成模块.本文将对结合机器学习的 ranking 技术--learning2rank--做个系统整理,包括 pointwise.pair ...

  8. Learning to rank 小结

    1.现有排序模型 排序(Ranking)一直是信息检索的核心研究问题,有大量的成熟的方法,主要可以分为以下两类:相关度排序模型和重要性排序模型. 1.1 相关度排序模型(Relevance Ranki ...

  9. Learning to Rank 简介

    非常好的一篇总结Learning to Rank的总结文章! 转载自:http://www.cnblogs.com/kemaswill/archive/2013/06/01/3109497.html ...

最新文章

  1. nginx主配置文件 在那找怎么打开
  2. 【MATLAB】矩阵操作 ( 矩阵下标 | 矩阵下标排列规则 )
  3. 【12】行为型-观察者模式
  4. Spring IoC、AOP、Transaction、MVC 归纳小结
  5. HashMap 怎么 hash?又如何 map?
  6. 解方程 2014NOIP提高组 (数学)
  7. 11 单线程+多任务异步协程 爬虫
  8. java连接zookeeper 找不到zoo.cfg_ZooInspector 连接不到 Zookeeper 的解决方法
  9. ubuntu14操作系统chrome标签和书签乱码解决
  10. 安装和启动tftp-server服务器及可能出现Redirecting to /bin/systemctl restart xinetd.service问题的解决方式...
  11. python批量读取landsat8的波段
  12. 电子商务B2C之未来-刘爽
  13. Vue2组件通信方式
  14. 八皇后问题(回溯算法)
  15. rime 简体中文 linux,Rime (简体中文)
  16. POJ 3290 WFF 'N PROOF 英文少
  17. 学大伟业 Day 6 培训总结
  18. 前端调用高德地图 百度地图
  19. pip安装报错: unable to creat process using ‘“‘的解决方法
  20. 画E-R图·数据库笔记(四)

热门文章

  1. 使用MyEclipse格式化JSP设置
  2. sap erp发展史
  3. 不同Vlan之间的PC相互通信(二)
  4. c# printdialog 打印html,c# – ReportViewer.PrintDialog()在打印到Adobe PDF时抛出异常
  5. 可以加载本地图片和网络资源的轮播图:TuTu
  6. 当前最流行的报表工具
  7. 温度补偿计算公式_温度补偿,matlab 计算
  8. 没有基础,能学PHP开发吗?
  9. SEM优化教程第1讲—SEM是什么?SEM与SEO的区别?
  10. 微信qq邮箱提醒 服务器繁忙,用qq邮箱找回微信密码为什么一直提示服务器繁忙???...