FixMatch 是对现有 SSL 方法的简化. FixMatch 首先对弱增强的未标记图像生成伪标签, 接着, 对同一图像进行强增强后, 再计算其预测分布, 最后计算强增强的预测与伪标签之间的交叉熵损失.

论文地址: FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
代码地址: https://github.com/google-research/fixmatch
会议: NeurIPS 2020
任务: 分类

FixMatch

FixMatch 是 SSL 两种方法的组合: 一致性正则化和伪标签. 它的新颖之处在于这两种方法的组合以及在执行一致性正则化时使用单独的弱增强和强增强.

FixMatch 简要示意图如下:

将弱增强图像输入模型, 当某一预测类别概率高于阈值(虚线)时, 预测将转换为 one-hot 伪标签. 然后, 计算模型对同一图像的强增强的预测. 计算强增强的预测与伪标签之间的交叉熵损失.

文中符号系统如下:

  • X=((xb,pb);b∈(1,…,B))\mathcal{X}=((x_b,p_b);b\in(1,\dots,B))X=((xb​,pb​);b∈(1,…,B)) 为一个 batch_size BBB 的带标签示例.
  • U=((ub);b∈(1,…,μB)\mathcal{U}=((u_b);b\in(1,\dots,\mu B)U=((ub​);b∈(1,…,μB) 为一个 batch_size μB\mu BμB 的无标签示例, 其中 μ\muμ 是确定 X\mathcal{X}X 和 U\mathcal{U}U 相对大小的超参数.
  • pm(y∣x)p_m(y\vert x)pm​(y∣x) 为预测类别分布.
  • H(p,q)\mathrm{H}(p,q)H(p,q) 为两个概率ppp, qqq分布之间的交叉熵.
  • A()\mathcal{A}()A(), α()\alpha()α() 分别为不同类型的增强.

一致性正则化及伪标签方法简要介绍如下:

Consistency regularization. 关于一致性正则化, 核心就是基于平滑假设, 模型对于对增强后数据的预测应与原始数据预测的结果一致.
Pseudo-labeling. 即利用模型本身来获取未标记数据的人工标签. 更具体地说, pbp_bpb​ 的伪标签 qbq_bqb​ 可以分别定义为基于锐化的连续分布(软)或基于 arg max⁡\argmaxargmax 操作的独热分布(硬). 在本文里, 人工标签一般指"硬"标签, 并且只保留最大类别概率高于预定阈值的情况. pseudo-labeling 使用如下损失函数:
1μB∑b=1μb(max⁡(qb)≥τ)H(q^b,qb)(1)\frac{1}{\mu B} \sum_{b=1}^{\mu b}(\max(q_b) \geq \tau)\mathrm{H}(\hat{q}_b,q_b) \tag{1} μB1​b=1∑μb​(max(qb​)≥τ)H(q^​b​,qb​)(1)
其中 qb=pm(y∣ub)q_b=p_m(y\vert u_b)qb​=pm​(y∣ub​), q^b=arg max⁡(qb)\hat{q}_b=\argmax(q_b)q^​b​=argmax(qb​), τ\tauτ 为阈值. 鼓励模型的预测是对未标记数据的低熵, 或者说是高置信度.

FixMatch 算法

FixMatch 的损失函数由两个交叉熵损失项组成: 应用于标记数据的监督损失 ℓs\ell_sℓs​ 和无监督损失 ℓu\ell_uℓu​. 具体来说, ℓs\ell_sℓs​ 只是弱增强标记示例上的标准交叉熵损失:
ℓs=1B∑b=1BB(pb,pm(y∣α(xb)))(2)\ell_s=\frac{1}{B} \sum_{b=1}^B \mathrm{B}(p_b,p_m(y\vert \alpha(x_b))) \tag{2} ℓs​=B1​b=1∑B​B(pb​,pm​(y∣α(xb​)))(2)
FixMatch 为每个未标记的示例计算一个人工标签, 然后将其用于标准交叉熵损失. 为了获得人工标签, 首先在给定未标记图像的弱增强版本的情况下计算模型的预测类别分布: qb=pm(y∣α(ub))q_b =p_m(y \vert \alpha(u_b))qb​=pm​(y∣α(ub​)). 然后, 使用 q^b=arg max⁡(qb)\hat{q}_b = \argmax(q_b)q^​b​=argmax(qb​) 作为伪标签, 与 ubu_bub​ 的强增强版本做交叉熵损失:
ℓu=1μB∑b=1μB(max⁡(qb)≥τ)H(q^b,pm(y∣A(ub)))(3)\ell_u=\frac{1}{\mu B} \sum_{b=1}^{\mu B} (\max(q_b)\geq \tau) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{3} ℓu​=μB1​b=1∑μB​(max(qb​)≥τ)H(q^​b​,pm​(y∣A(ub​)))(3)
综上, FixMatch 的损失函数定义为: loss=ℓs+λℓuloss=\ell_s+\lambda\ell_uloss=ℓs​+λℓu​. 完整的算法如下:

  • 1.计算弱增强标签数据集上的交叉熵损失 ℓs\ell_sℓs​.
  • 2.对每一个 μB\mu BμB batch, 计算弱增强无标签数据集上的预测分布及伪标签 qbq_bqb​, q^b\hat{q}_bq^​b​.
  • 3.计算无标签数据交叉熵损失 ℓu\ell_uℓu​
  • 4.得到目标函数总损失 ℓs+λℓu\ell_s+\lambda\ell_uℓs​+λℓu​.

FixMatch 中使用的增强方法

FixMatch 利用了两种增强: “弱"和"强”.

  • 弱增强是一种标准的翻转和移位增强策略. 例如在数据集上以 50% 的概率随机水平翻转图像, 并且在垂直和水平方向上随机平移.
  • 对于"强"增强, 文中尝试了两种基于 AutoAugment 的方法, 然后是 Cutout. AutoAugment 使用强化学习来查找包含来自 Python Imaging Library 的转换的增强策略. 这需要标记数据来学习增强策略, 这使得在可用标记数据有限的 SSL 设置中使用存在问题. 因此, 使用不需要利用标记数据学习增强策略的 AutoAugment 变体, 例如 RandAugment 和 CTAugment. RandAugment 和 CTAugment 都没有使用学习策略, 而是为每个样本随机选择转换. 对于 RandAugment, 控制所有失真严重程度的幅度是从预定义的范围内随机采样的. 具有随机幅度的 RandAugment 也被用于 UDA. 而对于 CTAugment, 单个变换的幅度是即时学习的.

其他

一些其他重要因素会影响 SSL 的性能, 例如: architecture, optimizer, training schedule 等. 经过实验, 文中发现正则化尤为重要. 在所有的模型和实验中, 使用简单的权重衰减正则化. 同时发现使用 Adam 优化器会导致更差的性能, 而使用 SGD 则没有这种情况, 另外, 使用 SGD 和使用 Nesterov 之间没有存在实质性差异. 对于学习率, 使用余弦学习率衰减. 它将学习率设置为 ηcos⁡7πk16K\eta \cos \frac{7\pi k}{16K}ηcos16K7πk​, 其中 η\etaη 是初始学习率, kkk 是当前训练步长, KKK 是总学习率训练步骤. 最后, 使用模型参数的指数移动平均值(EMA)报告最终性能.

FixMatch 可以很容易地使用 SSL 文献中的技术进行扩展. 例如, 来自 ReMixMatch 的增强锚定和分布对齐. 此外, 可以用与模态无关的增强策略, 例如 MixUp 或对抗性扰动代替 FixMatch 中的强增强. 对抗性扰动在 VAT, Adversarial Dropout 中已经应用. MixUp 也在 MixMatch, ICT 中成功应用.

[半监督学习] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence相关推荐

  1. 图解半监督学习FixMatch,只用10张标注图片训练CIFAR10

    2020-05-25 11:20:08 作者:amitness 编译:ronghuaiyang 导读 仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%,来看看是怎 ...

  2. 图构造总结-Graph‑based semi‑supervised learning via improving the quality of the graph dynamically

    前言 本博文主要对论文中提到的图构造方法进行梳理,论文自己提出的模型并未介绍,感兴趣的可以阅读原文 摘要 基于图的半监督学习GSSL主要包含两个过程:图的构建和标签推测.传统的GSSL中这两个过程是完 ...

  3. [半监督学习] Adversarial Dropout for Supervised and Semi-Supervised Learning

    引入了对抗性 dropout(AdD), 可最大限度地提高具有 dropouts 的网络输出之间的差异. 识别出的对抗性 dropout 用于在训练过程中自动重新配置神经网络, 是 Virtual A ...

  4. 【李宏毅机器学习】Semi-supervised Learning 半监督学习(p24) 学习笔记

    文章目录 Semi-supervised Learning Introduction Supervised Learning Semi-supervised Learning Why semi-sup ...

  5. 李弘毅机器学习笔记:第十五章—半监督学习

    李弘毅机器学习笔记:第十五章-半监督学习 监督学习和半监督学习 半监督学习的好处 监督生成模型和半监督生成模型 监督生成模型 半监督生成模型 假设一:Low-density Separation Se ...

  6. 半监督学习(Semi-Supervised Learning, SSL)-简述及论文整理

    本文参考An Overview of Deep Semi-Supervised Learning,An overview of proxy-label approaches for semi-supe ...

  7. NeurIPS 2020 | FixMatch:通过图像增强就能实现半监督学习

    前言 算法.算力.数据是深度学习的三架马车.深度学习是数据驱动式方法,目前的从业基本者都有一个共识就是:数据是非常重要的且不可或缺的.在实际环境中对数据标注又是一个耗时和昂贵的过程.但是受束于资源的限 ...

  8. 长文总结半监督学习(Semi-Supervised Learning)

    ©PaperWeekly 原创 · 作者|燕皖 单位|渊亭科技 研究方向|计算机视觉.CNN 在现实生活中,无标签的数据易于获取,而有标签的数据收集起来通常很困难,标注也耗时和耗力.在这种情况下,半监 ...

  9. 更少的标签,更好的学习,谷歌半监督学习算法FixMatch

    点击我爱计算机视觉标星,更快获取CVML新技术 本文向大家推荐谷歌前段时间发布的论文 FixMatch: Simplifying Semi-Supervised Learning with Consi ...

最新文章

  1. iphone XCode调试技巧之EXC_BAD_ACCESS中BUG解决
  2. Spring Boot 整合 Elasticsearch,实现 function score query 权重分查询
  3. ecdf函数_关于ecdf函数的使用问题
  4. 人工智能:模型与算法 之 启发式搜索
  5. PowerDesigner 正向工程 和 逆向工程 说明
  6. socket编程资料-网络收集
  7. activity-启动动画的设定(下面弹出出现,弹入下面消失)
  8. EasyExcel快速上手~读取
  9. 区域显示触发_Nature Communications:地幔数据显示可氧化的火山气体的减少可能触发了大氧化事件...
  10. 特斯拉上海超级工厂已在建设动力系统厂房
  11. windows下,linux下elasticsearch安装插件head插件的步骤
  12. 使用Nginx的proxy_cache缓存功能取代Squid(转)
  13. 50 行代码,实现中英文翻译
  14. OverFeat心得
  15. 创业维艰--书摘+乱七八糟
  16. win10锁屏时间太短就关闭屏幕
  17. 【Matlab元胞自动机】元胞自动机地铁火灾疏散模型【含源码 246期】
  18. 学计算机专业1050显卡够不够,gtx1050显卡性能怎么样
  19. 干货!基于元消歧的偏多标记学习
  20. 查找chrome浏览器历史记录

热门文章

  1. Office 2013 Preview专业增强版下载
  2. stm32点亮三个led灯
  3. css 3d闪烁动画,CSS3实现闪烁动画效果的方法
  4. DARPA“人工智能探索”工作进展
  5. 零基础入门python-零基础 Python 入门
  6. 帝国cms 未审核 showinfo.php,帝国CMS自动审核发布信息脚本
  7. 绘制“校园通”系统的上下文图
  8. VBA,userform的控件 controls,如何禁用一些image/commandbutton 处理例子
  9. php 打开word乱码,如何解决php word 乱码问题
  10. 计算机毕业设计django基于python图书馆借阅系统(源码+系统+mysql数据库+Lw文档)