对比学习可以使用梯度累积吗?
©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络
在之前的文章用时间换取效果:Keras 梯度累积优化器中,我们介绍过“梯度累积”,它是在有限显存下实现大 batch_size 效果的一种技巧。一般来说,梯度累积适用的是 loss 是独立同分布的场景,换言之每个样本单独计算 loss,然后总 loss 是所有单个 loss 的平均或求和。然而,并不是所有任务都满足这个条件的,比如最近比较热门的对比学习,每个样本的 loss 还跟其他样本有关。
那么,在对比学习场景,我们还可以使用梯度累积来达到大 batch_size 的效果吗?本文就来分析这个问题。
简介
一般情况下,对比学习的 loss 可以写为:
这里的 b 是 batch_size;
是事先给定的标签,满足 ,它是一个 one hot 矩阵,每一列只有一个 1,其余都为0;而 是样本 i 和样本 j 的相似度,满足 ,一般情况下还有个温度参数,这里假设温度参数已经整合到 中,从而简化记号。模型参数存在于 中,假设为 。
可以验证,一般情况下:
所以直接将小 batch_size 的对比学习的梯度累积起来,是不等价于大 batch_size 的对比学习的。类似的问题还存在于带 BN(Batch Normalization)的模型中。
梯度
注意,刚才我们说的是常规的简单梯度累积不能等效,但有可能存在稍微复杂一些的累积方案的。为此,我们分析式(1)的梯度:
其中
表示不需要对 求 的梯度,也就是深度学习框架的 stop_gradient 算子。上式表明,如果我们使用基于梯度的优化器,那么使用式(1)作为 loss,跟使用 作为 loss,是完全等价的(因为算出来的梯度一模一样)。
内积
接下来考虑
的计算,一般来说它是向量的内积形式,即 ,参数 在 里边,这时候:
所以 loss 中的
可以替换为 而效果不变:
第二个等号源于将
那一项的求和下标 i,j 互换而不改变求和结果。
流程
式(5)事实上就已经给出了最终的方案,它可以分为两个步骤。第一步就是向量:
的计算,这一步不需要求梯度,纯粹是预测过程,所以 batch_size 可以比较大;第二步就是把
当作“标签”传入到模型中,以 为单个样本的 loss 进行优化模型,这一步需要求梯度,但它已经转化为每个样本的梯度和的形式了,所以这时候就可以用常规的梯度累积了。
假设反向传播的最大 batch_size 是 b,前向传播的最大 batch_size 是 nb,那么通过梯度累积让对比学习达到 batch_size 为 nb 的效果,其格式化的流程如下:
1. 采样一个 batch 的数据
,对应的标签矩阵为 ,初始累积梯度为 g=0;
2. 模型前向计算,得到编码向量
以及对应的概率矩阵 ;
3. 根据式(6)计算标签向量
;
4. 对于
,执行:
5. 用 g 作为最终梯度更新模型,然后重新执行第 1 步。
总的来说,在计算量上比常规的梯度累积多了一次前向计算。当然,如果前向计算的最大 batch_size 都不能满足我们的需求,那么也可以分批前向计算,因为我们只需要把各个
算出来存好,而 可以基于 算出来。
最后还要提醒的是,上述流程只是在优化时等效于大 batch_size 模型,也就是说
的梯度等效于原 loss 的梯度,但是它的值并不等于原 loss 的值,因此不能用 作为 loss 来评价模型,它未必是单调的,也未必是非负的,跟原来的 loss 也不具有严格的相关性。
问题
上述流程有着跟《节省显存的重计算技巧也有了 Keras 版了》[1] 介绍的“重计算”一样的问题,那就是跟 Dropout 并不兼容,这是因为每次更新都涉及到了多次前向计算,每次前向计算都有不一样的 Dropout,这意味着我们计算标签向量
时所用的 跟计算梯度时所用的 并不是同一个,导致计算出来的梯度并非最合理的梯度。
这没有什么好的解决方案,最简单有效的方法就是在模型中去掉 Dropout。这对于 CV 来说没啥大问题,因为 CV 的模型基本也不见 Dropout 了;对于 NLP 来说,第一反应能想到的结果就是 SimCSE 没法用梯度累积,因为 Dropout 是 SimCSE 的基础。
小结
本文分析了对比学习的梯度累积方法,结果显示对比学习也可以用梯度累积的,只不过多了一次前向计算,并且需要在模型中去掉 Dropout。本文同样的思路还可以分析 BN 如何使用梯度累积,有兴趣的读者不妨试试。
参考文献
[1] https://kexue.fm/archives/7367
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
???? 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
???? 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
????
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。
对比学习可以使用梯度累积吗?相关推荐
- [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积
[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 文章目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前 ...
- 语义表征的无监督对比学习:一个新理论框架
点击上方↑↑↑蓝字关注我们~ 「2019 Python开发者日」7折优惠最后3天,请扫码咨询 ↑↑↑ 译者 | Linstancy 责编 | 琥珀 出品 | AI科技大本营(ID:rgznai100) ...
- 细节满满!理解对比学习和SimCSE,就看这6个知识点
©PaperWeekly 原创 · 作者 | 海晨威 研究方向 | 自然语言处理 2020 年的 Moco 和 SimCLR 等,掀起了对比学习在 CV 领域的热潮,2021 ...
- 对比学习(Contrastive Learning)相关进展梳理
©PaperWeekly 原创 · 作者|李磊 学校|西安电子科技大学本科生 研究方向|自然语言处理 最近深度学习两巨头 Bengio 和 LeCun 在 ICLR 2020 上点名 Self-Su ...
- 超细节的对比学习和SimCSE知识点
2020年的Moco和SimCLR等,掀起了对比学习在CV领域的热潮,2021年的SimCSE,则让NLP也乘上了对比学习的东风.下面就尝试用QA的形式挖掘其中一些细节知识点,去更好的理解对比学习和S ...
- 顶会中的对比学习论文-2
文章目录 1 NAACL-2022 DiffCSE:Difference-based Contrastive Learning for SentenceEmbeddings Learning Dial ...
- 对比学习顶会论文系列-3-2
文章目录 一.特定任务中的对比学习 1.2 摘要生成中的对比学习--SimCLS: A Simple Framework for Contrastive Learning of Abstractive ...
- 【CV】对比学习经典之作 SimLR 论文笔记
论文名称:A Simple Framework for Contrastive Learning of Visual Representations 论文下载:https://arxiv.org/ab ...
- MICCAI 2022 | CLFC:基于对比学习的多模态脑肿瘤分割与单模态正常脑图像的特征比较
MICCAI 2022 | CLFC基于对比学习的多模态脑肿瘤分割与单模态正常脑图像的特征比较 Multimodal Brain Tumor Segmentation Using Contrastiv ...
最新文章
- sftp工具都有哪些_色彩校正的工具都有哪些?
- 【SeeMusic】视频编辑 ( 视频 X 坐标 | 视频 Y 坐标 | 视频旋转 | 视频扭曲 )
- 数据结构一:链表(循环链表)
- IDEA工具实现反编译操作
- Linux中强大的输入输出重定向和管道
- 前端学习(3139):react-hello-react之生命周期组件挂载过程
- 2021曲靖高考成绩查询时间,2021年曲靖高考成绩排名及成绩公布时间什么时候出来...
- VirtualBox中虚拟Ubuntu添加新的虚拟硬盘
- python中pass语句的作用是_Python pass语句以及作用详解
- 如何查看服务器数据库修改密码,如何查看服务器数据库密码
- Java复习第三天-静态方法
- ie6不支持png图片的解决办法
- matlab约束转非约束,请问:fmincon非等和等于的约束条件
- 恢复rm删除的文件(ext3
- web安全深度剖析知识点总结
- 【每日算法Day 98】慈善赌神godweiyang教你算骰子点数概率!
- dp P1103 书本整理 洛谷
- java 二级联动实现
- UI框架的使用(NGUI)
- VBA 字典 键值可以是 二维数组
热门文章
- iap升级问题 stm32f103r8_STM32的基于串口的IAP固件升级与加密
- 鸿蒙so系统,鸿蒙手机版JNI实战(JNI开发、SO库生成、SO库使用)
- php 打印变量内存地址_Python合集之Python变量
- cudnn下载_Windows10安装 cuDNN 方法
- 构成子网与构成超网的分析
- 【转】Java内存与垃圾回收调优
- Json.NET Deserialize时如何忽略$id等特殊属性
- nchar,char,varchar与nvarchar区别
- Silverlight中调用WebService-发送邮件测试实例
- Live Messenger 邀请,再次放送