1 前言

1.1 Learning to Rank 简介

Learning to Rank (LTR) , 也被叫做排序学习, 是搜索中的重要技术, 其目的是根据候选文档和查询语句的相关性对候选文档进行排序, 或者选取topk文档. 比如在搜索引擎中, 需要根据用户问题选取最相关的搜索结果展示到首页. 下图是搜索引擎的搜索结果

1.2 LTR算法分类

根据损失函数可把LTR分为三种: 1. Pointwise, 该类型算法将LTR任务作为回归任务来训练, 即尝试训练一个为文档和查询语句的打分器, 然后根据打分进行排序. 2. Pairwise, 该类型算法的损失函数考虑了两个候选文档, 学习目标是把相关性高的文档排在前面, triplet loss 就属于Pairwise, 它的损失函数是$$ loss = max(0, score_{neg}-score_{pos}+margin)$$, 可以看出该损失函数一次考虑两个候选文档. 3. Listwise, 该类型算法的损失函数会考虑多个候选文档, 这是本文的重点, 下面会详细介绍.

1.3 本文主要内容

本文主要介绍了本人在学习研究过程中发明的一种新的Listwise损失函数, 以及该损失函数的使用效果. 如果读者对LTR任务及其算法还不够熟悉, 建议先去学习LTR相关知识, 同时本人博客自然语言处理中的负样本挖掘 (分类与排序任务中如何选择负样本) 也和本文关系较大, 可以先进行阅读.

2 预备知识

2.1 数学符号定义

$q$代表用户搜索问题, 比如"如何成为宇航员", $D$代表候选文档集合,$d^+$代表和$q$相关的文档,$d^-$代表和$q$不相关的文档, $d^+_i$代表第$i$个和$q$相关的文档, LTR的目标就是根据$q$找到最相关的文档$d$

2.2 学习目标

本次学习目标是训练一个打分器 scorer, 它可以衡量q和d的相关性, scorer(q, d)就是相关性分数,分值越大越相关. 当前主流方法下, scorer一般选用深度神经网络模型.

2.3训练数据分类

损失函数不同, 构造训练数据的方法也会不同:

-Pointwise, 可以构造回归数据集, 相关的数据设为1, 不相关设为0.
-Pairwise, 可构造triplet类型的数据集, 形如($q,d^+, d^-$) -Listwise, 可构造这种类型的训练集: ($q,d^+1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-{n+m}$), 一个正例还是多个正例也会影响到损失函数的构造, 本文提出的损失函数是针对多正例多负例的情况.

3 基于均值不等式的Listwise损失函数

3.1 损失函数推导过程

在上一小结我们可以知道,训练集是如下形式 ($q,d^+1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-{n+m}$), 对于一个q, 有m个相关的文档和n个不相关的文档, 那么我们一共可以获取m+n个分值:$(score_1,score_2,...,score_n,...,score_{n+m})$, 我们希望打分器对相关文档打分趋近于正无穷, 对不相关文档打分趋近于负无穷.

对m+n个分值做一个softmax得到$p_1,p_2,...,p_n,...,p_{n+m}$, 此时$p_i$可以看作是第i个候选文档与q相关的概率, 显然我们希望$p_1,p_2,...,p_m$越大越好, $p_{n+1},...,p_{m+n}$越小越好, 即趋近于0. 因此我们暂时的优化目标是$sum_{i=1}^{n}{p_i} rightarrow 1$.

但是这个优化目标是不合理的, 假设$p_1=1$, 其他值全为0, 虽然满足了上面的要求, 但这并不是我们想要的. 因为我们不仅希望$sum_{i=1}^{n}{p_i} rightarrow 1$, 还希望相关候选文档的每一个p值都要足够大, 即我们希望m个候选文档都与q相关的概率是最大的, 所以我们真正的优化目标是: $$max(prod_{i=1}^{n}{p_i} ) , sum_{i=1}^{n}{p_i} = 1$$

当前情况下, 损失函数已经可以通过代码实现了, 但是我们还可以做一些化简工作, $prod_{i=1}^{n}{p_i}$是存在最大值的, 根据均值不等式可得: $$prod_{i=1}^{n}{p_i} leq (frac{sum_{i=1}^{n}{p_i}}{n})^n$$

对两边取对数: $$sum_{i=1}^{n}{log(p_i)} leq -nlog(n)$$

这样是不是感觉清爽多了, 然后我们把它转换成损失函数的形式: $$ loss = -nlog(n) - sum_{i=1}^{n}{log(p_i)}$$

所以我们的训练目标就是$min{(loss)}$

3.2 使用pytorch实现该损失函数

在获取到最终的损失函数后, 我们还需要用代码来实现, 实现代码如下:

# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''

该示例代码仅展现了batch_size为1的情况, 在batch_size大于1时, 每一条数据都有不同的m和n, 为了能一起送入模型计算分值, 需要灵活的使用mask. 本人在实际使用该损失函数时,一共使用了两种mask, 分别mask每条数据所有候选文档和每条数据的相关文档, 供大家参考使用.

3.3 效果评估和使用经验

由于评测数据使用的是内部数据, 代码和数据都无法公开, 因此只能对使用效果做简单总结: 1. 效果优于PointwisePairwise, 但差距不是特别大 2. 相比Pairwise收敛速度极快, 训练一轮基本就可以达到最佳效果

下面是个人使用经验: 1. 该损失函数比较占用显存, 实际的batch_size是batch_size*(m+n), 建议显存在12G以上 2. 负例数量越多,效果越好, 收敛也越快 3. 用pytorch实现log_softmax时, 不要自己实现, 直接使用torch中的log_softmax函数, 它的效率更高些. 4. 只有一个正例, 还可以考虑转为分类问题,使用交叉熵做优化, 效果同样较好

### 4 总结 该损失函数还是比较简单的, 只需要简单的数学知识就可以自行推导, 在实际使用中也取得了较好的效果, 希望也能够帮助到大家. 如果大家有更好的做法欢迎告诉我.

文章可以转载, 但请注明出处:

  • 本人简书社区主页
  • 本人博客园社区主页
  • 本人知乎主页
  • 本人Medium社区主页

k均值的损失函数_一种基于均值不等式的Listwise损失函数相关推荐

  1. 肺结节目标检测_一种基于CT图像的肺结节检测方法及系统与流程

    本发明属于医学图像分析和计算机辅助诊断等技术领域,更具体地,涉及一种基于CT图像的肺结节检测方法及系统. 背景技术: 肺癌是导致患癌死亡的最危险的疾病之一,其发病率占所有癌症的三分之二,且5年存活率为 ...

  2. 基于信息熵确立权重的topsis法_一种基于改进多目标粒子群算法的受端电网储能优化配置方法与流程...

    本发明涉及受端电网中储能的规划问题,具体涉及一种基于改进多目标粒子群算法的受端电网储能优化配置方法. 背景技术: 随着煤炭等非可再生.高污染的能源总量日益减少,我国的电能结构正由火力发电向低碳化的清洁 ...

  3. 心电图计算心率公式_一种基于心电信号的心率计算方法与流程

    本发明涉及医学电子信息领域,具体涉及一种基于心电信号的心率计算方法. 背景技术: 心电图是临床最常用的检查之一,应用广泛,包括帮助诊断心律失常.心肌缺血.心肌梗死等.心电图记录的是随心动周期变化的体表 ...

  4. python随机森林筛选变量_一种基于随机森林的改进特征筛选算法

    刘云翔 陈斌 周子宜 摘  要: 肝癌是一种我国高发的消化系统恶性肿瘤,患者死亡率高,威胁极大.而其预后情况通常只能通过医生的专业知识和经验积累来粗略判断,准确率较差.因此文中在分析随机森林算法的基本 ...

  5. adc0832对光电二极管进行数据采集_一种基于光电二极管的麦克风跟踪检测电路的制作方法...

    本实用新型涉及一种基于光电二极管的麦克风跟踪检测电路,属于应用电子技术领域. 背景技术: 随着互联网的发展,语音交互应用正在日益变多,近几年视频直播.网络直播.K歌软件都发展得很快,也推高了麦克风的销 ...

  6. aes子密钥生成c语言_一种基于流密码算法的子密钥生成方法与流程

    本发明涉及一种用于分组加解密算法的子密钥的生成方法. 背景技术: 随着信息技术的发展,信息安全性的问题却愈来愈显得突出,保证信息安全的一个重要技术就是密码学.密码学在信息安全技术中扮演着基础的角色,是 ...

  7. dncnn图像去噪_一种基于DnCNNs改进的图像降噪方法与流程

    本发明涉及图像处理技术领域,具体涉及一种基于dncnns改进的图像降噪方法. 背景技术: 随着科技进步,新的图像技术在逐渐推广,在日常生活中人们对于图像的要求也越来越高,针对阴天或夜晚等弱光条件下拍摄 ...

  8. 度量相似性数学建模_一种基于粒子群位置更新思想灰狼优化算法的K-Means文本分类方法与流程...

    技术特征: 1.一种基于粒子群位置更新思想灰狼优化算法的k-means文本分类方法,其特征在于:包括以下步骤: s1:对文本数据进行预处理,得到预处理后文本数据: s2:采用余弦角度为相似性度量,分别 ...

  9. orb特征 稠密特征_一种基于ORB-SLAM2的双目三维稠密建图方法技术

    本发明专利技术公开了一种基于ORB‑SLAM2的双目稠密建图方法,涉及机器人同步定位与地图创建领域,该方法主要由跟踪线程.局部地图线程.闭环检测线程和稠密建图线程组成.其中稠密建图线程包含以下步骤:1 ...

最新文章

  1. web框架总结(django、flask)
  2. 基础数学:通俗解释,啥叫随机变量?
  3. java在程序中加入音频_在任意Java程序中播放音频
  4. tracert 路由跟踪程序
  5. Boost:bind绑定的回归测试
  6. 【详细注释】1051 Pop Sequence (25 分)
  7. 仅仅有人物没背景的图片怎么弄_五分钟写作课 人物篇 人物的出场是个关键时刻...
  8. vim 寄存器中的 ^@,^M,^J
  9. fping安装包linux,Linux安装fping和hping
  10. delete语句与reference约束冲突怎么解决_一条简单的更新语句,MySQL是如何加锁的?...
  11. Gaze Estimation学习笔记(1)-Appearance-Based Gaze Estimation in the Wild
  12. Python基础-基本语法
  13. 8051单片机的C语言程序设计
  14. CMYK与RGB参数转换公式及转换方法
  15. ringbuffer java例子_Java RingBuffer.publish方法代碼示例
  16. mac自带邮件设置QQ企业邮箱发邮件
  17. 华硕主板的网络唤醒(Wake-on-LAN)
  18. 《一次与IP MTU、TCP MSS导致SSL协商失败的案例》—那些年踩过的坑(二)
  19. Latex添加中文支持和A4纸张设置
  20. git reflog 恢复已删除分支

热门文章

  1. MySQL分页查询小技巧
  2. vue登录如何存储cookie_vue保持用户登录状态(各种token存储方式)
  3. spring mvc 接收页面数据
  4. sqlserver note
  5. link引入html5,CSS引入方式 | link和@import的区别 — 生僻的前端考点
  6. 【剑指offer - C++/Java】7、斐波那契数列
  7. 【OS学习笔记】三十八 保护模式十:中断和异常的处理与抢占式多任务对应的汇编代码----微型内核汇代码
  8. WPF--TextBlock的ToolTip附加属性
  9. oracle数据库查看用户相关语句
  10. 作为前端应当了解的Web缓存知识