这是近期的一篇视觉Transformer领域的工作,文章并没有设计更加复杂的token mixer,而是通过简单的池化算子验证视觉Transformer的成功在于整体架构设计,即MetaFormer。

简介

Transformer已经在计算机视觉中展现了巨大的潜力,一个常见的观念是视觉Transformer之所以取得如此不错的效果主要是由于基于self-attention的token mixer模块。但是视觉MLP的近期工作证明将这个token mixer换为spatial MLP依然可以保持相当好的效果。作者并没有在这方面做过多的探究,而是认为这些工作之所以成功的原因是因为他们模型结构采用MetaFormer这样的通用架构(即token mixer+channel MLP(FFN)),至于具体采用哪种token mixer模块,影响并不是那么大。

  • 论文标题

    MetaFormer is Actually What You Need for Vision

  • 论文地址

    https://arxiv.org/abs/2111.11418

  • 论文源码

    https://github.com/sail-sg/poolformer

介绍

如上图第二幅图所示,Transformer的编码器一般包含两个组件,即一个注意力模块和channel MLP等后续组件,前者用于混合token之间的信息,称为token mixer,后者包括channel MLP和残差连接等。

忽略掉token mixer用注意力模块实现的细节,可以将上述架构抽象为上图第一幅图的MetaFormer架构,后来的一些工作将注意力模块换成了简单的spatial MLP,如代表性的MLP-Mixer,也取得了不遑多让的效果。这也促使了研究者不断探索token mixer的形式,比如最近的工作采用傅里叶变换作为token mixer。

但是,这篇论文的作者并没有在这个思路上越扎越深,而是回过头了做了个总结,竟然各种token mixer效果都很好,那说明这些模型成功的原因应该是整个架构的设计啊,也就是MetaFormer可以保证模型总能获得一个不错的效果。那么自然就会产生一个新的问题了,这个token mixer可以简洁到什么地步呢?

这篇论文中,作者将注意力模块换为了一个没有参数的空间池化算子,原论文对此的描述为简单到让人尴尬。

PoolFormer

MetaFormer

首先,作者对视觉Transformer和视觉MLP等架构做了一个归纳,得到了上图最左侧的通用架构MetaFormer。

首先,输入III呗切分为patch并转换为token,得到序列X∈RN×CX \in \mathbb{R}^{N \times C}XRN×C,其中NNN为序列长度,CCC是嵌入的维度。

X=InputEmb⁡(I)X=\operatorname{InputEmb}(I) X=InputEmb(I)

然后,这个embedding tokens会被送入多个MetaFormer blocks中,每个block包含两个子block。

第一个子blcok通过一个token mixer模块来实现不同token之间的信息交互,NormNormNorm表示某种标准化方法,如LN、BN。TokenMixerTokenMixerTokenMixer的形式 就比较多样了,可以是最近的视觉Transformer提出的注意力模块,也可以是视觉MLP采用的spatialMLP结构。需要注意的是,尽管有些token mixer能够混合通道之间的信息,但是token mixer的主要功能还是混合不同token之间的信息。

Y=TokenMixer (Norm⁡(X))+XY=\text { TokenMixer }(\operatorname{Norm}(X))+X Y=TokenMixer(Norm(X))+X

第二个子block主要由两层MLP和非线性激活函数组成,形式如下。W1∈RC×rCW_{1} \in \mathbb{R}^{C \times r C}W1RC×rCW2∈RrC×CW_{2} \in \mathbb{R}^{r C \times C}W2RrC×C是可学习参数。

Z=σ(Norm⁡(Y)W1)W2+YZ=\sigma\left(\operatorname{Norm}(Y) W_{1}\right) W_{2}+Y Z=σ(Norm(Y)W1)W2+Y

将上述过程的token mixer更换,就形成了主流的视觉Transformer和MLP模型。

PoolFormer

此前的很多工作要么改进了自注意力模块,要么设计了更加精致的token mixer模块,很少有人关注整体架构。这篇论文的作者认为,MetaFormer这种通用架构才是Transformer和MLP模型取得成功的主要原因。

为了验证猜想,作者设计了一个没有参数的空间池化算子来作为token mixer模块,这个池化没有任何可学习的参数,只是用于使得每个token平均聚合其附近的tokens的信息。若输入为T∈RC×H×WT \in \mathbb{R}^{C \times H \times W}TRC×H×W,该池化算子数学上表示如下,KKK表示池化核尺寸,减去自身是因为后续有个残差连接会再次加上(为了统一为MetaFormer形式)。

T:,i,j′=1K×K∑p,q=1KT:,i+p−K+12,i+q−K+12−T:,i,jT_{:, i, j}^{\prime}=\frac{1}{K \times K} \sum_{p, q=1}^{K} T_{:, i+p-\frac{K+1}{2}, i+q-\frac{K+1}{2}}-T_{:, i, j} T:,i,j=K×K1p,q=1KT:,i+p2K+1,i+q2K+1T:,i,j

该算子的PyTorch风格代码如下,其中池化核和padding的设置是为了输入和输出尺寸不变。

要知道,self-attention和spatial MLP 的计算复杂度与要混合的token数量成平方倍。 更糟糕的是,spatial MLP 在处理更长的序列时会带来更多的参数。 因此,self-attention 和spatial MLP 通常只能处理数百个token。相比之下,池化的计算复杂度和序列的长度是线性关系,且不需要可学习参数。

因此,作者以池化为token mixer参考CNN结构和最近的层级Transformer结构,设计了如下图所示的网络,即PoolFormer。它由四个stage组成,每个stage下采样加倍,具体配置见下标。

实验

作者在图像分类、目标检测和实例分割以及语义分割等任务上验证了PoolFormer的效果。

下表是在ImageNet上验证图像分类任务的效果。可以看到,PoolFormer-S24 和 PoolFormer-S36 这样的小模型就可以分别达到80.3%80.3\%80.3%81.4%81.4\%81.4%的 Top-1,它们仅仅需要3.6G 和5.2G 的 MACs,超过了几种典型的 视觉Transformer 和 视觉MLP 模型。这说明,即使使用池化这种极其简单的 token mixer,MetaFormer仍然具有很强的性能,说明 这整个架构才是我们在设计视觉模型时所真正需要的。

下表是在COCO上目标检测和实例分割的结果,也很能说明问题。

最后,在ADE20K上进行语义分割的实验,结果如下。

作者还进行了一些消融实验,都在下表中列了出来,着重看最后一部分,作者将四个stage的池化换为了注意力模块或者spatial MLP,发现这些结构混用也不会有什么问题,而且前两个stage池化后两个stage注意力这种设计效果尤其不错,同比之下的ResMLP-B24需要7倍的参数量和8.5倍的MACs才能获得同等精度。

总结

这篇文章中,作者独树一帜提出视觉Transformer及其变种的成功原因主要是架构的设计,并且将token mixer换为了简单的池化获得了相当好的效果。这也反映了视觉Transformer其实还有很多值得研究的地方,这篇论文的代码也已经开源,代码量不大,感兴趣的可以通过源码了解到更多的细节。本文也只是我本人从自身出发对这篇文章进行的解读,想要更详细理解的强烈推荐阅读原论文。最后,如果我的文章对你有所帮助,欢迎一键三连,你的支持是我不懈创作的动力。

PoolFormer解读相关推荐

  1. Python Re 模块超全解读!详细

    内行必看!Python Re 模块超全解读! 2019.08.08 18:59:45字数 953阅读 121 re模块下的函数 compile(pattern):创建模式对象 > import ...

  2. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  3. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  4. NLP突破性成果 BERT 模型详细解读 bert参数微调

    https://zhuanlan.zhihu.com/p/46997268 NLP突破性成果 BERT 模型详细解读 章鱼小丸子 不懂算法的产品经理不是好的程序员 ​关注她 82 人赞了该文章 Goo ...

  5. 解读模拟摇杆原理及实验

    解读模拟摇杆原理及实验 Interpreting Analog Sticks 当游戏支持控制器时,玩家可能会一直使用模拟摇杆.在整个体验过程中,钉住输入处理可能会对质量产生重大影响.让来看一些核心概念 ...

  6. 自监督学习(Self-Supervised Learning)多篇论文解读(下)

    自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...

  7. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

  8. 可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读

    可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读 Visual Deprojection: Probabilistic Recovery of Collapsed Dimensions 论文链接: ...

  9. 从单一图像中提取文档图像:ICCV2019论文解读

    从单一图像中提取文档图像:ICCV2019论文解读 DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D Regressi ...

最新文章

  1. 图灵奖得主Raj Reddy:以历史的视角重新审视“人工智能”
  2. Github标星9k+,超赞的 PyTorch 资源大列表!
  3. 重定向dup2的本质
  4. python 空值_数据库中的空值与NULL的区别以及python中的NaN和None
  5. [iPhone高级] 基于XMPP的IOS聊天客户端程序(IOS端二)
  6. 第二阶段_第三小节_C#基础
  7. docker内存阀值_kubernetes调度之资源耗尽处理配置
  8. canvas游戏篇 - 贪吃蛇
  9. nginx开发(二)配置mp4文件在线播放
  10. 面向机器学习的特征工程 六、降维:用 PCA 压缩数据集
  11. java时间戳龙_Java时间戳与日期格式字符串的互转
  12. winform 图片集合
  13. 均匀分布(uniform distribution)期望的最大似然估计(maximum likelihood estimation)
  14. oracle当前用户创建的表不可见?
  15. 简述x264几种码率控制方式的实现
  16. java web play_玩转Java Web应用开发:Play框架
  17. 帆软报表 异常汇总及方案.
  18. 【csdn学习-Python】CSDN技能树-Python语言学习笔记
  19. IM即时通讯综合消息系统的架构
  20. 游戏大额数值转换“K“, “M“, “B“, “T“, “aa“, “ab“, “ac“, “ad“

热门文章

  1. 修改用户和用户组权限
  2. 高仿真 JDK Proxy手写实现
  3. 加速静态内容访问速度的CDN
  4. 四种方式下创建线程启动的区别
  5. 数据类型转换_注意事项
  6. sqoop导入-hdfs
  7. 部署RocketMQ的管理工具
  8. SpringBoot异常处理-@ControlleAdvice
  9. SpringBoot高级-任务-定时任务
  10. SpringBoot_入门-HelloWorld细节-场景启动器(starter)