文章目录

  • 写在前面
  • 一、Info Noise-contrastive estimation(Info NCE)
    • 1.1 描述
    • 1.2 实现
  • 二、HCL
    • 2.1 描述
    • 2.2 实现
  • 三、文字解释
  • 四、代码解释
    • 4.1 Info NCE
    • 4.2 HCL

写在前面

  最近在基于对比学习做实验,github有许多实现,虽然直接套用即可,但是细看之下,损失函数部分甚是疑惑,故学习并记录于此。关于对比学习的内容网络上已经有很多内容了,因此不再赘述。本文重在对InfoNCE的两种实现方式的记录。

一、Info Noise-contrastive estimation(Info NCE)

1.1 描述

  InfoNCE在MoCo中被描述为:
Lq=−log⁡exp⁡(q⋅k+/τ)∑i=0Kexp⁡(q⋅ki/τ)(1)\mathcal{L}_{q}=-\log \frac{\exp \left(q \cdot k_{+} / \tau\right)}{\sum_{i=0}^{K} \exp \left(q \cdot k_{i} / \tau\right)} \tag{1}Lq​=−log∑i=0K​exp(q⋅ki​/τ)exp(q⋅k+​/τ)​(1)
其中τ\tauτ是超参。

  • 分子表示:qqq对k+k_+k+​的点积。所谓点积就是描述qqq和k+k_+k+​两个向量之间的距离。
  • 分母表示:qqq对所有kkk的点积。所谓所有就是指正例(positive sample)和负例(negative sample),所以求和号是从i=0i=0i=0到KKK,一共K+1K+1K+1项。

1.2 实现

  MoCo源码的\moco\builder.py中,实现如下:

 # compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()...return logits, labels

这里的变量logits的意义我也查了一下:是未进入softmax的概率

这段代码根据注释即可理解:l_pos表示正样本的得分,l_neg表示所有负样本的得分,logits表示将正样本和负样本在列上cat起来之后的值。值得关注的是,labels的数值,是根据logits.shape[0]的大小生成的一组zero。也就是大小为batch_size的一组0。

  接下来看损失函数部分,\main_moco.py

 # define loss function (criterion) and optimizercriterion = nn.CrossEntropyLoss().cuda(args.gpu)...# compute outputoutput, target = model(im_q=images[0], im_k=images[1])loss = criterion(output, target)

这里直接对输出的logits和生成的labels计算交叉熵,然后就是模型的loss。这里就是让我不是很理解的地方。先将疑惑埋在心里~

二、HCL

2.1 描述

  在文章《Contrastive Learning with Hard Negative Samples》中描述到,使用负样本的损失函数为:
Ex∼p,x+∼px+[−log⁡ef(x)Tf(x+)ef(x)Tf(x+)+QN∑i=1Nef(x)Tf(xi−)](2)\mathbb{E}_{x \sim p, x^{+} \sim p_{x}^{+}}\left[-\log \frac{e^{f(x)^{T} f\left(x^{+}\right)}}{e^{f(x)^{T} f\left(x^{+}\right)}+\frac{Q}{N} \sum_{i=1}^{N} e^{f(x)^{T} f\left(x_{i}^{-}\right)}}\right] \tag{2}Ex∼p,x+∼px+​​[−logef(x)Tf(x+)+NQ​∑i=1N​ef(x)Tf(xi−​)ef(x)Tf(x+)​](2)

  • 分子:ef(x)Tf(x+)e^{f(x)^{T} f(x^{+})}ef(x)Tf(x+)表示学到的表示f(x)f(x)f(x)和正样本f(x+)f(x^+)f(x+)的点积。(其实也就是正样本的得分)
  • 分母:第一项表示正样本的得分,第二项表示负样本的得分。

其实本质上适合InfoNCE一个道理,都是mean(-log(正样本的得分/所有样本的得分))

2.2 实现

  但是在这篇文章的实现中,\image\main.py

def criterion(out_1,out_2,tau_plus,batch_size,beta, estimator):# neg scoreout = torch.cat([out_1, out_2], dim=0)neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)old_neg = neg.clone()mask = get_negative_mask(batch_size).to(device)neg = neg.masked_select(mask).view(2 * batch_size, -1)# pos scorepos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)pos = torch.cat([pos, pos], dim=0)# negative samples similarity scoringif estimator=='hard':N = batch_size * 2 - 2imp = (beta* neg.log()).exp()reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)# constrain (optional)Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))elif estimator=='easy':Ng = neg.sum(dim=-1)else:raise Exception('Invalid estimator selected. Please use any of [hard, easy]')# contrastive lossloss = (- torch.log(pos / (pos + Ng) )).mean()return loss

可以看到最后计算loss的公式是:

 loss = (- torch.log(pos / (pos + Ng) )).mean()

的确与我上文中的理解相同,可是为什么这样的实现,没有用到全0的label呢?

三、文字解释

  既然是同一种方法的两种实现,已经理解了第二种实现(HCL)。那么,问题就出在了:不理解第一种实现的label为何要这样生成? 于是乎,查看交叉熵的计算方式:
loss(x,class)=−log⁡(exp⁡(x[class])∑jexp⁡(x[j]))=−x[class]+log⁡(∑jexp⁡(x[j]))(3)\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)= -x[class] + \log\left(\sum_j \exp(x[j])\right) \tag{3}loss(x,class)=−log(∑j​exp(x[j])exp(x[class])​)=−x[class]+log(j∑​exp(x[j]))(3)

交叉熵的label的作用是:将label作为索引,来取得xxx中的项(x[class]x[class]x[class]),因此,这些项就是label。而倘若label是全0的项,那么其含义为:xxx中的第一列为label(正样本),其他列就是负样本。然后带入公式(3)中计算,即可得到交叉熵下的loss值。

  而对于HCL的实现方式,是直接将InfoNCE拆解开来,使用正样本的得分和负样本的得分来计算。

四、代码解释

  首先,生成pos得分和neg的得分:

注意,这里省略了生成的特征,直接生成了得分,

4.1 Info NCE

4.2 HCL


嗒哒~两者的结果“一模一样”(取值范围导致最后一位不太一样)

对比学习(Contrastive Learning)中的损失函数相关推荐

  1. 从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering)

    从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailug ...

  2. 对比学习Contrastive Learning

    对比学习是一种常用的自监督学习方法. 核心思想:把正样本距离拉近,正样本与负样本距离拉远.(类似度量学习中的margin, 但是对比学习为正负样本分类,无margin概念) 方式:通过一个正样本,以及 ...

  3. 理解对比表示学习(Contrastive Learning)

    目录 一.前言 二.对比学习 三.主要论文(附代码分析) 1. AMDIM ([Bachman](https://arxiv.org/pdf/1906.00910.pdf) *et al.* 2019 ...

  4. 从ACL2021看对比学习在NLP中的应用

    本文首发于微信公众号"夕小瑶的卖萌屋" 文 | 花小花Posy 源 | 夕小瑶的卖萌屋 最近关注对比学习,所以ACL21的论文列表出来后,小花就搜罗了一波,好奇NLPers们都用对 ...

  5. ICCV2021 比MoCo更通用的对比学习范式,中科大MSRA提出对比学习新方法MaskCo

    关注公众号,发现CV技术之美 今日分享 ICCV2021论文『Self-Supervised Visual Representations Learning by Contrastive Mask P ...

  6. 强化学习(Reinforcement Learning)中的Q-Learning、DQN,面试看这篇就够了!

    文章目录 1. 什么是强化学习 2. 强化学习模型 2.1 打折的未来奖励 2.2 Q-Learning算法 2.3 Deep Q Learning(DQN) 2.3.1 神经网络的作用 2.3.2 ...

  7. 理解对比学习(contrasive learning)

    1.什么是对比学习? 对比学习,顾名思义就是在训练中和某些东西进行对比从而学习,在自编码器中,输出与自己进行对比,从而得到一个中间量latent code,我认为这也是一种对比学习. 2.对比学习框架 ...

  8. 图对比学习入门 Contrastive Learning on Graph

    对比学习作为近两年的深度学习界的一大宠儿,受到了广大研究人员的青睐.而图学习因为图可以用于描述生活中广泛出现的非欧式数据,具有广大的应用前景.当图学习遇上了对比学习- 本文从对比学习入手,再介绍图对比 ...

  9. MoCo 动量对比学习——一种维护超大负样本训练的框架

    MoCo 动量对比学习--一种维护超大负样本训练的框架 FesianXu 20210803 at Baidu Search Team 前言 在拥有着海量数据的大型互联网公司中,对比学习变得逐渐流行起来 ...

  10. 对比学习顶会论文系列-3-2

    文章目录 一.特定任务中的对比学习 1.2 摘要生成中的对比学习--SimCLS: A Simple Framework for Contrastive Learning of Abstractive ...

最新文章

  1. android 中间按钮突出,Android 实现 按钮从两边移到中间动画效果
  2. C#学习基本概念之属性使用
  3. os.path.join
  4. 安装 | cmd(命令提示符)窗口下使用conda安装TensorFlow
  5. [.Net 多线程处理系列专题七——对多线程的补充
  6. Linux 网络相关命令
  7. 算法提高 邮票面值设计 搜索 动态规划
  8. 风控必知必会|两大逻辑表五大基础报表
  9. SpringBoot | 第十章:Swagger2的集成和使用
  10. 远程医疗监护系统开发
  11. 每日英语Daily English
  12. 使用Pr编辑字幕-快闪效果
  13. 019 Linux tcpdump 抓包案例入门可真简单啊?
  14. c语言小学生入门自学,小学生C语言编程入门书.pdf
  15. 全国大学生信息安全竞赛writeup--暗号(reverse300)
  16. OPENCV提取图片中的文字
  17. 百度地图让用户“私人定制“:一场语音定制背后的AI能力强势输出
  18. 数学的故事之“共轭”
  19. c语言花卉销售与管理系统,网上花卉销售和管理系统毕业设计论文.doc
  20. error LNK1123: 转换到 COFF 期间失败

热门文章

  1. js下载文件方法记录
  2. Jmeter 压力测试、并发测试、弱网测试
  3. Python 神器!自动识别文字中的省市区并绘图
  4. html5圆圈闪烁,html5 css3圆形波浪百分比加载动画特效
  5. 直击美团“远程面试”现场,面试官竟反问:你真懂数据库事务吗?
  6. 获取SVG文件中use标签xlink:href的值
  7. 怎样可以在线将pdf转换成jpg格式
  8. Scratch3.0创意编程(基础篇):第1课 Scratch画图形
  9. 虚拟机 安装 CUDA 可行性分析操作
  10. JAVA8的ConcurrentHashMap为什么放弃了分段锁,有什么问题吗,如果你来设计,你如何 设计。