文章目录

  • 1. 简介
  • 2. 运算讲解

1. 简介

SE Block并非一个完整的网络,而且一个子结构,可以嵌在其他分类或检测模型中。SE Block的核心思想是通过网络根据loss去学习特征权重,使得有效的feature map权重更大,无效或效果小的feature map权重更小的方式去训练模型已达到更好的结果。

当然,SE block嵌在原有的一些分类网络中不可避免地增加了一些参数和计算量,但是在效果面前还是可以接受的。

2. 运算讲解

SE block示意图:

Step1: 卷积操作(即图中的 F t r F_{tr} Ftr​操作)
严格来讲这一步是转换操作,并不是SE block的一部分,就是一个标准的卷积操作。输入输出定义如下:
F t r : X → U , X ∈ R W ′ ∗ H ′ ∗ C ′ , U ∈ R W ∗ H ∗ W F_{tr}: X \rightarrow U, X \in R^{W' * H' * C'}, U \in R^{W * H * W} Ftr​:X→U,X∈RW′∗H′∗C′,U∈RW∗H∗W
计算公式就是常规的卷积操作,计算公式如下:
u c = v c ∗ X = ∑ s = 1 C ′ v c s ∗ x s u_c = v_c * X = \sum ^{C'} _{s=1} v_c^{s} * x^s uc​=vc​∗X=s=1∑C′​vcs​∗xs
其中, v c v_c vc​ 表示第c个卷积核, x s x^s xs 表示当前卷积核覆盖下的第s个输入, C ′ C' C′ 表示卷积核个数。

该操作得到了上图中左起第2个矩阵,其维度 = [H ,W, C]

Step2: F s q F_{sq} Fsq​操作(即Squeenze操作)
该操作就是一个:global average pooling操作,公式如下:
z c = F s q ( u c ) = 1 W ∗ H ∑ i = 1 W ∑ j = 1 H u c ( i , j ) z_c = F_{sq}(u_c) = {1 \over W*H} \sum ^{W}_{i=1} \sum ^H_{j=1} u_c(i, j) zc​=Fsq​(uc​)=W∗H1​i=1∑W​j=1∑H​uc​(i,j)

这里使用代码进行一定的解释:
代码如下:

x = torch.ones(size=(1, 2, 2, 3))
x[0][0][0][0] = 7
print("x = ", x)avg_pool = torch.nn.AdaptiveAvgPool2d(1)    # 全局平均池化
x_pool = avg_pool(x)
print("x_pool.shape = ", x_pool.shape)
print("x_pool = ", x_pool)

输出结果:

计算的是每个通道的平均值,输出的shape=[1, 2, 1, 1]

这一步的结果相当于表明该层C个通道的数值分布情况,或者叫全局信息。

Step3: F e x F_{ex} Fex​操作(即Excitation操作)
计算公式如下:
s = s i g m o i d ( W 2 ∗ R e l u ( W 1 z ) ) s = sigmoid(W_2 * Relu(W_1 z)) s=sigmoid(W2​∗Relu(W1​z))
其中的 z z z表示上一步的 z z z, W 1 , W 2 W_1, W_2 W1​,W2​ 表示的是线性层。这里计算出来的 s s s就是该模块的核心,用来表示各个channel的权重, 而且这个权重是通过前面这些全连接层和非线性层学习得到的,因此可以end-to-end训练。这两个全连接层的作用就是融合各通道的feature map信息,因为前面的squeeze都是在某个channel的feature map里面操作。

这里结合代码容易理解(即Pytorch实现SE模块):

class SELayer_2d(nn.Module):def __init__(self, channel, reduction=16):super(SELayer_2d, self).__init__()self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)self.linear1 = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True))self.linear2 = nn.Sequential(nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, X_input):b, c, _, _ = X_input.size()   # shape = [32, 64, 2000, 80]y = self.avg_pool(X_input)        # shape = [32, 64, 1, 1]y = y.view(b, c)              # shape = [32,64]# 第1个线性层(含激活函数),即公式中的W1,其维度是[channel, channer/16], 其中16是默认的y = self.linear1(y)               # shape = [32, 64] * [64, 4] = [32, 4]# 第2个线性层(含激活函数),即公式中的W2,其维度是[channel/16, channer], 其中16是默认的y = self.linear2(y)             # shape = [32, 4] * [4, 64] = [32, 64]y = y.view(b, c, 1, 1)         # shape = [32, 64, 1, 1], 这个就表示上面公式的s, 即每个通道的权重return X_input*y.expand_as(X_input)

测试代码:

    data = torch.ones((32, 64, 2000, 80))se_2d = SELayer_2d(64)data_out = se_2d.forward(data)print("data_out = ", data_out.shape)

Step4: F s c a l e F_{scale} Fscale​操作
计算公式如下:
x ~ = F s c a l e ( u c , s c ) = s c ⋅ u c \widetilde {x} = F_{scale}(u_c, s_c) = s_c ·u_c x =Fscale​(uc​,sc​)=sc​⋅uc​

其中, u c u_c uc​表示 u u u中的一个通道, s c s_c sc​表示通道的权重。因此,相当于把每个通道的值乘以其权重。

代码即上述代码中的最后一行:

# y.expand_as(X_input)表示将y扩张到和X_input一样的维度
X_input*y.expand_as(X_input)        # 每个通道的值,乘以对应的权重

附录:
论文:Squeeze-and-Excitation Networks
论文链接:https://arxiv.org/abs/1709.01507
代码地址:https://github.com/hujie-frank/SENet
PyTorch代码地址:https://github.com/miraclewkf/SENet-PyTorch

引用:
有参考添加链接描述,在此文章的理解上,增加了一些代码注释。

SE(Squeeze and Excitation)模块的理解以及代码实现相关推荐

  1. 【新手小白向-自我感觉只有基础的高数和线代知识】-SE(Squeeze and Excitation)模块的原理理解与解释-以别人的文章为主加上自己的理解

    第1章 SE模块原理解释(照抄为加快理解) ++++通道注意力机制SE(Squeeze and Excitation)模块和动态激活函数引入骨干网络xx,增强特征提取模块对某个主要对象中关键特征的提取 ...

  2. 关于Autosar中的NM模块的理解

    关于Autosar中的NM模块的理解 本篇文章主要介绍AutoSar中关于NM模块的理解. 阅读本篇文章希望达到的目的: 1. NM(网络管理)是用来做什么的: 2. AutoSar中网络管理的原理: ...

  3. nodejs中的模块的理解

    nodejs所谓的模块就是一个文件,或者是匿名函数.(CommonJs) require exports module.exports. 为什么直接console.log(typeof require ...

  4. 关于iic协议和对AT24C02进行读写数据的理解和代码解读

    关于iic协议和对AT24C02进行读写数据的理解和代码解读 认识IIC协议 IIC协议软件模拟方法 管脚初始化 时序 AT24C02 简介 存储大小计算 工作方式 流程(代码) 认识IIC协议 本文 ...

  5. Transformer的理解与代码实现—Autoformer文献阅读

    文章目录 摘要 一. 关于Transformer的相关学习 1.1 手推transformer 1.1.1 Encoder部分 1.1.2 Decoder部分 1.2 Transformer的理解与实 ...

  6. 模块加载过程代码分析1

    一.概述 模块是作为ELF对象文件存放在文件系统中的,并通过执行insmod程序链接到内核中.对于每个模块,系统都要分配一个包含以下数据结构的内存区. 一个module对象,表示模块名的一个以null ...

  7. 深入理解C代码中的注释

    深入理解C代码中的注释 C 语言的注释可以出现在C 语言代码的任何地方?错!我们就看看下面的例子: A) int/*...*/i; B) char* s="abcdefgh //hijklm ...

  8. Spring松耦合的个人理解和代码实例

    Spring松耦合的个人理解和代码实例 理解Spring的松耦合概念,那么我们先来看看一个不使用Sring的实例代码 先看一下整个测试项目案例的结构 正常方式 创建一个接口,这个接口指定车辆的行驶速度 ...

  9. DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介、理解、代码实现、SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略

    DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介.理解.代码实现.SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略 目录 GD算法 ...

最新文章

  1. 操作系统外壳(shell)
  2. Zabbix全方位告警接入-电话/微信/短信都支持
  3. java线程创建销毁_c++多线程的创建挂起执行与销毁
  4. Docker应用基础
  5. android 5.1 内核版本号,最新的安卓5.1.1 ROOT教程(不需要刷第三方内核)
  6. c语言山东科技大学答案oj,山东科技大学oj部分题目记答案.doc
  7. java技术体系基础
  8. MNIST 数据集下载及图片可视化
  9. [IT新应用]无线投影技术
  10. vue 针试打印机实现
  11. 重难点详解-关系代数表达式
  12. 支付宝支付返回resultStatus:4000(系统繁忙,请稍后再试)
  13. 关于RedisPool配置参数
  14. matlab背景色为白色
  15. Android 11.0 PackageManagerService(一)工作原理和启动流程
  16. SE,SA和RD都代表什么
  17. VFW连接视频驱动不成功问题解决
  18. 基于spark的Scala编程—读取properties文件
  19. VxWorks操作系统shell命令与调试方法总结
  20. 在阿里(05)2022.04.19 周年啦

热门文章

  1. 能够出线的学生序号(0~9),每行一个序号。
  2. WINDDOWS 7 下Oracle11g的TNS-12535: TNS操作超时
  3. 浅浅的 C++ 11
  4. 读懂 ECMA 规格
  5. 日志采集方式 SNMP TRAP 和 Syslog 的区别
  6. 山东春考计算机专业计划,2020年山东省春季高考机电一体化专业本科招生计划!...
  7. pigx框架 源码_【Pig源码分析】谈谈Pig的数据模型
  8. 以假乱真的AI美女,有着让人羡慕的好身材
  9. 为什么别人的成长叫蓝图,你的成长始终是流浪!
  10. 网络通信(三): UDP 报头格式说明