在之前一篇推文一文串起从NLP到CV 预训练技术和范式演进中,由于篇幅有限,仅仅介绍了深度学习中的预训练技术发展,基本思路是顺着CV和NLP双线的预训练技术发展演进。

这里正式开启一个顺着这篇推文的倒叙精读系列。

Masked Autoencoders Are Scalable Vision Learners

正好mae的官方pytorch在两周前开源了

https://github.com/facebookresearch/mae

我们用倒叙的方式,从MAE往后看。开篇用一个非常夸张的实验效果demo图。这个效果实在是太夸张了,人都脑补不出来这样的马赛克程度。

摘要

MAE的方法非常简单,随机MASK住图片里的一些块,然后再去重构这些被MASK住的像素。这整个思想也来自 BERT 的带掩码的语言模型,但不一样的是这一个词(patches) 它就是一个 image 的一个块,然后它预测的是你这块里面的所有的像素。

全文有两个重要的创新点:跑得快+学得难

跑得快:非对称的自编码器架构(autoencoder),其编码器仅作用在可见的这些patch里面, 如果一个 patch 被它丢掉了,那么编码器就不会对它进行编码。这样图像encoder端的工作量就减少了,好处就是可以跑得很快。进一步地,解码器是一个比较轻量的解码器。一层transformer就够用。

学得难:预训练任务能够重构原始的像素级图片。并且,可以搞定75%的这些块全部遮住下的图像复原。这个事情是一个非平凡的,而且有意义的自监督的任务。如果你就简单遮住几块的话,那么就插一下值,你就可以出来了,这样整个模型可能学不到特别有意思的东西。但是你要是遮住高达75%的部分,苦一苦你的模型,说不定他会学到一些更好的一些表征

然后把这两个放在一起,跑得快+学得难,我们就可以让他做一些超越自己当前模型水平的水平的事情(老PUA了)

结果:用更小的数据来自监督预训练,超越了更多数据监督训练的ViT模型。他用来自于VIT这个论文的不加任何技巧的ViT-Huge的模型backbobe结构,

加上他的预训练方法,能够得到 87.8% 的ACC表现。

最后,强调一下迁移学习也很好。当然,预训练模型不迁移学习,那岂不预训练了一个寂寞。

结构

论文一般有两个图最重要,一个是第一页右上角的小图,第二个是第三页横跨双栏的大图。

这是 MAE体的架构图,预训练阶段一共分为四个部分,MASK,encoder,decoder。

MASK

可以看到一张图片进来,首先把你切块切成一个一个的小块,按格子切下来。

其中要被MASK住的这一块就是涂成一个灰色,然后没有MASK住的地方直接拎出来,这个地方75%的地方被MASK住了。注意是随机采样,而不是什么中心采样,网格的采样,局部采样等方式,s这部分在实验里对比过。这里比较符合认知的解释是,可以防止引入类似中心归纳偏好等特定bias,随机是最公平的。

encoder

前面拎起来的像素块即unmask部分,放进一个 encoder 的里面,这里采用了ViT论文中的transformer backbone,得到每一个块它对应的这一些特征。

在这个地方它要把它拉长,把这些被MASK那些块,重新放回到原来的位置,把它拉成一条向量。在预训练的时候,MASK住的东西,其实啥也没有了,作者给了他一个可以学习的共享隐向量+Position  embedding(!!!!这个地方比较难trick,推荐看一下代码实现)没有MASK住的,就是填上那 ViT 它出来的这些特征。组成一个长的隐层向量,输到一个解码器里面。

decoder

解码器会去尝试把里面的像素信息全部重构回来,得到最后的 target(目标的像素值)。要注意的是,解码的过程是没有加速度的,但是解码的模型一般都不大。我们知道编码的Transformer 这些模型计算量都特别大,如果有个几倍的加速,其实也是非常重要的一个事情。

下游任务

如果你想用这个模型来做一个下游任务呢,你就只需要它的编码器就行了,解码器是不需要的,你的图片进来你不需要对它做掩码

你直接切成这些格子块。然后过encoder它就会得到你所有那些块的一个特征的表示,这个就是你的图片的语义表征(representation)

实现细节

encoder

1.patch,图像切块, 图像在tensor中的表示为 (B,C,H,W) reshape 成 (B,N,PxPxC),其中B是Batch大小,N和P分别为 patch 数量 和 patch 大小。

N = H*W/P/P。

2.patch embedding, 1中的图片切块的嵌入表征,他是连续值经过一层全连接得到固定维度大小的值(dim),注意文本是one-hot形式,或者look up table的形式。

从1中的 (B,N,PxPxC) -> (B,N,dim)

3.position embedding,patch编码对应的embeding,这个和NLP中的词表查到的embedding是一样的。

4.部分编码,预训练阶段的Encoder从实现角度再复述一遍:图像切块-没有MASK的部分走patch embedding+position embedding

def forward_encoder(self, x, mask_ratio):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratiox, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blocksfor blk in self.blocks:x = blk(x)x = self.norm(x)return x, mask, ids_restore

decoder

1.mask部分的对应的隐向量并不来自于encoder的推断,而是直接在这里进行凭空初始化的共享token向量+position embedding。

2.decoder不需要用encoder那么重的模型。你可以理解为Bert的decoder就是个MLP,这里可以用一个特别简单的一层transformer。虽然decoder在数量补齐了复杂度,因为模型简单,压力并不太大。

def forward_decoder(self, x, ids_restore):# embed tokensx = self.decoder_embed(x)# append mask tokens to sequencemask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls tokenx_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshufflex = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token# add pos embedx = x + self.decoder_pos_embed# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionx = self.decoder_pred(x)# remove cls tokenx = x[:, 1:, :]return x

loss

1.仅用MSE算mask path的像素差值。只算mask patch是因为实验结论,否则有大约0.5%的ACC下降。

2.归一化的像素值作为target比较好。实验结果

总结

预训练阶段:

1.图片切patch

2.patch做embedding (projection方式)

3.加上position embedding (lookup table方式)

4.mask打码(75%)

5.无码部分进encoder

6.有码部分做好可训练的共享语义向量+position embedding

7.按patch的原始顺序拼好mask和unmask的对应语义向量,送decoder

8.取decoder出来的,mask部分对应的像素值算mse loss。

实验部分

1.mask比例

少了多了都不好。所以说恰当的压力才是前进的动力。直观理解就是太简单了学不到东西,太难了也学不会。

2.采样策略

随机采样效果最好,其他的方式多多少少泛化能力都差一点。

block的任务更难,扣掉一大块比例太大也学不好,对于模型来说太难了。扣掉50%差不多了,但是效果比随机还差一点。和上面个实验一样,刚刚好比较好。

3.decoder设计

用深层和更大decoder不太好。其实也可以理解,encoder出来的隐向量的信息已经够复杂了。第二点是苦一苦encoder,这样在下游任务他发挥的更好一点。要是用复杂的encoder,信息和建模能力,都隐藏在decoder恐怕就没有这么好的效果了。

4.重建目标

作者和 BEiT 那种预测token的方式 以及 PCA 的方式。patch 做 PCA 并预测最大的因子,进行了比较。有无归一化也进行了比较。

5.数据增强

保持图片局部完整信息的随机缩放,比其他引入噪声的方式都要好。

往期精彩回顾适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载中国大学慕课《机器学习》(黄海广主讲)机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码

【深度学习】 MAE|心中无码,便是高清相关推荐

  1. 心中无码,自然高清 | 联合去马赛克与超分辨率研究论文Pytorch复现

    作者 | 知凡,个人公众号:林木蔚然读书会(ID:EspressoOcean),知乎ID:Uno Whoiam 本文授权转载自知乎 本文结构 简单扫盲 什么是去马赛克 什么是超分辨率 <Deep ...

  2. 资源|最好的九张机器学习/深度学习代码速查表,附高清下载

    作者:Kailash Ahirwar 机器之心编译 文末附高清速查表下载 对于初学者来讲,入门机器学习和深度学习非常困难:同时深度学习库也难以理解.通过收集多方资源,我在 Github 上创建了一个速 ...

  3. 深度学习框架Caffe源码解析

    作者:薛云峰(https://github.com/HolidayXue),主要从事视频图像算法的研究, 本文来源微信公众号:深度学习大讲堂.  原文:深度学习框架Caffe源码解析  欢迎技术投稿. ...

  4. Swift开发小技巧--扫描二维码,二维码的描边与锁定,设置扫描范围,二维码的生成(高清,无码,你懂得!)...

    二维码的扫描,二维码的锁定与描边,二维码的扫描范围,二维码的生成(高清,无码,你懂得!),识别相册中的二维码 扫描二维码用到的三个重要对象的关系,如图: 1.懒加载各种类 // MARK: - 懒加载 ...

  5. ArcGIS\QGIS无插件加载(无偏移)MapBox高清影像图

    喜欢就关注我们吧! 首先介绍一下MapBOX. Mapbox 是用于移动和 Web 应用程序的位置数据平台.用户可以使用Mapbox Studio创建一个自定义.交互式的地图,然后可以将这些自定义的地 ...

  6. 使用RNN神经网络自动生成名字 (不使用深度学习框架,源码)

    本文讲解在不使用深度学习框架的情况下,构建一个基本的RNN神经网络来进行名字自动生成.RNN模型请看下面的三张图片.本文主要讲解数据集以及输入模型的数据格式. 数据集和可执行的源码下载地址:https ...

  7. 基于深度学习的二维码检测和识别(含完整代码和数据)

    最近尝试着将深度学习技术引入到二维码检测和识别中,期望能够提升传统二维码的识读性能,能够适用更多复杂背景,并且最终应用到工业生产中,方便生产线上对产品的ID管理. 项目最终实现效果如下所示: 相对来说 ...

  8. 深度学习03-sklearn.LinearRegression 源码学习

    在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现.但是在看到源码的一瞬间突然有种 ...

  9. 何必心中无码,AI让你眼见为实

    还在为珍贵的照片,被路人抢镜而苦恼吗? 还在为景区人山人海,而拍不到一人一景手足无措吗? 上周,英伟达发布了一个超牛逼的AI修图技术,不需要专业的修图师进行修图,能够很完美的解决以上问题.下面我先来一 ...

最新文章

  1. 图像特征点—SIFT特征点
  2. MySQL · myrocks · myrocks统计信息
  3. 《Java基础入门》课后习题答案 资源分享
  4. mysql 分组 字符串_MySQL查询以字符串字段中的数字字符对行进行分组?
  5. python表示数字6_【第六节】Python数字(Number)
  6. javaEE开发问题整理(1)
  7. C#入门详解(14)
  8. 【研究】Metasploit自动攻击模块
  9. 反射和多态的实现原理详解以及区别
  10. elment-ui的table组件多行合并
  11. linux下载的安装包位置,及下载安装包到本地
  12. 用ssms建sql server数据库和python连接到数据库
  13. 网页|利用touch实现下拉刷新
  14. python fields_Python fields.Nested方法代码示例
  15. Android Studio 3.5.2 入门教程(浓缩版)
  16. 近期做笔试题总结和思考(百度,滴滴,360)
  17. 如何做快手副业?怎么在快手上赚工资?快手发视频怎么赚钱?
  18. QMetaObjectPrivate meta_constractors Q_INVOKABLE
  19. 软文营销登顶销售奇迹的4U定律你知道吗?
  20. 使用anaconda编程c语言,Anaconda的安装与虚拟环境建立

热门文章

  1. 拼多多2021笔试真题集 -- 1. 多多的数字组合
  2. 第063讲: 论一只爬虫的自我修养11:Scrapy框架之初窥门径 | 学习记录(小甲鱼零基础入门学习Python)
  3. IntelliJ IDEA Remote Development 使用体验
  4. 《2040大预言:高科技引擎与社会新秩序》——2.4 在芯片上建造大金字塔
  5. 腾讯地图数据可视化之热力图
  6. 2021款途锐噪音测试软件,试驾2021款大众途锐:这才是原汁原味的德国沃尔夫斯堡的味道...
  7. 200套工作室设计行业响应式Html5模板HTML5+CSS3设计网站模板简洁设计师作品展示响应式模板整洁扁平宽屏CSS3网站模板html5网页静态模板Bootstrap扁平化网站源码css3手机se
  8. 大学计算机未来五年规划,大学生活评价与未来五年计划(8页)-原创力文档
  9. Postgres 数据存储位置
  10. 怎看传智播客学员如此吃香