来自:李rumor

关注CV领域的小伙伴一定都记得Hinton团队在年初提出的SimCLR[1],采用自监督的对比学习方法进行encoder的训练,各种碾压之前的模型。所以今年我一直在等某个大招,终于在20年的尾巴看到了一丝希望。

今天要介绍的这篇工作来自斯坦福和Facebook AI,作者在BERT分类任务的精调阶段加入了对比学习的loss,在各个任务上都获得了很稳定的提升:

上图中CE表示交叉熵,SCL表示Supervised Contrastive Learning。实话说结果并不够惊艳,用对抗学习也差不多可以做到,让我惊喜的是在Few-shot上的效果:

N表示训练样本数量。可以看到N=20时QNLI上有10个点之多的提升。

下面就让我们来走近科学,看看SCL是个啥玩意儿叭~

论文题目:Supervised Contrastive Learning for Pre-trained Language Model Fine-tuning
论文链接:https://arxiv.org/abs/2011.01403

对比学习

对比学习的核心思想,就是让模型学习如何将正样本和其他负样本区别开来,抓住样本的本质特征,而不是把每个细节都考虑到。拿人来举例,假如有人让你凭空画一张一美元,你可能只画成这样[2]

而如果给你一张美元照着临摹,可能还能画好看点,比如这样:

所以说我们记住的,不一定是像素级别的特征,而是更高维度的。在训练模型时,也不强求它们把所有信息都编码,只要细致到可以区分数据中的不同样本就可以。

如何实现呢?这个就体现在目标函数上:

在自监督的情况下,对比学习利用数据增强方法,给每个输入样本输入构建另一个view







作为正例,并使用同batch下其他样本







作为负例,达到拉近正例拉开负例的“对比”目的:

描述得更具体一点,就是把N个输入样本增强到2N个,然后进行2N分类(其中有2个正例2N-2个负例)。

P.S. 关于对比学习在图像领域的进展可以参考知乎@Tobias Lee的文章[3]

Supervised Contrastive Learning

上文讲了自监督的对比学习主要是靠一个batch内的样本间相互对比,那有监督的数据如何更好利用呢?

作者就针对分类任务进行了研究。分类的核心思想就是把不同类别的样本划分开来,通常使用交叉熵作为损失函数。作者则提出了一个新的对比学习loss SCL,将同一类的样本互相作为正例,不同类别的作为负例。以此达到拉近类内样本、拉开类间距离的目的:

具体的损失计算方法为(右滑公示):

其中







是正确label,




是归一化后的encoder输出,




是一个控制类间距离的超参数,越低负例就越难分。这个式子的主要目的就是拉近正样本(同类数据)的距离。

实验结果

除了开头展示的直接提升外,作者还进行了很多分析。从SST-2数据集的[CLS] embedding来看,通过CE(左)和SCL(右)损失训练出来的encoder对正负例的区分能力确实有不少差距:

同时在有噪声的训练数据上SCL鲁棒性会更强(T越高噪声越多):

总结

这篇文章目前正在投稿ICLR2021(都在arxiv上挂了还盲审啥。。),总体的改动比较简单,但对比学习的前景还是挺大的,同时加上SCL损失之后不仅对少样本的情况很有帮助,也能提升模型鲁棒性,相比于对抗学习的计算代价明显要小,还是比较实用的,一起立个flag,复现一波?

参考资料

[1]

A Simple Framework for Contrastive Learning of Visual Representations: https://arxiv.org/abs/2002.05709

[2]

Contrastive Self-Supervised Learning: https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html

[3]

对比学习(Contrastive Learning)相关进展梳理: https://zhuanlan.zhihu.com/p/141141365

下载一:中文版!学习TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!后台回复【五件套】
下载二:南大模式识别PPT后台回复【南大模式识别】

说个正事哈

由于微信平台算法改版,公号内容将不再以时间排序展示,如果大家想第一时间看到我们的推送,强烈建议星标我们和给我们多点点【在看】。星标具体步骤为:

(1)点击页面最上方深度学习自然语言处理”,进入公众号主页。

(2)点击右上角的小点点,在弹出页面点击“设为星标”,就可以啦。

感谢支持,比心

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

推荐两个专辑给大家:

专辑 | 李宏毅人类语言处理2020笔记

专辑 | NLP论文解读

专辑 | 情感分析


整理不易,还望给个在看!

给BERT加一个loss就能稳定提升?斯坦福+Facebook最新力作!相关推荐

  1. [学习日志]使用pytorch 和 bert 实现一个简单的文本分类任务

    项目简介 最近在学习pytorch和Bert,所以做了一个这样完全新手向的入门项目来练习. 由于之前在网上学习发现现存的教程比较少,所以记录一下自己的学习过程,加深印象,也希望能帮到别的学习者吧,能涨 ...

  2. 钉钉机器人关键字自动回复_如何用 GPT2 和 BERT 建立一个可信的 reddit 自动回复机器人?...

    上个月,我尝试构建一个 reddit 评论机器人,通过结合两个预先训练的深度学习模型 GPT-2 和 BERT 生成自然语言回复.在这里我想一步一步地介绍一下我的工作,这样其他人就可以用我所建立的东西 ...

  3. 一加6android p降级包,国内最快!一加6更新Android P稳定版,官方确认一加6T会预装...

    原标题:国内最快!一加6更新Android P稳定版,官方确认一加6T会预装 临近年底,各个行业都热闹非凡,纷纷推出了新品或者做起了活动.而在手机行业内,年底发布新旗舰似乎成为了约定成俗的规矩,近日一 ...

  4. 亚马逊:我们提取了BERT的一个最优子架构,只有Bert-large的16%,CPU推理速度提升7倍...

    选自arXiv 作者:Adrian de Wynter.Daniel J. Perry 机器之心编译 机器之心编辑部 提取 BERT 子架构是一个非常值得探讨的问题,但现有的研究在子架构准确率和选择方 ...

  5. 给你的开源项目加一个绶带吧

    D2 Ribbons 是一套为开发者准备的开源社区绶带资源,你你可以下载图片到你的项目中使用或者直接使用仓库资源链接. 素材地址 github.com/d2-projects- Features 扁平 ...

  6. pthread_cond_wait()加一个while为什么的解释

    等号上面这段是大多数网上给pthread_cond_wait()加一个while为什么的解释:但是有些地方不太明白或者说没有解释清晰: 准备:1:pthread_cond_singal是唤醒至少一个线 ...

  7. 接到一个需求,想在页面上加一个链接有多难?

    点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 作者 | 程序师 来源 | www.techug.com ...

  8. cassandra——可以预料的查询,如果你的查询条件有一个是根据索引查询,那其它非索引非主键字段,可以通过加一个ALLOW FILTERING来过滤实现...

    cassandra的索引查询和排序 转自:http://zhaoyanblog.com/archives/499.html cassandra的索引查询和排序 cassandra的查询虽然很弱,但是它 ...

  9. 数值格式化,每隔三位加一个逗号

    数值整数和小数 每隔三位加一个逗号方便阅读 function addCommas(nStr){ nStr += ''; x = nStr.split('.'); x1 = x[0]; x2 = x[1 ...

  10. 给博客园加一个会动的小人-spig.js

    给博客园加一个会动的小人-spig.js 效果大概是这样,感觉十分可爱qvq 那么怎么添加呢? 首先需要开通js/html权限. 然后在页脚html代码中加入以下代码 <script src=& ...

最新文章

  1. SAP Workload Monitor
  2. Lite-HRNet
  3. 编写易于理解代码的六种方式
  4. 【深度学习】CV和NLP通吃!谷歌提出OmniNet:Transformers的全方位表示
  5. vue源码之响应式数据
  6. ORACLE1.21 PLSQL 01
  7. 0-10不断循环的js
  8. java中case语句_Java:switch-case语句
  9. NodeJS服务器退出:完成任务,优雅退出
  10. java8: hashmap性能提升
  11. java teechart怎么用_TeeChart for Java
  12. stvd能编辑c语言吗,STVD自动生成的stm8_interrupt_vector.c中几个疑问
  13. dll修复工具安装教程
  14. 浪潮配置ipim_浪潮服务器管理口IP设置_IPMI设置
  15. 影响虚拟主机访问速度的因素,主要有哪些?
  16. SQL Server医疗信息管理系统数据库【英文版-源码】--(Medical Management System Database)
  17. 海康威视监控推流自建服务器实现网页端无插件1-2秒低延迟实时监控
  18. 【python】获取当前时间(年月日时分秒)
  19. Angular 组件类测试
  20. 完整的SEO团队应该包括哪些人员(细分八要职)

热门文章

  1. 管理信息系统第一次作业
  2. MVVM 实战之计算器
  3. 漫游Kafka实战篇clientAPI
  4. 用 Javascript 验证表单(form)中多选框(checkbox)值
  5. 雷声大雨点小-参加江西省网站内容管理系统培训有感
  6. java day23【函数式接口】
  7. 算法:两条线段求交点
  8. nyoj-488 素数环 +nyoj -32 组合数 (搜索)
  9. mysql 的命令行操作
  10. SonarLint插件的安装与使用