Masked Auto Encoder总结

文章目录

  • Masked Auto Encoder总结
    • MAE简介
    • Random Mask
      • random mask 逻辑
      • random mask 实现
    • Encoder
      • Encoder网络结构
      • Block源代码
      • Ecoder计算流程
    • Decoder
      • Decoder网络结构
      • Decoder计算流程
    • Loss计算
      • Loss计算流程

MAE简介

MAE是用于CV的自监督学习方法,优点是扩展性强的,方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。

MAE基于两个核心设计:

  1. 不对称的(asymmetric)编码解码结构,编码器仅仅对可见的patches进行编码,不对mask tokens进行任何处理,解码器将编码器的输出(latent representation)和 mask tokens 作为输入,重构image。
  2. 使用较高的mask比例(eg:75%)进行训练时,MAE展现了很强的迁移性能,且因为方法简单,可扩展性极强。

通过上图可以看到,我们首先将一张大图片切割为很多小的patches,然后随机mask了很多patch;再将没有被mask掉的图片送入encoder提取特征,之后将特征和被mask掉的patch(未经过任何处理)按照顺序拼接在一起,通过decoder进行图片重建。

MAE的encoder主要是基于ViT的,我们在讲解ViT的时候详细的说了ViT的结构(如果对ViT不了解,可以看我上一篇blog),所以MAE的encoder对于我们来讲就显得很容易理解。我们会对encoder做简要的说明,将主要篇幅放在MAE如何进行随机mask、如何还原原来的顺序、decoder的结构以及对于还原图片loss计算上。

Random Mask

首先我们进行MAE中第一步的解释,我们如何将图片随机mask。

进行mask之前还需要将图片变成patch,此时维度变化如下:
[b,c,h,w]−>[b,(h/patchh)∗(w/patchw),(patchh∗patchw∗c)][b,c,h,w]->[b,(h/patch_h)*(w/patch_w),(patch_h*patch_w*c)] [b,c,h,w]>[b,(h/patchh)(w/patchw),(patchhpatchwc)]
这部分代码中直接调用了Ross Wightman在github上的开源项目中的源码 源码链接,图片patch化的代码如下:

from timm.models.vision_transformer import PatchEmbed, Block
def forward_encoder(self, x, mask_ratio):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls token# 这里是加上了位置编码,这部分在ViT文章中解释过,在这里不赘述x = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratiox, mask, ids_restore = self.random_masking(x, mask_ratio)

获得patch化的x之后,接下来就是随即掩码的过程。

random mask 逻辑

首先我们需要清楚,对于输入x,他的维度是[b,patch_num,dim][b,patch\_num,dim][b,patch_num,dim],我们掩码x,实际的掩码对象是dim维度的数据。由于一个batch中有patch_numpatch\_numpatch_num个dim维数据,所以我们要掩码 patch_num∗75patch\_num * 75%patch_num75 这么多的数据。

为了达到这一目的,我们首先随机化一个[b,patch_num][b,patch\_num][b,patch_num]维的向量,然后对其第一维(patch_num)进行排序,并且对处于前25%的部分对应的dim维数据保持原样,剩下的进行掩码。这样就完成了随机掩码75%的操作。

这里要注意,因为其用到的是argsort()函数,所以返回的是排序的下标。由于在模型中我们需要记住打乱的数据原来的位置,此信息在这段代码中由ids_shuffleids_restore两个变量来保存。

保存的逻辑如下:

我们可以发现,ids_shuffleids_restore是的下标和值存在明显的对应关系,可以通过这两个变量记录之前patch的实际位置。

之后从ids_shuffle中前提取出len_keep长度的信息保留,并且使用torch.gather()函数将被保留的信息挑出来。这里torch.gather()函数是一个巧妙的函数,具体的介绍可以通过这个链接查看 gather函数介绍。这里我们只需要知道,它可以按照我们保留的num_patches下标,将需要的dim维向量整合出来,并保存在x_masked变量中。

同样的,代码最后的mask变量(维度为[b,patch_num][b,patch\_num][b,patch_num])也变得好理解了,我们用一个batch来举例,其中的变量是如下形式的:[1,0,1,1,1,...1,0,1][1,0,1,1,1,...1,0,1][1,0,1,1,1,...1,0,1]。其中1表示这个位置对应的dim维patch被mask了,反之没有。

random mask 实现

    def random_masking(self, x, mask_ratio):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restore

Encoder

前文说过,MAE中的Encoder是一个ViT,所以这里我们直接从代码中解读Encoder。

Encoder网络结构

从Encoder的网络结构可以看出,他基本就是一个ViT,我们可以看到非常熟悉的cls_token与pos_embed变量。

但是我们还注意到了,这个ViT的Transformer是通过Block实现的。这个Block同样也是来自Ross Wightman的代码,Block的源代码在后面贴出。我们可以看到Block是一个标准的多头注意力模型。

class MaskedAutoencoderViT(nn.Module):""" Masked Autoencoder with VisionTransformer backbone"""def __init__(self, img_size=224, patch_size=16, in_chans=3,embed_dim=1024,          # encoder的隐藏层维度depth=24,               # encoder中transformer的深度num_heads=16,decoder_embed_dim=512,   decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):super().__init__()# --------------------------------------------------------------------------# MAE encoder specificsself.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embeddingself.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(depth)])self.norm = norm_layer(embed_dim)# --------------------------------------------------------------------------

Block源代码

可以看出这是一个标准的多头注意力模型。

class Block(nn.Module):def __init__(self,dim,num_heads,mlp_ratio=4.,qkv_bias=False,drop=0.,attn_drop=0.,init_values=None,drop_path=0.,act_layer=nn.GELU,norm_layer=nn.LayerNorm):super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()# NOTE: drop path for stochastic depth, we shall see if this is better than dropout hereself.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))return x

Ecoder计算流程

可以看出Encoder的计算流程,相比ViT,只是多出了一个random_masking

最终x的维度为[b,keep_length+1,dim][b,keep\_length + 1,dim][b,keep_length+1,dim],之所以第二维+1是因为拼接上了cls_tokencls\_tokencls_token

    def forward_encoder(self, x, mask_ratio):# embed patchesx = self.patch_embed(x)# add pos embed w/o cls tokenx = x + self.pos_embed[:, 1:, :]# masking: length -> length * mask_ratiox, mask, ids_restore = self.random_masking(x, mask_ratio)# append cls tokencls_token = self.cls_token + self.pos_embed[:, :1, :]cls_tokens = cls_token.expand(x.shape[0], -1, -1)x = torch.cat((cls_tokens, x), dim=1)# apply Transformer blocksfor blk in self.blocks:x = blk(x)x = self.norm(x)return x, mask, ids_restore

Decoder

Decoder网络结构

由于encoder输出和decoder输入的维度不同,所以decoder先通过一个线性层映射,将维度转换成decoder的输入维度。

这里的mask_token是个重要的部分,暂时他的维度是[1,1,decoder_embed_dim][1,1,decoder\_embed\_dim][1,1,decoder_embed_dim],他表示的是一张图片中被掩码的patch。在之后的forward函数中,mask_token会被扩增为[b,patch_num∗75%,decoder_embed_dim][b,patch\_num * 75\%,decoder\_embed\_dim][b,patch_num75%,decoder_embed_dim],表示所有batch中被掩码掉的patch。

之后就是我们熟悉的位置编码、transformer层、norm层以及维度转换的Linear层。

class MaskedAutoencoderViT(nn.Module):""" Masked Autoencoder with VisionTransformer backbone"""def __init__(self, img_size=224, patch_size=16, in_chans=3,embed_dim=1024, depth=24, num_heads=16,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):super().__init__()# --------------------------------------------------------------------------# MAE encoder specifics# 这里是原来的encoder代码# --------------------------------------------------------------------------# --------------------------------------------------------------------------# MAE decoder specificsself.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embeddingself.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)for i in range(decoder_depth)])self.decoder_norm = norm_layer(decoder_embed_dim)self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch# --------------------------------------------------------------------------

Decoder计算流程

我们需要说明,decoder的输入是encoder的输出。所以我们首先需要线性层进行维度转换,下一步是很重要的一步——将被mask掉的值拼接回去,这一步的操作会结合代码详细讲解。

首先我们将 Decoder网络结构 中定义的mask_token进行repeat操作,使其最后的数量与被mask掉的块数量一致。具体讲解在代码中以注释的形式出现。

def forward_decoder(self, x, ids_restore):# embed tokensx = self.decoder_embed(x)# append mask tokens to sequence'''这里mask_tokens最后的维度是[b,patch_num*75%,decoder_embed_dim]、第一维是batchsize是很好理解的,这就是x.shape[0]的含义第二维有点复杂,首先明确ids_restore.shape[1]是patch_num,x.shape[1]是keep_length + 1,这个1需要提醒,是在x上面加入的cls_token,所以 ids_restore.shape[1] + 1 - x.shape[1] 就是patch_num + 1 - (keep_length + 1) = patch_num - keep_length。而keep_length就是patch_num * 25%的结果(这一点如果忘掉的话可以看Random Mask部分回忆一下),所以第二维就是被mask部分的长度。第三维很好理解,在repeat函数中参数为1,也就是保留之前的 decoder_embed_dim 维度'''mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)'''这里是将未被mask的x和刚刚构建出来的被mask部分拼接起来,但是在这里还没有恢复他们的正确位置,只是把mask_tokens简单的拼接在了x后面,并且此处的x是去掉了cls_token的'''x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token'''这里是利用gather函数将上一条代码的简单拼接改为了正确的patch位置。'''x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle'''这里将输入x的cls_token又重新拼接,因为我们并没有改变batch维度的顺序(原来在位置i的图片信息现在还在位置i),所以将参与了训练的cls_token直接拼接是不会有问题的'''x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token# add pos embedx = x + self.decoder_pos_embed# apply Transformer blocksfor blk in self.decoder_blocks:x = blk(x)x = self.decoder_norm(x)# predictor projectionx = self.decoder_pred(x)# remove cls tokenx = x[:, 1:, :]return x

Loss计算

图片重建部分的loss并没有调用pytorch提供的函数,但并不意味着这部分困难。相反,这里loss的计算逻辑很简单,只是计算被mask掉部分的预测值与实际值之间的欧氏距离。

Loss计算流程

    def forward_loss(self, imgs, pred, mask):"""imgs: [N, 3, H, W]pred: [N, L, p*p*3]mask: [N, L], 0 is keep, 1 is remove, """# 将原始图片patch化target = self.patchify(imgs)# 将target归一化if self.norm_pix_loss:mean = target.mean(dim=-1, keepdim=True)var = target.var(dim=-1, keepdim=True)target = (target - mean) / (var + 1.e-6)**.5# 计算每一个patch的lossloss = (pred - target) ** 2loss = loss.mean(dim=-1)  # [N, L], mean loss per patch# 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的lossloss = (loss * mask).sum() / mask.sum()  # mean loss on removed patchesreturn loss

atch的loss
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch

    # 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的lossloss = (loss * mask).sum() / mask.sum()  # mean loss on removed patchesreturn loss

Masked Auto Encoder总结相关推荐

  1. Auto Encoder用于异常检测

    对基于深度神经网络的Auto Encoder用于异常检测的一些思考 from:https://my.oschina.net/u/1778239/blog/1861724 一.前言 现实中,大部分数据都 ...

  2. 自编码器(Auto Encoder)原理及其python实现

    目录 一.原理 二.为什么要使用自编码器 三.代码实现 1.原始自编码器 2.多层(堆叠)自编码器 3.卷积自编码器 4.正则自编码器 4.1稀疏自编码器 四.降噪自编码器 五. 逐层贪婪训练堆叠自编 ...

  3. 2021李宏毅机器学习课程笔记——Auto Encoder

    注:这个是笔者用于期末复习的一个简单笔记,因此难以做到全面详细,有疑问欢迎大家在评论区讨论 https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-d ...

  4. 堆叠降噪自动编码器 Stacked Denoising Auto Encoder(SDAE)

    原文链接 自动编码器(Auto-Encoder,AE) 自动编码器(Auto-Encoder,AE)自编码器(autoencoder)是神经网络的一种,经过训练后能尝试将输入复制到输出.自编码器内部有 ...

  5. Auto Encoder(AE),Denoising Auto Encoder(DAE), Variational Auto Encoder(VAE) 区别

    文章主要内容参考李宏毅老师的机器学习课程:https://www.bilibili.com/video/BV1Wv411h7kN?p=70 Auto Encoder: 是什么?有什么用? Auto E ...

  6. 【人工智能概论】 变分自编码器(Variational Auto Encoder , VAE)

    [人工智能概论] 变分自编码器(Variational Auto Encoder , VAE) 文章目录 [人工智能概论] 变分自编码器(Variational Auto Encoder , VAE) ...

  7. 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Auto Encoder 自编码器 + PyTorch实战

    文章目录 一.Basic Idea of Auto Encoder 1.1 Auto Encoder 结构 1.2 Auto Encoder 降维 1.3 Why Auto Encoder 1.4 D ...

  8. 机器学习笔记:auto encoder

    1 autoencoder 介绍 这是一个无监督学习问题,旨在从原始数据x中学习一个低维的特征向量(没有任何标签) encoder 最早是用线性函数+非线性单元构成(比如Linear+nonlinea ...

  9. Auto Encoder再学习

    一:AutoEncoder基本概念 将输入的比较高维度信息,不管是语音,文字,图像经过encoder转成一个中间状态的向量(也叫做latent code),这是一个低维度的数据,再通过decoder ...

最新文章

  1. 2022-2028年中国综艺节目市场深度调研及投资前景预测报告
  2. 新装 Win7 系统装完驱动精灵,一打开到检测界面就卡死——原因与解决方案
  3. 手机浏览器UserAgnet大全
  4. 神经网络python实例分类_Python使用神经网络进行简单文本分类
  5. hibernate配置
  6. PHP微信SDK——Zebra-Wechat
  7. 超级干货!31 条2020 年最新版 ZooKeeper面试题,先收藏再看!| 博文精选
  8. java程序回滚之后在哪看_Java在触发事务回滚之后为什么会再一次回到Servlet开始的地方重新走一次流程?...
  9. 夏季学期软工综合实践小记(二)
  10. 戴机械手表有哪些事情就不能做了?
  11. SQL 2014 AlwaysOn 搭建
  12. 好好学习努力工作,要工作也要生活—2016总结,2017规划
  13. Kali Rolling更换登录界面的背景
  14. 调用支付宝第三方支付接口详解(沙箱环境)
  15. cvr存储服务器的优势,CVR是什么
  16. [报表篇] (11)设置印刷尺寸
  17. 03 计算机视觉-opencv图像形态学处理
  18. 此安装不支持该项目类型
  19. 用数组+链表实现哈希表
  20. 超微服务器怎么开虚拟化,超微6016TT-IBXF服务器Supermicro开启虚拟化支持

热门文章

  1. 保姆级——Java使用腾讯云实现手机号验证码登录
  2. GNINX下配置WHMCS伪静态教程
  3. 推石磨机器人_[我爱发明]豆花西施 机器人推石磨豆花机(发明人邹属民)
  4. 扔物线Git小册笔记
  5. RxJava的简单学习(学习自扔物线)
  6. Android面试-LaunchMode及Task工作模式(扔物线笔记)
  7. panda学习190103
  8. 从Noob开始学习python/pyqt5(1)环境安装,工程搭建与打包exe
  9. html网页中的锚点(命名锚记)的使用介绍
  10. PHP foreach循环语句