CMT

  • Overview
  • Introduction
    • Problems with the transformer
    • Block design in CMT
  • Related work
    • CNN
    • Vision Transformer
  • Method
    • Overall architecture
    • CMT Block
      • LPU
      • LMHSA
      • IRFFN
  • Code walkthrough
    • LPU
    • IRFFN
    • LMHSA

Overview

CMT uses CNNs to capture local information and Transformers to capture global information.

Introduction

Problems with the transformer

  • Splitting an image into patches ignores the latent 2D structure and the spatially local information inside the image.
  • The input and output sizes of a transformer block are fixed, which makes it hard to explicitly extract multi-scale and low-resolution features.
  • The computational cost is too high: self-attention scales quadratically with the input image size.
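Concretely, for an H×W input split into patches of side p, the attention map of standard multi-head self-attention over the resulting n = HW/p² tokens with embedding dimension d costs

```latex
\Omega(\mathrm{MHSA}) = O(n^2 d) = O\!\left(\frac{(HW)^2}{p^4}\, d\right)
```

so doubling the input side length multiplies the attention cost by a factor of 16.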

Block design in CMT

  • The Local Perception Unit (LPU) and the Inverted Residual Feed-Forward Network (IRFFN) in the CMT block help capture local and global structural information within the intermediate features and raise the representation power of the network.

Related work

CNN

| Network | Key idea |
| --- | --- |
| LeNet | handwritten digit recognition |
| AlexNet & VGGNet & GoogleNet & InceptionNet | ImageNet competition winners |
| ResNet | stronger generalization |
| SENet | adaptively recalibrates channel-wise feature responses |

Vision Transformer

| Network | Key idea |
| --- | --- |
| ViT | brings the Transformer from NLP into CV |
| DeiT | distillation: a teacher model guides the training of the Vision Transformer |
| T2T-ViT | recursively aggregates neighboring tokens into one; a transformer then models the visual tokens |
| TNT | an outer block models relations between patch embeddings; an inner block models relations between pixel embeddings |
| PVT | introduces a pyramid structure into ViT and can generate multi-scale feature maps for various pixel-level dense prediction tasks |
| CPVT and CvT | introduce CNNs into the transformer |

Method

Overall architecture

  • CMT stem (shrinks the image and extracts local information)
  • Conv stride (down-samples the feature map and increases the channels)
  • CMT block (captures global and local relations)
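The front of the network can be sketched as follows. The channel widths and layer counts here are illustrative assumptions, not the paper's exact configuration: a convolutional stem halves the resolution, and each stage begins with a stride-2 convolution that shrinks the feature map while widening the channels.

```python
import torch
import torch.nn as nn

# Sketch of the CMT front end (widths are illustrative assumptions):
# stem halves the resolution, then a 2x2 stride-2 "conv stride" layer
# down-samples again and increases the channel count before the CMT blocks.
def cmt_stem(in_ch=3, dim=16):
    return nn.Sequential(
        nn.Conv2d(in_ch, dim, 3, stride=2, padding=1), nn.GELU(),  # halve H, W
        nn.Conv2d(dim, dim, 3, stride=1, padding=1), nn.GELU(),    # local info
    )

def conv_stride(in_dim, out_dim):
    # between-stage down-sampling: smaller feature map, more channels
    return nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2)

stem = cmt_stem()
down = conv_stride(16, 32)
x = torch.randn(1, 3, 64, 64)
y = down(stem(x))  # 64 -> 32 (stem) -> 16 (conv stride)
```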

CMT Block

LPU

  • Positional encoding breaks the translation invariance provided by convolution and ignores both the local information between patches and the structural information inside each patch.
  • The LPU alleviates this problem.
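In the paper the LPU is a 3×3 depth-wise convolution with a residual connection:

```latex
\mathrm{LPU}(X) = \mathrm{DWConv}_{3\times 3}(X) + X
```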

LMHSA

  • A depth-wise convolution shrinks K and V, and a relative position bias is added to the attention map, yielding a lightweight self-attention computation.
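In the paper, a k×k depth-wise convolution with stride k reduces the key and value maps before attention, and a learnable relative position bias B is added inside the softmax:

```latex
\begin{aligned}
K' &= \mathrm{DWConv}_{k\times k}(K), \qquad V' = \mathrm{DWConv}_{k\times k}(V) \\
\mathrm{LMHSA}(Q, K, V) &= \mathrm{Softmax}\!\left(\frac{Q K'^{\top}}{\sqrt{d_k}} + B\right) V'
\end{aligned}
```

This shrinks the n × n attention map to n × (n/k²).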

IRFFN

  • A depth-wise convolution strengthens the extraction of local information, and the residual connection improves gradient propagation.
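In the paper the IRFFN is an inverted-residual style feed-forward block: expand with a 1×1 convolution, apply a 3×3 depth-wise convolution with a shortcut, then project back down:

```latex
\mathrm{IRFFN}(X) = \mathrm{Conv}_{1\times 1}\big(F\big(\mathrm{Conv}_{1\times 1}(X)\big)\big), \qquad F(X) = \mathrm{DWConv}_{3\times 3}(X) + X
```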

Code walkthrough

LPU

```python
import torch
import torch.nn as nn

class LocalPerceptionUint(nn.Module):
    def __init__(self, dim, act=False):
        super(LocalPerceptionUint, self).__init__()
        self.act = act
        # depth-wise 3x3 convolution to enhance local information extraction
        self.conv_3x3_dw = ConvDW3x3(dim)
        if self.act:
            self.actation = nn.Sequential(
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )

    def forward(self, x):
        if self.act:
            out = self.actation(self.conv_3x3_dw(x))
            return out
        else:
            out = self.conv_3x3_dw(x)
            return out
```
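ConvDW3x3 is not defined in these snippets; a minimal sketch of what it presumably looks like, assuming it is a padding-preserving depth-wise 3×3 convolution (groups equal to the channel count), is:

```python
import torch
import torch.nn as nn

# Hypothetical reconstruction of the ConvDW3x3 helper used above:
# a 3x3 depth-wise convolution (groups == channels) with padding 1,
# so the spatial size of the feature map is preserved.
class ConvDW3x3(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=kernel_size,
                              padding=kernel_size // 2, groups=dim)

    def forward(self, x):
        return self.conv(x)
```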

IRFFN

```python
class InvertedResidualFeedForward(nn.Module):
    def __init__(self, dim, dim_ratio=4.):
        super(InvertedResidualFeedForward, self).__init__()
        output_dim = int(dim_ratio * dim)
        # 1x1 expansion convolution with GELU and BatchNorm
        self.conv1x1_gelu_bn = ConvGeluBN(
            in_channel=dim,
            out_channel=output_dim,
            kernel_size=1,
            stride_size=1,
            padding=0
        )
        self.conv3x3_dw = ConvDW3x3(dim=output_dim)
        self.act = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm2d(output_dim)
        )
        # 1x1 point-wise projection back to the input dim
        self.conv1x1_pw = nn.Sequential(
            nn.Conv2d(output_dim, dim, 1, 1, 0),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x):
        x = self.conv1x1_gelu_bn(x)
        out = x + self.act(self.conv3x3_dw(x))  # inner residual around the DW conv
        out = self.conv1x1_pw(out)
        return out
```
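ConvGeluBN is likewise undefined here; a plausible sketch, assuming it is simply a convolution followed by GELU and BatchNorm with the keyword names the IRFFN code passes in, is:

```python
import torch
import torch.nn as nn

# Hypothetical reconstruction of the ConvGeluBN helper used above:
# convolution -> GELU -> BatchNorm, taking the keyword arguments
# (in_channel, out_channel, kernel_size, stride_size, padding) seen in IRFFN.
class ConvGeluBN(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride_size, padding=0):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride_size, padding),
            nn.GELU(),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        return self.block(x)
```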

LMHSA

```python
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_

class LightMutilHeadSelfAttention(nn.Module):
    """Calculate self-attention with down-sampled resolution for k, v,
    adding a relative position bias before the softmax.

    Args:
        dim (int): feature map channels or dims.
        num_heads (int): number of attention heads.
        relative_pos_embeeding (bool): use a relative position embedding.
        no_distance_pos_embeeding (bool): use a distance-free position embedding.
        features_size (int): feature map spatial size.
        qkv_bias (bool): whether to use bias in the q/kv projections.
        qk_scale (float): qk scale; if None, use the default head_dim ** -0.5.
        attn_drop (float): attention dropout rate.
        proj_drop (float): projection linear dropout rate.
        sr_ratio (float): k, v resolution down-sample ratio.
    Returns:
        x: LMHSA attention result of shape (B, C, H, W), the same as the input.
    """
    def __init__(self, dim, num_heads=8, features_size=56,
                 relative_pos_embeeding=False, no_distance_pos_embeeding=False,
                 qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 sr_ratio=1.):
        super(LightMutilHeadSelfAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}"
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads   # dim used by each attention head
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_pos_embeeding = relative_pos_embeeding
        self.no_distance_pos_embeeding = no_distance_pos_embeeding
        self.features_size = features_size

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # strided convolution that down-samples the k, v resolution
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        if self.relative_pos_embeeding:
            self.relative_indices = generate_relative_distance(self.features_size)
            self.position_embeeding = nn.Parameter(
                torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))
        elif self.no_distance_pos_embeeding:
            self.position_embeeding = nn.Parameter(
                torch.randn(self.features_size ** 2, self.features_size ** 2))
        else:
            self.position_embeeding = None
        if self.position_embeeding is not None:
            trunc_normal_(self.position_embeeding, std=0.2)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        x_q = rearrange(x, 'B C H W -> B (H W) C')  # B,C,H,W -> B,(H*W),C
        q = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # B,N,H,DIM -> B,H,N,DIM

        # convolution to down-sample the resolution of x for k, v
        if self.sr_ratio > 1:
            x_reduce_resolution = self.sr(x)
            x_kv = rearrange(x_reduce_resolution, 'B C H W -> B (H W) C ')
            x_kv = self.norm(x_kv)
        else:
            x_kv = rearrange(x, 'B C H W -> B (H W) C ')

        kv_emb = rearrange(self.kv(x_kv), 'B N (dim h l) -> l B h N dim',
                           h=self.num_heads, l=2)  # 2,B,H,N,DIM
        k, v = kv_emb[0], kv_emb[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B,H,Nq,DIM) @ (B,H,DIM,Nk) -> (B,H,Nq,Nk)

        # add the relative position bias; because k_n != q_n after
        # down-sampling, the position embedding matrix has to be sliced
        q_n, k_n = q.shape[2], k.shape[2]
        if self.relative_pos_embeeding:
            attn = attn + self.position_embeeding[
                self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]
        elif self.no_distance_pos_embeeding:
            attn = attn + self.position_embeeding[:, :k_n]

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B,H,Nq,Nk) @ (B,H,Nk,dim) -> (B,Nq,H*dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, 'B (H W) C -> B C H W ', H=H, W=W)
        return x
```
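generate_relative_distance is also not shown in this post; a sketch consistent with how the indices are used above (indexing a (2S-1) × (2S-1) bias table with per-pair row/column offsets) is:

```python
import torch

# Hypothetical reconstruction of generate_relative_distance: for an
# S x S token grid it returns an (N, N, 2) tensor (N = S*S) holding, for
# every token pair, the (row, col) offset shifted into [0, 2S-2] so it
# can index a (2S-1, 2S-1) relative position bias table.
def generate_relative_distance(features_size):
    coords = torch.stack(torch.meshgrid(
        torch.arange(features_size),
        torch.arange(features_size),
        indexing='ij')).flatten(1)                            # (2, N)
    relative = coords[:, :, None] - coords[:, None, :]        # (2, N, N), in [-(S-1), S-1]
    relative = relative.permute(1, 2, 0) + features_size - 1  # (N, N, 2), in [0, 2S-2]
    return relative
```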
