Vision Transformer详解
论文名称: An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale
论文下载链接:https://arxiv.org/abs/2010.11929
原论文对应源码:https://github.com/google-research/vision_transformer
Pytorch实现代码: pytorch_classification/vision_transformer
Tensorflow2实现代码:tensorflow_classification/vision_transformer
在bilibili上的视频讲解:https://www.bilibili.com/video/BV1Jh411Y7WQ
文章目录
- 前言
- 模型详解
- Vision Transformer模型详解
- Embedding层结构详解
- Transformer Encoder详解
- MLP Head详解
- 自己绘制的Vision Transformer网络结构
- Hybrid模型详解
- ViT模型搭建参数
前言
Transformer最初提出是针对NLP领域的,并且在NLP领域大获成功。这篇论文也是受到其启发,尝试将Transformer应用到CV领域。关于Transformer的部分理论之前的博文中有讲,链接,这里不在赘述。通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。
模型详解
在这篇文章中,作者主要拿ResNet、ViT(纯Transformer模型)以及Hybrid(卷积和Transformer混合模型)三个模型进行比较,所以本博文除了讲ViT模型外还会简单聊聊Hybrid模型。
Vision Transformer模型详解
下图是原论文中给出的关于Vision Transformer(ViT)的模型框架。简单而言,模型由三个模块组成:
- Linear Projection of Flattened Patches(Embedding层)
- Transformer Encoder(图右侧有给出更加详细的结构)
- MLP Head(最终用于分类的层结构)
Embedding层结构详解
对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。
对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到(224/16)2=196(224/16)^2=196(224/16)2=196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]
在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768]
,然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768]
,此时正好变成了一个二维矩阵,正是Transformer想要的。
在输入Transformer Encoder之前注意需要加上[class]token以及Position Embedding。 在原论文中,作者说参考BERT,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]
。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数(1D Pos. Emb.
),是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768]
,那么这里的Position Embedding的shape也是[197, 768]
。
对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb.
,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.
比起来没太大差别。
Transformer Encoder详解
Transformer Encoder其实就是重复堆叠Encoder Block L次,下图是我自己绘制的Encoder Block,主要由以下几部分组成:
- Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接
- Multi-Head Attention,这个结构之前在讲Transformer中很详细的讲过,不在赘述,不了解的可以参考链接
- Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但
rwightman
实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。 - MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍
[197, 768] -> [197, 3072]
,第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
MLP Head详解
上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]
输出的还是[197, 768]
。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]
中抽取出[class]token对应的[1, 768]
。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear
+tanh激活函数
+Linear
组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear
即可。
自己绘制的Vision Transformer网络结构
为了方便大家理解,我自己根据源代码画了张更详细的图(以ViT-B/16为例):
Hybrid模型详解
在论文4.1章节的Model Variants
中有比较详细的讲到Hybrid混合模型,就是将传统CNN特征提取和Transformer进行结合。下图绘制的是以ResNet50作为特征提取器的混合模型,但这里的Resnet与之前讲的Resnet有些不同。首先这里的R50的卷积层采用的StdConv2d不是传统的Conv2d,然后将所有的BatchNorm层替换成GroupNorm层。在原Resnet50网络中,stage1重复堆叠3次,stage2重复堆叠4次,stage3重复堆叠6次,stage4重复堆叠3次,但在这里的R50中,把stage4中的3个Block移至stage3中,所以stage3中共重复堆叠9次。
通过R50 Backbone进行特征提取后,得到的特征矩阵shape是[14, 14, 1024]
,接着再输入Patch Embedding层,注意Patch Embedding中卷积层Conv2d的kernel_size和stride都变成了1,只是用来调整channel。后面的部分和前面ViT中讲的完全一样,就不在赘述。
下表是论文用来对比ViT,Resnet(和刚刚讲的一样,使用的卷积层和Norm层都进行了修改)以及Hybrid模型的效果。通过对比发现,在训练epoch较少时Hybrid优于ViT,但当epoch增大后ViT优于Hybrid。
ViT模型搭建参数
在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16
的外还有32x32
的。其中的Layers就是Transformer Encoder中重复堆叠Encoder Block的次数,Hidden Size就是对应通过Embedding层后每个token的dim(向量的长度),MLP size是Transformer Encoder中MLP Block第一个全连接的节点个数(是Hidden Size的四倍),Heads代表Transformer中Multi-Head Attention的heads数。
Model | Patch Size | Layers | Hidden Size D | MLP size | Heads | Params |
---|---|---|---|---|---|---|
ViT-Base | 16x16 | 12 | 768 | 3072 | 12 | 86M |
ViT-Large | 16x16 | 24 | 1024 | 4096 | 16 | 307M |
ViT-Huge | 14x14 | 32 | 1280 | 5120 | 16 | 632M |
Vision Transformer详解相关推荐
- 深度学习网络模型——Vision Transformer详解 VIT详解
深度学习网络模型--Vision Transformer详解 VIT详解 通用深度学习网络效果改进调参训练公司自己的数据集,训练步骤记录: 代码实现version-Transformer网络各个流程, ...
- Swin Transformer详解: Hierarchical Vision Transformer using Shifted Windows
这篇文章结合了CNN的归纳偏置,基于局部窗口做注意力,并且逐步融合到深层transformer层中构建表征,来达到扩大感受野,并且极大降低了计算量.是一个特征提取的主干网络,backbone.构建了一 ...
- 【NLP】Transformer详解
[NLP]Transformer详解 Transformer在Google的一篇论文Attention is All You Need被提出,为了方便实现调用Transformer Google还 ...
- Transformer详解
Transformer详解 1. 简介 Transformer是一个面向sequence to sequence任务的模型,在17年的论文<Attention is all you need&g ...
- 史上最小白之Transformer详解
1.前言 博客分为上下两篇,您现在阅读的是下篇史上最小白之Transformer详解,在阅读该篇博客之前最好你能够先明白Encoder-Decoder,Attention机制,self-Attenti ...
- Transformer 详解(上) — 编码器【附pytorch代码实现】
Transformer 详解(上)编码器 Transformer结构 文本嵌入层 位置编码 注意力机制 编码器之多头注意力机制层 编码器之前馈全连接层 规范化层和残差连接 代码实现Transforme ...
- 【转载】Transformer详解
转载于https://blog.csdn.net/Tink1995/article/details/105080033 文章目录 1.前言 2.Transformer 原理 2.1 Transform ...
- NLP中的Attention注意力机制+Transformer详解
关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 作者: JayLou娄杰 知乎链接:https://zhuanlan.zhihu. ...
- 【李宏毅】注意力机制+transformer详解
Transformer 背景 关于RNN详解可以看这篇博客:循环神经网络 RNN一般被用来处理序列输入的,但是它有一个缺点就是不能并行化,后面一个神经元的输入要依赖与之前神经元的输出. 然后就有人提出 ...
最新文章
- c++ 异常处理(1)
- linux uname 命令详解
- dos 命令与 C++的联合作业,有点意思~
- 天津php二次开发培训,天津PHP后台开发培训短期班
- 动漫风html源码,CSS3动画制作一个卡通风格的404错误页面代码
- 在Java EE 7上骑骆驼–带有Swagger文档的REST服务
- 前端学习(1800):前端调试之清除浮动练习1
- 限制对web路径的访问
- php中将excel写入mysql数据库的示例
- 2017计算机c语言大纲,2017年计算机考研大纲
- 时间序列信号处理(一)-----变分模态分解(VMD)
- mysql dateofweek_日历表-月的周数
- 计算机科学与技术高中选课,2019-2021年新高考专业选课要求 大学个专业选科要求解读...
- C++实现排列组合问题
- abs函数c语言std,c++ 在std :: abs函数上
- CC2530+PA(CC2590)开启功率放大模块功能说明
- 如何选择适合你的兴趣爱好(二十二),羽毛球
- minio数据迁移工具rclone
- Kubernetes上安装weblogic monitoring exporter
- matlab中cond为啥比bet好,matlab用于超短脉冲中啁啾与色散概念的理解
热门文章
- 【青山css】css3阴影效果属性详解及创意玩法
- 三年级计算机群鸭戏水教案导入,信息技术教案_群鸭戏水教学设计.docx
- JavaScript设计模式(三) - 策略模式
- 俄罗斯方块c 语言课程设计流程图,C语言课程设计俄罗斯方块源代码详细分解.doc...
- Android流式布局的实现原理
- coolshell 谜题通关
- php各种加密特征,php三种常用的加密解密算法(介绍)
- 25条让人哭笑不得的趣味小
- STM32H7 SDMMC+FATFS+USBMSC+FREERTOS 虚拟U盘
- onmouseover事件