原文地址:https://zhuanlan.zhihu.com/p/266311690

论文地址:https://arxiv.org/pdf/2010.11929.pdf

代码地址:https://github.com/google-research/vision_transformer

用Transformer完全代替CNN

  • 1. Story
  • 2. Model
    • a 将图像转化为序列化数据
    • b Position embedding
    • c Learnable embedding
    • d Transformer encoder
  • 3. 混合结构
  • 4. Fine-tuning过程中高分辨率图像的处理
  • 5. 实验

1. Story

近年来,Transformer已经成了NLP领域的标准配置,但是CV领域还是CNN(如ResNet, DenseNet等)占据了绝大多数的SOTA结果。

最近CV界也有很多文章将transformer迁移到CV领域,这些文章总的来说可以分为两个大类:

  • 将self-attention机制与常见的CNN架构结合;
  • 用self-attention机制完全替代CNN。

本文采用的也是第2种思路。虽然已经有很多工作用self-attention完全替代CNN,且在理论上效率比较高,但是它们用了特殊的attention机制,无法从硬件层面加速,所以目前CV领域的SOTA结果还是被CNN架构所占据。

文章不同于以往工作的地方,就是尽可能地将NLP领域的transformer不作修改地搬到CV领域来。但是NLP处理的语言数据是序列化的,而CV中处理的图像数据是三维的(长、宽和channels)。

所以我们需要一个方式将图像这种三维数据转化为序列化的数据。文章中,图像被切割成一个个patch,这些patch按照一定的顺序排列,就成了序列化的数据。(具体将在下面讲述)

在实验中,作者发现,在中等规模的数据集上(例如ImageNet),transformer模型的表现不如ResNets;而当数据集的规模扩大,transformer模型的效果接近或者超过了目前的一些SOTA结果。作者认为是大规模的训练可以鼓励transformer学到CNN结构所拥有的translation equivariance和locality.

2. Model

Vision Transformer (ViT)结构示意图

模型的结构其实比较简单,可以分成以下几个部分来理解:

a 将图像转化为序列化数据

作者采用了了一个比较简单的方式。如下图所示。首先将图像分割成一个个patch,然后将每个patch reshape成一个向量,得到所谓的flattened patch。

具体地,如果图片是H×W×CH\times W\times CH×W×C维,用P×PP\times PP×P大小的patch去分割图片可以得到N个patch,那么每个patch的shape就是P×P×CP\times P\times CP×P×C,转化为向量就是P2CP^2CP2C维向量,将N个patch reshape后的向量concat在一起就得到了一个N×(P2C)N\times (P^2C)N×(P2C)的二维矩阵,相当于NLP中输入transformer的词向量。

  • 分割图像得到patch


从上面的过程可以看出,当patch的大小变化时(即 P 变化时),每个patch reshape后得到的 P2CP^2CP2C 维向量的长度也会变化。为了避免模型结构受到patch size的影响,作者对上述过程得到的flattened patches向量做了Linear Projection(如下图所示),将不同长度的flattened patch向量转化为固定长度的向量(记做D维向量)。

  • 对flattened patches做linear projection

    综上,原本H×W×CH\times W\times CH×W×C维的图片被转化为N个D维的向量(或者一个N×DN\times DN×D维的二维矩阵)。

b Position embedding

  • Position embedding

由于transformer模型本身是没有位置信息的,和NLP中一样,我们需要用position embedding将位置信息加到模型中去。

如上图所示1,编号有0-9的紫色框表示各个位置的position embedding,而紫色框旁边的粉色框则是经过linear projection之后的flattened patch向量。文中采用将position embedding(即图中紫色框)和patch embedding(即图中粉色框)相加的方式结合position信息。

c Learnable embedding


如果大家仔细看上图,就会发现带星号的粉色框(即0号紫色框右边的那个)不是通过某个patch产生的。这个是一个learnable embedding(记作 XclassX_{class}Xclass ),其作用类似于BERT中的[class] token。在BERT中,[class] token经过encoder后对应的结果作为整个句子的表示;类似地,这里 XclassX_{class}Xclass 经过encoder后对应的结果也作为整个图的表示。

至于为什么BERT或者这篇文章的ViT要多加一个token呢?因为如果人为地指定一个embedding(例如本文中某个patch经过Linear Projection得到的embedding)经过encoder得到的结果作为整体的表示,则不可避免地会使得整体表示偏向于这个指定embedding的信息(例如图像的表示偏重于反映某个patch的信息)。而这个新增的token没有语义信息(即在句子中与任何的词无关,在图像中与任何的patch无关),所以不会造成上述问题,能够比较公允地反映全图的信息。

d Transformer encoder


Transformer Encoder结构和NLP中transformer结构基本上相同,所以这里只给出其结构图,和公式化的计算过程,也是顺便用公式表达了之前所说的几个部分内容。

Transformer Encoder的结构如下图所示:


对于Encoder的第 lll 层,记其输入为zl−1z_{l-1}zl1,输出为zlz_lzl,则计算过程为:

其中MSA为Multi-Head Self-Attention(即Transformer Encoder结构图中的绿色框),MLP为Multi-Layer Perceptron(即Transformer Encoder结构图中的蓝色框),LN为Layer Norm(即Transformer Encoder结构图中的黄色框)。

Encoder第一层的输入z0z_0z0是通过下面的公式得到的:

其中Xp1,...,XpNX_p^1,...,X_p^NXp1,...,XpN即未Linear Projection后的patch embedding(都是p2Cp^2Cp2C维)

3. 混合结构

文中还提出了一个比较有趣的解决方案,将transformer和CNN结合,即将ResNet的中间层的feature map作为transformer的输入。

和之前所说的将图片分成patch然后reshape成sequence不同的是,在这种方案中,作者直接将ResNet某一层的feature map reshape成sequence,再通过Linear Projection变为Transformer输入的维度,然后直接输入进Transformer中。

4. Fine-tuning过程中高分辨率图像的处理

在Fine-tuning到下游任务时,当图像的分辨率增大时(即图像的长和宽增大时),如果保持patch大小不变,得到的patch个数将增加(记分辨率增大后新的patch个数为 N′N^{'}N )。但是由于在pretrain时,position embedding的个数和pretrain时分割得到的patch个数(即上文中的 N )相同。则多出来的 N′−NN^{'}-NNN 个positioin embedding在pretrain中是未定义或者无意义的。

为了解决这个问题,文章中提出用2D插值的方法,基于原图中的位置信息,将pretrain中的 N 个position embedding插值成N′N^{'}N 个。这样在得到 N′N^{'}N 个position embedding的同时也保证了position embedding的语义信息。

5. 实验

实验部分由于涉及到的细节较多就不具体介绍了,大家如果感兴趣可以参看原文。(不得不说Google的实验能力和钞能力不是一般人能比的…)

主要的实验结论在story中就已经介绍过了,这里复制粘贴一下:在中等规模的数据集上(例如ImageNet),transformer模型的表现不如ResNets;而当数据集的规模扩大,transformer模型的效果接近或者超过了目前的一些SOTA结果。

比较有趣的是,作者还做了很多其他的分析来解释transfomer的合理性。大家如果感兴趣也可以参看原文,这里放几张文章中的图。

用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE相关推荐

  1. 【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale

    文章目录 一.背景和动机 二.方法 三.效果 四.Vision Transformer 学习到图像的哪些特征了 五.代码 代码链接:https://github.com/lucidrains/vit- ...

  2. 用Transformer完全替代CNN?

    作者丨小小理工男@知乎 来源丨https://zhuanlan.zhihu.com/p/266311690 编辑丨极市平台 这里将介绍一篇我认为是比较新颖的一篇文章 --<An Image Is ...

  3. NLP/CV模型跨界进行到底,视觉Transformer要赶超CNN?

    机器之心报道 机器之心编辑部 在计算机视觉领域中,卷积神经网络(CNN)一直占据主流地位.不过,不断有研究者尝试将 NLP 领域的 Transformer 进行跨界研究,有的还实现了相当不错的结果.近 ...

  4. 虎年到,新年用Python与人工智能一起写春节对联 python+pytorch+Transformer+BiLSTM+ATTN+CNN

    艾薇巴迪大家好,虎年春节就要到了,首先在此祝大家新春快乐.虎年大吉. 用Python与人工智能一起写春联 前言 1.分析 2.配置对联项目 2.1.配置下载 2.2.数据预处理 2.3.训练 2.4. ...

  5. CNN+Transformer=SOTA!CNN丢掉的全局信息,Transformer来补

    转自:新智元 在计算机视觉技术发展中,最重要的模型当属卷积神经网络(CNN),它是其他复杂模型的基础. CNN具备三个重要的特性:一定程度的旋转.缩放不变性:共享权值和局部感受野:层次化的结构,捕捉到 ...

  6. 论文阅读(9)---基于Transformer的多模态CNN心电图心律失常分类

    Multi-module Recurrent Convolutional Neural Network with Transformer Encoder for ECG Arrhythmia Clas ...

  7. 披着transformer皮的CNN:SwinTransformer

    新一代backbone 源码https://github.com/microsoft/Swin-Transformer ICCV 2021最佳论文 解决问题: 图像中像素太多,需要更多特征就需要很长的 ...

  8. 谷歌开源BoTNet | CNN与Transformer结合!Bottleneck Transformers for Visual Recognition!CNN+Transformer!

    新思路! https://arxiv.org/abs/2101.11605 无需任何技巧,基于Mask R-CNNN框架,BoTNet在COCO实例分割任务上取得了44.4%的Mask AP与49.7 ...

  9. [Transformer] EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers

    EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers CVPR2022 论文: https: ...

最新文章

  1. 第一次全面揭示世界软件巨人微软致胜的技术奥秘
  2. pom.xml内容没有错,但一直报错红叉 解决办法
  3. viusal studio 调试错误及解决方法(长期更新记录)
  4. DL之CNN:基于CRNN_OCR算法(keras,CNN+RNN)利用数据集(torch,mdb格式)训练来实现新图片上不定长度字符串进行识别—训练过程
  5. Bootstrap 带下拉的导航
  6. Codeblocks中搭建Qt环境遇到个问题
  7. 实战:京东购物车静态界面实现
  8. Python——jieba优秀的中文分词库(基础知识+实例)
  9. 2022双非保研经历
  10. Largest Submatrix (最大全1子矩阵)
  11. RNG战队联名设计 iGame Z390 RNG限量版上线
  12. greenplum-cc-web4.0监控安装
  13. 打卡第四天 学习python读取mat和xslx文件+敢死队+XPOWER
  14. 中兴盒子B860AV1.1-T2版刷公版固件教程
  15. mysql统计一年中每周的数据,week(时间)和week(时间,1)不同
  16. cmd窗口最小化运行
  17. andriod——Fresco+Retrofit+GreenDao
  18. 13 PHP次末跳弹出pemultimate hop popping
  19. Process and Thread
  20. 是男人就下100层【第三层】——高仿交通银行手机客户端界面

热门文章

  1. loj 2542 随机游走 —— 最值反演+树上期望DP+fmt
  2. jzoj100029. 【NOIP2017提高A组模拟7.8】陪审团(贪心,排序)
  3. 吴恩达机器学习笔记(三) —— Regularization正则化
  4. ios业务模块间互相跳转的解耦方案
  5. throw throws 区别
  6. 【Java】 Thinking in Java 4.8 练习9
  7. php 重复区域,如何使用Mysql和PHP从重复区域单击缩略图后检索图像
  8. python整数和浮点数相乘_python中整数除法和浮点数到整数转换之间的区别是什么原因?...
  9. sql server 链接服务器 改访问接口_跨服务器链接数据库?其实很简单!(上)
  10. python无需修改是什么特性_用户编写的python程序无需修改就可以在不同的平台运行,是python的什么特征...