又是Dropout两次!这次它做到了有监督任务的SOTA
©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络
关注 NLP 新进展的读者,想必对四月份发布的 SimCSE [1] 印象颇深,它通过简单的“Dropout 两次”来构造正样本进行对比学习,达到了无监督语义相似度任务的全面 SOTA。无独有偶,最近的论文《R-Drop: Regularized Dropout for Neural Networks》提出了 R-Drop,它将“Dropout两次”的思想用到了有监督任务中,每个实验结果几乎都取得了明显的提升。此外,笔者在自己的实验还发现,它在半监督任务上也能有不俗的表现。
小小的“Dropout两次”,居然跑出了“五项全能”的感觉,不得不令人惊讶。本文来介绍一下 R-Drop,并分享一下笔者对它背后原理的思考。
论文标题:R-Drop: Regularized Dropout for Neural Networks
论文链接:https://arxiv.org/abs/2106.14448
代码链接:https://github.com/dropreg/R-Drop
SimCSE
《中文任务还是 SOTA 吗?我们给 SimCSE 补充了一些实验》[1] 中,我们已经对 SimCSE 进行了介绍。简单来说,SimCSE 是 NLP 的一种对比学习方案,对比学习的标准流程是同一个样本通过不同的数据扩增手段得到的结果视为正样本对,而 batch 内的所有其他样本视为负样本,然后就是通过 loss 来缩小正样本的距离、拉大负样本的距离了。
所以难度主要集中在数据扩增手段上。对于 NLP 来说,我们很难人工构建保证语义不变的数据扩增,所以 SimCSE 干脆不人工进行数据扩增,而是通过“Dropout 两次”的方式来得到同一个输入的不同特征向量,并将它们视为正样本对。奇怪的是,这个简单的“Dropout 两次”构造正样本,看上去是一种“无可奈何”的妥协选择,但消融实验却发现它几乎优于所有其他数据扩增方法,令人惊讶之余又让人感叹“大道至简”。
▲ SimCSE示意图
在实现上,SimCSE 也相当简单,所谓“Dropout 两次”,只需要将样本重复地输入到模型,然后计算相应的 loss 就行了,如上图所示。由于 Dropout 本身的随机性,每个样本的 Dropout 模式都是不一样的,所以只要单纯地重复样本,就可以实现“Dropout 两次”的效果。
R-Drop
从结果上来看,SimCSE 就是希望 Dropout对模型结果不会有太大影响,也就是模型输出对 Dropout 是鲁棒的。所以很明显,“Dropout 两次”这种思想是可以推广到一般任务的,这就是 R-Drop(Regularized Dropout)。
2.1 分类问题
在笔者看来,R-Drop 跟 SimCSE 是高度相关的,甚至 R-Drop 应该是受到了 SimCSE 启发的,不过 R-Drop 论文并没有引用 SimCSE,所以这就比较迷了。
▲ R-Drop示意图
具体来说,以分类问题为例,训练数据为 ,模型为 ,每个样本的 loss 一般是交叉熵:
在“Dropout 两次”的情况下,其实我们可以认为样本已经通过了两个略有不同的模型,我们分别记为 和 。这时候 R-Drop 的 loss 分为两部分,一部分是常规的交叉熵:
另一部分则是两个模型之间的对称 KL 散度,它希望不同 Dropout 的模型输出尽可能一致:
最终 loss 就是两个 loss 的加权和:
也就是说,它在常规交叉熵的基础上,加了一项强化模型鲁棒性正则项。
2.2 一般形式
可能有些读者会问非分类问题应该将 KL 项替换为什么,事实上原论文并没有在非分类问题上进行实验,不过这里可以补充一下。我们可以留意到:
所以,上述 只不过是 KL 散度的反复使用,它的一般形式是:
因此对于非分类问题,我们将 换成适当的度量(而不是 KL 散度)即可。
实验效果
我们先来看看 R-Drop 的实验结果。
R-Drop 的主要超参有三个:batch_size、 和 Dropout 概率。batch_size 一 般取决于我们的算力,对个人来说调整空间不大;原论文的 从 都有,笔者自己的实验中,则取了 ,也没细调。至于 Dropout的概率,跟笔者在《中文任务还是 SOTA 吗?我们给 SimCSE 补充了一些实验》[1] 所选的一样,设为 0.3 效果比较好。
3.1 论文报告
说实话,原论文所报告的 R-Drop 的效果是相当让人惊艳的,这也是笔者不得不要介绍一波 R-Drop 的主要原因。原论文在 NLU、NLG、CV 的分类等多种任务上都对 R-Drop 做了对比实验,大部分实验效果都称得上“明显提升”。
官方实现:https://github.com/dropreg/R-Drop
下面截图一部分实验结果:
▲ R-Drop在机器翻译任务上的效果
▲ R-Drop在GLUE任务上的效果
特别地,在机器翻译任务上,简单的“Transformer + R-Drop”超过了其他更加复杂方法的效果:
▲ 机器翻译任务上不同方法的对比
论文还包括自动摘要、语言模型、图像分类等实验,以及关于超参数的一些消融实验,大家仔细看原论文就好。总的来说,R-Drop 的这份“成绩单”,确实足以让人为之点赞了。
3.2 个人尝试
当然,笔者坚持的观点是“没有在中文测试过的模型是没有灵魂的”,一般情况下笔者都是在中文任务上亲自尝试过之后,才会写作分享。
个人实现:https://github.com/bojone/r-drop
有中文监督任务上,笔者实验了两个文本分类任务(CLUE 榜单的 IFLYTEK 和 TNEWS)。
和一个文本生成任务(CSL 标题生成,参考 Seq2Seq 中 Exposure Bias 现象的浅析与对策):
可以看到,R-Drop 的结果足以 PK 在对抗训练浅谈:意义、方法和思考(附Keras 实现)中介绍的著名正则化手段“对抗训练”和“梯度惩罚”了。
3.3 实现要点
相比于对抗学习等复杂正则化方法,R-Drop 的实现难度可谓是相当低了,这里以 bert4keras 为例,简单介绍一下如何将一个普通的训练脚本改为带 Dropout 的模式。
首先,是数据生成部分,改动如下:
class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):batch_token_ids, batch_segment_ids, batch_labels = [], [], []for is_end, (text, label) in self.sample(random):token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)# batch_token_ids.append(token_ids)# batch_segment_ids.append(segment_ids)# batch_labels.append([label])for i in range(2):batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)batch_labels.append([label])# if len(batch_token_ids) == self.batch_size or is_end:if len(batch_token_ids) == self.batch_size * 2 or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)batch_labels = sequence_padding(batch_labels)yield [batch_token_ids, batch_segment_ids], batch_labelsbatch_token_ids, batch_segment_ids, batch_labels = [], [], []
然后,自定义一个新 loss:
from keras.losses import kullback_leibler_divergence as klddef categorical_crossentropy_with_rdrop(y_true, y_pred):"""配合上述生成器的R-Drop Loss其实loss_kl的除以4,是为了在数量上对齐公式描述结果。"""loss_ce = K.categorical_crossentropy(y_true, y_pred) # 原来的lossloss_kl = kld(y_pred[::2], y_pred[1::2]) + kld(y_pred[1::2], y_pred[::2])return K.mean(loss_ce) + K.mean(loss_kl) / 4 * alpha
最后把模型的 Dropout 打开,并用这个 data_generator 和 categorical_crossentropy_with_rdrop 来训练模型就行了。
个人理解
看完了让人赏心悦目的实验结果后,我们来啃一下理论。原论文提供了对 R-Drop 的一个理论分析,大致意思是 R-Drop 会促进参数的同化,从而起到正则化作用。不过个人感觉这个解释并不直观,而且还不够本质。下面笔者试图提供 R-Drop 的另外几个角度的理解。
4.1 一致性
R-Dropout 可以看成是 Dropout 的改进,那 Dropout 有什么问题呢?其实 Dropout 是典型的训练和预测不一致的方法。具体来说,Dropout 在训练阶段往(某些层的)输入加上了乘性噪声,使得模型从 变成了 ,其中 的每个元素有 p 的概率为 0,剩下 1-p 的概率为 1/(1-p),训练目标就是:
这样训练之后,我们应该用哪个模型预测最好呢?不确定,但如果损失函数是 距离的话,那么我们可以推出最佳预测模型应该是:
推导:如果用 损失,此时单个样本的损失是:
注意,现在我们的问题是“模型训练完后应该用什么函数来预测”,所以 是常数,y 才是要优化的变量,这只不过是一个二次函数的最小值问题,容易解得 时损失函数最小。
我们假定这个结果能泛化到一般情况。上式告诉我们,带 Dropout 的模型的正确步骤是“模型融合”:
对同一个输入多次传入模型中(模型不关闭 Dropout),然后把多次的预测结果平均值作为最终的预测结果。
但我们一般情况下的预测方式显然不是这样的,而是直接关闭 Dropout 进行确定性的预测,这等价于预测模型由“模型平均”变成了“权重平均”:
这里的 1 指的是全 1 向量。所以,我们训练的是不同 Dropout 的融合模型,预测的时候用的是关闭 Dropout 的单模型,两者未必等价,这就是 Dropout 的训练预测不一致问题。
现在,我们就不难理解 R-Drop 了,它通过增加一个正则项,来强化模型对 Dropout 的鲁棒性,使得不同的 Dropout 下模型的输出基本一致,因此能降低这种不一致性,促进“模型平均”与“权重平均”的相似性,从而使得简单关闭 Dropout 的效果等价于多 Dropout 模型融合的结果,提升模型最终性能。
4.2 连续性
本文开头就提到 R-Drop 与 SimCSE 的相似性,事实上它还跟另外一个方法相当相似,那便是“虚拟对抗训练(Virtual Adversarial Training,VAT)”。(不过 R-Drop 也没引 VAT,难道就只有笔者觉得像吗??)
关于 VAT 的介绍,大家可以参考笔者之前的文章泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练。简单来说,VAT 也是通过一个正则项,使得模型对扰动更加鲁棒,增强模型本身的连续性(小的变化不至于对结果产生大的影响)。它们不同的地方在于加扰动的方式,VAT 只把扰动加入到输入中,并且通过对抗的思想提升扰动的针对性;R-Drop 的扰动则可以施加到模型的每一层中,并且扰动是随机的。
有读者可能想到了,VAT 可是主打半监督训练的,那是不是意味着 R-Drop 也可以做半监督训练?这部分原论文并没有实验,是笔者自己做的实验,答案是确实可以,跟 VAT 类似,R-Drop 新增的 KL 散度项是不需要标签的,因此可以无监督训练,混合起来就是半监督了,效果也还不错。下面是笔者的实验结果:
可以看到,R-Drop 的半监督效果完全不逊色于 VAT,而且它实现比 VAT 简单,速度也比 VAT 快!看来 VAT 有望退休了~ 直觉上来看,虽然 R-Drop 的扰动是随机的,但是 R-Drop 的扰动更多,所以它造成的扰动也会放大,也可能比得上 VAT 经过对抗优化的扰动,所以 R-Drop 能够不逊色于 VAT。
4.3 非目标类
一个比较直接的疑问是,如果我的模型够复杂,单靠交叉熵这一项,不能使得模型对 Dropout 鲁棒吗?KL 散度那一项造成了什么直接的区别?
事实上,还真的不能。要注意的是,交叉熵的训练目标主要是:让目标类的得分大于非目标类的得分,这样模型就能正确地把目标类预测出来了(参考将“softmax+交叉熵”推广到多标签分类问题)。也就是说,如果只有交叉熵这一项,模型的训练结果顶多是:
不同的 Dropout 下,目标类的得分都大于非目标类的得分。
而不能做到:
不同的 Dropout 下,每个类的得分一致。
所以也就没有解决训练预测不一致的问题。从公式上来看,交叉熵(2)只跟目标类别有关,不关心非目标类的分布,假如目标类为第一个类别,那么预测结果是 [0.5, 0.2, 0.3] 或 [0.5, 0.3, 0.2],对它来说都没区别。但对于 KL 散度项(3)来说就不一样了,每个类的得分都要参与计算,[0.5, 0.2, 0.3] 或 [0.5, 0.3, 0.2] 是有非零损失的。
本文小结
本文介绍了 R-Drop,它将“Dropout 两次”的思想用到了有监督任务中,每个实验结果几乎都取得了明显的提升。此外,笔者在自己的实验还发现,它在半监督任务上也能有不俗的表现。最后,分享了笔者对 R-Drop 的三个角度的思考。
参考文献
[1] https://kexue.fm/archives/8348
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
???? 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
???? 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
????
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。
又是Dropout两次!这次它做到了有监督任务的SOTA相关推荐
- 【论文复现】SimCSE对比学习: 文本增广是什么牛马,我只需要简单Dropout两下
文本增广是什么牛马,我只需要简单Dropout两下 Sentence Embeddings与对比学习 SimCSE 无监督Dropout 有监督对比学习 如何评判Sentence Embeddings ...
- 丹琦女神新作:对比学习,简单到只需要Dropout两下
文 | 花小花Posy 上周把 <对比学习有多火?文本聚类都被刷爆了...>分享到卖萌屋的群里后,遭到了群友们一波嫌弃安利. 小伙伴们表示,插入替换的数据增强方式已经Out了,SimCSE ...
- FPT:又是借鉴Transformer!这次多方向融合特征金字塔 | ECCV 2020
点击上方"CVer",选择加"星标"置顶 重磅干货,第一时间送达 本文转载自:晓飞的算法工程笔记 论文提出用于特征金字塔的高效特征交互方法FPT,包含3种精心设 ...
- 又是一天,这次是网页显示的问题
功能需求是在程序中加载HTML代码并进行显示,为了实现这个功能,昨天徘徊了一个晚上,拿不定主意到底该用SDK还是MFC来实现,用SDK能实现的话,可以节约不少时间,因为现在我对MFC一窍不通.但SDK ...
- 又是用户数据泄露!这次轮到了“卡饭论坛”
国内最大的计算机安全论坛之一的"卡饭论坛"承认用户数据泄露. 2015年10月28日晚十点半,"卡饭论坛"的会员公告栏突然贴出管理员"theone&q ...
- Dropout视角下的MLM和MAE:一些新的启发
©PaperWeekly 原创 · 作者 |苏剑林 单位 |追一科技 研究方向 |NLP.神经网络 大家都知道,BERT 的 MLM(Masked Language Model)任务在预训练和微调时的 ...
- ChildTuning:试试把Dropout加到梯度上去?
©PaperWeekly 原创 · 作者 |苏剑林 单位 |追一科技 研究方向 |NLP.神经网络 Dropout 是经典的防止过拟合的思路了,想必很多读者已经了解过它.有意思的是,最近 Dropou ...
- 克服Dropout缺陷,简单又有效的正则方法:R-Drop
本文转自微软研究院AI头条. 编者按:深度神经网络是深度学习的基础,但其在训练模型时会出现过拟合的问题,而简单易用的 Dropout 正则化技术可以防止这种问题的发生.然而 Dropout 的操作在一 ...
- 【NLP】对比学习——文本匹配(二)
文章目录 d. SimCSE:2021.04 Unsupervised Supervised e. R-Drop(Supervised):2021.06 f. ESimCSE(Unsupervised ...
最新文章
- Redis数据库 安装
- python计角_最小角python算法
- 《重新认识你自己》八:与真实的自我相处
- 超详细Pycharm部署项目视频教程
- jQuery选择器中的特殊符号和关键字
- Android里监视数据库的变化
- 【报错笔记】Eclipse导入Maven项目时pom.xml报错,项目上有红感叹号。
- spark | scala | 线性代数库Breeze学习
- canvas特效代码详解(2)
- Spring Boot中使用MyBatis注解配置详解
- 例子:10秒后同意按钮可点击
- Android中设置EditText显示或隐藏密码
- PRML学习总结(1)——Introduction
- 诺基亚如何利用计算机上网,诺基亚E63的WIFI上网功能全教程
- 模电、数电、电力电子、开关电源基础知识总结
- DateTime.Now函数详解 所有用法
- html 设计尺寸,多少像素才合适 网页设计标准尺寸大讲解
- 高亮蓝光油溶性ZnCdS/ZnS量子点(硫化锌镉/硫化锌)
- 漫画 |《帝都程序猿十二时辰》
- 【软件测试】离开“浪浪山“测试人迎来的春天......
热门文章
- php5.4 zend安装教程,linux下php5.4安装Zend Guard Loader扩展
- matlab 计算误码率,关于误码率的问题 急!!!!!
- 添加vlan后无法上网_KTV多SSID绑定VLAN实用案例,值得一看的干货
- java中所有函数都是虚函数_关于Java:虚拟函数与纯虚函数之间的区别是什么?...
- c++ map 析构函数_说说C++的虚析构函数
- 数据结构34:二叉树前序遍历、中序遍历和后序遍历
- 图像相似度测量与模板匹配总结
- Android开发切换host应用
- Java版世界时钟示例
- VSS2005 添加文件夹方法!