文 | 王希梅,高敬涵,龙明盛,王建民
源 | THUML

本文介绍ICML2021的中稿论文:Self-Tuning for Data-Efficient Deep Learning,就“如何减少对标记数据的需求”这一重要问题给出了我们的思考。

论文标题:
Self-Tuning for Data-Efficient Deep Learning

论文链接:
http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf

GitHub链接:
https://github.com/thuml/Self-Tuning

引言

大规模标记数据集推动深度学习获得了广泛应用,然而,在现实场景中收集足量的标记数据往往耗时耗力。为了减少对标记数据的需求,半监督学习和迁移学习的研究者们从两个不同的视角给出了自己的思考:半监督学习(Semi-supervised Learning, SSL)侧重于同时探索标记数据和无标记数据,通过挖掘无标记数据的内在结构增强模型的泛化能力,而迁移学习(Transfer Learning, TL)旨在将预训练模型微调到目标数据中,也就是我们耳熟能详的预训练-微调范式。

半监督学习的最新进展,例如UDA,FixMatch等方法,证明了自训练(Self-Training)的巨大潜力。通过弱增广样本为强增广样本生成伪标记(pseudo-label),FixMatch就可以在Cifar10、SVHN、STL-10数据集上取得了令人耳目一新的效果。然而,细心的读者会发现,上述数据集都是类别数较少的简单数据集(都是10类),当类别数增加到100时,FixMatch这种从头开始训练(train from scratch)的自训练方法的表现就差强人意了。进一步地,我们在CUB200上将类别数从10逐渐增加到200时,发现FixMatch的准确率随着伪标签的准确率的下降而快速下降。这说明,随着类别数的增加,伪标签的质量逐渐下降,而自训练的模型也被错误的伪标签所误导,从而难以在测试数据集上取得可观的效果。这一现象,被前人总结为自训练的确认偏差(confirmation bias)问题,说明Self-training虽然是良药,偶尔却有毒。

迁移学习在计算机视觉和自然语言处理中被广泛使用,预训练-微调(fine-tuning)的范式也比传统的领域适应(domain adaptation)约束更少,更具落地价值。然而,现有的迁移学习方法专注于从不同角度挖掘预训练模型和标记数据,却对更为容易获取的无标记数据熟视无睹。以迁移学习的最新方法Co-Tuning为例,它通过学习源领域类别和目标领域类别的映射关系,实现了预训练模型参数的完全迁移。然而,因为仅仅将预训练模型迁移到标记数据中,Co-Tuning容易过拟合到有限的标记数据上,测试准确率随着标记数据比例的减少而迅速下降,我们将这一现象总结为模型漂移(model shift)问题。

为了摆脱迁移学习和半监督学习的困境,我们提出了一种称为数据高效深度学习(data-efficient deep learning)的新场景, 通过同时挖掘预训练模型和无标记数据的知识,使他们的潜力得到充分释放。这可能是迁移学习在工业界最为现实的落地场景:当我们试图获得目标领域的一个优秀模型时,源领域的预训练模型和目标领域的无标记数据几乎唾手可得。同时,为了解决前述的确认偏差和模型漂移问题,我们提出了一种称为Self-Tuning的新方法,将标记数据和无标记数据的探索与预训练模型的迁移融为一体,以及一种通用的伪标签组对比机制(Pseudo Group Contrast),从而减轻对伪标签的依赖,提高对伪标签的容忍度。在多个标准数据集的实验表明,Self-Tuning远远优于半监督学习和迁移学习的同类方法。例如,在标签比例为15%的Stanford-Cars数据集上,Self-Tuning的测试精度比fine-tuning几乎提高了一倍

如何解决确认偏差问题?

为了找出自训练的确认偏差(confirmation bias)问题的根源,我们首先分析了伪标签(pseudo-label)广泛采用的交叉熵损失函数(Cross-Entropy, CE):

其中,是输入生成的伪标签, 而是模型对于样本。通常地,大多数自训练方法都会针对confidence做一个阈值过滤,只有大于阈值 (比如FixMatch中设置了0.95的阈值)的样本的预测标签才会被视为合格的伪标签加入模型训练。然而,如图2所示,由于交叉熵损失函数专注于学习不同类别的分类面,如果某些伪标签存在错误,通过交叉熵损失函数训练的模型就会轻易地被错误的伪标签所误导。

为了解决交叉熵损失函数的类别鉴别(class discrimination)特性对自训练带来的挑战,最近取得突破进展的基于样本鉴别(sample discrimination)思想的对比学习损失函数吸引了我们的注意。给定由输入生成的查询样本,在不同数据增广下生成的副本,以及个不同输入生成的负样本,则通过内积度量相似性的对比学习(Constrastive Learning, CL)损失函数可以定义为

可以看出,对比学习旨在最大化同一样本在两个不同数据增广下的表征相似性,而最小化不同样本间的表征相似性,从而实现样本鉴别,挖掘数据中隐藏的流形结构。这种设计与伪标签无关,天然地不受错误的伪标签的影响。然而,标准的对比学习损失函数未能将标签和伪标签嵌入到模型训练中,从而使有用的鉴别信息束之高阁。

为了解决这一挑战,我们提出了一种通用的伪标签组对比机制(Pseudo Group Contrast, PGC)。对于任何一个查询样本,它的伪标签用表示。PGC将具有相同伪标签()的样本都视为正样本,而具有不同伪标签()的样本则组成了负样本,从而最大化查询样本与具有相同伪标签的正样本的表征相似性,实现伪标签的组对比。

那么,为什么PGC机制就可以提高对错误的伪标签的容忍度呢?我们认为,这是因为PGC采用了具有竞争机制的softmax函数,同一伪标签下的正样本会互相竞争。如果正样本的伪标签是错误的,这些伪标签的样本也会在竞争中落败,因为那些具有正确伪标签的正例样本的表征与查询样本的表征更相似。这样的话,模型在训练过程中会更多地受到正确的伪标签的影响,而不是像交叉熵损失函数那样直接地受到错误的伪标签的误导。我们在CUB数据集上的分析实验也证明了这一点:1. 在模型训练伊始,Self-Tuning和FixMatch具有相似的伪标签准确率,但是随着模型逐渐趋于收敛,Self-Tuning的测试集准确率明显高于FixMatch。2. 在具有不同类别数的CUB数据集上,Self-Tuning的测试准确率始终高于伪标签准确率,而FixMatch的测试准确率被伪标签准确率给限制住了。

如何解决模型漂移问题?

如前所述,当我们只在有限的标记数据集上微调预训练模型时,模型漂移问题往往难以避免。为了解决这个问题,近期发表的一篇名为SimCLRv2的论文提出可以综合利用预训练模型、标记数据和无标记数据的信息。他们给出了一个有趣的解决方案:首先在标记数据集()上微调预训练模型(),继而在无标记数据集()上进行知识蒸馏。然而,通过这一从到再到的“序列化”方式,微调后的模型依然倾向于向有限的标记数据偏移。我们认为,应该将标记和未标记数据的探索与预训练模型的迁移统一起来。

与SimCLRv2的“序列化”方式不同,我们提出了一种“一体化”的形式来解决模型漂移问题。首先,与半监督学习从零开始训练模型的通用实践不同,Self-Tuning的模型起点是一个相对准确的大规模预训练模型,通过更准确的初始化模型来提供一个更好的隐式正则。同时,预训练模型的知识将并行地流入标记数据和无标记数据中,标记数据和无标记数据产生的梯度也会同时更新模型参数。这种“一体化”的形式有利于同时探索标记数据的判别信息和无标记数的内在结构,大大缓解模型漂移的挑战。

另一方面,在对比学习中,负样本的规模越大,模型的效果往往越好。与MoCo类似,我们也通过引入队列的方式将负样本规模与批量大小(batch-size)解耦,使得负样本规模可以远大于批量大小。另一方面,队列的方式可以保证每次对比时,每个伪类下的负样本数目恒定,不受每个minibatch随机采样的影响。与标准的对比学习不同的是,由于伪标签的引入,PGC需要维护C个队列,其中C是类别数。在每次模型迭代中,对于无标记样本,将根据他们的伪标签渐进地替换对应队列里面最早的样本。而对于标记数据,因为他们天然地拥有准确的标签,则可以根据他们的标签来更新对应的队列。值得注意的是,我们在标记数据和无标记数据间共享了这些队列。这一设计的好处在于:将标记数据中宝贵的准确标签嵌入到共享队列中,从而提高了无标记数据的候选样本的伪标签准确性。

实验

在实验部分,我们在5种数据集、3种标记数据比例和4种预训练模型下,测试了Self-Tuning的效果,同时与5种主流迁移学习方法、6种主流半监督学习方法以及他们的至强组合进行了充分的对比。

迁移学习的Benchmark

我们首先在迁移学习的常用数据集CUB-200-2011, Stanford Cas和FGVC Aircraft下进行实验,将标记数据的比例依次设置为15%,30%和50%,采用ResNet-50作为预训练模型。结果显示,Self-Tuning大幅领先于现有方法,例如,在标签比例为15%的Stanford-Cars数据集上,Self-Tuning的测试精度比fine-tuning几乎提高了一倍

半监督学习的Benchmark

在半监督学习的主流数据集CIFAR-100、CIFAR-10、SVHN和STL-10中,我们采用了类别数最多、最困难的CIFAR-100数据集。由于在ImageNet上预训练的WRN-28-8模型尚未公开,我们采用了参数少得多的EfficientNet-B2模型。实验结果表明,预训练模型的引入对于半监督学习有如虎添翼的效果。同时,由于采用了对伪标签依赖更小的PGC损失函数,Self-Tuning充分挖掘了预训练模型、标记数据和无标记数据的所有信息,在各种实验设定下均取得了state-of-the-art的测试准确率

无监督预训练模型

为了证明Self-Tuning可以拓展到无监督预训练模型中,我们做了MoCov2迁移到CUB-200的实验。无论是每类4个样本还是每类25个样本的实验设定,Self-Tuning相较于迁移学习和半监督学习的方法都有明显提升

命名实体识别

为了证明Self-Tuning可以拓展到自然语言处理(NLP)的任务中,我们在一个英语命名实体识别数据集CoNLL 2003上进行了实验。按照Co-Tuning的实验设定,我们采用掩蔽语言建模的BERT作为预训练模型。以命名实体的F1得分作为度量指标的话,fine-tuning的F1得分为90.81,BSS、L2-SP和Co-Tuning分别达到90.85、91.02和91.27,而Self-Tuning取得了明显更高的94.53的F1得分,初步证明了Self-Tuning在NLP领域的强大潜力。更加详尽的NLP实验,会在未来的期刊版本中进行拓展。

消融实验

在消融实验部分,我们从两个不同的角度进行了对比。首先是损失函数,PGC损失函数比Cross-Entropy和Contrastive Learning的损失函数有明显提升。其次是信息的探索方式,无论是去掉标记数据还是无标记数据上的PGC损失函数,抑或在标记数据和无标记数据间设置单独的负样本队列,都不及Self-Tuning所提的“一体化”信息探索。

展望

在深度学习社区中,如何减少对标记数据的需求是一个至关重要的问题。考虑到迁移学习和半监督学习的普通实践中只关注预训练模型或无标记数据的不足,本文提出了一种新的数据高效的深度学习机制,可以充分发挥预训练模型和无标记数据的优势。这一机制可能是迁移学习在工业界最为现实的落地场景,值得我们继续大力研究。另一方面,我们提出的Self-Tuning方法简单通用,是迁移学习、半监督学习和对比学习等领域的核心思想的集大成者,可以提高对伪标签的容忍度。对于其他需要用到伪标签的场景,应该也有一定的借鉴价值。

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

ICML2021 | Self-Tuning: 如何减少对标记数据的需求?相关推荐

  1. AI人工智能标记数据的技术:类型、方法、质量控制、应用

    AI人工智能 标记数据 在人工智能(Artificial Intelligence,简称AI)领域中,标记数据是非常重要的一环.它是指对原始数据进行标记和注释,以便机器学习算法可以理解和利用这些数据. ...

  2. 干货 | 只有100个标记数据,如何精确分类400万用户评论?

    来源:新智元 本文共2200字,建议阅读6分钟. 本文介绍了面向NLP任务的迁移学习新模型ULMFit,只需使用极少量的标记数据,文本分类精度就能和数千倍的标记数据训练量达到同等水平. [ 导读 ]在 ...

  3. 放弃手工标记数据吧!斯坦福大学开源弱监督框架

    https://www.toutiao.com/a6668443801553469965/ 手工标记大量数据始终是开发机器学习的一大瓶颈.斯坦福AI Lab的研究人员探讨了一种通过编程方式生成训练数据 ...

  4. VTK:标记数据映射器用法实战

    VTK:标记数据映射器用法实战 程序输出 程序完整源代码 程序输出 程序完整源代码 #include <vtkActor.h> #include <vtkActor2D.h> ...

  5. labelme标记数据后,批量处理json文件,生成标签

    1.安装labelme的过程省略,可参考别人 2.打开anaconda prompt 3.激活安装有labelme的虚拟环境 4.运用labelme命令打开labelme开始标记数据 5.处理json ...

  6. 全球Flickr地理标记数据,含经纬度

    数据内容:全球Flickr地理标记数据,含经纬度 数据来源:本数据来源于Flickr地理标记图片信息,flickr是全球最受欢迎和使用最多的公开图片网络社交平台.Flickr图片社交媒体为用户提供图片 ...

  7. 【Axure教程】中继器表格寻找和标记数据

    在系统表格中,我们想在表格中快速找到对应的数据,通常我们会用条件筛选来完成,但是用筛选的方式,其他数据就看不到了,少了两种条件之间的对比.所以如果需要数据对比的情况下,我们更多的是用标记数据的方式,将 ...

  8. 鼓励政府带头采购云服务,减少自建数据中心

    鼓励政府带头采购云服务,减少自建数据中心 返回列表   日期:2014-06-26 核心提示: 云计算发展与政策论坛的官方数据显示,已有10家国内企业获可信云认证通过,包括阿里云.中国电信.新浪.中国 ...

  9. oracle 数据块 修复,案例:Oracle坏块 使用RMAN工具的命令clear标记数据块为corrupt 修复坏块...

    天萃荷净 运维DBA巡检发现数据文件中存在坏块,使用RMAN工具的命令clear标记数据坏块,使用bbed修复坏块 在rman中有隐藏的命令clear,可以标记数据块为corrupt,从而实现数据库坏 ...

最新文章

  1. hmac-sha1加密算法C源码示例
  2. poj1236(强连通分量)
  3. 【项目实战】基于随机森林算法的房屋价格预测模型
  4. 5.2.3 OS之I/O设备的分配与回收(DCT-COCT-CHCT-SDT)
  5. boost::log::parse_formatter用法的测试程序
  6. 【QuotationTool】主要数据结构
  7. pip安装deb_技术|如何在 Ubuntu 上安装 pip
  8. vue借助axios实现网络通信
  9. MATLAB 添加自有的工具包
  10. 语句乎?表达式乎?(Python/C)
  11. Dungeon Master 地下城大师(BFS进阶)
  12. 【hdu2815-Mod Tree】高次同余方程-拓展BadyStepGaintStep
  13. python链表的实例_python数据结构链表之单向链表(实例讲解)
  14. 小米wifi 苹果驱动安装教程macOS Mojave 10.14,Sierra 10.12测试通过
  15. C语言网络编程——基础
  16. 关于胶囊检测的思考-代码实现
  17. 如何选择工业中CCD相机与CMOS相机
  18. iOS Safari阅读模式分析过程
  19. 资深项目经理推荐的五款项目管理工具
  20. 【c语言】求一个3行4列矩阵的外框的元素值之和

热门文章

  1. 【原】jQuery编写插件
  2. ios截屏 u3d导出Xcode工程截屏
  3. MVC阻止用户注入JavaScript代码或者Html标记
  4. linux查看磁盘占用
  5. 【Github教程】史上最全github用法:github入门到精通
  6. [转载]MVVM、MVVMLight、MVVMLight Toolkit之我见
  7. 《信息检索导论》第七章总结
  8. Linux中断子系统-通用框架处理
  9. centos 安装idea 非可视化_太厉害了!目前 Redis 可视化工具最全的横向评测
  10. react 组件封装原则_我理解的React:React 到底是什么?