CMT: Convolutional Neural Networks Meet Vision Transformers
CMT
- Overview
- Introduction
- Problems with Transformers
- Block Design in CMT
- Related Work
- CNN
- Vision Transformer
- Method
- Overall Architecture
- CMT Block
- LPU
- LMHSA
- IRFFN
- Code Walkthrough
- LPU
- IRFFN
- LMHSA
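Overview
CNNs capture local information, while Transformers capture global information; CMT combines the two.
Introduction
Problems with Transformers
- Splitting an image into patches ignores the underlying 2D structure and the spatially local information inside the image.
- A transformer block keeps its output the same size as its input, which makes it hard to explicitly extract multi-scale and low-resolution features.
- The computational cost is too high. Self-attention is quadratic in the number of patch tokens, and therefore grows quickly with the input image size.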
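To make the last point concrete, here is a small back-of-the-envelope sketch. It assumes square images and a patch size of 16 (a common ViT setting, not something stated in this post); the function name is made up for illustration.

```python
def attention_matrix_entries(image_size: int, patch_size: int = 16) -> int:
    """Entries in one N x N attention matrix for a square image (one head, one layer)."""
    tokens_per_side = image_size // patch_size
    num_tokens = tokens_per_side ** 2  # N grows with the image area
    return num_tokens ** 2             # the attention matrix is N x N


for size in (224, 448, 896):
    print(size, attention_matrix_entries(size))
```

Doubling the side length gives 4x as many tokens and 16x as many attention entries.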
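Block Design in CMT
- The local perception unit (LPU) and the inverted residual feed-forward network (IRFFN) in the CMT block help capture both local and global structural information within intermediate features, and improve the network's representational capacity.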
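Related Work
CNN
Network | Characteristics |
---|---|
LeNet | Handwritten digit recognition |
AlexNet & VGGNet & GoogLeNet & InceptionNet | ImageNet competition |
ResNet | Improved generalization |
SENet | Adaptively recalibrates channel-wise feature responses |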
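Vision Transformer
Network | Characteristics |
---|---|
ViT | Brings the Transformer from NLP into computer vision |
DeiT | Uses distillation: a teacher model guides the training of the Vision Transformer |
T2T-ViT | Recursively aggregates neighboring tokens into a single token, then models the resulting visual tokens with a transformer |
TNT | The outer block models relations between patch embeddings; the inner block models relations between pixel embeddings |
PVT | Introduces a pyramid structure into ViT, producing multi-scale feature maps for pixel-level dense prediction tasks |
CPVT and CvT | Introduce convolutions into the transformer |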
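Method
Overall Architecture
- CMT stem (reduces the image size and extracts local information)
- Conv Stride (a strided convolution that shrinks the feature map and increases the number of channels)
- CMT block (captures global and local relations)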
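The effect of the stem and the strided convolutions on tensor shapes can be sketched as below. The concrete numbers (224x224 input, a stride-2 stem, a stride-2 convolution before each of four stages, channel widths 64/128/256/512) are illustrative assumptions, not values taken from this post.

```python
def cmt_stage_shapes(image_size=224, stem_stride=2, stage_channels=(64, 128, 256, 512)):
    """Return (channels, height, width) after each stage under the assumed strides."""
    size = image_size // stem_stride  # the stem reduces the resolution first
    shapes = []
    for channels in stage_channels:
        size //= 2  # assumed stride-2 convolution in front of every stage
        shapes.append((channels, size, size))
    return shapes


for shape in cmt_stage_shapes():
    print(shape)
```

Under these assumptions the feature map halves in height and width at every stage while the channel count grows, which is what gives the network its multi-scale pyramid.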
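CMT Block
LPU
- Positional encoding breaks the translation invariance that convolutions provide, and it ignores both the local relations between patches and the structural information inside each patch.
- The LPU is introduced to alleviate this problem.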
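As far as I recall from the CMT paper, the LPU is simply a depthwise convolution with a shortcut connection, where X is the input feature map:

```latex
\mathrm{LPU}(X) = \mathrm{DWConv}(X) + X
```

Note that the `LocalPerceptionUint` class in the code walkthrough applies only the depthwise convolution, so in that implementation the shortcut, if used, has to be added by the caller.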
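LMHSA
- A depthwise convolution reduces the spatial size of K and V, and a relative position bias is added; together they form a lightweight self-attention computation.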
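A rough sketch of how much the K/V reduction saves, assuming the reduction is a convolution with kernel size and stride both equal to `sr_ratio` (as in the LMHSA code in the walkthrough); the function name and the 56x56 example are illustrative.

```python
def attention_pairs(height, width, sr_ratio=1):
    """Return (num_queries, num_keys, query_key_pairs) for one attention head."""
    num_queries = height * width
    # kernel = stride = sr_ratio shrinks each side of the K/V map by sr_ratio
    num_keys = (height // sr_ratio) * (width // sr_ratio)
    return num_queries, num_keys, num_queries * num_keys


full = attention_pairs(56, 56, sr_ratio=1)
light = attention_pairs(56, 56, sr_ratio=8)
print(full, light, full[2] // light[2])
```

The number of queries is unchanged, so the output resolution stays the same; only the number of keys and values (and hence the attention matrix) shrinks, by a factor of `sr_ratio ** 2`.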
IRFFN
- A depthwise convolution strengthens local information extraction, and the residual connection improves gradient propagation.
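Read off from the `InvertedResidualFeedForward` code in the walkthrough, the structure can be summarized as follows (activation and batch-norm layers omitted for brevity). The first 1x1 convolution expands the channels, the depthwise convolution sits inside a shortcut, and the last 1x1 convolution projects back:

```latex
\begin{aligned}
F(X) &= \mathrm{DWConv}(X) + X \\
\mathrm{IRFFN}(X) &= \mathrm{Conv}\bigl(F(\mathrm{Conv}(X))\bigr)
\end{aligned}
```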
Code Walkthrough
LPU
    from torch import nn

    # ConvDW3x3 (judging by its name, a 3x3 depthwise convolution) is defined
    # elsewhere in the original code base.


    class LocalPerceptionUint(nn.Module):
        def __init__(self, dim, act=False):
            super(LocalPerceptionUint, self).__init__()
            self.act = act
            # depthwise 3x3 convolution that strengthens local information extraction
            self.conv_3x3_dw = ConvDW3x3(dim)
            if self.act:
                self.actation = nn.Sequential(
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )

        def forward(self, x):
            out = self.conv_3x3_dw(x)
            if self.act:
                out = self.actation(out)
            return out
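`ConvDW3x3` is used above but not shown in this post. A minimal sketch of what such a helper might look like, assuming it is a plain depthwise convolution (`groups=dim`) with padding that preserves the spatial size; the real helper may differ.

```python
import torch
from torch import nn


class ConvDW3x3(nn.Module):
    """Hypothetical helper: 3x3 depthwise convolution that keeps H and W."""

    def __init__(self, dim, kernel_size=3):
        super().__init__()
        # groups=dim gives every channel its own 3x3 filter (depthwise)
        self.conv = nn.Conv2d(dim, dim, kernel_size=kernel_size,
                              padding=kernel_size // 2, groups=dim)

    def forward(self, x):
        return self.conv(x)


x = torch.randn(2, 16, 8, 8)
layer = ConvDW3x3(16)
out = layer(x)
print(out.shape)
```

Because each channel is filtered independently, the layer has only `dim * 9` weights plus biases, which is why it is a cheap way to inject locality.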
IRFFN
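    from torch import nn

    # ConvGeluBN and ConvDW3x3 are helpers defined elsewhere in the original code base.


    class InvertedResidualFeedForward(nn.Module):
        def __init__(self, dim, dim_ratio=4.):
            super(InvertedResidualFeedForward, self).__init__()
            output_dim = int(dim_ratio * dim)
            # 1x1 convolution that expands the channels from dim to output_dim
            self.conv1x1_gelu_bn = ConvGeluBN(in_channel=dim,
                                              out_channel=output_dim,
                                              kernel_size=1,
                                              stride_size=1,
                                              padding=0)
            # 3x3 depthwise convolution on the expanded features
            self.conv3x3_dw = ConvDW3x3(dim=output_dim)
            self.act = nn.Sequential(
                nn.GELU(),
                nn.BatchNorm2d(output_dim)
            )
            # 1x1 pointwise convolution that projects back to dim channels
            self.conv1x1_pw = nn.Sequential(
                nn.Conv2d(output_dim, dim, 1, 1, 0),
                nn.BatchNorm2d(dim)
            )

        def forward(self, x):
            x = self.conv1x1_gelu_bn(x)
            # shortcut around the depthwise convolution
            out = x + self.act(self.conv3x3_dw(x))
            out = self.conv1x1_pw(out)
            return out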
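`ConvGeluBN` is likewise not shown in this post. Going by its name and by the keyword arguments used above, a plausible sketch is convolution, then GELU, then batch norm. This is an assumption about the helper, not the original definition.

```python
import torch
from torch import nn


class ConvGeluBN(nn.Module):
    """Hypothetical helper: Conv2d -> GELU -> BatchNorm2d."""

    def __init__(self, in_channel, out_channel, kernel_size, stride_size, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride_size, padding),
            nn.GELU(),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        return self.block(x)


x = torch.randn(2, 16, 8, 8)
expand = ConvGeluBN(in_channel=16, out_channel=64, kernel_size=1, stride_size=1, padding=0)
out = expand(x)
print(out.shape)
```

With `kernel_size=1` and `padding=0`, as in the IRFFN above, the layer changes only the channel count and leaves the spatial size untouched.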
LMHSA
    import torch
    from torch import nn
    from torch.nn.init import trunc_normal_
    from einops import rearrange

    # generate_relative_distance is a helper defined elsewhere in the original code base.


    class LightMutilHeadSelfAttention(nn.Module):
        """Self-attention with spatially down-sampled k and v; a relative
        position bias is added before the softmax.

        Args:
            dim (int): number of feature-map channels
            num_heads (int): number of attention heads
            relative_pos_embeeding (bool): use a relative position embedding
            no_distance_pos_embeeding (bool): use a position embedding without relative distances
            features_size (int): side length of the feature map
            qkv_bias (bool): add a bias to the q and kv projections
            qk_scale (float): scale for q @ k^T; head_dim ** -0.5 if None
            attn_drop (float): attention dropout rate
            proj_drop (float): dropout rate after the output projection
            sr_ratio (int): down-sampling ratio of the k, v resolution
                (used as the kernel size and stride of a convolution)

        Returns:
            x: attention result of shape (B, C, H, W), the same as the input.
        """
        def __init__(self, dim, num_heads=8, features_size=56,
                     relative_pos_embeeding=False, no_distance_pos_embeeding=False,
                     qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
            super(LightMutilHeadSelfAttention, self).__init__()
            assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}"
            self.dim = dim
            self.num_heads = num_heads
            head_dim = dim // num_heads  # channels per attention head
            self.scale = qk_scale or head_dim ** -0.5
            self.relative_pos_embeeding = relative_pos_embeeding
            self.no_distance_pos_embeeding = no_distance_pos_embeeding
            self.features_size = features_size

            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            self.softmax = nn.Softmax(dim=-1)

            self.sr_ratio = sr_ratio
            if sr_ratio > 1:
                # strided convolution that shrinks the k, v feature map
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)

            if self.relative_pos_embeeding:
                self.relative_indices = generate_relative_distance(self.features_size)
                self.position_embeeding = nn.Parameter(
                    torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))
            elif self.no_distance_pos_embeeding:
                self.position_embeeding = nn.Parameter(
                    torch.randn(self.features_size ** 2, self.features_size ** 2))
            else:
                self.position_embeeding = None

            if self.position_embeeding is not None:
                trunc_normal_(self.position_embeeding, std=0.2)

        def forward(self, x):
            B, C, H, W = x.shape
            N = H * W
            x_q = rearrange(x, 'B C H W -> B (H W) C')  # flatten the spatial dims into tokens
            # (B, N, heads, head_dim) -> (B, heads, N, head_dim)
            q = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

            # down-sample the resolution of x for k and v
            if self.sr_ratio > 1:
                x_reduce_resolution = self.sr(x)
                x_kv = rearrange(x_reduce_resolution, 'B C H W -> B (H W) C')
                x_kv = self.norm(x_kv)
            else:
                x_kv = rearrange(x, 'B C H W -> B (H W) C')

            # (2, B, heads, Nk, head_dim)
            kv_emb = rearrange(self.kv(x_kv), 'B N (dim h l) -> l B h N dim', h=self.num_heads, l=2)
            k, v = kv_emb[0], kv_emb[1]

            # (B, heads, Nq, head_dim) @ (B, heads, head_dim, Nk) -> (B, heads, Nq, Nk)
            attn = (q @ k.transpose(-2, -1)) * self.scale

            # TODO: add the relative position bias; because k_n != q_n,
            # the position embedding matrix has to be sliced
            q_n, k_n = q.shape[2], k.shape[2]  # token counts (the original read q.shape[1], the head count)

            if self.relative_pos_embeeding:
                attn = attn + self.position_embeeding[
                    self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]
            elif self.no_distance_pos_embeeding:
                attn = attn + self.position_embeeding[:, :k_n]

            attn = self.softmax(attn)
            attn = self.attn_drop(attn)

            # (B, heads, Nq, Nk) @ (B, heads, Nk, head_dim) -> (B, Nq, heads * head_dim)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            x = rearrange(x, 'B (H W) C -> B C H W', H=H, W=W)
            return x
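The attention class calls `generate_relative_distance`, which this post does not show. Below is a minimal sketch of one way such a helper could work, assuming it returns, for every pair of positions on a `features_size x features_size` grid, the row/column offset shifted into the range `[0, 2 * features_size - 2]` so it can index the `(2F - 1) x (2F - 1)` bias table. The body is my reconstruction, not the original code.

```python
import torch


def generate_relative_distance(features_size):
    """Hypothetical helper: pairwise (row, col) offsets, shifted to be non-negative."""
    coords = torch.tensor([[row, col]
                           for row in range(features_size)
                           for col in range(features_size)])
    # (F*F, F*F, 2), offsets in [-(F - 1), F - 1]
    distances = coords[None, :, :] - coords[:, None, :]
    return distances + features_size - 1  # shift into [0, 2F - 2]


F = 4
indices = generate_relative_distance(F)
table = torch.randn(2 * F - 1, 2 * F - 1)          # stands in for position_embeeding
bias = table[indices[:, :, 0], indices[:, :, 1]]   # one bias value per (query, key) pair
print(indices.shape, bias.shape)
```

The gathered `bias` has one row per query and one column per full-resolution key. When `sr_ratio > 1` there are fewer keys than queries, and the forward pass above simply keeps the first `k_n` columns, which the source itself marks as a TODO.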