【深度学习基础】【深度学习基础】SENet——PyTorch实现CNN的SE结构改造

  • 1 论文关键信息
    • 1.1 SE block
      • 1.1.1 squeeze
      • 1.1.2 Exitation
  • 2 pytorch 实现
    • 2.1 SE block代码实现
    • 2.2 CNN加入SE进行改造
      • 2.2.1 网络中添加SE block的位置
      • 2.2.2 SE-ResNet-50, SE-ResNeXt-50(32x4d)的PyTorch实现

代码已同步至Github:https://github.com/EasonCai-Dev/torch_backbones

1 论文关键信息

每个卷积操作实际上是在输入的空间维度(宽高,HxW)和通道维度(channels)上面进行的乘加操作。论文的主要思想是:之前的网络一般在空间维度上进行优化,比如Residual block,Inception block和Dense block等等;这篇论文主要从特征的Channel维度上面着手考虑,提出了一种Squeeze-and-Exitation结构(SE block)。SE block可以嵌入到之前所有经典的网络结构中,实现模型的改造。

1.1 SE block

SE block的结构如下图所示,论文给出了公式介绍。

假设一层的卷积操作为 F t r F_{tr} Ftr​,其输入为 X ∈ R H × W × C X\in\mathbb{R}^{H\times W\times C} X∈RH×W×C,输出为 X ∈ U H × W × C X\in\mathbb{U}^{H\times W\times C} X∈UH×W×C,则该卷积可以表示为 V = [ v 1 , v 2 , . . . , v C ] V=[v_1, v_2, ..., v_C] V=[v1​,v2​,...,vC​],输出可以表示为 U = [ u 1 , u 2 , . . . , u C ] U=[u_1, u_2, ..., u_C] U=[u1​,u2​,...,uC​],卷积过程可以表示为:
u c = v c ∗ X = ∑ s = 1 C ′ v c s ∗ x s u_c = v_c*X = \sum_{s=1}^{C'}v_c^s*x^s uc​=vc​∗X=s=1∑C′​vcs​∗xs

1.1.1 squeeze

SE block的squeeze部分通过一个全局平均滤波实现全局信息的获取,如公式:
z c = F s q ( u c ) = 1 H × W ∑ i = 1 H ∑ i = 1 W u c ( i , j ) z_c = F_sq(u_c) = \frac{1}{H\times W}\sum_{i=1}^H\sum_{i=1}^Wu_c(i,j) zc​=Fs​q(uc​)=H×W1​i=1∑H​i=1∑W​uc​(i,j)

1.1.2 Exitation

这个部分首先通过一个全连接层Linear(C, C/r)将特征压缩到C/r通道,然后使用ReLU层进行非线性操作,接着使用全连接层Linear(C/r, C)将特征还原至C通道,最后使用一个Simoid函数激活 ,如下式
s = F e x ( z , W ) = σ ( W 2 δ ( W 1 z ) ) s = F_ex(z, W) = \sigma(W_2\delta(W_1z)) s=Fe​x(z,W)=σ(W2​δ(W1​z))
w h e r e , W 1 ∈ R C r × C , W 2 = ∈ R C × C r where, W_1\in\mathbb{R}^{{\frac{C}{r}}\times C}, W_2 = \in\mathbb{R}^{C\times {\frac{C}{r}}} where,W1​∈RrC​×C,W2​=∈RC×rC​
其中, W 1 W_1 W1​表示负责压缩的全连接层参数, W 2 W_2 W2​表示负责还原维度的全连接层参数, σ \sigma σ表示Sigmoid函数, δ \delta δ表示ReLU函数。变量 r r r是一个压缩参数。

在这一步之后,我们得到一个 s ∈ R 1 × 1 × C s\in\mathbb{R}^{1\times 1\times C} s∈R1×1×C,然后将 s s s与上一层卷积特征进行逐空间位置相乘,得到SE block的输出,如下式:
X c ~ = F s c a l e ( u c , s c ) = s c u c \tilde{X_c} = F_{scale}(u_c, s_c) = s_cu_c Xc​~​=Fscale​(uc​,sc​)=sc​uc​

2 pytorch 实现

2.1 SE block代码实现

考虑到,对1x1xC的输入进行1x1的卷积,其效果等同于全连接层,我们可以省去将张量进行降维和升维的过程,于是SE block的实现就变得简单了:

import torch.nn as nn
import torch.nn.functional as Fclass SE(nn.Module):def __init__(self, in_chnls, ratio):super(SE, self).__init__()self.squeeze = nn.AdaptiveAvgPool2d((1, 1))self.compress = nn.Conv2d(in_chnls, in_chnls//ratio, 1, 1, 0)self.excitation = nn.Conv2d(in_chnls//ratio, in_chnls, 1, 1, 0)def forward(self, x):out = self.squeeze(x)out = self.compress(out)out = F.relu(out)out = self.excitation(out)return F.sigmoid(out)

2.2 CNN加入SE进行改造

论文主要用了ResNet-50,ResNeXt-50(32x4d)进行实验,我也是改造了这两个网络。按照论文的意思,很多经典的CNN都可以加入SE进行改造,比如Inception,DenseNet,DPN,Res2Net,et al.

2.2.1 网络中添加SE block的位置

如下图,论文中给出,在Inception网络中,SE block应该加在Inception block之后,ResNet网络应该加在shortcut之前,bottleneck之后。同理,ResNeXt也是加在shortcut之前。

2.2.2 SE-ResNet-50, SE-ResNeXt-50(32x4d)的PyTorch实现

我们只需要在合适的位置加入SE block就可以完成模型的改造,我这里在之前博客代码的基础上实现。关于压缩参数 r r r的选取,论文做了一些实验,发现在ResNet-50,和ResNeXt-50(32x4d)中使用 r = 16 r=16 r=16能平衡准确率与模型参数

改造后代码如下:
(1)ResNet

class BasicBlock(nn.Module):"""basic building block for ResNet-18, ResNet-34"""message = "basic"def __init__(self, in_channels, out_channels, strides, is_se=False):super(BasicBlock, self).__init__()self.is_se = is_seself.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=strides, padding=1, bias=False)  # same paddingself.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_channels)if self.is_se:self.se = SE(out_channels, 16)# fit input with residual outputself.short_cut = nn.Sequential()if strides is not 1:self.short_cut = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, stride=strides, padding=0, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = self.conv1(x)out = self.bn(out)out = F.relu(out)out = self.conv2(out)out = self.bn(out)if self.is_se:coefficient = self.se(out)out *= coefficientout += self.short_cut(x)return F.relu(out)class BottleNeck(nn.Module):"""BottleNeck block for RestNet-50, ResNet-101, ResNet-152"""message = "bottleneck"def __init__(self, in_channels, out_channels, strides, is_se=False):super(BottleNeck, self).__init__()self.is_se = is_seself.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)  # same paddingself.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=strides, padding=1, bias=False)self.conv3 = nn.Conv2d(out_channels, out_channels * 4, 1, stride=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels * 4)if self.is_se:self.se = SE(out_channels * 4, 16)# fit input with residual outputself.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * 4, 1, stride=strides, padding=0, bias=False),nn.BatchNorm2d(out_channels*4))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = F.relu(out)out = self.conv2(out)out = self.bn1(out)out = F.relu(out)out = self.conv3(out)out = self.bn2(out)if self.is_se:coefficient = self.se(out)out *= coefficientout += self.shortcut(x)return F.relu(out)class ResNet(nn.Module):"""building ResNet_34"""def __init__(self, block: object, groups: object, num_classes, is_se=False) -> object:super(ResNet, self).__init__()self.channels = 64  # out channels from the first convolutional layerself.block = blockself.is_se = is_seself.conv1 = nn.Conv2d(3, self.channels, 7, stride=2, padding=3, bias=False)self.bn = nn.BatchNorm2d(self.channels)self.pool1 = nn.MaxPool2d(3, 2, 1)self.conv2_x = self._make_conv_x(channels=64, blocks=groups[0], strides=1, index=2)self.conv3_x = self._make_conv_x(channels=128, blocks=groups[1], strides=2, index=3)self.conv4_x = self._make_conv_x(channels=256, blocks=groups[2], strides=2, index=4)self.conv5_x = self._make_conv_x(channels=512, blocks=groups[3], strides=2, index=5)self.pool2 = nn.AvgPool2d(7)patches = 512 if self.block.message == "basic" else 512 * 4self.fc = nn.Linear(patches, num_classes)  # for 224 * 224 input sizedef _make_conv_x(self, channels, blocks, strides, index):"""making convolutional group:param channels: output channels of the conv-group:param blocks: number of blocks in the conv-group:param strides: strides:return: conv-group"""list_strides = [strides] + [1] * (blocks - 1)  # In conv_x groups, the first strides is 2, the others are ones.conv_x = nn.Sequential()for i in range(len(list_strides)):layer_name = str("block_%d_%d" % (index, i))  # when use add_module, the name should be difference.conv_x.add_module(layer_name, self.block(self.channels, channels, list_strides[i], self.is_se))self.channels = channels if self.block.message == "basic" else channels * 4return conv_xdef forward(self, x):out = self.conv1(x)out = F.relu(self.bn(out))out = self.pool1(out)out = self.conv2_x(out)out = self.conv3_x(out)out = self.conv4_x(out)out = self.conv5_x(out)out = self.pool2(out)out = out.view(out.size(0), -1)out = F.softmax(self.fc(out))return outdef ResNet_50_SE(num_classes=1000):return ResNet(block=BottleNeck, groups=[3, 4, 6, 3], num_classes=num_classes, is_se=True)

(2)ResNeXt

class ResNeXt_Block(nn.Module):"""ResNeXt block with group convolutions"""def __init__(self, in_chnls, cardinality, group_depth, stride, is_se=False):super(ResNeXt_Block, self).__init__()self.is_se = is_seself.group_chnls = cardinality * group_depthself.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)self.bn = nn.BatchNorm2d(self.group_chnls*2)if self.is_se:self.se = SE(self.group_chnls*2, 16)self.short_cut = nn.Sequential(nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),nn.BatchNorm2d(self.group_chnls*2))def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = self.bn(self.conv3(out))if self.is_se:coefficient = self.se(out)out *= coefficientout += self.short_cut(x)return F.relu(out)class ResNeXt(nn.Module):"""ResNeXt builder"""def __init__(self, layers: object, cardinality, group_depth, num_classes, is_se=False) -> object:super(ResNeXt, self).__init__()self.is_se = is_seself.cardinality = cardinalityself.channels = 64self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)d1 = group_depthself.conv2 = self.___make_layers(d1, layers[0], stride=1)d2 = d1 * 2self.conv3 = self.___make_layers(d2, layers[1], stride=2)d3 = d2 * 2self.conv4 = self.___make_layers(d3, layers[2], stride=2)d4 = d3 * 2self.conv5 = self.___make_layers(d4, layers[3], stride=2)self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input sizedef ___make_layers(self, d, blocks, stride):strides = [stride] + [1] * (blocks-1)layers = []for stride in strides:layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride, self.is_se))self.channels = self.cardinality*d*2return nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = F.max_pool2d(out, 3, 2, 1)out = self.conv2(out)out = self.conv3(out)out = self.conv4(out)out = self.conv5(out)out = F.avg_pool2d(out, 7)out = out.view(out.size(0), -1)out = F.softmax(self.fc(out))return outdef resNeXt50_32x4d_SE(num_classes=1000):return ResNeXt([3, 4, 6, 3], 32, 4, num_classes, is_se=True)

有相关实验需求或者更进一步兴趣的同学,也可以把其他网络也加入SE试试看,嘿嘿:)

【深度学习基础】SENet——PyTorch实现CNN的SE结构改造相关推荐

  1. (十四)从零开始学人工智能-深度学习基础及CNN

    文章目录 一.深度学习基础 1.1 深度学习及其发展历史 1.1.1 什么是学习? 1.1.2 什么是机器学习? 1.1.3 什么是深度学习? 1.1.4 深度学习发展历史 1.1.5 小结 1.2 ...

  2. 1.0 深度学习回顾与PyTorch简介 - PyTorch学习笔记

    P1 深度学习回顾与PyTorch简介 视频课程地址:点我 fly~~~ 本节课主要偏向于NLP,因为作者本人是做NLP方向NLP 预训练三种模型: BERT OpenAI GPT ELMo [NLP ...

  3. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  4. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  5. 五万字总结,深度学习基础。

    文章目录 1 基本概念 1.1 神经网络组成? 1.2 神经网络有哪些常用模型结构? 1.3 如何选择深度学习开发平台? 1.4 为什么深层神经网络难以训练? 1.5 深度学习和机器学习的异同? 2 ...

  6. 从零开始的深度学习(一) 经典CNN网络 LeNet-5

    从零开始的深度学习(一) 经典CNN网络 LeNet-5 之前的四篇博客围绕着一个大作业项目来进行的入门,由于小白初涉,因此行文中有时侧重于某些并不重要的东西,同时也忽略了许多其实蛮重要的东西,再加上 ...

  7. 卷积、池化、激活函数、初始化、归一化、正则化、学习率——深度学习基础总结

    有幸拜读大佬言有三的书<深度学习之模型设计>,以下是我的读书笔记,仅供参考,详细的内容还得拜读原著,错误之处还望指正.下面的三张图片来自知乎. <深度学习之模型设计>读书笔记- ...

  8. 资源 | Intel发布AI免费系列课程3部曲:机器学习基础、深度学习基础以及TensorFlow基础

    翻译 | AI科技大本营(公众号ID:rgznai100) 校对 | 成龙 编辑 | 明明 Intel于近期发布了三门AI系列的免费课程,分别是关于机器学习基础.深度学习基础.TensorFlow基础 ...

  9. 干货|《深度学习入门之Pytorch》资料下载

    深度学习如今已经成为了科技领域中炙手可热的技术,而很多机器学习框架也成为了研究者和业界开发者的新宠,从早期的学术框架Caffe.Theano到如今的Pytorch.TensorFlow,但是当时间线来 ...

最新文章

  1. 我为什么要使用IDE? [关闭]
  2. html 超链接 ppt,HTML超链接要点.ppt
  3. 吴恩达深度学习笔记6-Course2-Week2【优化算法】
  4. dwr配置文件dwr.xml详解
  5. scipy.interpolate: 插值和平滑处理
  6. jquery页面滚动显示浮动菜单栏锚点定位效果
  7. 98.验证二叉搜索树
  8. Android碎碎念 -- 广播LocalBroadcastManager的实现
  9. 用户事件的存储与分析
  10. 计算机毕业设计Java-超市会员积分管理系统
  11. Spring boot出现java.awt.HeadlessException【已解决】
  12. git cherry-pick 教程
  13. 华为nova6se怎么升级鸿蒙,华为EMUI11支持哪些手机
  14. 伦敦 quant_伦敦统一用户组(LUUG)见面v1.0
  15. python 写文本文件出现乱码
  16. Linux常用命令——who命令
  17. 点评Hack易支付 - 免签约支付平台 -彩虹易支付,1分钟快速接入支付功能
  18. The 2021 ICPC Asia Shanghai Regional Programming Contest 2021ICPC上海站VP
  19. 1、IOS开发--iPad之仿制QQ空间(登录界面搭建+登录逻辑实现)
  20. Golang学习篇——UTC时间互换标准时间

热门文章

  1. Android addr2line和 c++filt使用(三十六)
  2. 齐家网战略签约友邦吊顶 整合上游资源赋能装企
  3. 以前的老教程不要看了!2019年最新的WEB前端自学教程,新才是王道
  4. Redis消息队列发布微博
  5. [DataAnalysis]为什么说熵是不确定性的度量
  6. 桌面开发:Electron 代码打包 asar
  7. 运动耳机哪种型号好、口碑最好的运动蓝牙耳机排行榜
  8. 华为mate40和p40哪个好有什么区别 华为mate40和p40参数对比
  9. SQL常见的一些面试题(太有用啦)
  10. opencv python搞个写轮眼