深度学习之图像分类(二十五)-- S2MLPv2 网络详解
深度学习之图像分类(二十五)S2MLPv2 网络详解
目录
- 深度学习之图像分类(二十五)S2MLPv2 网络详解
- 1. 前言
- 2. S2MLPv2
- 2.1 S2MLPv2 Block
- 2.2 Spatial-shift 与感受野反思
- 3. 总结
- 4. 代码
经过 S2MLP 和 Vision Permutator 的沉淀,为此本节我们便来学习学习 S2MLPv2 的基本思想。
1. 前言
S2MLPv2 依是百度提出的用于视觉的空间位移 MLP 架构,其作者以及顺序与 S2MLP 一模一样,其论文为 S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision。S2MLPv2 的修改点主要在于三处:金字塔结构(参考 ViP)、分三类情况进行考虑(参考 ViP)、使用 Split Attention(参考 ViP 和 ResNeSt)。总结而言就是把 ViP 中的 Permute-MLP layer 中别人沿着不同方向进行交互替换为了 Spatial-shift 操作。在参数量基本一致的情况下,其性能优于 ViP。
2. S2MLPv2
2.1 S2MLPv2 Block
S2MLPv2 和 S2MLPv1 类似,整体网络结构不做过多赘述,主要讲解一下 S2MLPv2 Block 的细节(建议大家先回顾之前的章节 S2MLP 以及 ViP):
- 首先是特征图输入后,对 Channel 进行一个全连接,这里是对于特定位置信息进行交流,其实也就是 1 × 1 1 \times 1 1×1 卷积,只不过这里将维度变为了原来的 3 倍。然后经过一个 GELU 激活函数。
- 将特征图均分为 3 等分,分别用于后续三个 Spatial-shift 分支的输入。
- 第一个分支进行与 S2MLPv1 一样的 Spatial-shift 操作,即右-左-下-上移动。
- 第二个分支进行与第一个分支反对称的 Spatial-shift 操作,即下-上-右-左移动。
- 第三个分支保持不变
- 之后将三个分支的结果通过 Split Attention 结合起来。这样不同位置的信息就被加到同一个通道上对齐了。
- 再经过一个 MLP 进行不同位置的信息整合,然后经过 LN 激活函数。(看了这么多网络,其实激活函数在前在后都可以)
def spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c/4] = x[:,:w-1,:,:c/4]x[:,:w-1,:,c/4:c/2] = x[:,1:,:,c/4:c/2]x[:,:,1:,c/2:c*3/4] = x[:,:,:h-1,c/2:c*3/4]x[:,:,:h-1,3*c/4:] = x[:,:,1:,3*c/4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c/4] = x[:,:,:h-1,:c/4]x[:,:,:h-1,c/4:c/2] = x[:,:,1:,c/4:c/2]x[:,1:,:,c/2:c*3/4] = x[:,:w-1,:,c/2:c*3/4]x[:,:w-1,:,3*c/4:] = x[:,1:,:,3*c/4:]return xclass S2-MLPv2(nn.Module):def __init__(self, channels):super().__init__()self.mlp1 = nn.Linear(channels,channels * 3)self.mlp2 = nn.Linear(channels,channels)self.split_attention = SplitAttention()def forward(self, x):b,w,h,c = x.size()x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c/3])x2 = spatial_shift2(x[:,:,:,c/3:c/3*2])x3 = x[:,:,:,c/3*2:]a = self.split_attention(x1,x2,x3)x = self.mlp2(a)return x
- 接下来的就是和 MLP-Mixer 中的 Channel-mixing MLP 一致。
2.2 Spatial-shift 与感受野反思
三组 Spatial-shift (包括恒等)与一组相比有什么进步和问题呢?
- 传统计算机视觉感受野以及近期 ViP 工作等等,都提倡奇数和中心概念,即在某中心卷积核大小是奇数的,一左一右一上一下是对称的。原始的一组 Spatial-shift 其实是一个菱形感受野且不包括中心。现在有恒等之后则是菱形感受野且包括中心了,这是一个进步。
- 但是第二组设计为与第一组反对称的结构,但是这个反没有反彻底。其实这三组 Spatial-shift 也可看作是人精心设计构造的。那么我们仔细看一下,其实没有实现完全的互补。让我们把目光放到 Split Attention 之后,输出的特征图其实也可被看作四个部分,每部分对应着:左上中相加,右下中相加,上左中相加,下右中相加。为了更好的方便大家理解这句话,我们不妨先忽略 Split Attention 给出的权重,并将经过 Spatial-shift 操作前的三部分特征图分别记录为 f , g , h f,g,h f,g,h,输出记录为 z z z。则有如下公式,其中下标表示不同的旋转部分。
- 如果从强迫症的观点看:第一组 Spatial-shift 是 右-左-下-上,则第二组 Spatial-shift 应该是 上-下-右-左 才对。
- 如果从感受野完整性的观点看:第一组 Spatial-shift 是 右-左-下-上,则第二组 Spatial-shift 应该是 左上-左下-右上-右下 才对。
z 1 ( x , y ) = f 1 ( x − 1 , y ) + g 1 ( x , y − 1 ) + h 1 ( x , y ) z 2 ( x , y ) = f 2 ( x + 1 , y ) + g 2 ( x , y + 1 ) + h 2 ( x , y ) z 3 ( x , y ) = f 3 ( x , y − 1 ) + g 3 ( x − 1 , y ) + h 3 ( x , y ) z 4 ( x , y ) = f 4 ( x , y + 1 ) + g 4 ( x + 1 , y ) + h 4 ( x , y ) z_{1}(x,y) = f_{1}(x-1,y) + g_{1}(x,y-1) + h_{1}(x,y) \\ z_{2}(x,y) = f_{2}(x+1,y) + g_{2}(x,y+1) + h_{2}(x,y) \\ z_{3}(x,y) = f_{3}(x,y-1) + g_{3}(x-1,y) + h_{3}(x,y) \\ z_{4}(x,y) = f_{4}(x,y+1) + g_{4}(x+1,y) + h_{4}(x,y) z1(x,y)=f1(x−1,y)+g1(x,y−1)+h1(x,y)z2(x,y)=f2(x+1,y)+g2(x,y+1)+h2(x,y)z3(x,y)=f3(x,y−1)+g3(x−1,y)+h3(x,y)z4(x,y)=f4(x,y+1)+g4(x+1,y)+h4(x,y)
关于 Split 的消融实验,作者分别移除了第二部分和第三部分,发现移除第二部分损失的性能还比第三部分(恒等)的多,但是就差 0.1%,这个消融实验其实很难解释三部分怎么相互作用的,至少从计算机视觉感受野的角度不太说得清楚。或许 MLP 结构就不太适合用感受野来分析吧…
3. 总结
相比于现有的 MLP 的结构,S2-MLP 的一个重要优势是仅仅使用通道方向的全连接( 1 × 1 1 \times 1 1×1 卷积)是可以作为 Backbone 的,期待该团队后续的进展。S2-MLPv2 其实是通过 Spatial-shift 和 Split Attention 代替原有的 N × N N \times N N×N 卷积,本质上并没有延续 MLP-Mixer 架构中长距离依赖的思想。S2-MLPv2 中也并没有长距离依赖的使用。S2-MLPv2 虽然性能提升了,但是还没有开源,本身自己的贡献点其实不太足,这样做的理论性也不足。
延续我一贯的认识,如何在 MLP 架构中如何结合图像局部性和长距离依赖依然是值得探讨的点。
4. 代码
代码并没有开源,非官发复现的代码详见 此处。
import torch
from torch import nn
from einops.layers.torch import Reduce
from .utils import pairclass PreNormResidual(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.LayerNorm(dim)def forward(self, x):return self.fn(self.norm(x)) + xdef spatial_shift1(x):b,w,h,c = x.size()x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]return xdef spatial_shift2(x):b,w,h,c = x.size()x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]return xclass SplitAttention(nn.Module):def __init__(self, channel = 512, k = 3):super().__init__()self.channel = channelself.k = kself.mlp1 = nn.Linear(channel, channel, bias = False)self.gelu = nn.GELU()self.mlp2 = nn.Linear(channel, channel * k, bias = False)self.softmax = nn.Softmax(1)def forward(self,x_all):b, k, h, w, c = x_all.shapex_all = x_all.reshape(b, k, -1, c) #bs,k,n,ca = torch.sum(torch.sum(x_all, 1), 1) #bs,chat_a = self.mlp2(self.gelu(self.mlp1(a))) #bs,kchat_a = hat_a.reshape(b, self.k, c) #bs,k,cbar_a = self.softmax(hat_a) #bs,k,cattention = bar_a.unsqueeze(-2) # #bs,k,1,cout = attention * x_all # #bs,k,n,cout = torch.sum(out, 1).reshape(b, h, w, c)return outclass S2Attention(nn.Module):def __init__(self, channels=512):super().__init__()self.mlp1 = nn.Linear(channels, channels * 3)self.mlp2 = nn.Linear(channels, channels)self.split_attention = SplitAttention(channels)def forward(self, x):b, h, w, c = x.size()x = self.mlp1(x)x1 = spatial_shift1(x[:,:,:,:c])x2 = spatial_shift2(x[:,:,:,c:c*2])x3 = x[:,:,:,c*2:]x_all = torch.stack([x1, x2, x3], 1)a = self.split_attention(x_all)x = self.mlp2(a)return xclass S2Block(nn.Module):def __init__(self, d_model, depth, expansion_factor = 4, dropout = 0.):super().__init__()self.model = nn.Sequential(*[nn.Sequential(PreNormResidual(d_model, S2Attention(d_model)),PreNormResidual(d_model, nn.Sequential(nn.Linear(d_model, d_model * expansion_factor),nn.GELU(),nn.Dropout(dropout),nn.Linear(d_model * expansion_factor, d_model),nn.Dropout(dropout)))) for _ in range(depth)])def forward(self, x):x = x.permute(0, 2, 3, 1)x = self.model(x)x = x.permute(0, 3, 1, 2)return xclass S2MLPv2(nn.Module):def __init__(self,image_size=224,patch_size=[7, 2],in_channels=3,num_classes=1000,d_model=[192, 384],depth=[4, 14],expansion_factor = [3, 3],):image_size = pair(image_size)oldps = [1, 1]for ps in patch_size:ps = pair(ps)assert (image_size[0] % (ps[0] * oldps[0])) == 0, 'image must be divisible by patch size'assert (image_size[1] % (ps[1] * oldps[1])) == 0, 'image must be divisible by patch size'oldps[0] = oldps[0] * ps[0]oldps[1] = oldps[1] * ps[1]assert (len(patch_size) == len(depth) == len(d_model) == len(expansion_factor)), 'patch_size/depth/d_model/expansion_factor must be a list'super().__init__()self.stage = len(patch_size)self.stages = nn.Sequential(*[nn.Sequential(nn.Conv2d(in_channels if i == 0 else d_model[i - 1], d_model[i], kernel_size=patch_size[i], stride=patch_size[i]),S2Block(d_model[i], depth[i], expansion_factor[i], dropout = 0.)) for i in range(self.stage)])self.mlp_head = nn.Sequential(Reduce('b c h w -> b c', 'mean'),nn.Linear(d_model[-1], num_classes))def forward(self, x):embedding = self.stages(x)out = self.mlp_head(embedding)return out
深度学习之图像分类(二十五)-- S2MLPv2 网络详解相关推荐
- Keras深度学习实战(22)——生成对抗网络详解与实现
Keras深度学习实战(22)--生成对抗网络详解与实现 0. 前言 1. 生成对抗网络原理 2. 模型分析 3. 利用生成对抗网络生成手写数字图像 小结 系列链接 0. 前言 生成对抗网络 (Gen ...
- 系统学习NLP(二十六)--BERT详解
转自:https://zhuanlan.zhihu.com/p/48612853 前言 BERT(Bidirectional Encoder Representations from Transfor ...
- 深度学习自学(二十五):目标跟踪
运动目标跟踪主流算法大致分类 主要基于两种思路: a)不依赖于先验知识,直接从图像序列中检测到运动目标,并进行目标识别,最终跟踪感兴趣的运动目标: b)依赖于目标的先验知识,首先为运动目标建模,然后 ...
- 【C语言进阶深度学习记录】二十五 指针与数组的本质分析二
文章目录 1 数组的访问方式 1.1 数组的访问方式代码分析 2 数组和指针不同 3 a 和 &a 的区别 3.1 指针运算的经典代码案例分析 4 数组作为函数的参数 4.1 数组作为函数参数 ...
- 深度学习之图像分类(十二)--MobileNetV3 网络结构
深度学习之图像分类(十二)MobileNetV3 网络结构 目录 深度学习之图像分类(十二)MobileNetV3 网络结构 1. 前言 2. 更新 BlocK (bneck) 3. 重新设计激活函数 ...
- 深度学习入门笔记(十五):深度学习框架(TensorFlow和Pytorch之争)
欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...
- 未处理异常和C++异常——Windows核心编程学习手札之二十五
未处理异常和C++异常 --Windows核心编程学习手札之二十五 当一个异常过滤器返回EXCEPTION_CONTINUE_SEARCH标识符时是告诉系统继续上溯调用树,寻找另外的异常过滤器,但当每 ...
- 深度学习之图像分类(十六)-- EfficientNetV2 网络结构
深度学习之图像分类(十六)EfficientNetV2 网络结构 目录 深度学习之图像分类(十六)EfficientNetV2 网络结构 1. 前言 2. 从 EfficientNetV1 到 Eff ...
- 深度学习之图像分类(十九)-- Bottleneck Transformer(BoTNet)网络详解
深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 目录 深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解 1 ...
最新文章
- 正则语法完全正则表达式手册_语法格式重点
- Xamarin ios 教程 Xamarin跨平台开发 C#苹果应用开发
- SQLServer数据库的备份/恢复的3中策略实例
- iOS系统 越狱系统还原(平刷)
- 重磅发布 | 承载亿级流量的开发框架,闲鱼Flutter技术解析与实战大公开
- 分成收益破5000后,我决定将付费专栏开源了
- 顺序表输入栈元素c语言,C语言数据结构之栈简单操作
- Android中使用shape来定义控件的显示属性
- 在用户空间加载和卸载驱动
- 计算机组成原理课后答案(唐朔飞第三版) 第一章
- 数组中除一个元素外其他所有元素出现二或三次,找到只出现一次的元素
- linux动态库so更新
- 大学英语计算机统考怎么过,2011年9月大学英语B 统考 计算机网考样题
- (1)桌面客制化之单屏幕修改以及wight修改
- 打造一个无广告无弹窗快速的Windows办公环境(软件推荐)
- Kinect for Unity检测身高方法
- 【数据库数据恢复】华为云mysql数据库数据被delete的数据恢复案例
- NLP预训练模型综述
- C语言初学基础篇:编译型语言和解释型语言
- 北师大 外国教育史-1(古希腊教育)
热门文章
- 重庆美食地图(绝对霸道)
- 云效搭建流水线实现自动化构建部署
- [转载]Z-stack 应用程序编程接口(API)-网络层
- 设备管理---要点练习及总结
- 联想ERP项目实施案例分析(9):工作方法总结
- Python numpy.abs和abs函数别再傻傻分不清了
- 关于H5页面在iPhoneX适配(转)
- FPGA系列7——Xilinx复数乘法器(Complex Multiplier v6.0)使用小结
- 【Cadence使用】PCB元器件匹配3D模型
- Step1我学习ros2的一些经历(从ubuntu安装到ros2中的位姿转换)