点击上方“计算机视觉工坊”,选择“星标”

干货第一时间送达

作者丨陀飞轮@知乎(已授权)

来源丨Smarter

编辑丨极市平台

导读

本文主要解析了CNN based和Transformer based的网络架构设计,其中CNN based涉及ResNet和BoTNet,Transformer based涉及ViT和T2T-ViT。

从DETR到ViT等工作都验证了Transformer在计算机视觉领域的潜力,那么很自然的就需要考虑一个新的问题,图像的特征提取,究竟是CNN好还是Transformer好?

其中CNN的优势在于参数共享,关注local信息的聚合,而Transformer的优势在于全局感受野,关注global信息的聚合。直觉上来讲global和local的信息聚合都是有用的,将global信息聚合和local信息聚合有效的结合在一起可能是设计最佳网络架构的正确方向。

如何有效的结合global和local信息,最近的几篇文章主要分成了两个方向:CNN based和Transformer based。以下主要解析一下CNN based和Transformer based的网络架构设计,其中CNN based涉及ResNet和BoTNetTransformer based涉及ViT和T2T-ViT

网络架构设计的相互关系

BoTNet在ResNet的基础上将Bottlenneck的3x3卷积替换成MHSA,增加CNN based的网络架构的global信息聚合能力。T2T-ViT在ViT的基础上将patch的linear projection替换成T2T,增加Transformer based的网络架构的local信息聚合能力。

ResNet&BoTNet

ResNet的结构设计,ResNet主要由Bottleneck结构堆叠而成,一层Bottlenneck由1x1conv、3x3conv和1x1conv堆叠构成残差分支,然后和skip connect分支相加。BoTNet在Bottlenneck结构的基础上将中间的3x3conv替换成MHSA结构,跟之间的Non-local等工作非常相似,本质上在CNN中引入global信息聚合。

MHSA结构如上图所示,代码如下。

class MHSA(nn.Module):def __init__(self, n_dims, width=14, height=14):super(MHSA, self).__init__()self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)self.softmax = nn.Softmax(dim=-1)def forward(self, x):n_batch, C, width, height = x.size()q = self.query(x).view(n_batch, C, -1)k = self.key(x).view(n_batch, C, -1)v = self.value(x).view(n_batch, C, -1)content_content = torch.bmm(q.permute(0, 2, 1), k)content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)content_position = torch.matmul(content_position, q)energy = content_content + content_positionattention = self.softmax(energy)out = torch.bmm(v, attention.permute(0, 2, 1))out = out.view(n_batch, C, width, height)return out

跟Transformer中的multi-head self-attention非常相似,区别在于MSHA将position encoding当成了spatial attention来处理,嵌入两个可学习的向量看成是横纵两个维度的空间注意力,然后将相加融合后的空间向量于q相乘得到contect-position(相当于是引入了空间先验),将content-position和content-content相乘得到空间敏感的相似性feature,让MHSA关注合适区域,更容易收敛。另外一个不同之处是MHSA只在蓝色块部分引入multi-head。

ViT

ViT是第一篇纯粹的将Transformer用于图像特征抽取的文章。

Vision Transformer(ViT)将输入图片拆分成16x16个patches,每个patch做一次线性变换降维同时嵌入位置信息,然后送入Transformer。类似BERT[class]标记位的设置,ViT在Transformer输入序列前增加了一个额外可学习的[class]标记位,并且该位置的Transformer Encoder输出作为图像特征。

假设输入图片大小是256x256,打算分成64个patch,每个patch是32x32像素。

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
# 将3072变成dim,假设是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)

这个写法是采用了爱因斯坦表达式,具体是采用了einops库实现,内部集成了各种算子,rearrange就是其中一个,非常高效。p就是patch大小,假设输入是b,3,256,256,则rearrange操作是先变成(b,3,8x32,8x32),最后变成(b,8x8,32x32x3)即(b,64,3072),将每张图片切分成64个小块,每个小块长度是32x32x3=3072,也就是说输入长度为64的图像序列,每个元素采用3072长度进行编码。考虑到3072有点大,ViT使用linear projection对图像序列编码进行降维。

T2T-ViT

ViT虽然验证了Transformer在图像分类网络架构设计的潜力,但是需要额外的大规模数据来进行pre-train,而在中等规模数据集如imagenet上效果却不理想。T2T-ViT引入了local的信息聚合来增强ViT局部结构建模的能力,使得T2T-ViT在中等规模imagenet上训练能达到更高的精度。

在T2T模块中,先将输入图像软分割为小块,然后将其展开成一个tokens T0序列。然后tokens的长度在T2T模块中逐步减少(文章中使用两次迭代然后输出Tf)。后续跟ViT基本上一致。

一次迭代T2T结构由re-structurization和soft split构成,re-structurization将一维序列reshape成二维图像, soft split对二维图像进行滑窗操作,拆分成重叠块。

以token transformer为例,先将输入图像拆分成7x7的重叠块,然后通过token transformer,进行块内的global信息聚合,然后通过re-structurization和soft split进行token重组和拆分成3x3的重叠块,得到长度更短的token序列,重复迭代两次,最后linear projection进一步降低token序列长度。

class T2T_module(nn.Module):"""Tokens-to-Token encoding module"""def __init__(self, img_size=224, in_chans=3, embed_dim=768, token_dim=64):super().__init__()self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)self.project = nn.Linear(token_dim * 3 * 3, embed_dim)self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 soft split, stride are 4,2,2 seperatelydef forward(self, x):# step0: soft splitx = self.soft_split0(x).transpose(1, 2)# iteration1: restricturization/reconstructionx = self.attention1(x)B, new_HW, C = x.shapex = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))# iteration1: soft splitx = self.soft_split1(x).transpose(1, 2)# iteration2: restricturization/reconstructionx = self.attention2(x)B, new_HW, C = x.shapex = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))# iteration2: soft splitx = self.soft_split2(x).transpose(1, 2)# final tokensx = self.project(x)return x

总结

1.global和local信息聚合的关系

global和local应该相互补充来同时balance 速度和精度,同时提升速度和精度的上限

2.CNN based和Transformer based的关系,CNN based 和 Transformer based哪个好

本质上是网络架构设计是以CNN为主好还是Transformer为主好的问题,CNN为主还是将输入当成二维的图像信号来处理,Transformer为主则将输入当成一维的序列信号来处理,所以想要研究清楚CNN为主好还是Transformer为主好的问题,需要去探索哪种输入信号更加具有优势,之前不少研究都表明CNN的padding可能透露了位置信息,而Transformer因为没有归纳偏见,需要增加position encoding来引入位置信息。CNN为主和Transformer为主各有优劣,目前来看暂无定论,且看后续发展。

Reference

[1] Deep Residual Learning for Image Recognition

[2] Bottleneck Transformers for Visual Recognition

[3] An image is worth 16x16 words: Transformers for image recognition at scale

[4] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

本文仅做学术分享,如有侵权,请联系删文。

重磅!计算机视觉工坊-学习交流群已成立

扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在交流顶会、顶刊、SCI、EI等写作与投稿事宜。

同时也可申请加入我们的细分方向交流群,目前主要有ORB-SLAM系列源码学习、3D视觉CV&深度学习SLAM三维重建点云后处理自动驾驶、CV入门、三维测量、VR/AR、3D人脸识别、医疗影像、缺陷检测、行人重识别、目标跟踪、视觉产品落地、视觉竞赛、车牌识别、硬件选型、深度估计、学术交流、求职交流等微信群,请扫描下面微信号加群,备注:”研究方向+学校/公司+昵称“,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进去相关微信群。原创投稿也请联系。

▲长按加微信群或投稿

▲长按关注公众号

3D视觉从入门到精通知识星球:针对3D视觉领域的视频课程(三维重建系列三维点云系列结构光系列手眼标定相机标定、激光/视觉SLAM、自动驾驶等)、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:

学习3D视觉核心技术,扫描查看介绍,3天内无条件退款

圈里有高质量教程资料、可答疑解惑、助你高效解决问题

觉得有用,麻烦给个赞和在看~

网络架构设计:CNN based和Transformer based相关推荐

  1. 【深度学习】网络架构设计:CNN based和Transformer based

    从DETR到ViT等工作都验证了Transformer在计算机视觉领域的潜力,那么很自然的就需要考虑一个新的问题,图像的特征提取,究竟是CNN好还是Transformer好? 其中CNN的优势在于参数 ...

  2. 服务器虚拟化架构设计,服务器虚拟化与数据中心I/O网络架构设计

    这是关于网络架构设计两部分系列中的第一篇.想了解关于在网络架构中整合I/O虚拟化方面的知识,请点击第二篇:如何将I/O虚拟化整合到网络体系架构中. 服务器虚拟化对于数据中心I/O网络架构的需求非常强烈 ...

  3. 三层架构学习的困难_“网工起航计划”3天集训营 带你了解大型企业网络架构设计!...

    点击蓝字 关注我们 网工起航计划 3天集训营 带你了解大型企业网络架构设计 开营时间:8月26日晚8点 DAY1:企业园区网二层架构冗余设计实战  时间:8月26日20:00-21:30 1.网络通信 ...

  4. 从哲学源头思考自动驾驶网络架构设计

    摘要:本篇从哲学的角度阐述自动驾驶网络架构设计的方法. 自动驾驶网络关键在架构创新,创新不是漫无边际,毫无逻辑和实现可能性的瞎想,没有约束和方法论的瞎想是民科干的事情.我们要通过坚实的架构设计方法,铺 ...

  5. IMT-2020(5G)推进组《5G网络架构设计》白皮书

    为了应对各类移动互联网和物联网应用场景的差异化极致性能需求,有效服务于工业.交通.医疗等其他行业应用,需要对5G网络架构进行端到端的统一设计.白皮书从系统设计和组网设计两个角度深入分析,提出了新型5G ...

  6. IMT-2020(5G)推进组的《5G网络架构设计》白皮书

    由我国IMT-2020(5G)推进组联合欧盟5G PPP.韩国5G论坛.日本5GMF和美洲5G Americas主办的第一届全球5G大会于2016年5月31日至6月1日在北京召开,来自中国.欧盟.日本 ...

  7. CVPR 2022 3月3日论文速递(19 篇打包下载)涵盖网络架构设计、姿态估计、三维视觉、动作检测、语义分割等方向

    以下CVPR2022论文打包合集:下载地址 神经网络架构设计 [1] An Image Patch is a Wave: Quantum Inspired Vision MLP(图像补丁是波浪:量子启 ...

  8. 微博技术:千万级规模高性能高并发的网络架构设计

    分享人:卫向军(毕业于北京邮电大学,现任微博平台架构师,先后在微软.金山云.新浪微博从事技术研发工作,专注于系统架构设计.音视频通讯系统.分布式文件系统和数据挖掘等领域.) 架构以及我理解中架构的本质 ...

  9. CESS 机制详解(1):多层网络架构设计

    随着区块链的发展,当前的很多公链项目都采用多层的模块化设计.举例来说,以太坊当前就正在进行 PoS 共识层的开发,未来也将成为拥有共识层和执行层的网络:再例如波卡网络,由中继链和平行链网络组成.类似的 ...

  10. 深度网络架构的设计技巧(三)之ConvNeXt:打破Transformer垄断的纯CNN架构

    单位:FAIR (DenseNet共同一作,曾获CVPR2017 best paper),UC伯克利 ArXiv:https://arxiv.org/abs/2201.03545 Github:htt ...

最新文章

  1. react遇到的各种坑
  2. 不同编程语言在发生stackoverflow之前支持的调用栈最大嵌套层数
  3. retain copy(浅复制) mutablecopy (深复制)
  4. 一个关于HINT中指定索引查询的问题
  5. note deletion case
  6. TensorFlow2快速模型构建及tensorboard初体验
  7. python无法导入numpy_python – Pycharm无法导入numpy
  8. JavaScript MSN 弹出消息框
  9. YOLOv1-YOLOv4
  10. 监听url地址栏变化
  11. 【Linux】 CentOS 7 安装 RabbitMQ
  12. HSImageSidebarView
  13. Redis在CentOS 7上的安装部署
  14. Ubuntu中vim编辑器的常用操作
  15. 学说不能选计算机专硕的课,初试前先选组,选定离手还不能改?北京邮电大学计算机...
  16. 有趣的微分方程传之可分离变量的微分方程
  17. Python基础学习笔记【廖雪峰】
  18. vue实现横向或竖向滑动轮播
  19. 微信小程序期末大作业 记单词小程序 适合初学者学习使用
  20. 商业综合体能耗在线监测管理系统_商场管理平台

热门文章

  1. C中的C文件与h文件辨析(转)
  2. store procedure 翻页
  3. 2. mysql 基本命令
  4. iOS 关于单例那点事
  5. JavaScript之全局函数详解
  6. 使用Git将项目上传到GitHub(Windows+Linux双教程)
  7. 查看IIS进程所对应的应用程序池名称
  8. Javascript、jQuery 操作select控件大全(新增、修改、删除、选中、清空、判断存在等)(转)...
  9. Asp.Net资料网址
  10. 2.React学习笔记----修改模板并使用Ant Design