一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss
翻译自FesianXu, 2020/1/13, 原文链接 https://gombru.github.io/2019/04/03/ranking_loss/

前言

ranking loss在很多不同的领域,任务和神经网络结构(比如siamese net或者Triplet net)中被广泛地应用。其广泛应用但缺乏对其命名标准化导致了其拥有很多其他别名,比如对比损失Contrastive loss,边缘损失Margin loss,铰链损失hinge loss和我们常见的三元组损失Triplet loss等。

本文翻译自https://gombru.github.io/2019/04/03/ranking_loss/,如有谬误,请提出指正,谢谢。

∇ \nabla 联系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号


ranking loss函数:度量学习

不像其他损失函数,比如交叉熵损失和均方差损失函数,这些损失的设计目的就是学习如何去直接地预测标签,或者回归出一个值,又或者是在给定输入的情况下预测出一组值,这是在传统的分类任务和回归任务中常用的。ranking loss的目的是去预测输入样本之间的相对距离。这个任务经常也被称之为度量学习(metric learning)。

在训练集上使用ranking loss函数是非常灵活的,我们只需要一个可以衡量数据点之间的相似度度量就可以使用这个损失函数了。这个度量可以是二值的(相似/不相似)。比如,在一个人脸验证数据集上,我们可以度量某个两张脸是否属于同一个人(相似)或者不属于同一个人(不相似)。通过使用ranking loss函数,我们可以训练一个CNN网络去对这两张脸是否属于同一个人进行推断。(当然,这个度量也可以是连续的,比如余弦相似度。)

在使用ranking loss的过程中,我们首先从这两张(或者三张,见下文)输入数据中提取出特征,并且得到其各自的嵌入表达(embedded representation,译者:见[1]中关于数据嵌入的理解)。然后,我们定义一个距离度量函数用以度量这些表达之间的相似度,比如说欧式距离。最终,我们训练这个特征提取器,以对于特定的样本对(sample pair)产生特定的相似度度量。

尽管我们并不需要关心这些表达的具体值是多少,只需要关心样本之间的距离是否足够接近或者足够远离,但是这种训练方法已经被证明是可以在不同的任务中都产生出足够强大的表征的。

ranking loss的表达式

正如我们一开始所说的,ranking loss有着很多不同的别名,但是他们的表达式却是在众多设置或者场景中都是相同的并且是简单的。我们主要针对以下两种不同的设置,进行两种类型的ranking loss的辨析

  1. 使用一对的训练数据点(即是两个一组)
  2. 使用三元组的训练数据点(即是三个数据点一组)

这两种设置都是在训练数据样本中进行距离度量比较。

成对样本的ranking loss

Fig 2.1 成对样本ranking loss用以训练人脸认证的例子。在这个设置中,CNN的权重值是共享的。我们称之为Siamese Net。成对样本ranking loss还可以在其他设置或者其他网络中使用。

在这个设置中,由训练样本中采样到的正样本和负样本组成的两种样本对作为训练输入使用。正样本对 ( x a , x p ) (x_a, x_p) (xa,xp)由两部分组成,一个锚点样本 x a x_a xa 和 另一个和之标签相同的样本 x p x_p xp ,这个样本 x p x_p xp与锚点样本在我们需要评价的度量指标上应该是相似的(经常体现在标签一样);负样本对 ( x a , x n ) (x_a,x_n) (xa,xn)由一个锚点样本 x a x_a xa和一个标签不同的样本 x n x_n xn组成, x n x_n xn在度量上应该和 x a x_a xa不同。(体现在标签不一致)

现在,我们的目标就是学习出一个特征表征,这个表征使得正样本对中的度量距离 d d d尽可能的小,而在负样本对中,这个距离应该要大于一个人为设定的超参数——阈值 m m m。成对样本的ranking loss强制样本的表征在正样本对中拥有趋向于0的度量距离,而在负样本对中,这个距离则至少大于一个阈值。用 r a , r p , r n r_a, r_p, r_n ra,rp,rn分别表示这些样本的特征表征,我们可以有以下的式子:
L = { d ( r a , r p ) 正 样 本 对 ( x a , x p ) max ⁡ ( 0 , m − d ( r a , r n ) ) 负 样 本 对 ( x a , x n ) (2.1) L = \begin{cases} \mathrm{d}(r_a, r_p) & 正样本对(x_a, x_p) \\ \max(0, m-\mathrm{d}(r_a, r_n)) & 负样本对(x_a,x_n) \end{cases} \tag{2.1} L={d(ra,rp)max(0,md(ra,rn))(xa,xp)(xa,xn)(2.1)
对于正样本对来说,这个loss随着样本对输入到网络生成的表征之间的距离的减小而减少,增大而增大,直至变成0为止。

对于负样本来说,这个loss只有在所有负样本对的元素之间的表征的距离都大于阈值 m m m的时候才能变成0。当实际负样本对的距离小于阈值的时候,这个loss就是个正值,因此网络的参数能够继续更新优化,以便产生更适合的表征。这个项的loss最大值不会超过 m m m,在 d ( r a , r n ) = 0 \mathrm{d}(r_a,r_n)=0 d(ra,rn)=0的时候取得。这里设置阈值的目的是,当某个负样本对中的表征足够好,体现在其距离足够远的时候,就没有必要在该负样本对中浪费时间去增大这个距离了,因此进一步的训练将会关注在其他更加难分别的样本对中。

假设用 r 0 , r 1 r_0,r_1 r0,r1分别表示样本对两个元素的表征, y y y是一个二值的数值,在输入的是负样本对时为0,正样本对时为1,距离 d d d是欧式距离,我们就能有最终的loss函数表达式:
L ( r 0 , r 1 , y ) = y ∣ ∣ r 0 − r 1 ∣ ∣ + ( 1 − y ) max ⁡ ( 0 , m − ∣ ∣ r 0 − r 1 ∣ ∣ ) (2.2) L(r_0,r_1,y) = y||r_0-r_1||+(1-y)\max(0,m-||r_0-r_1||) \tag{2.2} L(r0,r1,y)=yr0r1+(1y)max(0,mr0r1)(2.2)

三元组样本对的ranking loss

三元组样本对的ranking loss称之为triplet loss。在这个设置中,与二元组不同的是,输入样本对是一个从训练集中采样得到的三元组。这个三元组 ( x a , x p , x n ) (x_a,x_p,x_n) (xa,xp,xn)由一个锚点样本 x a x_a xa,一个正样本 x p x_p xp,一个负样本 x n x_n xn组成。其目标是锚点样本与负样本之间的距离 d ( r a , r n ) \mathrm{d}(r_a,r_n) d(ra,rn) 与锚点样本和正样本之间的距离 d ( r a , r p ) \mathrm{d}(r_a,r_p) d(ra,rp)之差大于一个阈值 m m m,可以表示为:
L ( r a , r p , r n ) = max ⁡ ( 0 , m + d ( r a , r p ) − d ( r a , r n ) ) (2.3) L(r_a,r_p,r_n)=\max(0,m+\mathrm{d}(r_a,r_p)-\mathrm{d}(r_a,r_n)) \tag{2.3} L(ra,rp,rn)=max(0,m+d(ra,rp)d(ra,rn))(2.3)

Fig 2.2 Triplet loss的例子,其中的CNN网络的参数是共享的。

在训练过程中,对于一个可能的三元组,我们的triplet loss可能有三种情况:

  • “简单样本”的三元组(easy triplet): d ( r a , r n ) > d ( r a , r p ) + m \mathrm{d}(r_a,r_n) > \mathrm{d}(r_a,r_p)+m d(ra,rn)>d(ra,rp)+m。在这种情况中,在嵌入空间(译者:指的是以嵌入特征作为表征的欧几里德空间,空间的每个基底都是一个特征维,一般是赋范空间)中,对比起正样本来说,负样本和锚点样本已经有足够的距离了(即是大于 m m m)。此时loss为0,网络参数将不会继续更新。
  • “难样本”的三元组(hard triplet): d ( r a , r n ) < d ( r a , r p ) \mathrm{d}(r_a,r_n) < \mathrm{d}(r_a,r_p) d(ra,rn)<d(ra,rp)。在这种情况中,负样本比起正样本,更接近锚点样本,此时loss为正值(并且比 m m m大),网络可以继续更新。
  • “半难样本”的三元组(semi-hard triplet): d ( r a , r p ) < d ( r a , r n ) < d ( r a , r p ) + m \mathrm{d}(r_a,r_p) < \mathrm{d}(r_a,r_n) < \mathrm{d}(r_a,r_p)+m d(ra,rp)<d(ra,rn)<d(ra,rp)+m。在这种情况下,负样本到锚点样本的距离比起正样本来说,虽然是大于后者,但是并没有大于设定的阈值 m m m,此时loss仍然为正值,但是小于 m m m,此时网络可以继续更新。

Fig 2.3 三元组可能的情况。

负样本的挑选

在训练中使用Triplet loss的一个重要选择就是我们需要对负样本进行挑选,称之为负样本选择(negative selection)或者三元组采集(triplet mining)。选择的策略会对训练效率和最终性能结果有着重要的影响。一个明显的策略就是:简单的三元组应该尽可能被避免采样到,因为其loss为0,对优化并没有任何帮助。

第一个可供选择的策略是离线三元组采集(offline triplet mining),这意味着在训练的一开始或者是在每个世代(epoch)之前,就得对每个三元组进行定义(也即是采样)。第二种策略是在线三元组采集(online triplet mining),这种方案意味着在训练中的每个批次(batch)中,都得对三元组进行动态地采样,这种方法经常具有更高的效率和更好的表现。

然而,最佳的负样本采集方案是高度依赖于任务特性的。但是在本篇文中不会在此深入讨论,因为本文只是对ranking loss的不同别名的综述并且讨论而已。可以参考[2]以对负样本采样进行更深的了解。

ranking loss的别名们~名儿可真多啊

ranking loss家族正如以上介绍的,在不同的应用中都有广泛应用,然而其表达式都是大同小异的,虽然他们在不同的工作中名儿并不一致,这可真让人头疼。在这里,我尝试对为什么采用不同的别名,进行解释:

  • ranking loss:这个名字来自于信息检索领域,在这个应用中,我们期望训练一个模型对项目(items)进行特定的排序。比如文件检索中,对某个检索项目的排序等。
  • Margin loss:这个名字来自于一个事实——我们介绍的这些loss都使用了边界去比较衡量样本之间的嵌入表征距离,见Fig 2.3
  • Contrastive loss:我们介绍的loss都是在计算类别不同的两个(或者多个)数据点的特征嵌入表征。这个名字经常在成对样本的ranking loss中使用。但是我从没有在以三元组为基础的工作中使用这个术语去进行表达。
  • Triplet loss:这个是在三元组采样被使用的时候,经常被使用的名字。
  • Hinge loss:也被称之为max-margin objective。通常在分类任务中训练SVM的时候使用。他有着和SVM目标相似的表达式和目的:都是一直优化直到到达预定的边界为止。

Siamese 网络和 Triplet网络

Siamese网络(Siamese Net)和Triplet网络(Triplet Net)分别是在成对样本和三元组样本 ranking loss采用的情况下训练模型。

Siamese网络

这个网络由两个相同并且共享参数的CNN网络(两个网络都有相同的参数)组成。这些网络中的每一个都处理着一个图像并且产生对于的特征表达。这两个表达之间会进行比较,并且计算他们之间的距离。然后,一个成对样本的ranking loss将会作为损失函数进行训练模型。

我们用 f ( x ) f(x) f(x)表示这个CNN网络,我们有Siamese网络的损失函数如:
L ( x 0 , x 1 , y ) = y ∣ ∣ f ( x 0 ) − f ( x 1 ) ∣ ∣ + ( 1 − y ) max ⁡ ( 0 , m − ∣ ∣ f ( x 0 ) − f ( x 1 ) ∣ ∣ ) (4.1) L(x_0,x_1,y) = y||f(x_0)-f(x_1)||+(1-y)\max(0,m-||f(x_0)-f(x_1)||) \tag{4.1} L(x0,x1,y)=yf(x0)f(x1)+(1y)max(0,mf(x0)f(x1))(4.1)

Triplet网络

这个基本上和Siamese网络的思想相似,但是损失函数采用了Triplet loss,因此整个网络有三个分支,每个分支都是一个相同的,并且共享参数的CNN网络。同样的,我们能有Triplet网络的损失函数表达为:
L ( x a , x p , x n ) = max ⁡ ( 0 , m + ∣ ∣ f ( x a ) − f ( x p ) ∣ ∣ − ∣ ∣ f ( x a ) − f ( x n ) ∣ ∣ ) (4.2) L(x_a,x_p,x_n) = \max(0, m+||f(x_a)-f(x_p)||-||f(x_a)-f(x_n)||) \tag{4.2} L(xa,xp,xn)=max(0,m+f(xa)f(xp)f(xa)f(xn))(4.2)

在多模态检索中使用ranking loss

根据我的研究,在涉及到图片和文本的多模态检索任务中,使用了Triplet ranking loss。训练数据由若干有着相应文本标注的图片组成。任务目的是学习出一个特征空间,模型尝试将图片特征和相对应的文本特征都嵌入到这个特征空间中,使得可以将彼此的特征用于在跨模态检索任务中(举个例子,检索任务可以是给定了图片,去检索出相对应的文字描述,那么既然在这个特征空间里面文本和图片的特征都是相近的,体现在距离近上,那么就可以直接将图片特征作为文本特征啦~当然实际情况没有那么简单)。为了实现这个,我们首先从孤立的文本语料库中,学习到文本嵌入信息(word embeddings),可以使用如同Word2Vec或者GloVe之类的算法实现。随后,我们针对性地训练一个CNN网络,用于在与文本信息的同一个特征空间中,嵌入图片特征信息。

具体来说,实现这个的第一种方法可以是:使用交叉熵损失,训练一个CNN去直接从图片中预测其对应的文本嵌入。结果还不错,但是使用Triplet ranking loss能有更好的结果。

使用Triplet ranking loss的设置如下:我们使用已经学习好了文本嵌入(比如GloVe模型),我们只是需要学习出图片表达。因此锚点样本 a a a是一个图片,正样本 p p p是一个图片对应的文本嵌入,负样本 n n n是其他无关图片样本的对应的文本嵌入。为了选择文本嵌入的负样本,我们探索了不同的在线负样本采集策略。在多模态检索这个问题上使用三元组样本采集而不是成对样本采集,显得更加合乎情理,因为我们可以不建立显式的类别区分(比如没有label信息)就可以达到目的。在给定了不同的图片后,我们可能会有需要简单三元组样本,但是我们必须留意与难样本的采样,因为采集到的难负样本有可能对于当前的锚点样本,也是成立的(虽然标签的确不同,但是可能很相似。)

在该实验设置中,我们只训练了图像特征表征。对于第 i i i个图片样本,我们用 f ( i ) f(i) f(i)表示这个CNN网络提取出的图像表征,然后用 t p , t n t_p,t_n tp,tn分别表示正文本样本和负文本样本的GloVe嵌入特征表达,我们有:
L ( i , t p , t n ) = max ⁡ ( 0 , m + ∣ ∣ f ( i ) − t p ∣ ∣ − ∣ ∣ f ( i ) − t n ∣ ∣ ) (5.1) L(i, t_p, t_n) = \max(0, m+||f(i)-t_p||-||f(i)-t_n||)\tag{5.1} L(i,tp,tn)=max(0,m+f(i)tpf(i)tn)(5.1)
在这种实验设置下,我们对比了triplet ranking loss和交叉熵损失的一些实验的量化结果。我不打算在此对实验细节写过多的笔墨,其实验细节设置和[4,5]一样。基本来说,我们对文本输入进行了一定的查询,输出是对应的图像。当我们在社交网络数据上进行半监督学习的时候,我们对通过文本检索得到的图片进行某种形式的评估。采用了Triplet ranking loss的结果远比采用交叉熵损失的结果好。

深度学习框架中的ranking loss层

Caffe

  • Constrastive loss layer
  • pycaffe triplet ranking loss layer

PyTorch

  • CosineEmbeddingLoss
  • MarginRankingLoss
  • TripletMarginLoss

TensorFlow

  • contrastive_loss
  • triplet_semihard_loss

Reference

[1]. https://blog.csdn.net/LoseInVain/article/details/88373506

[2]. https://omoindrot.github.io/triplet-loss

[3]. https://github.com/adambielski/siamese-triplet

[4]. https://arxiv.org/abs/1901.02004

[5]. https://gombru.github.io/2018/08/01/learning_from_web_data/

一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss相关推荐

  1. 一文理解Ranking Loss/Margin Loss/Triplet Loss

    点击蓝字  关注我们 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/158853633 本文已获授权,未经作者许可,不得二次转载. 前言 Ranking loss在 ...

  2. hinge loss的一种实现方法

    hinge loss的一种实现方法 FesianXu 20220820 at Baidu Search Team hinge loss是一种常用损失[1],常用于度量学习和表征学习.对于一个模型 y ...

  3. 【机器学习】SVM支持向量机在手写体数据集上进行二分类、采⽤ hinge loss 和 cross-entropy loss 的线性分类模型分析和对比、网格搜索

    2022Fall 机器学习 1. 实验要求 考虑两种不同的核函数:i) 线性核函数; ii) ⾼斯核函数 可以直接调⽤现成 SVM 软件包来实现 ⼿动实现采⽤ hinge loss 和 cross-e ...

  4. Hinge Loss

    Hinge Loss 作者:陈雕 链接:https://www.zhihu.com/question/47746939/answer/286432586 来源:知乎 著作权归作者所有.商业转载请联系作 ...

  5. 机器学习中的各种损失函数(Hinge loss,交叉熵,softmax)

    机器学习中的各种损失函数 SVM multiclass loss(Hinge loss) 这是一个合页函数,也叫Hinge function,loss 函数反映的是我们对于当前分类结果的不满意程度.在 ...

  6. 【论文理解】ArcFace: Additive Angular Margin Loss for Deep Face Recognition(InsightFace)

    论文地址:https://arxiv.org/abs/1801.07698 github:https://github.com/deepinsight/insightface 这篇论文基本介绍了近期较 ...

  7. 理解Hinge Loss (折页损失函数、铰链损失函数)

    理解Hinge Loss (折页损失函数.铰链损失函数) 原文:https://blog.csdn.net/fendegao/article/details/79968994 Hinge Loss 是 ...

  8. 从loss的硬截断、软化到Focal Loss

    对于二分类模型,我们总希望模型能够给正样本输出1,负样本输出0,但限于模型的拟合能力等问题,一般来说做不到这一点.而事实上在预测中,我们也是认为大于0.5的就是正样本了,小于0.5的就是负样本.这样就 ...

  9. SoftTriple Loss: Deep Metric Learning Without Triplet Sampling

    1. Abstract 距离度量学习(DML)用于学习嵌入(特征提取),其中来自同一类别的示例比来自不同类别的示例更接近. 可以将其转换为具有三元约束的优化问题. 由于存在大量三元组约束,因此DML的 ...

最新文章

  1. VTK:参数化超椭球用法实战
  2. Ajax跨域提交JSON和JSONP
  3. React-引领未来的用户界面开发框架-读书笔记(五)
  4. Android char数据类型乱码��解决方法
  5. sequelize 增加数据库字段_Node项目使用Sequelize操作数据库(一)(包括模型,增,删、改等)...
  6. Oracle join多表查询
  7. python查函数用法语句_Python-17 (函数的基本使用)
  8. 谈谈C#中的事件注册和注销
  9. 软件测试中测试用例的简单案例
  10. Linux驱动加载总结
  11. DevCpp 如何进行调试
  12. 2020年10款网盘大比拼,总有一款适合你
  13. 在mybatis中怎么书写工具类,也就是创建一个sqlsessionFactory
  14. 解决git fatal:无法找到‘https‘的远程助手
  15. android wifi智能硬件4g,智能硬件 篇五:把WiFi带在身上,告别宽带!华为随行WiFi2畅享版真实体验...
  16. 物理机安装esxi系统
  17. MMC子系统之SDIO卡驱动
  18. 第4章 虚拟机性能监控与故障处理工具
  19. 电脑一个磁盘分为两个磁盘
  20. 无尽对决一直显示正在连接服务器,无尽对决服务器连接不上 | 手游网游页游攻略大全...

热门文章

  1. Jupyter Notebook中%time和%timeit 的使用
  2. 小学计算机应用到英语课教案,利用信息技术提升小学英语课堂教学
  3. VMware与USB3.0不解情缘(此文献给win7系统只有usb3.0口死活与虚拟机连不上的朋友们)
  4. char数组和char指针数组赋值的区别
  5. ubuntu 由于没有公钥,无法验证下列签名
  6. CTO养成记(一)CTO理论知识恶补
  7. RK3588+AI+FPGA图像处理硬件算法加速解决方案
  8. win10如何查看系统盘分区表类型
  9. 1/20 武功秘籍~!
  10. java 判断是几位数_Java 快速判断一个int值是几位数