论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
论文下载链接:https://arxiv.org/abs/2010.11929
原论文对应源码:https://github.com/google-research/vision_transformer

Pytorch实现代码: pytorch_classification/vision_transformer
Tensorflow2实现代码:tensorflow_classification/vision_transformer
在bilibili上的视频讲解:https://www.bilibili.com/video/BV1Jh411Y7WQ


文章目录

  • 前言
  • 模型详解
    • Vision Transformer模型详解
      • Embedding层结构详解
      • Transformer Encoder详解
      • MLP Head详解
      • 自己绘制的Vision Transformer网络结构
    • Hybrid模型详解
  • ViT模型搭建参数

前言

Transformer最初提出是针对NLP领域的,并且在NLP领域大获成功。这篇论文也是受到其启发,尝试将Transformer应用到CV领域。关于Transformer的部分理论之前的博文中有讲,链接,这里不在赘述。通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。


模型详解

在这篇文章中,作者主要拿ResNet、ViT(纯Transformer模型)以及Hybrid(卷积和Transformer混合模型)三个模型进行比较,所以本博文除了讲ViT模型外还会简单聊聊Hybrid模型。

Vision Transformer模型详解

下图是原论文中给出的关于Vision Transformer(ViT)的模型框架。简单而言,模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(图右侧有给出更加详细的结构)
  • MLP Head(最终用于分类的层结构)

Embedding层结构详解

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。


对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到(224/16)2=196(224/16)^2=196(224/16)2=196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

在输入Transformer Encoder之前注意需要加上[class]token以及Position Embedding。 在原论文中,作者说参考BERT,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]

对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb.,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.比起来没太大差别。


Transformer Encoder详解

Transformer Encoder其实就是重复堆叠Encoder Block L次,下图是我自己绘制的Encoder Block,主要由以下几部分组成:

  • Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接
  • Multi-Head Attention,这个结构之前在讲Transformer中很详细的讲过,不在赘述,不了解的可以参考链接
  • Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
  • MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]


MLP Head详解

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。


自己绘制的Vision Transformer网络结构

为了方便大家理解,我自己根据源代码画了张更详细的图(以ViT-B/16为例):


Hybrid模型详解

在论文4.1章节的Model Variants中有比较详细的讲到Hybrid混合模型,就是将传统CNN特征提取和Transformer进行结合。下图绘制的是以ResNet50作为特征提取器的混合模型,但这里的Resnet与之前讲的Resnet有些不同。首先这里的R50的卷积层采用的StdConv2d不是传统的Conv2d,然后将所有的BatchNorm层替换成GroupNorm层。在原Resnet50网络中,stage1重复堆叠3次,stage2重复堆叠4次,stage3重复堆叠6次,stage4重复堆叠3次,但在这里的R50中,把stage4中的3个Block移至stage3中,所以stage3中共重复堆叠9次。

通过R50 Backbone进行特征提取后,得到的特征矩阵shape是[14, 14, 1024],接着再输入Patch Embedding层,注意Patch Embedding中卷积层Conv2d的kernel_size和stride都变成了1,只是用来调整channel。后面的部分和前面ViT中讲的完全一样,就不在赘述。

下表是论文用来对比ViT,Resnet(和刚刚讲的一样,使用的卷积层和Norm层都进行了修改)以及Hybrid模型的效果。通过对比发现,在训练epoch较少时Hybrid优于ViT,但当epoch增大后ViT优于Hybrid。


ViT模型搭建参数

在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16的外还有32x32的。其中的Layers就是Transformer Encoder中重复堆叠Encoder Block的次数,Hidden Size就是对应通过Embedding层后每个token的dim(向量的长度),MLP size是Transformer Encoder中MLP Block第一个全连接的节点个数(是Hidden Size的四倍),Heads代表Transformer中Multi-Head Attention的heads数。

Model Patch Size Layers Hidden Size D MLP size Heads Params
ViT-Base 16x16 12 768 3072 12 86M
ViT-Large 16x16 24 1024 4096 16 307M
ViT-Huge 14x14 32 1280 5120 16 632M

Vision Transformer详解相关推荐

  1. 深度学习网络模型——Vision Transformer详解 VIT详解

    深度学习网络模型--Vision Transformer详解 VIT详解 通用深度学习网络效果改进调参训练公司自己的数据集,训练步骤记录: 代码实现version-Transformer网络各个流程, ...

  2. Swin Transformer详解: Hierarchical Vision Transformer using Shifted Windows

    这篇文章结合了CNN的归纳偏置,基于局部窗口做注意力,并且逐步融合到深层transformer层中构建表征,来达到扩大感受野,并且极大降低了计算量.是一个特征提取的主干网络,backbone.构建了一 ...

  3. 【NLP】Transformer详解

    [NLP]Transformer详解   Transformer在Google的一篇论文Attention is All You Need被提出,为了方便实现调用Transformer Google还 ...

  4. Transformer详解

    Transformer详解 1. 简介 Transformer是一个面向sequence to sequence任务的模型,在17年的论文<Attention is all you need&g ...

  5. 史上最小白之Transformer详解

    1.前言 博客分为上下两篇,您现在阅读的是下篇史上最小白之Transformer详解,在阅读该篇博客之前最好你能够先明白Encoder-Decoder,Attention机制,self-Attenti ...

  6. Transformer 详解(上) — 编码器【附pytorch代码实现】

    Transformer 详解(上)编码器 Transformer结构 文本嵌入层 位置编码 注意力机制 编码器之多头注意力机制层 编码器之前馈全连接层 规范化层和残差连接 代码实现Transforme ...

  7. 【转载】Transformer详解

    转载于https://blog.csdn.net/Tink1995/article/details/105080033 文章目录 1.前言 2.Transformer 原理 2.1 Transform ...

  8. NLP中的Attention注意力机制+Transformer详解

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 作者: JayLou娄杰 知乎链接:https://zhuanlan.zhihu. ...

  9. 【李宏毅】注意力机制+transformer详解

    Transformer 背景 关于RNN详解可以看这篇博客:循环神经网络 RNN一般被用来处理序列输入的,但是它有一个缺点就是不能并行化,后面一个神经元的输入要依赖与之前神经元的输出. 然后就有人提出 ...

最新文章

  1. c++ 异常处理(1)
  2. linux uname 命令详解
  3. dos 命令与 C++的联合作业,有点意思~
  4. 天津php二次开发培训,天津PHP后台开发培训短期班
  5. 动漫风html源码,CSS3动画制作一个卡通风格的404错误页面代码
  6. 在Java EE 7上骑骆驼–带有Swagger文档的REST服务
  7. 前端学习(1800):前端调试之清除浮动练习1
  8. 限制对web路径的访问
  9. php中将excel写入mysql数据库的示例
  10. 2017计算机c语言大纲,2017年计算机考研大纲
  11. 时间序列信号处理(一)-----变分模态分解(VMD)
  12. mysql dateofweek_日历表-月的周数
  13. 计算机科学与技术高中选课,2019-2021年新高考专业选课要求 大学个专业选科要求解读...
  14. C++实现排列组合问题
  15. abs函数c语言std,c++ 在std :: abs函数上
  16. CC2530+PA(CC2590)开启功率放大模块功能说明
  17. 如何选择适合你的兴趣爱好(二十二),羽毛球
  18. minio数据迁移工具rclone
  19. Kubernetes上安装weblogic monitoring exporter
  20. matlab中cond为啥比bet好,matlab用于超短脉冲中啁啾与色散概念的理解

热门文章

  1. 【青山css】css3阴影效果属性详解及创意玩法
  2. 三年级计算机群鸭戏水教案导入,信息技术教案_群鸭戏水教学设计.docx
  3. JavaScript设计模式(三) - 策略模式
  4. 俄罗斯方块c 语言课程设计流程图,C语言课程设计俄罗斯方块源代码详细分解.doc...
  5. Android流式布局的实现原理
  6. coolshell 谜题通关
  7. php各种加密特征,php三种常用的加密解密算法(介绍)
  8. 25条让人哭笑不得的趣味小
  9. STM32H7 SDMMC+FATFS+USBMSC+FREERTOS 虚拟U盘
  10. onmouseover事件