摘要

虽然Transformer架构已经是NLP领域的一个标准,但是应用transformer到CV领域效果还是很有限的。在视觉领域,自注意力要么和卷积神经网络一起使用,要么是将卷积神经网络里的卷积替换成自注意力。但是仍然保持整体结构不变。本文证明了这种对于卷积神经网络的依赖是完全不必要的,一个纯的transformer直接作用于一系列图像块的时候也可以在图像分类任务上表现很好。尤其是当我们在大规模的数据集上进行预训练,然后迁移到中小型数据集上使用时,Vision transformer能获得和最好的卷积神经网络相媲美的结果。这里我们将CIFAR-100,Imagenet、VATB当作小数据集。

1 引言

自注意力机制网络,尤其是Transformer已经是自然语言处理领域的必选模型了。现在比较主流的方式是先去一个大规模的数据集上去做预训练,然后再在一些特定领域的小数据集上进行微调(fine-tune)。多亏了Transformer的计算高效性和可扩展性,现在已经可以训练超过1000亿参数的模型了。随着模型和数据集的增长,我们还没有看到任何性能饱和的现象。

这里介绍下将Transformer应用到视觉问题上的一些难处。先回顾下transformer,假设我们有如下的Transformer的Encoder以及一些输入元素(在自然语言处理中,这就是一个句子里面的一个一个的单词):

Transformer里面最主要的就是自注意力操作,自注意力操作就是每个元素都要和每个元素去做互动,然后计算得到一个Attention图,接下来用这个Attention图去进行加权平均,最后得到输出:

因为在做自注意力时我们是两两相互的,这个计算复杂度是和序列的长度成乘方关系O(n2)O(n^2)O(n2),目前一般在自然语言处理领域,硬件能够支持的序列长度也就是几百或者上千。在BERT里,序列长度是512。

在视觉领域,首先要解决的任务是将2d图片转化为1d的序列。最直观的方式是把像素图拉直,但是复杂度太高 224×224=50176224\times224=50176224×224=50176

回到引言第二段,在视觉领域,卷积神经网络仍然占主导地位。但是由于到Transformer在NLP领域的成功,现在许多工作尝试把CNN结构和自注意力结构结合在一起,还有一些工作将整个卷积神经都替换掉了,全部使用自注意力。这些方法其实都是在降低序列长度。完全用自注意力替代卷积的这一类工作虽然理论上非常高效,但事实上由于他们的自注意力操作都是一些比较特殊的自注意力操作,没有应用到现有的硬件结构进行加速,所以很难训练一个较大的模型。因此,在大规模的图像识别上,传统的ResNet结构网络还是效果最好的。

作者被Transformer在NLP领域的可扩展性启发,他们希望直接将一个标准的Transformer应用于图片,尽量少做修改(不针对视觉任务进行特定的改变)。Vision Transformer将一个图片打成了很多16×1616\times 1616×16的patch,此时宽度和高都是224/16=14224/16=14224/16=14,最后序列长度就变成了196。Vision Transformer将每个patch作为一个元素,通过fc layer可以得到linear embedding,这些会被当作输入传递给transformer。我们可以把这些patch作为NLP里面的单词,训练Vision Transformer使用的是有监督训练。

当在中型大小的数据集上(比如ImageNet)进行训练时,如果不加比较强的约束(strong regularization),ViT和同等大小的残差网络相比,是要弱几个点的。这个看起来不太好的结果其实是可以预期的,因为Transformer和卷积神经网络相比,缺少一些卷积神经网络有的归纳偏置(inductive biases,这指一些先验知识,即我们提前做好的假设,对于卷积神经网络来说,我们常说的有两个inductive bias,一个叫做locality,由于卷积神经网络以滑动窗口的形式一点点在图片上进行卷积,所以假设图片上相邻的区域会有相邻的特征,另外一个inductive bias叫做平移等变性,translation equivariance,f(g(x))=g(f(x))f(g(x))=g(f(x))f(g(x))=g(f(x)),可以把ggg理解为平移,fff理解为卷积),由于拥有这些inductive bias,卷积神经网络具有许多先验信息,所以可以应用相对少的数据来训练一个较好的模型,但是对于Transformer来说,没有这些先验信息,对于vision领域的知识全部需要自己学习。

为了验证这个假设,作者在更大的数据集上进行了预训练(14M指代ImageNet 22K数据集,300M指代Google的JFT 300M数据集。大规模的预训练表明优于归纳偏置。Vision Transformer只要在有足够数据进行预训练的情况下就能在下游任务上获得较好的迁移学习效果。在ImageNet 21k或者JFT-300M上进行训练时,ViT就能获得和现在最好的残差网络相近,或者说更好的结果。具体而言,在ImageNet上实现了88.55%,在ImageNet-ReaL上实现了90.72%,在CIFAR-100上实现了94.55%,在VTAB上实现了77.63%(这个数据集融合了19个数据集,主要用于测试鲁棒性)。

2 相关工作

24.39

3 方法

模型设计是尽可能贴近原始的transformer,这样做的好处是可以直接把NLP那边已经成功的Transformer架构直接拿过来用,不用自己再去魔改模型。而且因为Transformer已经在NLP领域火了这么久,现在有一些写得非常高效的实现,同样Vision Transformer可以直接拿过来用。

3.1 Vision transformer

模型总览图如图一所示:


标准的Transformer需要一系列1D序列作为输入,所以为了符合Transformer的结构,我们将图片x∈RH×W×Cx\in\mathbb{R}^{H\times W\times C}xRH×W×C变成一系列展平的2D patches xp∈RN×(P2×C)x_p\in\mathbb{R}^{N\times (P^2\times C)}xpRN×(P2×C),这里(H,W)(H,W)(H,W)是原始图片的分辨率,CCC是通道数,(P,P)(P,P)(P,P)是每个图片patch的分辨率,N=HW/P2N=HW/P^2N=HW/P2是patch的数量,这就是最终传入Transformer的有效序列长度。Transformer从头到尾都是使用DDD作为向量长度(768),为了和Transformer的维度相匹配,所以我们的图像patch维度也设定为768(具体做法是使用了一个可以训练的linear projection,即全连接层)。从这个全连接层出来的向量我们称之为patch embedding。

为了进行最后的分类,作者借鉴了BERT里面的 [class][class][class] token,这个token是一个可以学习的特征,且和图像的特征具有相同的维度,token初始表示为z00=xclassz_0^0=x_{class}z00=xclass,经过多层Transformer处理后,我们将token表示为zL0z_L^0zL0,此时我们将这个token当成整个Transformer的输出,也就是当作整个图片的特征。在pre-training以及fine-tuning阶段,一个分类头都连接到了zL0z_L^0zL0。这个分类头是一个MLP,这个MLP在pre-training阶段有一个hidden layer,在fine-tuning阶段有一个linear layer。

位置编码信息被添加到patch embedding中来保留位置信息。本文使用标准的可以学习的1D position embedding,也就是BERT里面使用的位置编码。作者也尝试了其他编码形式(因为我们是针对图像任务),例如一个2D-aware的位置编码。实验显示最后的结果相差不大。

作者用公式描述了整体的过程:

z0=[xclass;xp1E;xp2E;…;xpNE]+Epos,E∈R(N+1)×Dz_0=[x_{class};x_p^1E;x_p^2E;\dots;x_p^NE]+E_{pos},\quad\quad E\in\mathbb{R}^{(N+1)\times D}z0=[xclass;xp1E;xp2E;;xpNE]+Epos,ER(N+1)×D

这里xp1,xp2x_p^1,x_p^2xp1,xp2等其实就是这些图像块中的patch,一共有NNN个patch,每个patch先和linear projection(这里表示为EEE)进行转换,从而得到patch embedding,在得到这些linear embedding后,我们在前面拼接一个class embedding,利用它得到最后的输出。在得到了所有的tokens后,我们需要对这些token进行位置编码,我们将位置编码信息EposE_{pos}Epos直接加入矩阵中,此时z0z_0z0就是transformer的输入了。接下来是一个循环:

zℓ′=MSA(LN(zℓ−1))+zℓ−1ℓ=1...Lzℓ=MLP(LN(zℓ′))+zℓ′ℓ=1...Lz'_{\ell}=\text{MSA}(\text{LN}(z_{\ell-1}))+z_{\ell-1}\quad\quad\quad \ell=1...L\\ z_{\ell}=\text{MLP}(\text{LN}(z'_{\ell}))+z'_{\ell}\quad\quad\quad\quad\quad\ell=1...Lz=MSA(LN(z1))+z1=1...Lz=MLP(LN(z))+z=1...L

对于每个Transformer block来说,里面都有两个操作,一个是MLP,另外一个是MSA(多头自注意力),在进行这两个操作前,我们要先经过layer norm(LN),然后每一层出来的结果都要去再经过一次残差连接。zℓ′z_{\ell}'z就是多头自注意力的结果,zℓz_{\ell}z就是每个Transformer block整体出来的结果。在l层循环后,我们将zL0z_{L}^0zL0,也就是最后一层的第一个token拿出来,当作整体图像的一个特征,从而去做最后的这个分类任务。

归纳偏置:Vision Transformer相比CNN而言少了许多图像特有的归纳偏置,比如CNN里面存在的locality和translation equivariance。在ViT中,只有MLP层是局部且平移等变性的。但是自注意力层是全局的。

混合结构:我们原来有一张图片,然后我们用Res50等结构去处理得到特征图(14×1414\times1414×14),这时这个特征图也是196个元素,然后我们用新得到的这196个元素去和全连接层进行操作。得到新的patch embedding,其实这就是两种不同的对图片预处理的方式。

3.2

52分钟左右

其主要包括以下模块:

图片预处理:

作者将x∈RH×W×Cx\in\mathbb{R}^{H\times W\times C}xRH×W×C的图片,变成一个xp∈RN×(P2⋅C)x_p\in\mathbb{R}^{N\times (P^2\cdot C)}xpRN×(P2C)的sequence of flattened 2D patches。这可以看做是一个2D块序列,序列中一共有N=HW/P2N=HW/P^2N=HW/P2个展平的2D块,每个块的维度是(P2⋅C)(P^2\cdot C)(P2C),其中PPP是块大小,CCC是通道数。

作者进行这步的意图是:因为Transformer希望输入是一个二维的矩阵(N,D)(N,D)(N,D),其中NNN是序列长度,DDD是序列中每个向量的维度(常用256)。所以这里我们也要设法将H×W×CH\times W\times CH×W×C的三维图片转化为(N,D)(N,D)(N,D)的二维输入。

对应代码是:

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

具体使用了einops库,具体可以参考这篇文章

现在得到的向量维度是:xp∈N×(P2⋅C)x_p\in N\times (P^2\cdot C)xpN×(P2C),要转化成(N,D)(N,D)(N,D)的二维输入,我们还需要进一步叫做Patch Embedding的步骤。

Patch Embedding

这步要做的是对每个向量都做一个线性变换(即全连接层),压缩后的维度为DDD,我们称其为Patch Embedding。

z0=[xclass;xp1E;xp2E;…;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×Dz_0=[x_{class};\ x_p^1E; x_p^2E;\dots;\ x_p^NE]+E_{pos},\quad\quad E\in\mathbb{R}^{(P^2\cdot C)\times D},\ E_{pos}\in\mathbb{R}^{(N+1)\times D}z0=[xclass;xp1E;xp2E;;xpNE]+Epos,ER(P2C)×D,EposR(N+1)×D

全连接层就是上式中的EEE,它的输入维度大小是(P2⋅C)(P^2\cdot C)(P2C),输出维度大小是DDD

注意上面式子中存在一个xclassx_{class}xclass,假设将整个图片切成9个块,但是最终输入到transformer中的是10个向量,这是人为增加的一个向量。

为什么要追加这个向量?

如果没有这个向量,假设N=9N=9N=9个向量输入transformer encoder,输出9个编码向量,然后呢?对于分类任务而言,我们应该用哪个输出向量进行后续分类呢?

所以我们干脆使用一个向量xclass(vector,dim=D)x_{class}(vector,dim=D)xclass(vector,dim=D),这个向量是可学习的嵌入向量,它和那9个向量一起输入transformer encoder,输出1+9个编码向量,然后使用第0个编码向量,即xclassx_{class}xclass的输出进行分类预测即可。

这么做的原因可以理解为:ViT只用到了transformer的encoder,而并没有用到decoder,而xclassx_{class}xclass的作用有点类似于解码器中的Query的作用,相对应的Key,Value就是其他9个编码向量的输出。

xclassx_{class}xclass是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的image的类别。

代码为:

 dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))# forward前向代码
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块进行concat
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)

Positional Encoding

按照transformer的位置编码方式,在本文中同样使用了位置编码。引入了一个positional encoder EposE_{pos}Epos 来加入位置信息,同样在这里引入了pos embedding,这是一个可训练的变量。

在ViT中,我们没有使用原版Transformer的sincos编码,而是直接设置为可学习的Positional Encoding,这两个的效果差不多。我们对训练好的pos embedding进行可视化,如下图所示。


可以发现,位置越接近,往往具有更相似的位置编码。此外,还出现了行列结构;同一行/列中的patch具有相似的位置编码。

代码表示如下:

# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

Transformer的前向过程:

z0=[xclass;xp1E;xp2E;…;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×Dzl′=MSA(LN(zl−1))+zl−1,l=1…,Lzl=MLP(LN(zl′))+zl′,l=1…Ly=LN(zL0)z_0=[x_{class};\ x_p^1E; x_p^2E;\dots;\ x_p^NE]+E_{pos},\quad\quad\quad\quad E\in\mathbb{R}^{(P^2\cdot C)\times D},\ E_{pos}\in\mathbb{R}^{(N+1)\times D}\\ z_l'=\text{MSA}(\text{LN}(z_{l-1}))+z_{l-1},\quad\quad\quad\quad\quad\quad\quad\ l=1\dots,L\\z_l=\text{MLP}(\text{LN}(z_l'))+z_l',\quad\quad\quad\quad\quad\quad\quad\quad\quad l=1\dots L\\y=\text{LN}(z_L^0)\quad\quad\quad\quad\quad\quad\quad\quad\quad\quadz0=[xclass;xp1E;xp2E;;xpNE]+Epos,ER(P2C)×D,EposR(N+1)×Dzl=MSA(LN(zl1))+zl1,l=1,Lzl=MLP(LN(zl))+zl,l=1Ly=LN(zL0)

其中,第一个式子为上面提及的Patch Embedding 和 Positional Encoding的过程。

第二个式子为Transformer Encoder的Multi-head Self-attention, Add and Norm的过程,重复L次。

第三个式子为Transformer Encoder的Feed Forward Network,Add and Norm的过程,重复L次。

作者采用的是没有任何改动的transformer。

最后一个是MLP的Classification Head,整体的结构只有这些,如下图所示(变量的维度变化过程标注在了图中):

4 实验

这个章节主要对比了ResNet,Vision Transformer以及混合模型的表征学习能力。为了了解训练好每个模型到底需要多少数据,我们在不同大小的数据集上进行预训练,然后在很多benchmark上进行测试。当考虑到预训练的代价,即预训练的时间长短时,ViT表现的非常好,能在大多数数据集上取得最好的结果。同时需要更少的时间去训练。最后作者还做了一个小小的自监督实验,结果还可以,说自监督的ViT还是比较有潜力的。

4.1 设置

数据集:作者使用了ILSVRC-2012 ImageNet数据集,同时使用了大家最普遍使用的这1000个类(称为ImageNet-1k),和更大规模的数据集(ImageNet-21k),作者还使用了JFT数据集(Google自己的数据集,包含了三亿张图片)。

5 结论

An Image is worth 16x16 words:transformers for image recognition at scale相关推荐

  1. An Image is worth 16*16 words: Transformers for image recognition at scale.

    An Image is worth 16*16 words: Transformers for image recognition at scale. Abstract 虽然Transformer架构 ...

  2. AN IMAGE IS WORTH 16X16 WORDS :TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE(VIT)

    最近看transformer用于CV比较热门,特意去进行了解,这里用分类的一篇文章进行讲解. NLP中的transformer和代码讲解参考我另一篇文章. 论文链接:AN IMAGE IS WORTH ...

  3. AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE——ViT全文翻译

    一文读懂ViT:ViT 快速理解 Vision in Transformer 文章目录 全文翻译- Vision in Transformer- 相关说明 基本信息介绍 ABSTRACT 1 INTR ...

  4. 【读点论文】AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE(ViT)像处理自然语言那样处理图片

    AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE abstract 虽然Transformer体系结 ...

  5. 李沐精读论文:ViT 《An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale》

    视频:ViT论文逐段精读[论文精读]_哔哩哔哩_bilibili 代码:论文源码 使用pytorch搭建Vision Transformer(vit)模型 vision_transforme · WZ ...

  6. 论文解读:ViT | AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

    发表时间:2021 论文地址:https://openreview.net/pdf?id=YicbFdNTTy 项目地址:https://github.com/lucidrains/vit-pytor ...

  7. 用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

    原文地址:https://zhuanlan.zhihu.com/p/266311690 论文地址:https://arxiv.org/pdf/2010.11929.pdf 代码地址:https://g ...

  8. 【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale

    文章目录 一.背景和动机 二.方法 三.效果 四.Vision Transformer 学习到图像的哪些特征了 五.代码 代码链接:https://github.com/lucidrains/vit- ...

  9. 重读经典:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》

    ViT论文逐段精读[论文精读] 这次李沐博士邀请了亚马逊计算机视觉专家朱毅博士来精读 Vision Transformer(ViT),强烈推荐大家去看本次的论文精读视频.朱毅博士讲解的很详细,几乎是逐 ...

  10. VIT: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(arXiv 2021)

    最前面是论文翻译,中间是背景+问题+方法步骤+实验过程,最后是文中的部分专业名词介绍(水平线分开,能力有限,部分翻译可能不太准确) 摘要: 尽管Tansformer结构已经成为自然语言处理的事实标准, ...

最新文章

  1. 『科技』2019全球最有前景AI公司TOP100
  2. win10安装pytorch
  3. BZOJ3160:万径人踪灭
  4. VALVE SURVEY RESULTS
  5. Python3网络爬虫(四): 登录
  6. IPSEC非单播流量处理
  7. Educational Codeforces Round 112 (Rated for Div. 2)(A-D)
  8. 通过网络地址进行真机调试
  9. DNS协议及客户端实现
  10. PHP面相对象中的重载与重写
  11. Python的类与对象
  12. python手册中文版-python手册中文
  13. springboot整合websocket实现微信小程序聊天
  14. npm加速器、github加速器
  15. 视频压缩软件APP有哪些?让我来告诉你答案
  16. 北漂三年多 我选择离开,眼神更加坚定!
  17. 一款让人耳目一新的事件驱动型RTOS
  18. 用有数据的单元格内容向下填充空白单元格
  19. 信捷PLc的C语言大小排序筛选,信捷PLC顺序控制怎么写
  20. Coderwars使用

热门文章

  1. Linux设置node的process.env.NODE_ENV
  2. PHP实现单向链表解决约瑟夫环问题
  3. 静态的通讯录(C语言)
  4. Flink SQL 系列 | 5 个 TableEnvironment 我该用哪个?
  5. 老板要我开发一个简单的工作流引擎 !
  6. 职场不是家,不会方法,如何混职场
  7. c++ opengl 三维图形中显示文字_opengl基本流程
  8. Qt利用avilib实现录屏功能_如何找到电脑录屏功能?4种方法教你一键打开,不会用来学一学...
  9. html如何连接外部网页,怎么链接一个外部的css文件?
  10. ATMV1函数版v1