深度学习之图像分类(二十五)S2MLPv2 网络详解

目录

  • 深度学习之图像分类(二十五)S2MLPv2 网络详解
    • 1. 前言
    • 2. S2MLPv2
      • 2.1 S2MLPv2 Block
      • 2.2 Spatial-shift 与感受野反思
    • 3. 总结
    • 4. 代码

经过 S2MLP 和 Vision Permutator 的沉淀,为此本节我们便来学习学习 S2MLPv2 的基本思想。

1. 前言

S2MLPv2 依是百度提出的用于视觉的空间位移 MLP 架构,其作者以及顺序与 S2MLP 一模一样,其论文为 S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision。S2MLPv2 的修改点主要在于三处:金字塔结构(参考 ViP)、分三类情况进行考虑(参考 ViP)、使用 Split Attention(参考 ViP 和 ResNeSt)。总结而言就是把 ViP 中的 Permute-MLP layer 中别人沿着不同方向进行交互替换为了 Spatial-shift 操作。在参数量基本一致的情况下,其性能优于 ViP。

2. S2MLPv2

2.1 S2MLPv2 Block

S2MLPv2 和 S2MLPv1 类似,整体网络结构不做过多赘述,主要讲解一下 S2MLPv2 Block 的细节(建议大家先回顾之前的章节 S2MLP 以及 ViP):

  • 首先是特征图输入后,对 Channel 进行一个全连接,这里是对于特定位置信息进行交流,其实也就是 1 × 1 1 \times 1 1×1 卷积,只不过这里将维度变为了原来的 3 倍。然后经过一个 GELU 激活函数。
  • 将特征图均分为 3 等分,分别用于后续三个 Spatial-shift 分支的输入。
    • 第一个分支进行与 S2MLPv1 一样的 Spatial-shift 操作,即右-左-下-上移动。
    • 第二个分支进行与第一个分支反对称的 Spatial-shift 操作,即下-上-右-左移动。
    • 第三个分支保持不变
  • 之后将三个分支的结果通过 Split Attention 结合起来。这样不同位置的信息就被加到同一个通道上对齐了。
  • 再经过一个 MLP 进行不同位置的信息整合,然后经过 LN 激活函数。(看了这么多网络,其实激活函数在前在后都可以)

def spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c/4] = x[:,:w-1,:,:c/4]x[:,:w-1,:,c/4:c/2] = x[:,1:,:,c/4:c/2]x[:,:,1:,c/2:c*3/4] = x[:,:,:h-1,c/2:c*3/4]x[:,:,:h-1,3*c/4:] = x[:,:,1:,3*c/4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c/4] = x[:,:,:h-1,:c/4]x[:,:,:h-1,c/4:c/2] = x[:,:,1:,c/4:c/2]x[:,1:,:,c/2:c*3/4] = x[:,:w-1,:,c/2:c*3/4]x[:,:w-1,:,3*c/4:] = x[:,1:,:,3*c/4:]return xclass S2-MLPv2(nn.Module):def __init__(self, channels):super().__init__()self.mlp1 = nn.Linear(channels,channels * 3)self.mlp2 = nn.Linear(channels,channels)self.split_attention = SplitAttention()def forward(self, x):b,w,h,c = x.size()x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c/3])x2 = spatial_shift2(x[:,:,:,c/3:c/3*2])x3 = x[:,:,:,c/3*2:]a = self.split_attention(x1,x2,x3)x = self.mlp2(a)return x
  • 接下来的就是和 MLP-Mixer 中的 Channel-mixing MLP 一致。

2.2 Spatial-shift 与感受野反思

三组 Spatial-shift (包括恒等)与一组相比有什么进步和问题呢

  • 传统计算机视觉感受野以及近期 ViP 工作等等,都提倡奇数和中心概念,即在某中心卷积核大小是奇数的,一左一右一上一下是对称的。原始的一组 Spatial-shift 其实是一个菱形感受野且不包括中心。现在有恒等之后则是菱形感受野且包括中心了,这是一个进步。
  • 但是第二组设计为与第一组反对称的结构,但是这个没有反彻底。其实这三组 Spatial-shift 也可看作是人精心设计构造的。那么我们仔细看一下,其实没有实现完全的互补。让我们把目光放到 Split Attention 之后,输出的特征图其实也可被看作四个部分,每部分对应着:左上中相加,右下中相加,上左中相加,下右中相加。为了更好的方便大家理解这句话,我们不妨先忽略 Split Attention 给出的权重,并将经过 Spatial-shift 操作前的三部分特征图分别记录为 f , g , h f,g,h f,g,h,输出记录为 z z z。则有如下公式,其中下标表示不同的旋转部分。
    • 如果从强迫症的观点看:第一组 Spatial-shift 是 右-左-下-上,则第二组 Spatial-shift 应该是 上-下-右-左 才对。
    • 如果从感受野完整性的观点看:第一组 Spatial-shift 是 右-左-下-上,则第二组 Spatial-shift 应该是 左上-左下-右上-右下 才对。

z 1 ( x , y ) = f 1 ( x − 1 , y ) + g 1 ( x , y − 1 ) + h 1 ( x , y ) z 2 ( x , y ) = f 2 ( x + 1 , y ) + g 2 ( x , y + 1 ) + h 2 ( x , y ) z 3 ( x , y ) = f 3 ( x , y − 1 ) + g 3 ( x − 1 , y ) + h 3 ( x , y ) z 4 ( x , y ) = f 4 ( x , y + 1 ) + g 4 ( x + 1 , y ) + h 4 ( x , y ) z_{1}(x,y) = f_{1}(x-1,y) + g_{1}(x,y-1) + h_{1}(x,y) \\ z_{2}(x,y) = f_{2}(x+1,y) + g_{2}(x,y+1) + h_{2}(x,y) \\ z_{3}(x,y) = f_{3}(x,y-1) + g_{3}(x-1,y) + h_{3}(x,y) \\ z_{4}(x,y) = f_{4}(x,y+1) + g_{4}(x+1,y) + h_{4}(x,y) z1(x,y)=f1(x1,y)+g1(x,y1)+h1(x,y)z2(x,y)=f2(x+1,y)+g2(x,y+1)+h2(x,y)z3(x,y)=f3(x,y1)+g3(x1,y)+h3(x,y)z4(x,y)=f4(x,y+1)+g4(x+1,y)+h4(x,y)

关于 Split 的消融实验,作者分别移除了第二部分和第三部分,发现移除第二部分损失的性能还比第三部分(恒等)的多,但是就差 0.1%,这个消融实验其实很难解释三部分怎么相互作用的,至少从计算机视觉感受野的角度不太说得清楚。或许 MLP 结构就不太适合用感受野来分析吧…

3. 总结

相比于现有的 MLP 的结构,S2-MLP 的一个重要优势是仅仅使用通道方向的全连接( 1 × 1 1 \times 1 1×1 卷积)是可以作为 Backbone 的,期待该团队后续的进展。S2-MLPv2 其实是通过 Spatial-shift 和 Split Attention 代替原有的 N × N N \times N N×N 卷积,本质上并没有延续 MLP-Mixer 架构中长距离依赖的思想。S2-MLPv2 中也并没有长距离依赖的使用。S2-MLPv2 虽然性能提升了,但是还没有开源,本身自己的贡献点其实不太足,这样做的理论性也不足。

延续我一贯的认识,如何在 MLP 架构中如何结合图像局部性和长距离依赖依然是值得探讨的点。

4. 代码

代码并没有开源,非官发复现的代码详见 此处。

import torch
from torch import nn
from einops.layers.torch import Reduce
from .utils import pairclass PreNormResidual(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.LayerNorm(dim)def forward(self, x):return self.fn(self.norm(x)) + xdef spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]return xclass SplitAttention(nn.Module):def __init__(self, channel = 512, k = 3):super().__init__()self.channel = channelself.k = kself.mlp1 = nn.Linear(channel, channel, bias = False)self.gelu = nn.GELU()self.mlp2 = nn.Linear(channel, channel * k, bias = False)self.softmax = nn.Softmax(1)def forward(self,x_all):b, k, h, w, c = x_all.shapex_all = x_all.reshape(b, k, -1, c)          #bs,k,n,ca = torch.sum(torch.sum(x_all, 1), 1)       #bs,chat_a = self.mlp2(self.gelu(self.mlp1(a)))  #bs,kchat_a = hat_a.reshape(b, self.k, c)         #bs,k,cbar_a = self.softmax(hat_a)                 #bs,k,cattention = bar_a.unsqueeze(-2)             # #bs,k,1,cout = attention * x_all                     # #bs,k,n,cout = torch.sum(out, 1).reshape(b, h, w, c)return outclass S2Attention(nn.Module):def __init__(self, channels=512):super().__init__()self.mlp1 = nn.Linear(channels, channels * 3)self.mlp2 = nn.Linear(channels, channels)self.split_attention = SplitAttention(channels)def forward(self, x):b, h, w, c = x.size()x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c])x2 = spatial_shift2(x[:,:,:,c:c*2])x3 = x[:,:,:,c*2:]x_all = torch.stack([x1, x2, x3], 1)a = self.split_attention(x_all)x = self.mlp2(a)return xclass S2Block(nn.Module):def __init__(self, d_model, depth, expansion_factor = 4, dropout = 0.):super().__init__()self.model = nn.Sequential(*[nn.Sequential(PreNormResidual(d_model, S2Attention(d_model)),PreNormResidual(d_model, nn.Sequential(nn.Linear(d_model, d_model * expansion_factor),nn.GELU(),nn.Dropout(dropout),nn.Linear(d_model * expansion_factor, d_model),nn.Dropout(dropout)))) for _ in range(depth)])def forward(self, x):x = x.permute(0, 2, 3, 1)x = self.model(x)x = x.permute(0, 3, 1, 2)return xclass S2MLPv2(nn.Module):def __init__(self,image_size=224,patch_size=[7, 2],in_channels=3,num_classes=1000,d_model=[192, 384],depth=[4, 14],expansion_factor = [3, 3],):image_size = pair(image_size)oldps = [1, 1]for ps in patch_size:ps = pair(ps)assert (image_size[0] % (ps[0] * oldps[0])) == 0, 'image must be divisible by patch size'assert (image_size[1] % (ps[1] * oldps[1])) == 0, 'image must be divisible by patch size'oldps[0] = oldps[0] * ps[0]oldps[1] = oldps[1] * ps[1]assert (len(patch_size) == len(depth) == len(d_model) == len(expansion_factor)), 'patch_size/depth/d_model/expansion_factor must be a list'super().__init__()self.stage = len(patch_size)self.stages = nn.Sequential(*[nn.Sequential(nn.Conv2d(in_channels if i == 0 else d_model[i - 1], d_model[i], kernel_size=patch_size[i], stride=patch_size[i]),S2Block(d_model[i], depth[i], expansion_factor[i], dropout = 0.)) for i in range(self.stage)])self.mlp_head = nn.Sequential(Reduce('b c h w -> b c', 'mean'),nn.Linear(d_model[-1], num_classes))def forward(self, x):embedding = self.stages(x)out = self.mlp_head(embedding)return out

深度学习之图像分类(二十五)-- S2MLPv2 网络详解相关推荐

  1. Keras深度学习实战(22)——生成对抗网络详解与实现

    Keras深度学习实战(22)--生成对抗网络详解与实现 0. 前言 1. 生成对抗网络原理 2. 模型分析 3. 利用生成对抗网络生成手写数字图像 小结 系列链接 0. 前言 生成对抗网络 (Gen ...

  2. 系统学习NLP(二十六)--BERT详解

    转自:https://zhuanlan.zhihu.com/p/48612853 前言 BERT(Bidirectional Encoder Representations from Transfor ...

  3. 深度学习自学(二十五):目标跟踪

    运动目标跟踪主流算法大致分类 主要基于两种思路: a)不依赖于先验知识,直接从图像序列中检测到运动目标,并进行目标识别,最终跟踪感兴趣的运动目标:  b)依赖于目标的先验知识,首先为运动目标建模,然后 ...

  4. 【C语言进阶深度学习记录】二十五 指针与数组的本质分析二

    文章目录 1 数组的访问方式 1.1 数组的访问方式代码分析 2 数组和指针不同 3 a 和 &a 的区别 3.1 指针运算的经典代码案例分析 4 数组作为函数的参数 4.1 数组作为函数参数 ...

  5. 深度学习之图像分类(十二)--MobileNetV3 网络结构

    深度学习之图像分类(十二)MobileNetV3 网络结构 目录 深度学习之图像分类(十二)MobileNetV3 网络结构 1. 前言 2. 更新 BlocK (bneck) 3. 重新设计激活函数 ...

  6. 深度学习入门笔记(十五):深度学习框架(TensorFlow和Pytorch之争)

    欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...

  7. 未处理异常和C++异常——Windows核心编程学习手札之二十五

    未处理异常和C++异常 --Windows核心编程学习手札之二十五 当一个异常过滤器返回EXCEPTION_CONTINUE_SEARCH标识符时是告诉系统继续上溯调用树,寻找另外的异常过滤器,但当每 ...

  8. 深度学习之图像分类(十六)-- EfficientNetV2 网络结构

    深度学习之图像分类(十六)EfficientNetV2 网络结构 目录 深度学习之图像分类(十六)EfficientNetV2 网络结构 1. 前言 2. 从 EfficientNetV1 到 Eff ...

  9. 深度学习之图像分类(十九)-- Bottleneck Transformer(BoTNet)网络详解

    深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 目录 深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 1 ...

最新文章

  1. 正则语法完全正则表达式手册_语法格式重点
  2. Xamarin ios 教程 Xamarin跨平台开发 C#苹果应用开发
  3. SQLServer数据库的备份/恢复的3中策略实例
  4. iOS系统 越狱系统还原(平刷)
  5. 重磅发布 | 承载亿级流量的开发框架,闲鱼Flutter技术解析与实战大公开
  6. 分成收益破5000后,我决定将付费专栏开源了
  7. 顺序表输入栈元素c语言,C语言数据结构之栈简单操作
  8. Android中使用shape来定义控件的显示属性
  9. 在用户空间加载和卸载驱动
  10. 计算机组成原理课后答案(唐朔飞第三版) 第一章
  11. 数组中除一个元素外其他所有元素出现二或三次,找到只出现一次的元素
  12. linux动态库so更新
  13. 大学英语计算机统考怎么过,2011年9月大学英语B 统考 计算机网考样题
  14. (1)桌面客制化之单屏幕修改以及wight修改
  15. 打造一个无广告无弹窗快速的Windows办公环境(软件推荐)
  16. Kinect for Unity检测身高方法
  17. 【数据库数据恢复】华为云mysql数据库数据被delete的数据恢复案例
  18. NLP预训练模型综述
  19. C语言初学基础篇:编译型语言和解释型语言
  20. 北师大 外国教育史-1(古希腊教育)

热门文章

  1. 重庆美食地图(绝对霸道)
  2. 云效搭建流水线实现自动化构建部署
  3. [转载]Z-stack 应用程序编程接口(API)-网络层
  4. 设备管理---要点练习及总结
  5. 联想ERP项目实施案例分析(9):工作方法总结
  6. Python numpy.abs和abs函数别再傻傻分不清了
  7. 关于H5页面在iPhoneX适配(转)
  8. FPGA系列7——Xilinx复数乘法器(Complex Multiplier v6.0)使用小结
  9. 【Cadence使用】PCB元器件匹配3D模型
  10. Step1我学习ros2的一些经历(从ubuntu安装到ros2中的位姿转换)