导读

本文是苏黎世联邦理工-Luc Van Gool组联合南开大学和电子科技大学在Transformer上的最新工作,这项工作主要提出了一个新的分层级联多头自注意力模块(MHSA),通过分层级联的方式显著降低了计算/空间复杂度。H-MHSA 模块可轻松插入任何CNN架构中,并且可以通过反向传播进行训练。基于此,我们提出了一种新的骨干网络叫做TransCNN,它完美继承了CNN和Transformer的优点,与当前最具竞争力的基于CNN和Transformer的网络架构相比,TransCNN在计算/空间复杂度和准确率上实现了最先进的性能。

  • 作者单位:苏黎世联邦理工(Luc Van Gool大佬组)、南开大学, 电子科大

  • 论文:https://arxiv.org/abs/2106.03180

  • 代码: https://github.com/yun-liu/TransCNN

摘要与介绍

本文主要提出了一种分层级联自注意力模块(MHSA),通过以一种级联计算的方式解决了多头自注意(MHSA)中由于计算/空间复杂度高而导致视觉Transformer效率低下的问题,同时该模块还可以使得Transformer中的自注意力计算更加灵活和高效。具体来说,首先将图像拆分为多个patch,而每一个patch都可视为一个token(标记)来学习小网格内的特征关系。我们不再跨所有patch去计算注意力,而是将patch进一步分组到每个小网格中,并在每个网格中计算自注意力,从而捕获局部特征关系,并产生可区分的局部特征表示。然后,将小网格合并到较大网格中,同时将上一步中的小网格视为一个新的token来重新计算下一个网格中的注意力。通过反复迭代这个过程以逐渐减少token的数量。在整个过程中,我们的H-MHSA模块在不断增加的区域网络大小中逐步计算自我注意力,并自然地以分层的方式对全局特征关系进行建模。由于每个网格在每一步只有少量的标记,所以,我们可以显著降低视觉Transformer的计算/空间复杂度。最后,H-MHSA 模块可轻松插入任何CNN架构中,并且可以通过反向传播进行训练。因此,我们提出了一种新的骨干网络叫做TransCNN,它本质上继承了Transformer和CNN的优点。所以,它具有学习尺度变换以及位移不变特征表示的能力,还可以对输入数据建立长期依赖关系。根据ImageNet和MS-COCO数据集的实验结果表明:与当前最具竞争力的基于CNN和Transformer的网络架构相比,TransCNN在计算复杂度和准确率上实现了最先进的性能。

模型架构

最近关于很多vision Transformer的工作通常旨在建立纯Transformer网络,同时为视觉和自然语言处理任务提供统一的体系结构。然而,这对于视觉任务可能不是最优的选择,因为这些架构不擅长学习那些对视觉数据至关重要的局域表示。受此启发,我们设计了H-MHSA模块,它可以很容易地插入到任何现有的CNN架构中。同时,我们的网络基本上可以继承Transformer和CNN的优点。首先,我们在骨干网络中保留3D特征输入图,并使用全局平均池化层和全连接层来预测图像类别。这点与现有的依赖另一个1D类标记进行预测的Transformer是不同的。之前的Transformer网络通常采用GELU函数进行非线性激活。然而,在网络训练中,GELU函数需要大量内存。根据经验发现,SiLU函数的性能与GELU相同,而且更利于存储。因此,这里我们使用SiLU函数进行非线性激活。TransCNN的总体架构如图1所示。在开始阶段,不同于之前的Transformer通过直接拉直图像patch,我们应用了两个顺序的vanilla






卷积,每个卷积的步长为2,然后将输入图像下采样到原来的1/4尺度大小.然后,我们将H-MHSA和卷积块进行交替叠加使用,主要分为4个阶段,分别以1/4,1/8,1/16,1/32的金字塔特征尺度进行划分。我们采用倒残差瓶颈块(IRB,图1c)和深度可分离卷积。在每个阶段的最后,我们设计了一个简单的二分支降采样块(TDB,图1d)。它包括两个分支:一个分支是步长为2的vanilla






卷积;另一个分支包括一个池化层和一个






卷积。在特征降采样过程中,这两个分支通过按元素求和的方式融合,以保留更多的上下文信息。实验表明,TDB的性能优于直接降采样。TransCNN的详细配置如表1所示。我们提出了两个版本: TransCNN-small和TransCNN-base。具体网络结构配置如下表所示:

模型方法

接下来,我们首先回顾一下视觉Transformer。然后,我们将详细介绍本文提出的H-MHSA模块。

Revisit Vision Transformer

Transformer主要依赖MHSA来建模长期特征依赖关系,这里,我们用













表示输入,









分别代表Token的数量以及每个token的特征维度,然后,我们有the query :












, the key:












, the value:












,并且

在假设输入和输出具有相同维度的情况下,传统的MHSA可以通过下式计算:

其中









表示近似归一化,将Softmax函数应用于矩阵的每一行。注意,为了简单起见,我们在这里省略了多个头的计算。矩阵乘积










具体做法是首先计算每对token之间的相似度。然后,在所有token的组合基础上再派生获取得到每个新的token。在MHSA计算后,可以进一步添加残差连接以方便优化,如:























在这里,


















是特征投影的权重矩阵。最后,采用MLP来对特征进行增强,表示形式为:





























这里,




代表一个Transformer block的输出。MHSA的计算复杂度表示如下:因此,很容易推断出空间复杂度(内存消耗)应当为












级别。因此对于高分辨率的输入特征图,












可能变得非常大,这就大大限制了vision Transformer在视觉任务中的泛化性。基于此,我们的目标是在不降低性能损失的基础上,进一步降低计算/空间复杂度,并保持全局特征关系建模的能力。

Hierarchical Multi-Head Self-Attention

在这里,我们的H-MHSA模块,它可以降低计算/空间复杂度。首先,我们不是针对整个输入中计算注意力,而是以分层的方式计算注意力,这样每个步骤只处理有限数量的token。图1b为H-MHSA的范式。

假设输入特征映射

























的高度为









,宽度为









,我们有


















。然后,我们可以将特征图划分为小网格,每个网格大小为
















,因此,我们对输入的特征图进行重构得到新的











:

并且:对于产生的局部自注意力









。为了简化网络优化,我们也进行如下变换:

添加残差连接后:


















因为









式计算每个小的
















网络快,所以这个计算复杂度得到了显著地减少。对于第




步,我们可以将第






步得到的更小的网络块




















视为一个新的token,它可以简单地通过对注意力特征











进行下采样来实现:

这里,



















































分别表示:最大池化和平均池化。进一步针对







































这里

然后我们同样将
















划分为
















大小的网格,并重新得到:

然后进一步有:

最后,我们得到:

残差连接:






























这个过程将不断迭代,直到足够小。然后我们停止切分网格块。H-MHSA的最终输出为:

其中














表示将注意力特征上采样到原始大小,


















是特征投影的权重矩阵。




为最大迭代步数。通过这种方式,H-MHSA可以建立全局特征依赖关系,相当于传统的MHSA。很容易证明,在所有









都相同的假设下,H-MHSA的计算复杂度为:

所以,我们显著降低了计算复杂度,即从












降低到了
















,并且这里

















小得多。同理,空间复杂度也得到了显著降低。

Experiments

本文我们主要在ImageNet数据集上对提出来的TransCNN网络结构进行图像分类任务。首先,为了更好地理解我们的TransCNN结构,我们设计了一个消融实验。然后,我们将TransCNN与现有的基于CNN和Transformer的主干网络结构进行比较。最后,我们将TransCNN应用于MS-COCO数据集上进行目标检测和实例分割任务,进一步验证了该算法的优越性。

首先,我们对所提出的TransCNN的各种设计方案进行了评估,消融实验结果如下表所示:

在ImageNet数据集上,TransCNN与最先进的CNN和Transformer网络架构的在分类任务识别效果的比较,具体结果如下所示:

在MS-COCO val2017数据集上物体检测效果如下所示:

在MS-COCO val2017数据集上实例分割效果如下所示:

本文创新点总结

1. 本文提出了一个级联的MHSA模块,即H-MHSA,显著地降低了计算/空间复杂度。H-MHSA有两个显著的优势: i)直接对输入特征图进行全局依赖关系建模,ii)能够轻松处理大量输入图像。

2. H-MHSA可以灵活的插入到CNN中,而不用像传统的ViT那样,在注意力计算之后使用MLP进行特征增强。

3. TransCNN完美的继承了Transform和CNN的优点,根据图像分类、目标检测和实例分割的实验结果表明,TransCNN在特征表示学习方面有着较大的潜力和高效性。

4. 未来,神经结构搜索技术也可以应用于TransCNN进行参数配置优化。

重磅!DLer-AI顶会交流群已成立!

大家好,这是DLer-AI顶会交流群!首先非常感谢大家的支持和鼓励,欢迎各位加入DLer-AI顶会交流群!本群旨在学习交流人工智能顶会(CVPR/ICCV/ECCV/NIPS/ICML/ICLR/AAAI等)、顶刊(IJCV/TPAMI/TIP等)写作与投稿事宜。包括第一时间发布论文信息和公开演讲视频,以及各大会议的workshop等等。希望能给大家提供一个更精准的研讨交流平台!!!

添加请备注:AI顶会+学校/公司+昵称(如CVPR+上交+小明)

???? 长按识别添加,即可进群!

分层级联Transformer!苏黎世联邦提出TransCNN: 显著降低了计算/空间复杂度!相关推荐

  1. CVPR 2022 | 提高小数据集利用效率,复旦等提出分层级联ViT网络

    ©作者 | 戈维峰 单位 | 复旦大学 来源 | 机器之心 来自复旦大学.上海市智能信息处理重点实验室和香港大学的研究者提出了一种基于 DINO 知识蒸馏架构的分层级联 Transformer (HC ...

  2. 苏黎世联邦理工学院SML课题组招收统计机器学习全奖博士生

    来源:AI求职 ETH Zurich 苏黎世联邦理工学院(ETH Zurich),由瑞士联邦政府创建于 1854 年,坐落于瑞士联邦第一大城市苏黎世,ETH 专注于工程技术.自然科学与建筑学的教育与研 ...

  3. 苏黎世联邦理工开发的多相机光学触觉传感器,可以实现基于视觉的机器人皮肤

    最近瑞士苏黎世联邦理工学院的一组研究人员开发了一种多相机光学触觉传感器(即基于光学设备的触觉传感器),该传感器收集有关施加到其表面的接触力分布的信息.在arXiv上预发表的一篇论文中介绍的这种传感器可 ...

  4. Talk | 清华大学陈晓宇苏黎世联邦理工黄嘉伟 :基于实际应用的强化学习

    本期为TechBeat人工智能社区第455期线上Talk! 北京时间11月17日(周四)20:00,清华大学交叉信息研究院在读博士生--陈晓宇与苏黎世联邦理工大学计算机科学在读博士生--黄嘉伟的Tal ...

  5. 中科大计算机苏黎世联邦理工,从国内top10到世界top10-苏黎世联邦理工offer到!...

    -滕慧芝武汉前途欧洲部留学顾问 国内985院校英语专业毕业,硕士就读于法国精英高等商学院,在欧洲生活学习工作3年多,回国后一直专注欧洲高端留学,精通欧洲各国留学业务 L同学去年年初找到我的时候还是有些 ...

  6. 苏黎世联邦理工学院计算机系研究生,苏黎世联邦理工学院硕士申请条件都有哪些?...

    针对苏黎世联邦理工学院硕士申请条件都有哪些,下面我们一起来看苏黎世联邦理工学院研究生申请条件: 苏黎世联邦理工学院与美国麻省理工齐名,是瑞士唯一一所国际性的大学.学校质量很好,到现在为止就有21为诺贝 ...

  7. Swin-Transformer:基于移位窗口(Shifted Windows)的分层视觉Transformer

    论文链接:Swin Transformer 论文代码:https://github.com/microsoft/Swin-Transformer 目录 1.摘要和背景介绍 2.整体框架 2.1.基于移 ...

  8. 一张照片就能生成3D模型,GAN和自动编码器碰撞出奇迹,苏黎世联邦理工学院出品...

    萧箫 发自 凹非寺 量子位 | 公众号 QbitAI 2D图片"脑补"3D模型,这次真的只用一张图就行了-- 只需要给AI随便喂一张照片,它就能从不一样的角度给你生成"新 ...

  9. CV初级研究工程师,苏黎世联邦理工学院招聘

    瑞士苏黎世联邦理工学院招聘计算机视觉初级研究工程师,该职位来自 ETH Media Technology Center  部门. The ETH Media Technology Center is ...

最新文章

  1. Eigen/Matlab 使用小结
  2. c++ fork 进程时 共享内存_因为没答好进程间通信,面试挂了...
  3. node-serialport —— Node.js 串口数据读写包
  4. linux mysql安装失败 lib冲突问题_Linux 安装 Mysql 冲突 问题
  5. GitHub 热榜:这款开源神器可帮您将文本转换为手写文字,并下载为 PDF 格式文件!...
  6. 爬虫IP被禁的简单解决方法
  7. 利用11行Python代码,盗取了室友的U盘,内容十分刺激!
  8. 【网络】c++ socket 学习笔记(一)
  9. php基本语法实验总结,PHP总结(一)基本语法内容
  10. 从 FingBugs的错误来看JAVA代码质量
  11. MySQL:Specified key was too long; max key length is 1000 bytes
  12. java 虚拟机--新生代与老年代GC [转]
  13. 近期面试了三个人之感想
  14. 上海市二级c语言软件环境,上海市2019年9月计算机二级考试复习教程:(C语言)上机考试新版题库+全真模拟试卷(2本装)...
  15. CPU缓存侧信道攻击综述-Survey of CPU Cache-Based Side-Channel Attacks
  16. 万字长文:功能安全量产落地的三座大山
  17. 数字逻辑·逻辑代数【运算、函数】
  18. 弘辽科技:拼多多高客单价怎么改低价格提升?
  19. WEB安全——CS中Beacon的使用
  20. web前端数据表格有合并项的一种简单实现方法

热门文章

  1. TTCN手动测试总结
  2. [每周软件]:Cucumber:未完待续的原因
  3. 先为成功的人工作,再与成功的人合作,最后是让成功的人为你工作
  4. 地图样式自定义_干货在线 | ArcGIS中定义图框样式
  5. pandas.apply 有源码github
  6. 自适应激活函数 ACON:统一ReLU和Swish的新范式
  7. 智源沙龙 | 人工智能“3个30年”之后,下个30年将走向何方?
  8. 孟宪会老师推荐的一部C#图解教程
  9. 使用OpenCV的findContours获取轮廓并切割(python)
  10. 深度学习4:使用MNIST数据集(tensorflow)