近期对比学习在NLP\text{NLP}NLP领域取得了不错的成绩,例如句嵌入方法SimCSE[1]\text{SimCSE}^{[1]}SimCSE[1]和短文本聚类方法SCCL[2]\text{SCCL}^{[2]}SCCL[2]。为了能更好的理解近期的进展,期望通过一系列相关的文章来循序渐进的介绍其中的技术和概念。本文就作为该系列的第一篇文章吧~

一、SimCLR\text{SimCLR}SimCLR简介

  • 原始论文:A Simple Framework for Contrastive Learning of Visual Representations
  • 核心思想:
      1. 将一个样本xxx数据增强为两个不同个样本x~i\tilde{x}_ix~i​和x~j\tilde{x}_jx~j​;
      2. 拉近样本x~i\tilde{x}_ix~i​和x~j\tilde{x}_jx~j​的距离,并拉远它们和其他样本的距离;

二、SimCLR\text{SimCLR}SimCLR框架

图1. SimCLR框架

SimCLR\text{SimCLR}SimCLR是一个对比学习的框架,其结构如图1所示,主要包含四个组件:

1. 数据增强模块

该模块会为一个样本随机生成两个增强样本x~i\tilde{x}_ix~i​和x~j\tilde{x}_jx~j​,这两个样本组成了一个正样本对(x~i,x~j)(\tilde{x}_i,\tilde{x}_j)(x~i​,x~j​)。

  • 论文主要是针对图像的。因此,采用的数据增强方式包括:裁剪、颜色失真、高斯模糊;

2. 编码器

编码器f(⋅)f(\cdot)f(⋅)的作用是将增强样本转换为向量表示,hi=f(x~i)\textbf{h}_i=f(\tilde{x}_i)hi​=f(x~i​);

  • 论文选择ResNet\text{ResNet}ResNet作为编码器,hi=f(x~i)=ResNet(x~i)\textbf{h}_i=f(\tilde{x}_i)=\text{ResNet}(\tilde{x}_i)hi​=f(x~i​)=ResNet(x~i​);

3. 投影头(Projection head)

投影头g(⋅)g(\cdot)g(⋅)是一个小型神经网络,其作用是将样本的向量表示映射至可以对比的空间中(也就是适合Loss计算的表示空间);

  • 论文使用单层全连接神经网络作为投影头,即zi=g(hi)=W(2)σ(W(1)hi)z_i=g(\textbf{h}_i)=W^{(2)}\sigma(W^{(1)}\textbf{h}_i)zi​=g(hi​)=W(2)σ(W(1)hi​),σ\sigmaσ是ReLU\text{ReLU}ReLU激活函数;

4. 对比损失函数

对比损失函数l\mathcal{l}l,其作用是:在一个包含正样本对(x~i,x~j)(\tilde{x}_i,\tilde{x}_j)(x~i​,x~j​)的集合{x~k}\{\tilde{x}_k\}{x~k​},给定样本x~i\tilde{x}_ix~i​,从{x~k}k≠i\{\tilde{x}_k\}_{k\neq i}{x~k​}k​=i​中确定出x~j\tilde{x}_jx~j​;

三、框架的实现

上面描述了SimCLR\text{SimCLR}SimCLR框架,本小节则是该框架的一个具体实现。

1. 损失函数NT-Xent\text{NT-Xent}NT-Xent

  • 随机采样NNN个样本作为minibatch\text{minibatch}minibatch,并通过数据增强生成2N2N2N个样本。这里将正样本对以外2(N−1)2(N-1)2(N−1)个样本当做负样本;
  • 向量相似度计算方式为:sim(u,v)=u⊤v/∥u∥∥v∥\text{sim}(u,v)=u^\top v/\Vert u\Vert\Vert v\Vertsim(u,v)=u⊤v/∥u∥∥v∥;
  • 正样本对(i,j)(i,j)(i,j)的损失函数

li,j=−logexp(sim(zi,zj)/τ)∑k=12N1k≠iexp(sim(zi,zk)/τ)\mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)} li,j​=−log∑k=12N​1k​=i​exp(sim(zi​,zk​)/τ)exp(sim(zi​,zj​)/τ)​
​ 其中,1[k≠i]∈{0,1}1_{[k\neq i]}\in\{0,1\}1[k​=i]​∈{0,1}是指示函数,τ\tauτ是温度(temperature)参数;

  • 同一个minibatch\text{minibatch}minibatch中所有正样本对的损失之和为最终的loss,称这个loss为NT-Xent\text{NT-Xent}NT-Xent。

2. 完整的算法描述

输入:batch size NNN,常量τ\tauτ,结构f,g,Tf,g,\mathcal{T}f,g,T;

for 采样的minibatch {xk}k=1N\{x_k\}_{k=1}^N{xk​}k=1N​ do

​  for all k∈{1,…,N}k\in\{1,\dots,N\}k∈{1,…,N} do

​ 随机选择两种数据增强函数t∼T,t′∼Tt\sim\mathcal{T},t'\sim\mathcal{T}t∼T,t′∼T

​​ # 第一个数据增强

​​ x~2k−1=t(xk)\tilde{x}_{2k-1}=t(x_k)x~2k−1​=t(xk​)

​​ h2k−1=f(x~2k−1)h_{2k-1}=f(\tilde{x}_{2k-1})h2k−1​=f(x~2k−1​) # 表示

​​ z2k−1=g(h2k−1)z_{2k-1}=g(h_{2k-1})z2k−1​=g(h2k−1​) # 投影

​​ # 第二个数据增强

​​ x~2k=t′(xk)\tilde{x}_{2k}=t'(x_k)x~2k​=t′(xk​)

​​ h2k=f(x~2k−1)h_{2k}=f(\tilde{x}_{2k-1})h2k​=f(x~2k−1​) # 表示

​ ​ z2k=g(h2k−1)z_{2k}=g(h_{2k-1})z2k​=g(h2k−1​) # 投影

​  end for

for all i∈{1,…,2N}and j∈{1,…,2N}i\in\{1,\dots,2N\}\text{ and } j\in\{1,\dots,2N\}i∈{1,…,2N} and j∈{1,…,2N} do

​  si,j=zizj/(∣∣zi∣∣∣∣zj∣∣)s_{i,j}=z_iz_j/(||z_i||||z_j||)si,j​=zi​zj​/(∣∣zi​∣∣∣∣zj​∣∣)

​  end for

​  定义l(i,j)\mathcal{l}(i,j)l(i,j)为li,j=−logexp(sim(zi,zj)/τ)∑k=12N1k≠iexp(sim(zi,zk)/τ)\mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)}li,j​=−log∑k=12N​1k​=i​exp(sim(zi​,zk​)/τ)exp(sim(zi​,zj​)/τ)​

​  L=12N∑k=1N[l(2k−1,2k)+l(2k,2k−1)]\mathcal{L}=\frac{1}{2N}\sum_{k=1}^N[\mathcal{l}(2k-1,2k)+\mathcal{l}(2k,2k-1)]L=2N1​∑k=1N​[l(2k−1,2k)+l(2k,2k−1)]

​  通过最小化L\mathcal{L}L来更新网络fff和ggg

end for

return 返回编码网络f(⋅)f(\cdot)f(⋅),并丢弃g(⋅)g(\cdot)g(⋅)

3. 训练细节

  • 为了不使用memory bank,将batch size从256增大至8192;
  • 由于SGD\text{SGD}SGD在大batch size上不稳定,使用LARS进行训练;

四、分析

  • 数据增强操作的组合对于学习好的向量表示至关重要

    上图是不同种数据增强方式间组合带来的影响,对角线表示单个一种数据增强方法。可以发现,对角线的颜色都比较深,也就是说单一的数据增强方式效果并不好。两两组合的数据增强方式效果更佳。

  • 相较于有监督学习,数据增强对对比学习更加有效

    上表时数据增强程度对有监督学习(Supervised)和对比学习(SimCLR)的影响。可以发现,数据增强对“对比学习”影响更大。

  • 模型越大、对比学习效果越好

上图中红色的点是对比学习的效果,随着模型规模的增大,效果也越来越好;

  • 非线性投影头能改善向量表示的质量

    上图中,非线性投影头优于线性投影头,线性投影头优于不进行投影;

  • 合适的温度参数能够帮助模型学习到更难的负样本

观察上表,l2 norm\text{l2 norm}l2 norm是有效的,而是适当大小的τ\tauτ也有助于模型的表现;

  • 大batch size和长的训练时间也有益于对比学习

    观察上图,大的batch size和较大的epoch有助于模型的表现;

【自然语言处理】【对比学习】搞nlp还不懂对比学习,不会吧?快来了解下SimCLR相关推荐

  1. 【自然语言处理】【对比学习】SimCSE:基于对比学习的句向量表示

    相关博客: [自然语言处理][对比学习]SimCSE:基于对比学习的句向量表示 [自然语言处理]BERT-Whitening [自然语言处理][Pytorch]从头实现SimCSE [自然语言处理][ ...

  2. 从ACL2021看对比学习在NLP中的应用

    本文首发于微信公众号"夕小瑶的卖萌屋" 文 | 花小花Posy 源 | 夕小瑶的卖萌屋 最近关注对比学习,所以ACL21的论文列表出来后,小花就搜罗了一波,好奇NLPers们都用对 ...

  3. 对比学习在NLP和多模态领域的应用

    © 作者|杨锦霞 研究方向 | 多模态 引言 对比学习的主要思想是相似的样本的表示相近,而不相似的远离.对比学习可以应用于监督和无监督的场景下,并且目前在CV.NLP等领域中取得了较好的性能.本文先对 ...

  4. 【自然语言处理】一文概述2017年深度学习NLP重大进展与趋势

    选自 tryolabs 机器之心编译 参与:路雪.黄小天.蒋思源 作者通过本文概述了 2017 年深度学习技术在 NLP 领域带来的进步,以及未来的发展趋势,并与大家分享了这一年中作者最喜欢的研究.2 ...

  5. 《Python自然语言处理-雅兰·萨纳卡(Jalaj Thanaki)》学习笔记:06 高级特征工程和NLP算法

    06 高级特征工程和NLP算法 6.1 词嵌入 6.2 word2vec基础 6.2.1 分布语义 6.2.2 定义word2vec 6.2.3 无监督分布语义模型中的必需品 6.3 word2vec ...

  6. 还不懂你现在学习的编程语言能做什么?还不懂如何进阶?过来看图

    前言说七说八 本篇文章的配图标注.内容并不代表仅有:本篇仅以个人经验及当前大学(大专.本科)相关课程作对比,列出比较常规的语言发展走向及相关技术:再次重申,本图及本文所涉及的技术发展走向并不代表着仅有 ...

  7. 《Python自然语言处理-雅兰·萨纳卡(Jalaj Thanaki)》学习笔记:05 特征工程和NLP算法

    05 特征工程和NLP算法 5.1 理解特征工程 5.1.1 特征工程的定义 5.1.2 特征工程的目的 5.1.3 一些挑战 5.2 NLP中的基础特征 5.2.1 句法解析和句法解析器 5.2.2 ...

  8. 《Python自然语言处理-雅兰·萨纳卡(Jalaj Thanaki)》学习笔记:11 如何提高你的NLP技能

    11 如何提高你的NLP技能 11.1 开始新的NLP职业生涯 11.2 备忘列表 11.3 确定你的领域 11.4 通过敏捷的工作来实现成功 11.5 NLP和数据科学方面一些有用的博客 11.6 ...

  9. 【深度学习NLP】初识深度学习(DL)与自然语言(NLP)

    一.自然语言(NLP)处理概述 1.什么是自然语言(NLP) 自然语言就是人类所了解到的语言,与计算机语言相比本质上两者是同义的. 2.自然语言处理(NLP)的基础概念 (1)横跨了计算机科学.语言学 ...

  10. 微信高级研究员解析深度学习在NLP中的发展和应用 | 公开课笔记

    作者 | 张金超(微信模式识别中心的高级研究员) 整理 | Just 出品 | 人工智能头条(公众号ID:AI_Thinker) 近年来,深度学习方法极大的推动了自然语言处理领域的发展.几乎在所有的 ...

最新文章

  1. 深度剖析 浮点型 在内存中的存储【C语言】
  2. 第二次作业+105032014001
  3. 使用字典编码每个字再编码每句话不知对nlp是否有帮助(深度大脑)
  4. 对软连接进行cp,rm
  5. cout、cerr、clog
  6. 系统通知并发问题_玩转Java高并发?请先说明下并发下的惊群效应
  7. 计算机显示文本自定义130%,实训课题目
  8. 使用账户和密码在FTP客户端连接FTP服务器,出现vsftpd:500 OOPS: vsftpd: refusing to run with writable root inside chroot
  9. 机器学习算法基础8-Nagel-Schreckenberg交通流模型-公路堵车概率模型
  10. 使用FZip创建压缩文件保存到桌面
  11. 学习OpenCV(2)OpenCV初探-2
  12. VIN码识别,车架号识别,移动端VIN码识别独家支持云识别
  13. SUMIFS函数 、MATCH及INDEX函数
  14. 【转】图像视觉开源代码
  15. 全面了解 360 评估
  16. 嵌入式学习笔记-2022.2.22
  17. 抖音落地页一键复制微信号跳转微信的方法
  18. Bootstrap实战练习---Web全栈课程体系(表格+巨幕)
  19. 因数(factor)
  20. #includeiomanip

热门文章

  1. 关于zabbix中vm.memery.size监控项后的参数
  2. DAY9:尚学堂高琪JAVA(98)
  3. Python网络爬虫与信息提取学习
  4. starbound服务器配置文件怎么写,【mod向】简单修改文件迅速刷到任何想要物品以及修改随机生成物品入手时数据可以带入任何服务器【修改向】...
  5. Node.js简介及安装
  6. 支付宝转账提现相关问题
  7. qte5编译dub.json
  8. pdf转换成jpg python_Python Wand将PDF转换为JPG background
  9. EKS使用AWS EFS CSI
  10. JavaWeb - 工作窃取算法 Work-Stealing