点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Keshav G

编译:ronghuaiyang

导读

这是对比损失函数的一种变体,不再是使用绝对距离,还要考虑batch中其他样本对的整体距离分布来对损失进行加权,大家可以试试。

度量学习的目的是学习一个嵌入空间,在这个空间中,相似样本的嵌入向量被拉近,而不同样本的嵌入向量被推远。Multi Similarity Loss提出了一种直观的更好的方法来实现这一目标,并在公共基准数据集上得到了验证。本文的主要贡献有两个方面:a)在混合算法中引入多重相似性,b)困难样本对挖掘。

多重相似度损失

这种损失涉及携带配对信息的三种类型的相似性。

1. 自相似度:

x1 = anchor, x2 = positive, x3,x4 = negatives

自相似性确保属于正类的实例距离anchor的距离比属于负类的实例距离锚的距离更近。

Sᵢₖ= 样本对的余弦相似度,λ = 相似度margin, α,β = 超参数

MS-Loss包括两个部分:

i) 正样本部分

x1 = Anchor, x2,x3 = positives, λ = margin

这部分只讨论正样本对。λ表示相似度的margin,控制了正样本对的紧密程度,对那些相似度<λ的正样本对进行惩罚。在上面的图中我们可以看到两个样本对(x1, x2)和(x1,x3),正样本对(x1, x2)的损失很低,因为,由于超参数α总是大于零,这一项的值相比(x1,x3)是非常小的。对于(x1,x3)这一对的损失为








α
















α





ii) 负样本部分:

x1 = anchor, x2,x3 = negatives, λ = margin

这部分只处理负样本对,这部分损失确保负样本与anchor的相似性尽可能低。这意味着靠近x1的负样本(即具有高相似性)应该比远离x1的负样本(即具有较低的相似性)受到更大的惩罚。这从损失中是很明显的,损失(x1, x2)为,而损失x1-x3为

2. 负样本相对相似度

在MS损失中分配给负样本对的权值,这是由MS损失对单个样本对的导数推导出来的。

样本对权重wᵢⱼ被定义为这个样本对的损失相对于总损失的贡献

MS-loss只考虑了一个负样本对x1-x2,不仅根据x1-x2之间的自相似度,而且根据其相对相似度,即批处理中存在的所有其他对x1的负样本来分配权重。在上面的式子中, Sᵢⱼ指(x1, x2)的自相似度,Sᵢₖ指x1-x3,x1-x4, x1-x5, x1-x6 x1-x7之间的相似度。

在上图中虽然x1-x2在所有的case中具有相同的Sᵢⱼ,但是其权重wᵢⱼ在不同的case中是不一样的。相同相似年代ᵢⱼ在所有情况下,wᵢⱼ因情况而异。

  • Case 1: 所有其他的负样本相对于x2都距离x1更远。

  • Case 2: 所有的其他负样本相对于x1的距离和x2一样。

  • Case 3: 所有其他的负样本相对于x1的距离比x2更近。

在三个case中,wᵢⱼ的区别是分母项











β















ᵢⱼ






,其中Sᵢₖ= x1-x3,x1-x4, x1-x5 x1-x6 x1-x7之间的余弦相似度,Sᵢⱼ=x-x2之间的余弦相似度。

  • Case 1: wᵢⱼ最大,因为











    β















    ᵢⱼ






    最小,Sᵢₖ<Sᵢⱼ使得指数是个负数。

  • Case 2: wᵢⱼ中等,因为











    β















    ᵢⱼ






    中,指数是0。is in middle, since in denominator term Σ[e^(β(Sᵢₖ- Sᵢⱼ))], Sᵢₖ≃ Sᵢⱼ making it e^(zero-ish term).

  • Case 3: wᵢⱼ最小,因为











    β















    ᵢⱼ






    最大,Sᵢₖ>Sᵢⱼ,使得指数是整数。

3. 正样本相对相似度

在MS损失中分配给正样本对的权重,这是由MS损失相对于单个正样本对的导数推导出来的。

负样本相似度表示单个负样本对与batch中所有其他负样本对的关系。类似地,正样本相对相似度定义了batch中单个正样本对(x1-x2)与所有其它所有正样本对(x1-x3、x1-x4、x1-x5、x1-x6)之间的关系。按照我们在负样本相对相似度下所做的步骤,我们可以很容易地验证上图所示的结果。

困难正负样本的挖掘

多重相似度损失论文的作者在训练中只使用了困难的负样本和正样本,并丢弃了所有其他的样本对,因为它们对效果的提升几乎没有贡献,有时也降低了性能。只选择那些携带最多信息的对也会使算法的计算速度更快。

A = anchor, P = positives, N = negatives

i) 困难负样本挖掘

上面的式子表明只有那些与anchor点相似度大于正样本点最小相似度的负样本才应该包含在训练中。因此,在上面的图表中,我们所选择的是红色的负样本,因为它们都在与anchor的相似性最小的正样本的内部,其余的负样本都被丢弃。

ii) 困难正样本挖掘

上面的式子表明,只有那些与anchor点相似度小于具有最大相似度(最接近anchor点)的负样本点的正样本点才应该被包括在训练中。困难的正样本被涂成蓝色,其余的则被丢弃。

代码理解

class MultiSimilarityLoss(nn.Module):def __init__(self, cfg):super(MultiSimilarityLoss, self).__init__()self.thresh = 0.5self.margin = 0.1self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POSself.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEGdef forward(self, feats, labels):# feats = features extracted from backbone model for images# labels = ground truth classes corresponding to imagesbatch_size = feats.size(0)sim_mat = torch.matmul(feats, torch.t(feats))         # since feats are l2 normalized vectors, taking
its dot product with transpose of itself will yield a similarity matrix whose i,j (row and column) will correspond to similarity between i'th embedding and j'th embedding of the batch, dim of sim mat = batch_size * batch_size. zeroth row of this matrix correspond to similarity between zeroth embedding of the batch with all other embeddings in the batch.epsilon = 1e-5loss = list()for i in range(batch_size): # i'th embedding is the anchorpos_pair_ = sim_mat[i][labels == labels[i]] # get all positive pair simply by matching ground truth labels of those embedding which share the same label with anchorpos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] # remove the pair which calculates similarity of anchor with itself i.e the pair with similarity one.neg_pair_ = sim_mat[i][labels != labels[i]] # get all negative embeddings which doesn't share the same ground truth label with the anchorneg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]  # mine hard negatives using the method described in the blog, a margin of 0.1 is added to the neg pair similarity to fetch negatives which are just lying on the brink of boundary for hard negative which would have been missed if this term was not present.pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]# mine hard positives using the method described in the blog with a margin of 0.1.if len(neg_pair) < 1 or len(pos_pair) < 1:continue# continue calculating the loss only if both hard pos and hard neg are present.# weighting steppos_loss = 1.0 / self.scale_pos * torch.log(1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))neg_loss = 1.0 / self.scale_neg * torch.log(1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))# losses as described in the equationloss.append(pos_loss + neg_loss)if len(loss) == 0:return torch.zeros([], requires_grad=True)loss = sum(loss) / batch_sizereturn loss

论文 : http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf

代码: https://github.com/MalongTech/research-ms-loss/blob/master/ret_benchmark/losses/multi_similarity_loss.py

后台回复“MSLoss”获取打包好的论文和代码

—END—

英文原文:https://medium.com/@kshavgupta47/multi-similarity-loss-for-deep-metric-learning-ad194691e2d3

请长按或扫描二维码关注本公众号

喜欢的话,请给我个在看吧

Muti-Similarity Loss:考虑了batch中整体距离分布的对比损失函数相关推荐

  1. CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss、Center Loss)简介、使用方法之详细攻略

    CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss.Center Loss)简介.使用方法之详细攻略 目录 T1.Triplet Loss 1.英文原文解释 ...

  2. 一条消息是如何按照二进制协议写入到Batch中的ByteBuffer中的

    double-check模式中会将消息放到Batch中: //将消息放到batch中去 FutureRecordMetadata future = Utils.notNull(batch.tryApp ...

  3. 基于double-check模式尝试将消息放到batch中

    多线程并发获取Deque,但只会有一个人会new.然后就会对获取到的Deque加锁,并尝试将消息放入到获取到的Deque的已有的batch中: //对分区队列加锁后,尝试将消息放到队列中的已有的bat ...

  4. Spring Batch中的块处理

    大数据集的处理是软件世界中最重要的问题之一. Spring Batch是一个轻量级且强大的批处理框架,用于处理数据集. Spring Batch Framework提供了"面向Tasklet ...

  5. Spring Batch中面向TaskletStep的处理

    许多企业应用程序需要批处理才能每天处理数十亿笔交易. 必须处理这些大事务集,而不会出现性能问题. Spring Batch是一个轻量级且强大的批处理框架,用于处理这些大数据集. Spring Batc ...

  6. Protel 2004 的PCB中整体修改元件的技巧

    Protel 2004 的PCB中整体修改元件的技巧   在Protel 2004中统一修改元件的PCB封装类型 (1)在电路原理图中找到其中一个需要修改元件属性的元件.如选择一个电容元件,在该电容元 ...

  7. ppt如何旋转流程图_PPT中smartart图形怎么旋转?PPT中整体旋转或翻转smartart图形的方法介绍...

    powerpoint中smartart图形怎么做整体旋转?SmartArt图形在PPT中用途很大,并且也有很大的作用,但如果使用的比较频繁,就会发现一个问题,如果要进行整体旋转或者是翻转,那这样要怎么 ...

  8. 介绍Spring Batch 中Tasklet 和 Chunks

    介绍Spring Batch 中Tasklet 和 Chunks Spring Batch 提供了两种不同方式实现job: tasklet 和 chunk.本文通过实例实践两种方法. 示例需求说明 给 ...

  9. Seesaw Loss:一种面向长尾目标检测的平衡损失函数

    本文转自知乎,已获作者授权转载,请勿二次转载. 链接:https://zhuanlan.zhihu.com/p/339126633 前言 这篇论文是对 MMDet 团队参加今年7月份举办的 LVIS ...

最新文章

  1. requests(二): json请求中固定键名顺序消除键和值之间的空格
  2. CentOS7 上安装 Zookeeper-3.4.9 服务
  3. 英语口语-文章朗读Week8 Friday
  4. 分数小数互换图_五年级数学分数和小数的互换(15悬赏)
  5. mysql limitorderby
  6. 论《LEFT JOIN条件放ON和WHERE后的区别》
  7. 如何禁用 macOS 更新通知?
  8. 无线安全审计工具FruityWifi初体验
  9. PHP数据结构之实现链式二叉树与遍历
  10. Javashop 7.0 增加小程序支付(二次开发)
  11. 那点房事(难以齿口)
  12. 思科、华为、H3C交换机命名规则全收录
  13. 三网融合融什么融,团购网站团什么团【涂雅速涂】
  14. cdrx8如何批量导出jpg_cdr超级伴侣批量导图v8.0 免费版
  15. Angular2 模型渲染的一个坑
  16. 【BZOJ4864】[BeiJing 2017 Wc]神秘物质 Splay
  17. matlab谐波电流测量,基于MATLAB的谐波电流检测方法的建模与仿真
  18. 常州SEO姜东:怎么找出自己网站的所有连接?
  19. zha男/女的三种境界
  20. (附源码)基于Python的“书怡”在线书店系统的设计与实现 毕业设计082332

热门文章

  1. SRS之SrsConfig类
  2. Android内购+IM
  3. 从你的全世界路过 经典语录
  4. Android 构建简单app 步骤
  5. java实现 腾讯人机验证 + 前端
  6. 在linux下恶臭hellotext中作用的?,《Linux内核与程序设计》实验学习笔记
  7. C语言实现英寸单位与厘米的转换(两种方法)特简单!!!
  8. lt;一 SAP ABAP 将数字转换成本地语言(中文、英文)大写
  9. [深度学习]动手学深度学习笔记-11
  10. unity中Loding.UpdatePreloading占用CPU过高如何解决?