目录

  • 前言
  • 一睹为快,眼见为实
  • 代码解读

前言

github:https://github.com/facebookresearch/mae
论文:https://arxiv.org/abs/2111.06377
解读:何凯明最新一作MAE解读系列1

MAE实践主要从几个方面展开:

  1. 复现代码;复现代码过程,去粗存精,保留软件环境说明,如pytorch+cuda版本等;从jupyter来可视化过程。
  2. 解读代码。解读代码过程,双管齐下,参照论文framework来解读代码。

“知行合一”,“纸上得来终觉浅,觉知此事要躬行”

一睹为快,眼见为实

在jupyter中先快速可视化MAE的工作效果,以下展示使用MAE重构遮挡75%比例的图像效果,不用GPU即可完成操作。






















github






















download code+model









Demo: Jupyter Visualize







Debug









Show







Bug: __init__() got an unexpected keyword argument 'qk_scale' #58

https://github.com/facebookresearch/mae/issues/58#

Jupyter通过命令行保存成python格式

Jupyter通过命令行保存成python格式,代码如下(示例):

try:   !jupyter nbconvert --to python mae_visualize.ipynb# python即转化为.py,script即转化为.html# file_name.ipynb即当前module的文件名
except:pass

Run MAE on the image

# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
run_one_image(img, model_mae)

MAE with pixel reconstruction:


以上部分,可视化了MAE可视化的效果。主要由三个部分实现的,第一,随机采样:通过高比例如75%的遮挡像素块,留下可见的25%子块;第二,编码器将这1/4的子块进行表征学习,注意,没有任何的令牌输入到编码器中;第三,解码器将编码后的表征与3/4遮挡子块对应的令牌一起输入到解码器中,解码器的输出为归一化的像素值,他们将会重塑成有序的图像,即重构后的图片。
以上部分是预训练的过程。在图像推断或者识别中,去掉解码器,只保留编码器部分,并且输入到编码器的是一张完整的图片,不再遮挡掩码。
我们解读MAE的目的就是预训练后的模型,作为有效的骨架,来为下游任务提供服务。目前,原文中实现的下游任务有目标检测(Faster RCNN)和图像分割(UpperNet),我们的目标是将之应用到2D pose中。
所以我们的思路如下:首先爬到fine-tuned的分类模型,作为我们的backbone;第二,在backbone后设计一个light-weighted Head,输出heapmap以适应pose估计的任务。

代码解读

Mask Autoencoder结构代码如下:

def __init__(self, img_size=224, patch_size=16, in_chans=3,embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):super().__init__()# --------------------------------------------------------------------------# MAE encoder specificsself.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embeddingself.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)# --------------------------------------------------------------------------# --------------------------------------------------------------------------# MAE decoder specificsself.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embeddingself.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(decoder_depth)])self.decoder_norm = norm_layer(decoder_embed_dim)self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch# --------------------------------------------------------------------------self.norm_pix_loss = norm_pix_lossself.initialize_weights()

MAE结构,主要包含两个部分构成,即encoder+decoder定义。encoder结构,基本上就是ViT-Large模型,包含24个transformer block;decoder 是8个transformer block的轻量级网络。

有关PatchEmbed、Block和pos_embed的原理与用法,会专门设置一个主题来讲原始的transformer。

patchify函数:将N×3×224×224图像变成图像块,每个图像块大小为16,并将16×16图像拉伸为1×256序列,这样输出为:N×196×768

    def patchify(self, imgs):"""imgs: (N, 3, H, W)x: (N, L, patch_size**2 *3)"""p = self.patch_embed.patch_size[0]   #p=16assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0h = w = imgs.shape[2] // p   # 224/16=14x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))x = torch.einsum('nchpwq->nhwpqc', x)x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))   #14*14=196,16^2=256return x

unpatchify函数:patchify的逆过程,将序列恢复到图像。即序列N×196×768,恢复到RGB图像N×3×224×224。

    def unpatchify(self, x):"""x: (N, L, patch_size**2 *3)imgs: (N, 3, H, W)"""p = self.patch_embed.patch_size[0]   #16h = w = int(x.shape[1]**.5)  #14assert h * w == x.shape[1]x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))x = torch.einsum('nhwpqc->nchpwq', x)imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))return imgs

random_masking函数:执行随机采样。随机采样的过程是,首先通过均匀分布产生一个随机的数组,然后将之进行高低排序,选择小数值即1/4部分保留,余下的去掉,即洗牌操作,可实现图像块的随机遮挡。

    def random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dim=N×196×768len_keep = int(L * (1 - mask_ratio))  #保留的图像块个数为49noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restore

有关torch.gather,可以参考torch官方文档说明。dim=1即按照列向量组合index,去取值mask。

forward_encoder函数:一张完整的图片输入到encoder中,首先执行常见的patch_embedding,然后添加2D_sin_cos位置信息;再执行随机掩码,这里随机掩码后的输出x与原始输入x的维度是否一致,从代码中看,masked后的输出添加了class token和pos_embed,因此两者的维度应该一致。然后将class token与masked后的x合并起来,输入到一个堆叠的block栈中。最后需要进行Layer Norm归一化。

    def forward_encoder(self, x, mask_ratio):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratiox, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blocksfor blk in self.blocks:x = blk(x)x = self.norm(x)return x, mask, ids_restore

forward_decoder函数:输入到解码器中,包含两个部分,即encoded学习到的representation,和剩下3/4部分的图像块对应的掩码id。mask_token是个虚构的标记,文章中说是为了模仿NLP中的class_token而设计的一个虚构标记;然后将encoded的representation和mask_token通过tensor合并到一起,之后再添加class_token。这波操作,真的是看的人云里雾里。mask_token和class_token具体是怎样的,只能在debug中去打印出来,看下shape或者数据。然后再添加decoder的位置信息,输入到一个浅层的transformer栈中,最后除了layer norm外,还通过以全连接层映射到输出,并去掉class_token。不禁提问,为什么这么做?

    def forward_decoder(self, x, ids_restore):# embed tokensx = self.decoder_embed(x)# append mask tokens to sequencemask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls tokenx_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshufflex = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token# add pos embedx = x + self.decoder_pos_embed# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionx = self.decoder_pred(x)# remove cls tokenx = x[:, 1:, :]return x

forward_loss函数:输出是重构的向量,原始的图片可以通过patchify将图像数据转换成重构向量一样的维度。这里,原始图片即目标,需要像素值的归一化,首先取均值再求标准差。损失函数即均方误差,并且只在3/4部分上求误差。

    def forward_loss(self, imgs, pred, mask):"""imgs: [N, 3, H, W]pred: [N, L, p*p*3]mask: [N, L], 0 is keep, 1 is remove, """target = self.patchify(imgs)if self.norm_pix_loss:mean = target.mean(dim=-1, keepdim=True)var = target.var(dim=-1, keepdim=True)target = (target - mean) / (var + 1.e-6)**.5loss = (pred - target) ** 2loss = loss.mean(dim=-1)  # [N, L], mean loss per patchloss = (loss * mask).sum() / mask.sum()  # mean loss on removed patchesreturn loss

何凯明最新一作MAE解读系列2之代码实践相关推荐

  1. 何凯明最新一作MAE(mask掉图片的部分信息也能重建识别)

    导读 凯明出品,必属精品.没有花里胡哨的修饰,MAE就是那么简单的强大,即结构简单但可扩展性能强大.MAE通过设计一个非对称的编码解码器,在预训练阶段,通过高比例的掩码原图,将可见部分输入到编码器中: ...

  2. 如何看待何恺明最新一作论文Masked Autoencoders?

    来源 | 知乎问题 地址 | https://www.zhihu.com/question/498364155 编辑 | 机器学习算法与自然语言处理 原问题:如何看待何恺明最新一作论文Masked A ...

  3. 大道至简,何恺明最新一作火了:让计算机觉视觉通向大模型!

    何恺明,清华大学本科,港中文博士 来源 | 知乎,MLNLP编辑 https://www.zhihu.com/question/498364155 原问题:如何看待何恺明最新一作论文Masked Au ...

  4. 何恺明最新一作论文:无监督胜有监督,迁移学习无压力,刷新7项检测分割任务...

    鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 何恺明的一作论文,又刷新了7项分割检测任务. 这一次,涉及的是无监督表征学习.这一方法广泛应用在NLP领域,但尚未在计算机视觉中引起注意. ...

  5. 【深度学习】preprint版本 | 何凯明大神新作MAE | CVPR2022最佳论文候选

    文章转自:微信公众号[机器学习炼丹术] 笔记作者:炼丹兄(已授权转载) 联系方式:微信cyx645016617 论文题目:"Masked Autoencoders Are Scalable ...

  6. 何恺明最新工作:简单实用的自监督学习方案MAE,ImageNet-1K 87.8%

    作者丨happy 编辑丨极市平台 本文首发于极市平台,转载须经授权并注明来源 论文链接:https://arxiv.org/pdf/2111.06377.pdf 恺明出品,必属精品!这篇文章延续了其一 ...

  7. 如何从数学角度解释何恺明新作Masked Autoencoders (MAE)?

    何恺明最新一作论文 Masked Autoencoders(MAE)为自监督学习方法带来了一场革命,自提出以来,在 AI 领域内得到了极大的关注.MAE 不仅在图像预训练上达到了 SOTA 性能,更是 ...

  8. Lossless Codec---APE代码解读系列(二)

    APE file 一些概念 APE代码解读系列(一) APE代码解读系列(三) 1. 先要了解APE compression level APE主要有5level, 分别是: CompressionL ...

  9. 大概是全网最详细的何恺明团队顶作MoCo系列解读...(完结篇)

    ​作者丨科技猛兽 编辑丨极市平台 本文原创首发于极市平台,转载请获得授权并标明出处. 大概是全网最详细的何恺明团队顶作 MoCo 系列解读!(上) 本文目录 1 MoCo v2 1.1 MoCo v2 ...

  10. 【深度学习】大概是全网最详细的何恺明团队顶作MoCo系列解读...(完结篇)

    作者丨科技猛兽 编辑丨极市平台 导读 kaiming 的 MoCo让自监督学习成为深度学习热门之一, Yann Lecun也在 AAAI 上讲 Self-Supervised Learning 是未来 ...

最新文章

  1. 控制服务器信息不存在或已删除,错误1075:依存服务不存在, 或已标记为删除的解决方法...
  2. win10网络邻居看到linux,在Deepin 20系统中网络共享Windows无法访问的另类解决方法...
  3. 剑指offer-数组中出现次数超过一半的数字
  4. RDD 与 DataFrame原理-区别-操作详解
  5. Oracle 11g新特性:Result Cache
  6. python 3d大数据可视化软件_最受欢迎的大数据可视化软件
  7. .NET Core 3.0之深入源码理解Startup的注册及运行
  8. 2018091-2 博客作业
  9. ORACLE复杂查询之连接查询
  10. 最实用的Git命令总结:新建本地分支、远程分支、关联和取消关联分支、清除本地和远程分支、合并分支、版本还原、tag命令、中文乱码解决方案、如何fork一个分支和修改后发起合并请求
  11. java的位置_Java中数据存放的位置
  12. elcentro matlab,EL-Centro地震波积分计算与基线调整.docx
  13. Pensieve Multi_agent代码详解以及A3C强化学习代码详解
  14. amd linux显卡驱动,AMDAMD ATI Radeon Mobility FireGL 9.10显卡驱动官方正式版下载,适用于linux-驱动精灵...
  15. 计算机矩阵入门(eigen)0XC000041D
  16. c语言编程十进制转八进制算法,C语言十进制如何转八进制?
  17. canvas实现粒子特效
  18. c语言break后要分号吗,C语言程序每行结尾处都必须加分号(;)作为结束符号。
  19. 关于VMware上的VAAI特性详解
  20. 生物基础知识---CDS,基因,Matlab生物信息工具箱

热门文章

  1. Android群英传知识点回顾——第七章:Android动画机制与使用技巧
  2. OPENCV 实现png绘制,alpha通道叠加。
  3. 淘宝天猫商城的推广方法大总结
  4. 国内外各大免费搜索引擎、导航网址提交入口(转载)
  5. 几款基于ODE的机器人仿真软件
  6. ecshop ectouch 不支持html,ECShop上传的商品图片在ECTouch不能显示,怎么解决
  7. 赵伟功老师 管理系统提升专家
  8. 矩阵论知识整理(未完成,同步更新)
  9. mysql 存储微信表情
  10. mysql处理微信表情