关注公众号,发现CV技术之美

本文分享一篇 ICCV2021 论文:『Rethinking Spatial Dimensions of Vision Transformers』

详细信息如下:

  • 论文链接:https://arxiv.org/abs/2103.16302

  • 项目链接:https://github.com/naver-ai/pit

导言:

Vision Transformer (ViT)将Transformer结构的应用范围从自然语言处理扩展到计算机视觉任务,成为了一种替代现有卷积神经网络(CNN)的架构。然而,基于Transformer-based的结构在计算机视觉任务中提出的时间还不久,大多工作都是将Transformer结构直接拿过来用,没有考虑CV和NLP任务和特征的区别,因此对视觉Transformer有效结构设计的研究较少。在本文中,作者从CNN的设计原则出发,研究了空间维度转换在Transformer-based结构中的作用及其有效性。

随着深度的增加,传统CNN会增加通道维度,减少空间维度;但是Transformer并没有这个性质,不同层的通道和空间维度都没有发生变化。作者通过实验表明,其实这种空间降维、通道升维的方法也有利于Transformer结构性能的提升。因此,作者在ViT模型的基础上,提出了一种新的基于池化的Vision Transformer(PiT)。

最后作者通过实验证明,与ViT相比,PiT拥有更好的学习能力和泛化性能。在图像分类、目标检测等多任务上,PiT均优于Baseline模型,证明了池化操作在Vision Transformer结构的有效性。

      01      

Motivation

基于自注意力机制的Transformer结构在自然语言处理领域取得了巨大的成功,由于Self-Attention强大的建模能力,一些研究人员也将其在计算机视觉任务中进行了尝试。Non-local和DETR首先证明了Self-Attention在CV任务中的有效性,最近的一系列ViT工作也证明了Transformer结构能够在一定程度上达到比CNN更好的performance。

ViT与卷积神经网络(CNN)对特征处理的方式有很大的不同。输入图像首先被划分为16×16个patch并输入到Vision Transformer网络中;除了第一层的embedding,ViT中没有卷积操作,不同位置的交互仅通过自注意力层来实现。CNN由于卷积核大小的限制,在全局的信息建模上存在局限性;ViT允许图像中的所有位置通过Self-Attention来进行全局的交互。

虽然ViT是一种创新的架构,并且现有的工作也证明其强大的图像识别能力,但它遵循NLP中的Transformer架构,几乎没有对CV任务做针对性的改进。而CNN的一些基本设计原则,在过去的十年中被证明在计算机视觉领域是有效的,但这些在ViT中都没有得到充分的体现。因此,作者重新回顾了CNN架构的设计原则,并实验了将其应用于ViT架构中的有效性。

CNN从空间维度大、通道维度小的特征开始,在空间维度减小的同时逐渐增大通道维度。池化层这样的操作在另一方面也能够影响卷积的感受野。除此之外,目前也有一些工作表明池化层有助于网络的表达能力和泛化性能的提高 。然而,与CNN不同的是,ViT不使用池化层,而是在所有层中使用相同大小的空间token。

在本文中,作者首先验证了CNN上池化层的优点。通过实验表明,池化层提高了ResNet的建模能力和泛化性能。为了将池化层的优点扩展到ViT,作者又提出了一种基于池化的Vision Transformer(PiT)。此外,作者还研究了PiT相比于ViT的优点,从而得出了池化层也提高了ViT性能的结论。最后,为了分析ViT中池化层的作用,作者计算了ViT的空间交互比(类似于CNN的感受野大小),发现ViT中的池化层具有控制Self-Attention空间相互作用大小的效果(类似于CNN的感受野控制)。

      02      

方法

2.1. Effect of pooling on CNN

如上图所示,大多数卷积神经网络都有池化层,池化层降低了空间维度,增加了通道维度。在ResNet50中,stem layer首先将图像的空间尺寸减小到56 × 56,然后步长为2的卷积层使空间维度减半,通道维度加倍。

为了分析卷积结构中存在或不存在池化层的性能差异,作者进行了一个实验。对于有池化层的网络,作者采用了传统的ResNet50结构;对于没有池化层的网络,作者直接用stem layer将特征缩小为14 × 14,然后将在后续过程中保持空间和通道维度不变。

首先,作者测量了有池化层和没有池化层时FLOPs和training loss之间的关系。

如上图所示,在相同的FLOPs下,具有池化层的ResNet具有了更低的training loss。这意味着池化层增加了模型的学习能力。

然后,作者分析了训练和验证精度之间的关系,这代表了模型的泛化性能。

如上图所示,在相同的训练准确率下,具有池化层的ResNet比没有池化层的ResNet具有更高的验证准确率。因此,池化层也有助于提高ResNet的泛化性能。

总的来说,池化层能够提高CNN模型的学习能力和泛化性能,从而显著提高验证集上的精度(如上图所示)。

2.2. Pooling-based Vision Transformer (PiT)

Vision Transformer(ViT)是基于Self-Attention进行计算的,而不是卷积操作。在Self-Attention机制中,空间信息的交互是基于位置两两之间的相似性。

与CNN的stem层类似,ViT在一个embedding层将图片分成多个patch,然后再embedding到token。ViT结构没有池化层,所以整个网络层都比保持相同数量的空间token(如上图所示)。

虽然Self-Attention操作不受空间距离的限制,但参与Attention的空间区域的大小可能会受到特征空间大小而有所影响。因此,如果想在计算过程中调整特征的空间维度大小,池化层对于ViT来说也是有必要的。

基于上面的思想,作者在ViT上加入池化层,并提出了一个基于池化的Vision Transformer(PiT)。

首先,池化层的设计如上图所示。由于ViT处理的是二维矩阵而不是三维张量,池化层首先将空间上的二维token特征reshape成具有空间结构的三维张量。然后,通过一个depth-wise卷积,降低空间维度的大小,提高通道维度的大小。然后再将三维张量reshape回二维的token特征。对于跟空间特征无关的token(比如说用于分类的[cls] token),在这里就直接用一个FC进行映射,来扩大通道维度。

PiT的结构如上图所示,相比于ViT,PiT多了两层池化层,因此在整个网络中就有三种不同尺度的特征。

为了验证池化层在ViT中的有效性,作者又做了一系列实验。

上图展示了有池化层和没有池化层时ViT的模型学习能力。与池化层在ResNet上的实验结果相似,在ViT中加入池化层也能提高模型的学习能力。

如上图所示,即使训练集的准确性提高了,没有池化的ViT并不会提高验证集的准确性。

而使用池化的ViT验证准确性随着训练准确性的增加而增加,证明了PiT的泛化性能比ViT好。

如上图所示,泛化性能的巨大差异导致了有池化层的ViT与没有池化层的ViT之间的最终性能差异。

随着FLOPs的增加,没有池化层的ViT在验证集上的准确率几乎不变。而ViT这种泛化性的不足,可以通过加入池化层来弥补,因此作者认为池化层对于ViT的泛化性能提高是必要的。

2.3. Spatial interaction

接着,作者又通过分析Vision Transformer中的Self-Attention,来探究ViT中的池化层的影响。在CNN中,池化层的作用是用来调整感受野。在特征的空间维度较大时,卷积的感受野通常比较小;但是在特征的空间维度比较小的时候,相同kernel size的感受野就会比较大。因此,在CNN中,池化层可以调节空间上交互的区域大小。在ViT中,无论特征的空间维度是大还是小,Self-Attention都会进行全局的特征交互。

然而,在token数量比较多的时候,Self-Attention的交互也会受到一定影响(可以理解为token数太多时,Self-Attention关注重点的能力可能就没有那么强了)。作者在ImageNet上用预训练的ViT和PiT进行了交互面积的计算。空间交互的衡量是基于Attention矩阵进行Softmax之后的分数。
作者使用1%和10%作为阈值,计算交互发生在阈值以上的空间token数量,并通过交互token数量除以空间token的总大小计算出空间交互比率。上图展示了注意力超过1%的位置的空间交互比率。

在ViT中,相互作用比率在20%到40%之间,因为没有池化层,数值并不会因深度的不同而发生显著变化。PiT通过池化减少token的数量,在前期各层交互比例较小,而后期各层交互比例接近100%。

为了与ResNet进行比较,作者将阈值改为10%,结果如上图所示。

ViT各层之间的交互比相似,但ResNet和PiT的交互比随着其通过池化层的加入而增加。因此,池化层不仅能让ResNet的交互范围变大,也能让Transformer的交互范围变大。作者认为交互范围的控制与模型的性能提升密切相关。

2.4. Architecture design

基于上面的思想,作者对网络结构进行了实例化,并提出了四种不同参数量和计算量的网络。

      03      

实验

3.1. ImageNet classification

从上表可以看出,PiT比ViT不仅有更高的精度,还有更少的FLOPs和更快的速度。并且CutMix、Distill等在ViT上适用的技术,在PiT上依旧适用。

从上表结果可以看出,相比于ViT和其他CNN结构,在相似的参数量下,PiT依旧具有性能上的优势。

3.2. Object detection

通过将DETR中的backbone从ResNet或者ViT换成PiT,检测器的性能有明显提升,这也表明了PiT在目标检测任务中也能有比较好的性能。

3.3. Robustness benchmarks

作者还比较了ResNet、PiT、ViT在四个ImageNet robustness benchmarks上的表现,可以看出在所有的robustness benchmark上,PiT的performance都是最高的。

      04      

总结

在本文中,作者提出了一种带池化层的Transformer结构,并在多个任务上证明了带池化层Transformer的有效性。个人觉得,其实本文做的工作非常简单,在ViT上改动也非常小;不过相对来说这篇工作做的还是比较完整的。

作者首先在ResNet上进行了研究,发现在空间维度上进行变换可以提高计算效率和泛化能力。然后作者为了探究这种变换在ViT上影响,提出了一个将池化层合并到ViT中的PiT,再通过一系列实验来表明,池化层在ViT中也是非常有用的。最后,在显著提高ViT结构性能的同时,作者还用空间交互比率来证明池化层能够提高ViT中Self-Attention的交互面积。

作者介绍

厦门大学人工智能系20级硕士

研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。

知乎:努力努力再努力

公众号:FightingCV

END,入群????备注:TFM

ICCV2021-PiT-池化操作不是CNN的专属,ViT说:“我也可以”;南大提出池化视觉Transformer(PiT)...相关推荐

  1. 卷积神经网络中的各种池化操作

    参考:https://www.cnblogs.com/pprp/p/12456403.html 池化操作(Pooling)是CNN中非常常见的一种操作,Pooling层是模仿人的视觉系统对数据进行降维 ...

  2. 【综述】盘点卷积神经网络中的池化操作

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 池化操作(Pooling)是CNN中非常常见的一种操作,池化操作通 ...

  3. 双线性池化_卷积神经网络中的各种池化操作

    池化操作(Pooling)是CNN中非常常见的一种操作,Pooling层是模仿人的视觉系统对数据进行降维,池化操作通常也叫做子采样(Subsampling)或降采样(Downsampling),在构建 ...

  4. 【Pytorch神经网络理论篇】 13 深层卷积神经网络介绍+池化操作+深层卷积神经网络实战

    1 深层卷积神经网络概述 1.1 深层卷积神经网络模型结构图 1.1.1 深层卷积神经网络的正向结构构成剖析 输入层,将每个像素作为一个特征节点输入网络. 卷积层:由多个滤波器组合而成. 池化层:将卷 ...

  5. 【OpenCV3】阈值化操作——cv::threshold()与cv::adaptiveThreshold()详解

    阈值化操作在图像处理中是一种常用的算法,比如图像的二值化就是一种最常见的一种阈值化操作.opencv2和opencv3中提供了直接阈值化操作cv::threshold()和自适应阈值化操作cv::ad ...

  6. TensorFlow实现卷积、池化操作

    1.调用tf.nn.conv2d()实现卷积 首先是调用卷积函数实现卷积操作: 这里说明一下conv2d的定义及参数含义: 参考 [定义:] tf.nn.conv2d (input, filter, ...

  7. 图神经网络的池化操作

    图神经网络有两个层面的任务:一个是图层面(graph-level),一个是节点(node-level)层面,图层面任务就是对整个图进行分类或者回归(比如分子分类),节点层面就是对图中的节点进行分类回归 ...

  8. nn.AvgPool2d——二维平均池化操作

    PyTorch学习笔记:nn.AvgPool2d--二维平均池化操作 torch.nn.AvgPool2d( kernel_size , stride=None , padding=0 , ceil_ ...

  9. Torch 池化操作大全 MaxPool2d MaxUnpool2d AvgPool2d FractionalMaxPool2d LPPool2d AdaptivePool2d dilation详解

    torch 池化操作 1 池化操作 2 类型 2.1 MaxPool2d() 2.2 MaxUnPool2d() 2.3 AvgPool2d() 2.4 FractionalMaxPool2d() 2 ...

最新文章

  1. plsql查找不到带中文的纪录
  2. 对E—R模型的深入理解
  3. oracle什么是全局锁,深入浅出oracle锁---原理篇
  4. Facebook's New Real-time Messaging System: HBase to Store 135+ Billion Messages a Month
  5. dofilter 无效_“鹅厂”商标注册成功,腾讯异议无效
  6. Redis Hash 哈希 结构
  7. 如何设置XMind思维导图线条
  8. linux升级openssl需要先卸载吗,在Linux系统上升级OpenSSL的方法
  9. linux查看进程中的线程名,linux 怎么样查看一个进程的线程
  10. s2 安恒 漏洞验证工具_Struts2漏洞利用工具下载(更新2017-V1.8版增加S2-045/S2-046)-阿里云开发者社区...
  11. 连接查询(交叉连接,内连接,外连接,自然连接)
  12. 虚拟主机搬迁服务器要重新备案吗,域名更换虚拟主机要重新备案吗
  13. html调用网易云播放器无法自动播放,HTML网页调用 网易云 音乐播放器代码-Go语言中文社区...
  14. 微信小程序——校园服务小程序(四)校园论坛加预约理发服务
  15. php eclipse aptana,eclipse 下如何安装 Aptana插件
  16. 大数据工程师要学的编程_每个数据工程师都应了解的ml编程技巧,第2部分
  17. php 法定节假日接口,通过百度接口获取每一个月的工作和法定假日
  18. MATLAB代码保存为word,MATLAB怎么保存为Word?
  19. Red Panda DEV-C++更新到6.7.5啦
  20. 设计模式中的工厂类图

热门文章

  1. java中集合(List)的嵌套分配值、移除等操作
  2. Zdenek Kalal的TLD Tracker(牛啊,学习!)
  3. 天池-新闻推荐-数据分析
  4. 蒙特卡罗方法—举例说明(C++、python)
  5. 679 - Dropping Balls
  6. Linux学习笔记8
  7. mysql 分表_MySQL如何分库分表
  8. vue监听字符串长度_Vue 的 computed 和 watch 的区别
  9. c# npoi 2.5版本设置字体加粗_Python帮你做Excel——格式设置与画图
  10. php前端路由权限,SaaS-前端权限控制