作者丨小小理工男@知乎

来源丨https://zhuanlan.zhihu.com/p/266311690

编辑丨极市平台

这里将介绍一篇我认为是比较新颖的一篇文章 ——《An Image Is Worth 16X16 Words: Transformers for Image Recognition at Scale》[1]。因为还是 ICLR 2021 under review,所以作者目前还是匿名的,但是看其实验用到的TPU,能够大概猜出应该是Google爸爸的文章(看着实验的配置,不得不感慨钞能力的力量)。

1. Story

近年来,Transformer已经成了NLP领域的标准配置,但是CV领域还是CNN(如ResNet, DenseNet等)占据了绝大多数的SOTA结果。

最近CV界也有很多文章将transformer迁移到CV领域,这些文章总的来说可以分为两个大类:

  1. 将self-attention机制与常见的CNN架构结合;

  2. 用self-attention机制完全替代CNN。

本文采用的也是第2种思路。虽然已经有很多工作用self-attention完全替代CNN,且在理论上效率比较高,但是它们用了特殊的attention机制,无法从硬件层面加速,所以目前CV领域的SOTA结果还是被CNN架构所占据。

文章不同于以往工作的地方,就是尽可能地将NLP领域的transformer不作修改地搬到CV领域来。但是NLP处理的语言数据是序列化的,而CV中处理的图像数据是三维的(长、宽和channels)。

所以我们需要一个方式将图像这种三维数据转化为序列化的数据。文章中,图像被切割成一个个patch,这些patch按照一定的顺序排列,就成了序列化的数据。(具体将在下面讲述)

在实验中,作者发现,在中等规模的数据集上(例如ImageNet),transformer模型的表现不如ResNets;而当数据集的规模扩大,transformer模型的效果接近或者超过了目前的一些SOTA结果。作者认为是大规模的训练可以鼓励transformer学到CNN结构所拥有的translation equivariance 和locality.

translation equivariance解释:

https://aboveintelligent.com/ml-cnn-translation-equivariance-and-invariance-da12e8ab7049

2. Model

Vision Transformer (ViT)结构示意图

模型的结构其实比较简单,可以分成以下几个部分来理解:

a. 将图像转化为序列化数据

作者采用了了一个比较简单的方式。如下图所示。首先将图像分割成一个个patch,然后将每个patch reshape成一个向量,得到所谓的flattened patch。

具体地,如果图片是 








维的,用






大小的patch去分割图片可以得到 




个patch,那么每个patch的shape就是 








,转化为向量后就是 








维的向量,将 




个patch reshape后的向量concat在一起就得到了一个 












的二维矩阵,相当于NLP中输入transformer的词向量。

分割图像得到patch

从上面的过程可以看出,当patch的大小变化时(即 




变化时),每个patch reshape后得到的 








维向量的长度也会变化。为了避免模型结构受到patch size的影响,作者对上述过程得到的flattened patches向量做了Linear Projection(如下图所示),将不同长度的flattened patch向量转化为固定长度的向量(记做 




维向量)。

对flattened patches做linear projection

综上,原本








维的图片被转化为了 




个 




维的向量(或者一个 






维的二维矩阵)。

b. Position embedding

positiion embedding示意图

由于transformer模型本身是没有位置信息的,和NLP中一样,我们需要用position embedding将位置信息加到模型中去。

如上图所示1,编号有0-9的紫色框表示各个位置的position embedding,而紫色框旁边的粉色框则是经过linear projection之后的flattened patch向量。文中采用将position embedding(即图中紫色框)和patch embedding(即图中粉色框)相加的方式结合position信息。

c. Learnable embedding

如果大家仔细看上图,就会发现带星号的粉色框(即0号紫色框右边的那个)不是通过某个patch产生的。这个是一个learnable embedding(记作 













),其作用类似于BERT中的[class] token。在BERT中,[class] token经过encoder后对应的结果作为整个句子的表示;类似地,这里 













经过encoder后对应的结果也作为整个图的表示。

至于为什么BERT或者这篇文章的ViT要多加一个token呢?因为如果人为地指定一个embedding(例如本文中某个patch经过Linear Projection得到的embedding)经过encoder得到的结果作为整体的表示,则不可避免地会使得整体表示偏向于这个指定embedding的信息(例如图像的表示偏重于反映某个patch的信息)。而这个新增的token没有语义信息(即在句子中与任何的词无关,在图像中与任何的patch无关),所以不会造成上述问题,能够比较公允地反映全图的信息。

d. Transformer encoder

Transformer Encoder结构和NLP中transformer结构基本上相同,所以这里只给出其结构图,和公式化的计算过程,也是顺便用公式表达了之前所说的几个部分内容。

Transformer Encoder的结构如下图所示:

Transformer Encoder结构图

对于Encoder的第 




层,记其输入为 











,输出为 







,则计算过程为:

其中MSA为Multi-Head Self-Attention(即Transformer Encoder结构图中的绿色框),MLP为Multi-Layer Perceptron(即Transformer Encoder结构图中的蓝色框),LN为Layer Norm(即Transformer Encoder结构图中的黄色框)。

Encoder第一层的输入 







是通过下面的公式得到的:

其中 
















即未Linear Projection后的patch embedding(都是 








维),右乘 










维的矩阵




表示Linear Projection,得到的 


















都是 




维向量;这 




个 




维向量和同样是 




维向量的 













concat就得到了 










维矩阵。加上 






个 




维position embedding拼成的










维矩阵 











,即得到了encoder的原始输入 







3. 混合结构

文中还提出了一个比较有趣的解决方案,将transformer和CNN结合,即将ResNet的中间层的feature map作为transformer的输入。

和之前所说的将图片分成patch然后reshape成sequence不同的是,在这种方案中,作者直接将ResNet某一层的feature map reshape成sequence,再通过Linear Projection变为Transformer输入的维度,然后直接输入进Transformer中。

4. Fine-tuning过程中高分辨率图像的处理

在Fine-tuning到下游任务时,当图像的分辨率增大时(即图像的长和宽增大时),如果保持patch大小不变,得到的patch个数将增加(记分辨率增大后新的patch个数为 







)。但是由于在pretrain时,position embedding的个数和pretrain时分割得到的patch个数(即上文中的 




)相同。则多出来的 









个positioin embedding在pretrain中是未定义或者无意义的。

为了解决这个问题,文章中提出用2D插值的方法,基于原图中的位置信息,将pretrain中的 




个position embedding插值成 







个。这样在得到 







个position embedding的同时也保证了position embedding的语义信息。

5. 实验

实验部分由于涉及到的细节较多就不具体介绍了,大家如果感兴趣可以参看原文。(不得不说Google的实验能力和钞能力不是一般人能比的...)

主要的实验结论在story中就已经介绍过了,这里复制粘贴一下:在中等规模的数据集上(例如ImageNet),transformer模型的表现不如ResNets;而当数据集的规模扩大,transformer模型的效果接近或者超过了目前的一些SOTA结果。

比较有趣的是,作者还做了很多其他的分析来解释transfomer的合理性。大家如果感兴趣也可以参看原文,这里放几张文章中的图。

参考

1.https://openreview.net/forum?id=YicbFdNTTy

觉得有用麻烦给个在看啦~  

用Transformer完全替代CNN?相关推荐

  1. CV领域,Transformer在未来有可能替代CNN吗?

    Transformer在CV领域得到广泛关注,从Vision Transformer到层出不穷的变种,不断地刷新了各项任务地榜单.在CV领域的应用,Transformer在未来有可能替代CNN吗? 在 ...

  2. 医学图像领域,是时候用视觉Transformer替代CNN了吗?

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 本文转载自:集智书童 Is it Time to Replace CNNs with Transformer ...

  3. ICLR 6-6-6!自注意力可以替代CNN,能表达任何卷积滤波层

    像素层面上,CNN能做的,自注意力(self-attention)也都能做. 统御NLP界的注意力机制,不仅被迁移到了计算机视觉中,最新的研究还证明了: CNN卷积层可抽取的特征,自注意力层同样可以. ...

  4. ICLR 6-6-6!自注意力可以替代CNN,能表达任何卷积滤波层丨代码已开源

    鱼羊 十三 发自 凹非寺 量子位 报道 | 公众号 QbitAI 像素层面上,CNN能做的,自注意力(self-attention)也都能做. 统御NLP界的注意力机制,不仅被迁移到了计算机视觉中,最 ...

  5. 用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

    原文地址:https://zhuanlan.zhihu.com/p/266311690 论文地址:https://arxiv.org/pdf/2010.11929.pdf 代码地址:https://g ...

  6. NLP/CV模型跨界进行到底,视觉Transformer要赶超CNN?

    机器之心报道 机器之心编辑部 在计算机视觉领域中,卷积神经网络(CNN)一直占据主流地位.不过,不断有研究者尝试将 NLP 领域的 Transformer 进行跨界研究,有的还实现了相当不错的结果.近 ...

  7. 虎年到,新年用Python与人工智能一起写春节对联 python+pytorch+Transformer+BiLSTM+ATTN+CNN

    艾薇巴迪大家好,虎年春节就要到了,首先在此祝大家新春快乐.虎年大吉. 用Python与人工智能一起写春联 前言 1.分析 2.配置对联项目 2.1.配置下载 2.2.数据预处理 2.3.训练 2.4. ...

  8. Transformer在CV领域有可能替代CNN吗?还有哪些应用前景?

    来源丨知乎问答 编辑丨极市平台 本文转自知乎问答,所有回答均已获得作者授权. 问题背景:目前已经有基于Transformer在三大图像问题上的应用:分类(ViT),检测(DETR)和分割(SETR), ...

  9. 如何看待Transformer在CV上的应用前景,未来有可能替代CNN吗?

    链接:https://www.zhihu.com/question/437495132 编辑:深度学习与计算机视觉 声明:仅做学术分享,侵删 目前已经有基于Transformer在三大图像问题上的应用 ...

最新文章

  1. 2019年上半年收集到的人工智能机器学习方向干货文章
  2. 使用Restful、Guns、SpringBoot实现前后端分离
  3. 一束激光冒充人声:110米外黑掉智能音箱,手机电脑平板也中招
  4. 如何知道刚刚插入数据库那条数据的id
  5. 安卓手机系统可删除的自带软件大集合
  6. 项目分享 | 好牛X的开源项目,看完忍不住分享(高手作品分享)
  7. P0INP = 0Xfd;P1DIR |= 0X01;
  8. poj3268(Silver Cow Party)最短路
  9. paip.log4j兼容linux windows 路径设置
  10. ADOBE PS镜像某个图层
  11. python3处理普通文件【open内置函数】
  12. 计算机设备图标怎么删除,电脑设备和驱动器中没用的图标怎么删除
  13. 《思考力---引爆无限潜能》书摘(二)
  14. Android:圆形头像
  15. 交换机的主要技术指标
  16. JavaScript诞生二十年,作者Brendan Eich自述10天内开发出JS语言
  17. Word中的初号、小初、一号等是什么意思
  18. 2022届秋招,从被拒到SP+ 谈谈YK菌在2021年的经历与收获
  19. spring tool suit 安装 Lombok 步骤
  20. NLP ——Doc2vec

热门文章

  1. 资料分享:送你一本《数据结构(C语言版)》电子书!
  2. 利用BP神经网络教计算机进行非线函数拟合(代码部分单层)
  3. 整理了 65 个 Matplotlib 案例,这能不收藏?
  4. 中关村开源创新大赛-达闼赛道如火如荼进行中
  5. 一位合格软件工程师应该具备怎样的工程化、交付能力?
  6. 最高3000元/人 , 助你成为C站红人 !
  7. 98年“后浪”科学家,首次挑战图片翻转不变性假设,一作拿下CVPR最佳论文提名​...
  8. 集五福,我用Python
  9. Github标星24k,127篇经典论文下载,这份深度学习论文阅读路线图不容错过
  10. 只做好CTR预估远不够,淘宝融合CTR、GMV、收入等多目标有绝招