点上方蓝字计算机视觉联盟获取更多干货

在右上方 ··· 设为星标 ★,与你不见不散

编辑:Sophia
计算机视觉联盟  报道  | 公众号 CVLianMeng

转载于 :Sherlock 知乎链接:https://zhuanlan.zhihu.com/p/108625273

AI博士笔记系列推荐:

博士笔记 | 周志华《机器学习》手推笔记“神经网络”

导读:最近 self-supervised learning 变得非常火,首先是 kaiming 的 MoCo 引发一波热议,然后最近 Yann 在 AAAI 上讲 self-supervised learning 是未来。所以觉得有必要了解一下 SSL,也看了一些 paper 和 blog,最后决定写这篇文章作为一个总结。

什么是 Self-Supervised Learning

首先介绍一下到底什么是 SSL,我们知道一般机器学习分为监督学习,非监督学习和强化学习。而 self-supervised learning 是无监督学习里面的一种,主要是希望能够学习到一种通用的特征表达用于下游任务。其主要的方式就是通过自己监督自己,比如把一段话里面的几个单词去掉,用他的上下文去预测缺失的单词,或者将图片的一些部分去掉,依赖其周围的信息去预测缺失的 patch。

根据我看的文章,现在 self-supervised learning 主要分为两大类:1. Generative Methods;2. Contrastive Methods。下面我们分别简要介绍一下这这两种方法。

Generative Methods

首先我们介绍一下 generative methods。这类方法主要关注 pixel space 的重建误差,大多以 pixel label 的 loss 为主。主要是以 AutoEncoder 为代表,以及后面的变形,比如 VAE 等等。对编码器的基本要求就是尽可能保留原始数据的重要信息,所以如果能通过 decoder 解码回原始图片,则说明 latent code 重建的足够好了。

source: [Towards Data Science]
(https://towardsdatascience.com/generating-images-with-autoencoders-77fd3a8dd368)

这种直接在 pixel level 上计算 loss 是一种很直观的做法,除了这种直接的做法外,还有生成对抗网络的方法,通过判别网络来算 loss。

对于 generative methods,有一些问题,比如:

  1. 基于 pixel 进行重建计算开销非常大;

  2. 要求模型逐像素重建过于苛刻,而用 GAN 的方式构建一个判别器又会让任务复杂和难以优化。

从这个 blog 中我看到一个很好的例子来形容这种 generative methods。对于一张人民币,我们能够很轻易地分辨其真假,说明我们对其已经提取了一个很好的特征表达,这个特征表达足够去刻画人民币的信息, 但是如果你要我画一张一模一样的人民币的图片,我肯定没法画出来。通过这个例子可以明显看出,要提取一个好的特征表达的充分条件是能够重建,但是并不是必要条件,所以有了下面这一类方法。

source:[blog]
(https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html)

Contrasive self-supervised learning

除了上面这类方法外,还有一类方法是基于 contrastive 的方法。这类方法并不要求模型能够重建原始输入,而是希望模型能够在特征空间上对不同的输入进行分辨,就像上面美元的例子。

这类方法有如下的特点:1. 在 feature space 上构建距离度量;2. 通过特征不变性,可以得到多种预测结果;3. 使用 Siamese Network;4. 不需要 pixel-level 重建。正因为这类方法不用在 pixel-level 上进行重建,所以优化变得更加容易。当然这类方法也不是没有缺点,因为数据中并没有标签,所以主要的问题就是怎么取构造正样本和负样本。

目前基于 contrastive 的方法已经取得了很好的紧张,在分类任上已经接近监督学习的效果,同时在一些检测、分割的下游任务上甚至超越了监督学习作为 pre-train的方法。

下面是这两类方法的总结图片。

source: [blog]
(https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html)

为什么需要 self-supervised learning

上面我们讲了什么是 self-supervised learning,那么为什么我们需要自监督学习呢,以及它能够给我们带来哪些帮助?

在目前深度学习发展的情况下,对于监督学习,我们希望使用更少的标注样本就能够训练一个泛化能力很好的模型,因为数据很容易获取,但是标注成本却是非常昂贵的。而在强化学习中,需要大量的经验对 agent 进行训练,如果能搞减少 agent 的尝试次数,也能够加速训练。除此之外,如果拿到一个好的特征表达,那么也有利于做下游任务的 fintuen和 multi-task 的训练。

最后我们总结一下监督学习和自监督学习的特点,其中 supervised learning 的特点如下:

  1. 对于每一张图片,机器预测一个 category 或者是 bounding box

  2. 训练数据都是人所标注的

  3. 每个样本只能提供非常少的信息(比如 1024 个 categories 只有 10 bits 的信息)

于此对比的是,self-supervised learning 的特点如下:

  1. 对于一张图片,机器可以预任何的部分

  2. 对于视频,可以预测未来的帧

  3. 每个样本可以提供很多的信息

所以通过自监督学习,我们可以做的事情可以远超过监督学习,也难怪 Yann 未来看好 self-supervised learning。目前出现的性能很好的文章主要是基于 contrastive 的方法,所以下面我们介绍几篇基于 contrastive 方法的文章。

Contrastive Predictive Coding

第一篇文章是 Representation Learning with Contrastive Predictive Coding(https://arxiv.org/abs/1807.03748)。这篇文章主要是通过 contrastive 的方式在 speech, images, text 和 在reinforcement learning 中都取得了很好的效果。

从前面我们知道,由一个原始的 input 去建模一个 high-level representation 是很难的,这也是自监督学习想做的事情。其中常用的策略是:future,missing 和 contextual,即预测未来的信息,比如 video 中当前帧预测后面的帧;丢失的信息或者是上下文的信息,比如 NLP 里面的 word2vec 和 BERT。

对于一个目标 x 和他的上下文 c 来说,直接去建模输出 p(x|c) 会损失很多信息,将 target x 和 context c 更合适的建模方式是最大化他们之间的 mutual information,即下面的公式:

优化了他们之间的互信息,即最大化 ,说明 要远大于 ,即在给定 context c 的情况下, 要找到专属于 c 的那个 x,而不是随机采样的 x。

基于这个观察,论文对 density ratio 进行建模,这样可以保留他们之间的互信息

对于这个 density ratio,可以构建左边的函数 f 去表示它,只要基于函数 f 构造下面的损失函数,优化这个损失函数就等价于优化这个 density ratio,下面论文会证明这一点。

而这个损失函数,其实就是一个类似交叉熵的函数,分子是正样本的概率,分母是正负样本的概率求和。

下面我们证明如果能够最优化这个损失函数,则等价于优化了 density ratio,也就优化了互信息。

首先将这个 loss 函数变成概率的形式,最大化这个正样本的概率分布,然后通过 bayesian 公式进行推导,其中 X 是负样本,和 以及 c 都无关。

通过上面的推导,可以看出优化这个损失函数其实就是在优化 density ratio。论文中把 f 定义成一个 log 双线性函数,后面的论文更加简单,直接定义为了 cosine similarity。

有了这个 loss,我们只需要采集正负样本就可以了。对于语音和文本,可以充分利用了不同的 k 时间步长,来采集正样本,而负样本可以从序列随机取样来得到。对于图像任务,可以使用 pixelCNN 的方式将其转化成一个序列类型,用前几个 patch 作为输入,预测下一个 patch。

source: [blog]

(https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html)

source:[ Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748)

Deep InfoMax

通过上面的分析和推导,我们有了这样一个通用的框架,那么 deep infomax 这篇文章就非常好理解了,其中正样本就是第 i 张图片的 global feature 和中间 feature map 上个的 local feature,而负样本就是另外一张图片作为输入,非常好理解。

source: [Learning deep representations by mutual information estimation and maximization](https://arxiv.org/abs/1808.06670)

Contrastive MultiView Coding

除了像上面这样去构建正负样本,还可以通过多模态的信息去构造,比如同一张图片的 RGB图 和 深度图。CMC 这篇 paper 就是从这一点出发去选择正样本,而且通过这个方式,每个 anchor 不仅仅只有一个正样本,可以通过多模态得到多个正样本,如下图右边所示。

source: [Contrastive Multiview Coding](http://arxiv.org/abs/1906.05849)

现在我们能够拿到很多正样本,问题是怎么获得大量的负样本,对于 contrastive loss 而言,如何 sample 到很多负样本是关键,mini-batch 里面的负样本太少了,而每次对图片重新提取特征又非常的慢。虽然可以通过 memory bank 将负样本都存下来,但是效果并不好,所以如何节省内存和空间获得大量的负样本仍然没有很好地解决。

MoCo

有了上面这么多工作的铺垫,其实 contrastive ssl 的大框架已经形成了,MoCo 这篇文章也变得很好理解,可以把 target x 看成第 i 张图片的随机 crop,他的正样本通过一个 model ema 来得到,可以理解为过去 epochs 对这张图片的 smooth aggregation。而负样本则从 memory bank 里面拿,同时 memory bank 的 feature 也是通过 model ema 得到,并且通过队列的形式丢掉老的 feature。

source: [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722)

MoCo 通过工程的方式,和一些 trick,比如 model ema 和 shuffleBN 来解决之前没法很好 sample 负样本的问题。

SimCLR

最近,hinton 组也放了一篇做 ssl 的 paper,其实都是用的同一套框架,也没有太多的 novelty。虽然摘要里面说可以抛弃 memory bank,不过细看论文,训练的 batchsize 需要到几千,要用32-128 cores 的 TPU,普通人根本用不起。

不过这篇文章系统地做了很多实验,比如探究了一下数据增强的影响,以及的 projection head 的影响等,不过也没有从理论上去解释这些问题,只是做了实验之后获得了一些结论。

Results

source: [A Simple Framework for Contrastive Learning of Visual Representations]
(https://arxiv.org/abs/2002.05709)

最后展示了不同方法的结果,可以看到在性能其实已经逼近监督学习的效果,但是需要 train 4x 的时间,同时网络参数也比较大。

虽然性能没有超过监督学习,不过我认为这仍然给了我们很好的启发,比如训练一个通用的 encoder 来接下游任务,或者是在 cross domain 的时候只需要少量样本去 finetune,这都会给实际落地带来收益。

Reference

contrastive self-supervised learning,https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html

deep infomax 和 深度学习中的互信息,https://zhuanlan.zhihu.com/p/46524857

END

声明:本文来源于网络

如有侵权,联系删除

联盟学术交流群

扫码添加联盟小编,可与相关学者研究人员共同交流学习:目前开设有人工智能、机器学习、计算机视觉、自动驾驶(含SLAM)、Python、求职面经、综合交流群扫描添加CV联盟微信拉你进群,备注:CV联盟  

最新热文荐读

GitHub | 计算机视觉最全资料集锦

Github | 标星1W+清华大学计算机系课程攻略!

Github | 吴恩达新书《Machine Learning Yearning》

收藏 | 2020年AI、CV、NLP顶会最全时间表!

收藏 | 博士大佬总结的Pycharm 常用快捷键思维导图!

收藏 | 深度学习专项课程精炼图笔记!

笔记 | 手把手教你使用PyTorch从零实现YOLOv3

笔记 | 如何深入理解计算机视觉?(附思维导图)

笔记 | 深度学习综述思维导图(可下载)

笔记 | 深度神经网络综述思维导图(可下载)

总结 | 2019年人工智能+深度学习笔记思维导图汇总

点个在看支持一下吧

一文带你读懂Self-Supervised Learning(自监督学习)相关推荐

  1. DNN、RNN、CNN.…..一文带你读懂这些绕晕人的名词

    DNN.RNN.CNN.-..一文带你读懂这些绕晕人的名词 https://mp.weixin.qq.com/s/-A9UVk0O0oDMavywRGIKyQ 「撞脸」一直都是娱乐圈一大笑梗. 要是买 ...

  2. 一文带您读懂FCC、CE、CCC认证的区别

    一文带您读懂FCC.CE.CCC认证的区别 参考资料:https://3g.k.sohu.com/t/n411629823 FCC认证,CE认证,CCC认证是产品认证中比较常见的几个认证,前两者经常有 ...

  3. 机器学习中为什么需要梯度下降_机器学习101:一文带你读懂梯度下降

    原标题 | Machine Learning 101: An Intuitive Introduction to Gradient Descent 作者 | Thalles Silva 译者 | 汪鹏 ...

  4. 一文带你读懂HTTP协议的前世今生

    点击上方蓝字关注我们 HTTP,Hypertext Transfer Protocol,超文本协议,是在万维网上传输文件(如文本.图形图像.声音.视频和其他多媒体文件)的规则集.如果web用户打开他们 ...

  5. 用程序员计算机算进制,一文带你读懂计算机进制

    hi,大家好,我是开发者FTD.在我们的学习和工作中少不了与进制打交道,从出生开始上学,最早接触的就是十进制,当大家学习和使用计算机时候,我们又接触到了二进制.八进制以及十六进制.那么大家对进制的认识 ...

  6. 一文带你读懂“经典TRIZ”

    本文承接上文<一文带第读懂TRIZ>,下面开始看第二个问题:什么是"经典TRIZ"? 很多书里都有对TRIZ的产生与发展的描述. 我个人在看了很多的书和文献以后,认为: ...

  7. 简单一文带你读懂Java变量的作用和三要素

    Java变量的作用 不只是java,在其他的编程语言中变量的作用只有一个:存储值(数据) 在java中,变量本质上是一块内存区域,数据存储在java虚拟机(JVM)内存中 变量的三要素 变量的三要素分 ...

  8. 一文带你读懂感知机的前世今生(上)

    一文带你读懂感知机的前世今生 前言 男女不分 什么是神经元 M-P神经元 全或无定律 McCulloch和Pitts 一种高度简化的模型 MP神经元和真值表 MP神经元的几何理解 后记 参考 前言 男 ...

  9. 《一文带你读懂:云原生时代业务监控》

    点击上方蓝字关注我们! 对业务来说,完备的应用健康性和数据指标的监控非常重要,通过采集准确的监控指标.配置合理的告警机制,我们能够提前或者尽早发现问题,并做出响应.解决问题,进而保证产品的稳定性,提升 ...

  10. au加载默认的输入和输出设备失败_一文带你读懂 C/C++ 语言输入输出流与缓存区...

    (给CPP开发者加星标,提升C/C++技能) 作者:技术让梦想更伟大 / 李肖遥 (本文来自作者投稿) 前言 有没有发现,基本上所有的C语言入门书籍,或者是我们的教程里面,第一个C语言程序实体,都是& ...

最新文章

  1. linux修改mysql默认大小写配置,linux下设置mysql不区分大小写
  2. 【推荐】揭秘谷歌电影票房预测模型
  3. iqn怎么查 linux_程序员必备:46个Linux面试常见问题!收藏!
  4. Java--线程同步
  5. 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn
  6. linux管理外部工具,linux – 除了iptables之外的数据包管理工具?
  7. 快速使用nexus搭建maven本地私服
  8. 常见数据类型的手机二维码生成与识别格式参考
  9. android平台Camera采集数据ffmpeg进行编码探究
  10. 二十一天学通JavaScript:cookie的安全性
  11. visio 2019 简单流程图教程
  12. mysql软件字体模糊_Windows 10字体模糊发虚! 如何解决?
  13. php随机点名代码怎么做,html座位表随机点名的实例代码
  14. 阿里云视频直播功能升级
  15. 使用 mongorestore恢复数据以及使用 Studio 3T GUI 管理数据库
  16. 什么是SAS硬盘,服务器硬盘sas和sata有什么区别
  17. SpringBoot强制下载文件
  18. pstack无法查看进程堆栈“Could not attach to target”问题
  19. Codeforces Round #470 (Div. 2) A Protect Sheep (基础)输入输出的警示、边界处理
  20. 从一到无穷大 #6 盘满排查过程

热门文章

  1. java集合框架的接口_Java集合框架——Set接口
  2. halcon获取图像中心点_关于Halcon的复杂图形中心点查找
  3. 大数据职业理解_数据分析师真有那么好?其实正在面临3大职业困境
  4. 安卓button设置背景图_这些安卓源码调试技巧,不懂的人月薪绝对不过 30k !
  5. oom 如何避免 高并发_如何设计这样一个高并发系统?
  6. Linux怎么设置ntp授时,linux设置ntp时间同步服务器地址
  7. @RequestBody配合@JsonFormat注解实现字符串自动转换成Date
  8. 【Golang 接口自动化05】使用yml管理自动化用例
  9. 【转】前端的BFC、IFC、GFC和FFC
  10. SpringBoot | 第六章:常用注解介绍及简单使用