LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference论文解读

  • 前言
  • 0.摘要
  • 1.introduction
  • 2.related work
  • 3.Motivation
  • 4.Model
    • Patch embedding
    • No CLS token
    • Multi-resolution pyramid
  • 4.experiment

前言

  本文的创新点在于提出了transformer金字塔,attention计算中减小Q的大小,让特征图过了几层transformer金字塔后HW大大缩小,C有限增加,宏观上LeViT是CNN金字塔+transformer金字塔,最后实现小数据量的层次性transforrmer结构。另外本文还提出了attention bias用来取代position encoding.
论文地址:论文PDF地址
代码地址:github代码地址

0.摘要

  作者说他们重新审视了CNN结构的优点,想着把CNN的结构引入transformer.特别是CNN金字塔,他们对CNN中分辨率不断降低的activation map(其实就是featrue map)很感兴趣,因为分辨率降低,通道数增加,但总的来说数据量还是降低的,这对能快速推理很重要。所以作者介绍的LeViT模型能在推理速度和准确率取得比较平衡的结果。

1.introduction

  作者首先简要介绍了transformer模块。然后提出他们比较看重在性能和准确率之间的取得比较平衡的结果。所以最后使用了带池化的金字塔transformer块取代传统的transformer块。因为类似于LeNet,所以起名叫LeViT.
  最后看看本文的contribution:
1.能在transformer块里实现下采样的金字塔型结构
2.可学习的attention bias能取代position encoding

2.related work

。。。

3.Motivation

本节作者简要的介绍了将transformer嫁接到resnet的实验。通过逐渐改变resnet阶数与Deit的层数来考察其在imagenet上的准确度与速度。

从实验结果来看两者混合后的实验结果要比单独一种结构更好,这个实验给作者带来了信心,CNN与transformer的混合结构能带来更好的结果。于是作者提出了CNN与transformer结合更加紧密的LeViT,而不是简单的嫁接。

4.Model


上图是LeViTd的整体结构图,也是我们要介绍的主要部分。可以看出LeViT结合了CNN金字塔与transformer金字塔。

Patch embedding

  patch embedding部分就是上图的CNN金字塔部分。输入图像不考虑batchsize的话shape是(3x224x224)即(CxHxW)。通过四次3x3卷积,stride取2.每次输入都会H,W减半,C加倍。C从3->32->64->128->256。H和W是224->112->56->28->14经过卷积后,H,W缩小到1/16,相当于patch_size=16后取的token.注意Conv2d_BN包括一次卷积,一次BN。

#patch embedding
def b16(n, activation, resolution=224):#n是的embed_dimension[0],将作为transformer第一层输入的维度。return torch.nn.Sequential(Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution),activation(),Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),activation(),Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),activation(),Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8))

No CLS token

为了更好的保护BCHW的数据格式,最终决定不加入class token,这样token的数量224/16=14.

Multi-resolution pyramid

  上述MLPx2层包括Linear+BN重复两次。本文使用了两种attention块。一种是普通的attention一种是shrink attention。两者交替使用构成了transformer金字塔.

上图是shrink attention模块,下图是普通的attention模块。

下方是两种attention模块的结构图。

  可以看出普通的transformer模块基本上遵循传统的attention计算方式,不同的地方是给QKT加上了attention_bias,以替代postion encoding。另外还增加了Hardswish激活函数。
  比较值得关注的是shrink attention,在进行attention计算前,其中一个输入经过sub-sample后shape从CxHxW->Cx(H/2)x(W/2)
这样取Q的shape就变成(DxHW/4).Q
KT的shape变为(HW/4)xHW.把QKT看成相似性sim。则softmax(sim)V的shape变为((HW/4)x4D)->(4DxH/2xW/2)->(C’xH/2xW/2)这样就完成了数据量的减少。C’是embed_dimension,由作者所固定设置。
接下来说说attention bias.attention abis被设置成可学习变量来表示位置,初始化是torch.zeros来完成的。
attention bias一共的数量是HW个,如果是CHW(256,14,14)进来的,那么attention bias取14
14=196个。
原因在于其编码的方式

我们可以把QKT看成矩阵里面不同的pixel相乘,就是上面公式Q(x,y) * K(x’,y’)
就是Q里面的(x,y)点与K里面的(x’,y’)点相乘,那么作者认为positon应该是两点之间的差的绝对值|x-x’| , |y-y’|,注意这里是相对position,还没有进行编码。也就是说实际上有效的position位置仅仅只有H
W个,例如H,W取14,则相对position只有[(0,0),(0,1),(0,2),…(0,13),(1,0),(1,1),…(1,13),…(13,13)]共14*14个坐标。

       points = list(itertools.product(range(resolution), range(resolution)))###resolution代表k的H,W,例如为14points_ = list(itertools.product(range(resolution_), range(resolution_)))###resolution_代表q的H,W,例如为7##points=[(0,0),(0,1),(0,2),...(0,13),(1,0),(1,1),...(1,13),...(13,13)]共14^2个坐标,囊括了K的所有点,points_也是如此N = len(points)#N=196N_ = len(points_)#N_=49attention_offsets = {}idxs = []for p1 in points_:#Q坐标合集for p2 in points:#K坐标合集size = 1offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),#这里是x-x0abs(p1[1] * stride - p2[1] + (size - 1) / 2))#这里是y-y0if offset not in attention_offsets:attention_offsets[offset] = len(attention_offsets)#attention_offsets是关于(x,y)与bias的字典。例如(1, 3): 17代表|x-x0|=1,|y-y0|=3时,相对position取17idxs.append(attention_offsets[offset])#idxs是HW*HW,包含了Q*KT的所有Pixel的position,不是bias,是positon.self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))#attention_offsets的长度代表了所有有效的position数量。例如H,W都是14,那么(0,0)(0,1)(0,2)...(0,13)(1,0)(1,1)...(1,13)...(13,13)共196个self.register_buffer('attention_bias_idxs',torch.LongTensor(idxs).view(N_, N))
#####################计算attention并加biasattn = (q @ k.transpose(-2, -1)) * self.scale + (self.attention_biases[:, self.attention_bias_idxs#(num_head,Hq*Wq,Hk*Wk)

4.experiment

  实验中我们比较关心的是消融研究。从实验结果可以看到金字塔transformer和attention bias都是有增益的。启示是金字塔型transformer是可行的。我们可以尝试用金字塔型transformer来代替传统的CNN transformer来做文章。

【LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference论文解读】相关推荐

  1. ICCV 2021 | LeViT: a Vision Transformer in ConvNet‘s Clothing for Faster Inference论文阅读笔记

    论文:https://arxiv.org/abs/2104.01136 代码(刚刚开源): https://github.com/facebookresearch/LeViT ABSTRACT 我们设 ...

  2. LeViT: a Vision Transformer in ConvNet‘s Clothing for Faster Inference

    文章目录 前言 1. 模型 1.1 设计原则 1.2 模型组件 patch embedding no classitication token normalization layers and act ...

  3. LeViT: aVision Transformer in ConvNet‘s Clothing for Fast in

    摘要 我们设计了一系列图像分类架构,可以在高速模式下优化精度和效率之间的平衡.我们的工作利用了基于注意力的体系结构的最新发现,这种体系结构在高度并行处理硬件上具有竞争力.我们重温了大量文献中关于卷积神 ...

  4. 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...

  5. 【Timm】搭建Vision Transformer系列实践,终于见面了,Timm库!

    前言:工具用不好,万事都烦恼,原本真的就是很简单的一个思路实现,偏偏绕了一圈又一圈,今天就来认识认识Timm库吧! 目录 1.百度飞桨提供的-从零开始学视觉Transformer 2.资源:视觉Tra ...

  6. 搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    ↑ 点击蓝字 关注极市平台 作者丨科技猛兽 编辑丨极市平台 极市导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Trans ...

  7. Vision Transformer 论文解读

    原文链接:https://blog.csdn.net/qq_16236875/article/details/108964948 扩展阅读:吸取CNN优点!LeViT:用于快速推理的视觉Transfo ...

  8. 一文细数Vision transformer家族成员

    可以看作是vision transformer的idea集,查漏补缺使用.需要精读的文章前面加了*号,均附有文章链接及代码链接. 下面这个链接基本上有所有的ViT的论文地址:https://githu ...

  9. OUC暑期培训(深度学习)——第六周学习记录:Vision Transformer amp; Swin Transformer

    第六周学习:Vision Transformer & Swin Transformer Part 1 视频学习及论文阅读 1.Vision Transformer 原文链接:https://a ...

最新文章

  1. KineticJS教程(3)
  2. Linux系统编程----15(线程与进程函数之间的对比,线程属性及其函数,线程属性控制流程,线程使用注意事项,线程库)
  3. 工作23:vue---封装request做数据请求
  4. 使用java来进行分词处理
  5. python爬取json数据_Python爬取数据保存为Json格式的代码示例
  6. PMP 考试一定要报培训班吗?(PMP备考资料分享)
  7. 计算机c盘删除的文件怎么找回,两分钟恢复电脑误删除的文件数据
  8. 35岁以上的大龄程序员们,后来都干什么去了?
  9. 7-28 猴子选大王(20 分)
  10. chipsel语言_英语快速记忆法视频
  11. 树莓派外设开发——IIC接口OLED屏幕
  12. flutter打包出错了,有大神帮忙看看吗?
  13. Excel-自网站粘贴
  14. xx.h和xx.c的奥妙
  15. SCAU 8609 哈夫曼树
  16. 安卓配置正式包和测试包不同的名字、图标、同时安装,(极光配置测试和正式)
  17. [随心译]2017.8.7-这些难以置信的地球太空夜景图实际上全是假货
  18. 公文专用计算机,[计算机]常用公文写作方法
  19. DevComponents.DotNetBar之SuperTabControl动态调整TABS页的位置,以动态调整按钮ButtonItem
  20. 【OFDM系列6】MIMO-OFDM系统模型、迫零(ZF)均衡检测和最小均方误差(MMSE)均衡检测原理和公式推导

热门文章

  1. 中国地质调查局:汶川地震原因已有初步的结论
  2. AUC、ROC、ACC区别
  3. 使用教育网邮箱学生验证Microsoft Imagine 微软开发者 获取window server 2016正版密钥教程
  4. 天梯赛的善良 (20 分)
  5. 【解决方案】Command failed due to signal: Segmentation fault: 11
  6. easyX中loadimage()函数共计有5个参数详解
  7. html5 css3 jquery 画板
  8. fpga串口打印计数值作业
  9. python画趋势图_python 绘制走势图
  10. Spring官方文档中文翻译