As an academic rookie I read Kaiming He's paper in awe. First, a link to the man's homepage: Kaiming He
Before getting into Masked Autoencoders Are Scalable Vision Learners, and since I don't have a deep understanding of Transformers myself, I will interleave some Transformer and ViT background along the way. Without further ado, let's get started.
Masked Autoencoders Are Scalable Vision Learners

ViT

Before covering MAE, and to better understand its ideas, here is a brief introduction to ViT.
ViT paper and ViT code

ViT architecture

import torch
import torch.nn as nn
from einops import repeat
from einops.layers.torch import Rearrange

# helper from the same repo: turn an int into a (height, width) pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),  # dim = 1024
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # torch.Size([1, 65, 1024])
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # torch.Size([1, 1, 1024])
        self.dropout = nn.Dropout(emb_dropout)

        # Transformer is defined elsewhere in the vit-pytorch repo
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape  # torch.Size([1, 64, 1024])

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
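A quick sanity check of the class above, with illustrative hyperparameter values (this assumes the Transformer class from the same vit-pytorch repo is available):

v = ViT(image_size=256, patch_size=32, num_classes=1000,
        dim=1024, depth=6, heads=16, mlp_dim=2048)
img = torch.randn(1, 3, 256, 256)
preds = v(img)
print(preds.shape)  # torch.Size([1, 1000])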

First, let's look at how einops performs the patch dimension change (see: PyTorch 70.einops:优雅地操作张量维度):

# a 3x256x256 image is split into 64 patches of size 3x32x32
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
'''
torch.Size([1, 3, 256, 256])
b = 1, c = 3, h = 8, p1 = 32, w = 8, p2 = 32
torch.Size([1, 64, 3072])
'''
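To make the shape change concrete, here is a small runnable check of the Rearrange above (a minimal sketch, assuming the 256x256 image / 32x32 patch setup):

import torch
from einops.layers.torch import Rearrange

img = torch.randn(1, 3, 256, 256)  # b=1, c=3
to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
print(to_patches(img).shape)  # torch.Size([1, 64, 3072]): 64 patches, each 32*32*3 = 3072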

After the Rearrange, the original image has been split into multiple patches, which nn.Linear then embeds into a 64x1024 tensor. Next, a learnable positional encoding is added (why do we need positional encoding? See Transformer Architecture: The Positional Encoding) along with a class token (Vision Transformer). The sequence is then fed into the Transformer, and the resulting encoded embedding is used for the classification task.
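A minimal sketch of the class-token and positional-encoding step just described (the tensors below are stand-ins for the ViT attributes; shapes follow the 64-patch, dim = 1024 example):

import torch
import torch.nn as nn

dim, num_patches = 1024, 64
patch_tokens = torch.randn(1, num_patches, dim)                     # output of to_patch_embedding
cls_token = nn.Parameter(torch.randn(1, 1, dim))                    # learnable class token
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # learnable positions

x = torch.cat((cls_token, patch_tokens), dim=1)  # prepend cls token -> [1, 65, 1024]
x = x + pos_embedding                            # add positional information
print(x.shape)  # torch.Size([1, 65, 1024])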

MAE architecture

With this ViT background in place, we can now turn to MAE, whose structure is shown in the figure below:

Here ViT serves as the encoder, and its input contains only the unmasked patches. Although these visible patches carry no masking themselves, they make up only a small fraction of all patches (the paper masks out 75% of them), which is why MAE can train large encoders with comparatively little memory and compute.
The decoder then operates on the full set of patches (the encoded visible patches plus mask tokens for the masked positions), again with positional embeddings added. The masked patches are exactly the parts of the image to be reconstructed, and training minimizes the mean squared error (MSE) between the reconstructed and the original image; a minimal sketch follows below.
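To make the masking-and-reconstruction idea concrete, here is a minimal sketch (not the official implementation; encoder and decoder are hypothetical stand-ins) of 75% random masking with the MSE loss computed only on the masked patches:

import torch
import torch.nn as nn

def mae_step(patches, encoder, decoder, mask_ratio=0.75):
    """One MAE-style step: mask 75% of patches, encode the rest, reconstruct all."""
    B, N, D = patches.shape                      # [batch, num_patches, patch_dim]
    num_keep = int(N * (1 - mask_ratio))

    # Randomly shuffle patch indices and keep only the visible subset.
    noise = torch.rand(B, N)
    ids_keep = noise.argsort(dim=1)[:, :num_keep]
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    latent = encoder(visible)                    # encoder sees visible patches only
    pred = decoder(latent)                       # assumed to output all N patches: [B, N, D]

    # Compute the MSE only at the masked positions.
    mask = torch.ones(B, N, dtype=torch.bool)
    mask[torch.arange(B).unsqueeze(1), ids_keep] = False  # False = visible, True = masked
    return nn.functional.mse_loss(pred[mask], patches[mask])

In the real model the decoder also receives learnable mask tokens and positional embeddings for the masked positions; the sketch only shows where the loss is taken.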

Code

The code here comes from a third-party reproduction on GitHub: Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

import math
import sys
from typing import Iterable

import torch
import torch.nn as nn
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import utils  # MetricLogger / SmoothedValue helpers from the referenced repo


def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    patch_size: int = 16, normlize_target: bool = True, log_writer=None,
                    lr_scheduler=None, start_steps=None, lr_schedule_values=None,
                    wd_schedule_values=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    loss_func = nn.MSELoss()

    for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # assign learning rate & weight decay for each step
        it = start_steps + step  # global training iteration
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        images, bool_masked_pos = batch
        images = images.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)

        with torch.no_grad():
            # calculate the predict label
            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None]
            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None]
            unnorm_images = images * std + mean  # in [0, 1]

            if normlize_target:
                images_squeeze = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c',
                                           p1=patch_size, p2=patch_size)
                images_norm = (images_squeeze - images_squeeze.mean(dim=-2, keepdim=True)) / (
                    images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
                # we find that the mean is about 0.48 and standard deviation is about 0.08.
                images_patch = rearrange(images_norm, 'b n p c -> b n (p c)')
            else:
                images_patch = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                                         p1=patch_size, p2=patch_size)

            B, _, C = images_patch.shape
            labels = images_patch[bool_masked_pos].reshape(B, -1, C)

        with torch.cuda.amp.autocast():
            outputs = model(images, bool_masked_pos)
            loss = loss_func(input=outputs, target=labels)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                parameters=model.parameters(), create_graph=is_second_order)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if lr_scheduler is not None:
            lr_scheduler.step_update(start_steps + step)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

The experimental results are shown in the figure below: even when most of an image is masked out, the decoder can still reconstruct the original image. As the mask ratio increases, the reconstructions still recover what the model has learned, only with less of it, and the results remain semantically consistent (a mushroom is still reconstructed as a mushroom). This suggests the model has learned generalizable, object-level features of images and already generalizes well.

References and resources:

1. Masked Autoencoders Are Scalable Vision Learners
2. ViT paper
3. ViT code
4. PyTorch 70.einops:优雅地操作张量维度
5. Transformer Architecture: The Positional Encoding
6. Vision Transformer
7. Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners
