【LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference论文解读】
LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference论文解读
- 前言
- 0.摘要
- 1.introduction
- 2.related work
- 3.Motivation
- 4.Model
- Patch embedding
- No CLS token
- Multi-resolution pyramid
- 4.experiment
前言
本文的创新点在于提出了transformer金字塔,attention计算中减小Q的大小,让特征图过了几层transformer金字塔后HW大大缩小,C有限增加,宏观上LeViT是CNN金字塔+transformer金字塔,最后实现小数据量的层次性transforrmer结构。另外本文还提出了attention bias用来取代position encoding.
论文地址:论文PDF地址
代码地址:github代码地址
0.摘要
作者说他们重新审视了CNN结构的优点,想着把CNN的结构引入transformer.特别是CNN金字塔,他们对CNN中分辨率不断降低的activation map(其实就是featrue map)很感兴趣,因为分辨率降低,通道数增加,但总的来说数据量还是降低的,这对能快速推理很重要。所以作者介绍的LeViT模型能在推理速度和准确率取得比较平衡的结果。
1.introduction
作者首先简要介绍了transformer模块。然后提出他们比较看重在性能和准确率之间的取得比较平衡的结果。所以最后使用了带池化的金字塔transformer块取代传统的transformer块。因为类似于LeNet,所以起名叫LeViT.
最后看看本文的contribution:
1.能在transformer块里实现下采样的金字塔型结构
2.可学习的attention bias能取代position encoding
2.related work
。。。
3.Motivation
本节作者简要的介绍了将transformer嫁接到resnet的实验。通过逐渐改变resnet阶数与Deit的层数来考察其在imagenet上的准确度与速度。
从实验结果来看两者混合后的实验结果要比单独一种结构更好,这个实验给作者带来了信心,CNN与transformer的混合结构能带来更好的结果。于是作者提出了CNN与transformer结合更加紧密的LeViT,而不是简单的嫁接。
4.Model
上图是LeViTd的整体结构图,也是我们要介绍的主要部分。可以看出LeViT结合了CNN金字塔与transformer金字塔。
Patch embedding
patch embedding部分就是上图的CNN金字塔部分。输入图像不考虑batchsize的话shape是(3x224x224)即(CxHxW)。通过四次3x3卷积,stride取2.每次输入都会H,W减半,C加倍。C从3->32->64->128->256。H和W是224->112->56->28->14经过卷积后,H,W缩小到1/16,相当于patch_size=16后取的token.注意Conv2d_BN包括一次卷积,一次BN。
#patch embedding
def b16(n, activation, resolution=224):#n是的embed_dimension[0],将作为transformer第一层输入的维度。return torch.nn.Sequential(Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution),activation(),Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),activation(),Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),activation(),Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8))
No CLS token
为了更好的保护BCHW的数据格式,最终决定不加入class token,这样token的数量224/16=14.
Multi-resolution pyramid
上述MLPx2层包括Linear+BN重复两次。本文使用了两种attention块。一种是普通的attention一种是shrink attention。两者交替使用构成了transformer金字塔.
上图是shrink attention模块,下图是普通的attention模块。
下方是两种attention模块的结构图。
可以看出普通的transformer模块基本上遵循传统的attention计算方式,不同的地方是给QKT加上了attention_bias,以替代postion encoding。另外还增加了Hardswish激活函数。
比较值得关注的是shrink attention,在进行attention计算前,其中一个输入经过sub-sample后shape从CxHxW->Cx(H/2)x(W/2)
这样取Q的shape就变成(DxHW/4).QKT的shape变为(HW/4)xHW.把QKT看成相似性sim。则softmax(sim)V的shape变为((HW/4)x4D)->(4DxH/2xW/2)->(C’xH/2xW/2)这样就完成了数据量的减少。C’是embed_dimension,由作者所固定设置。
接下来说说attention bias.attention abis被设置成可学习变量来表示位置,初始化是torch.zeros来完成的。
attention bias一共的数量是HW个,如果是CHW(256,14,14)进来的,那么attention bias取1414=196个。
原因在于其编码的方式
我们可以把QKT看成矩阵里面不同的pixel相乘,就是上面公式Q(x,y) * K(x’,y’)
就是Q里面的(x,y)点与K里面的(x’,y’)点相乘,那么作者认为positon应该是两点之间的差的绝对值|x-x’| , |y-y’|,注意这里是相对position,还没有进行编码。也就是说实际上有效的position位置仅仅只有HW个,例如H,W取14,则相对position只有[(0,0),(0,1),(0,2),…(0,13),(1,0),(1,1),…(1,13),…(13,13)]共14*14个坐标。
points = list(itertools.product(range(resolution), range(resolution)))###resolution代表k的H,W,例如为14points_ = list(itertools.product(range(resolution_), range(resolution_)))###resolution_代表q的H,W,例如为7##points=[(0,0),(0,1),(0,2),...(0,13),(1,0),(1,1),...(1,13),...(13,13)]共14^2个坐标,囊括了K的所有点,points_也是如此N = len(points)#N=196N_ = len(points_)#N_=49attention_offsets = {}idxs = []for p1 in points_:#Q坐标合集for p2 in points:#K坐标合集size = 1offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),#这里是x-x0abs(p1[1] * stride - p2[1] + (size - 1) / 2))#这里是y-y0if offset not in attention_offsets:attention_offsets[offset] = len(attention_offsets)#attention_offsets是关于(x,y)与bias的字典。例如(1, 3): 17代表|x-x0|=1,|y-y0|=3时,相对position取17idxs.append(attention_offsets[offset])#idxs是HW*HW,包含了Q*KT的所有Pixel的position,不是bias,是positon.self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))#attention_offsets的长度代表了所有有效的position数量。例如H,W都是14,那么(0,0)(0,1)(0,2)...(0,13)(1,0)(1,1)...(1,13)...(13,13)共196个self.register_buffer('attention_bias_idxs',torch.LongTensor(idxs).view(N_, N))
#####################计算attention并加biasattn = (q @ k.transpose(-2, -1)) * self.scale + (self.attention_biases[:, self.attention_bias_idxs#(num_head,Hq*Wq,Hk*Wk)
4.experiment
实验中我们比较关心的是消融研究。从实验结果可以看到金字塔transformer和attention bias都是有增益的。启示是金字塔型transformer是可行的。我们可以尝试用金字塔型transformer来代替传统的CNN transformer来做文章。
【LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference论文解读】相关推荐
- ICCV 2021 | LeViT: a Vision Transformer in ConvNet‘s Clothing for Faster Inference论文阅读笔记
论文:https://arxiv.org/abs/2104.01136 代码(刚刚开源): https://github.com/facebookresearch/LeViT ABSTRACT 我们设 ...
- LeViT: a Vision Transformer in ConvNet‘s Clothing for Faster Inference
文章目录 前言 1. 模型 1.1 设计原则 1.2 模型组件 patch embedding no classitication token normalization layers and act ...
- LeViT: aVision Transformer in ConvNet‘s Clothing for Fast in
摘要 我们设计了一系列图像分类架构,可以在高速模式下优化精度和效率之间的平衡.我们的工作利用了基于注意力的体系结构的最新发现,这种体系结构在高度并行处理硬件上具有竞争力.我们重温了大量文献中关于卷积神 ...
- 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...
- 【Timm】搭建Vision Transformer系列实践,终于见面了,Timm库!
前言:工具用不好,万事都烦恼,原本真的就是很简单的一个思路实现,偏偏绕了一圈又一圈,今天就来认识认识Timm库吧! 目录 1.百度飞桨提供的-从零开始学视觉Transformer 2.资源:视觉Tra ...
- 搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
↑ 点击蓝字 关注极市平台 作者丨科技猛兽 编辑丨极市平台 极市导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Trans ...
- Vision Transformer 论文解读
原文链接:https://blog.csdn.net/qq_16236875/article/details/108964948 扩展阅读:吸取CNN优点!LeViT:用于快速推理的视觉Transfo ...
- 一文细数Vision transformer家族成员
可以看作是vision transformer的idea集,查漏补缺使用.需要精读的文章前面加了*号,均附有文章链接及代码链接. 下面这个链接基本上有所有的ViT的论文地址:https://githu ...
- OUC暑期培训(深度学习)——第六周学习记录:Vision Transformer amp; Swin Transformer
第六周学习:Vision Transformer & Swin Transformer Part 1 视频学习及论文阅读 1.Vision Transformer 原文链接:https://a ...
最新文章
- KineticJS教程(3)
- Linux系统编程----15(线程与进程函数之间的对比,线程属性及其函数,线程属性控制流程,线程使用注意事项,线程库)
- 工作23:vue---封装request做数据请求
- 使用java来进行分词处理
- python爬取json数据_Python爬取数据保存为Json格式的代码示例
- PMP 考试一定要报培训班吗?(PMP备考资料分享)
- 计算机c盘删除的文件怎么找回,两分钟恢复电脑误删除的文件数据
- 35岁以上的大龄程序员们,后来都干什么去了?
- 7-28 猴子选大王(20 分)
- chipsel语言_英语快速记忆法视频
- 树莓派外设开发——IIC接口OLED屏幕
- flutter打包出错了,有大神帮忙看看吗?
- Excel-自网站粘贴
- xx.h和xx.c的奥妙
- SCAU 8609 哈夫曼树
- 安卓配置正式包和测试包不同的名字、图标、同时安装,(极光配置测试和正式)
- [随心译]2017.8.7-这些难以置信的地球太空夜景图实际上全是假货
- 公文专用计算机,[计算机]常用公文写作方法
- DevComponents.DotNetBar之SuperTabControl动态调整TABS页的位置,以动态调整按钮ButtonItem
- 【OFDM系列6】MIMO-OFDM系统模型、迫零(ZF)均衡检测和最小均方误差(MMSE)均衡检测原理和公式推导
热门文章
- 中国地质调查局:汶川地震原因已有初步的结论
- AUC、ROC、ACC区别
- 使用教育网邮箱学生验证Microsoft Imagine 微软开发者 获取window server 2016正版密钥教程
- 天梯赛的善良 (20 分)
- 【解决方案】Command failed due to signal: Segmentation fault: 11
- easyX中loadimage()函数共计有5个参数详解
- html5 css3 jquery 画板
- fpga串口打印计数值作业
- python画趋势图_python 绘制走势图
- Spring官方文档中文翻译