Channel Attention网络结构、源码解读系列一

SE-Net、SK-Net与CBAM

1 SENet

原文链接:SENet原文
源码链接:SENet源码

Squeeze-and-Excitation Networks(SENet)是由自动驾驶公司Momenta在2017年公布的一种全新的图像识别结构,它通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率。这个结构是2017 ILSVR竞赛的冠军,作者在原文中提到,SENet将top5的错误率达到了2.251%,比2016年的第一名还要低25%,在当年也是很有成就的一件事。

1.1 Squeeze-and-Excitation Blocks


SE Block模块主要由Squeeze操作和Excitation操作组成:Squeeze操作负责将spatial维度进行全局池化(比如7 x 7 -->1 x 1);Excitation操作则学习池化后的通道依赖关系,并进行通道权重的赋权。上图网络结构其实很好地概括了SENet的主题思想,下面我将会从Squeeze和Excitation两个方面具体讲解。

1.1.1 Squeeze: Global Information Embedding

网络结构最开始的部分Ftr:X->U是以往的经典卷积结构,U之后的部分才是SENet的创新部分:使用全局平均池化在H和W两个维度对U进行Squeeze,将一个channel上整个空间特征编码为一个全局特征,得到1x1xC的中间输出。说得通俗点,这里其实就是使用一个二维的池化核对特征图进行降维,由原来的H、W、C上的3个维度降到了C这1个维度上,使得后续的通道赋权操作可行,其公式如下图所示:

1.1.2 Excitation: Adaptive Recalibration

为了更好的学习到Squeeze操作得到的特征信息,作者使用Excitation操作获取通道之间的依赖关系。为了实现这一目标,作者分析到该函数必须满足两个标准:(1)它必须是灵活的(特别是能够学习通道之间的非线性交互作用);(2)它必须能够学习一种非互斥的关系(因为我们希望确保允许强调多个通道)。所以作者使用两个全连接层FC学习通道之间的依赖关系,最后再通过sigmoid函数对权重进行归一(将各通道的权重值限制在0-1,权重和限定为1),其公式如下:

1.1.3 举个栗子:SE-ResNet Module


上图是SE-ResNet的网络结构。对于Residual阶段,SE-Block会通过一次全局池化进行降维(说降维可能不规范)得到通道C这一维度的特征,而后经过两层FC。第一层FC会继续降低C的维度,主要通过超参数r来实现(r是指压缩的比例,作者尝试了r在各种取值下的性能 ,最后得出结论r=16时整体性能和计算量最平衡);经过激活后,第二层FC则将压缩后的通道映射回原来的维度,最后利用Sigmoid函数对每个通道赋予不同的权重。
Scale代表将权重与待加权的特征相乘的操作,经过Scale操作后,channel维度上权重就完美地添加到特征中了。

1.2 代码实现

1.2.1 SE module

SE的实现如下代码所示,具体每一步我都做了详细的注释。如果前面的公式看不明白,对应这里的函数操作可能会帮助理解公式。

class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)# Squeeze操作的定义self.fc = nn.Sequential(# Excitation操作的定义nn.Linear(channel, channel // reduction, bias=False),# 压缩nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),# 恢复nn.Sigmoid()# 定义归一化操作)def forward(self, x):b, c, _, _ = x.size()# 得到H和W的维度,在这两个维度上进行全局池化y = self.avg_pool(x).view(b, c)# Squeeze操作的实现y = self.fc(y).view(b, c, 1, 1)# Excitation操作的实现# 将y扩展到x相同大小的维度后进行赋权return x * y.expand_as(x)

1.2.2 SE-ResNet

下列代码展示了将SENet中加入到Resnet中残差链接前的操作,其实理论上来说SENet可以在浅层Block中添加(如添加在conv1前),也可以在深层中添加(bn2后),具体的添加位置要根据自身任务确定。 如果你的网络更关注浅层特征,如纹理特征,那么就可以加在浅层;相反,如果你的网络更关注深层特征,如轮廓特征、结构特征,那就应该加在深层, 具体问题具体分析。

class SEBasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):super(SEBasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes, 1)self.bn2 = nn.BatchNorm2d(planes)self.se = SELayer(planes, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.se(out)# 加入通道注意力机制if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out

2 SKNet

原文链接:SKNet原文
源码链接:SKNet源码

CVPR2019的文章Selective Kernel Networks,这篇文章也是致敬了SENet的思想。 SENet提出了Sequeeze and Excitation block,而SKNet提出了Selective Kernel Convolution. 二者都可以很方便的嵌入到现在的网络结构,比如ResNet、Inception、ShuffleNet,实现精度的提升。

2.1 Selective Kernel Convolution

文章关注点主要是不同大小的感受野对于不同尺度的目标有不同的效果,而我们又应该采取什么方法使得网络可以自动地利用对分类有效的感受野呢?为了解决这个问题,作者在文章中提出了一种对卷积核的动态选择机制,该机制允许每个神经元根据输入信息的多尺度自适应地调整其感受野(卷积核)的大小。

上图就是select kernel convolution模块,网络中主要包括Split、Fuse、Select三个操作。Split通过多条不同大小的kernel产生不同特征图,上图中的模型只设计了两个不同大小的卷积核,实际上可以设计多个分支的多个卷积核;Fuse运算结合并聚合来自多个路径的信息,以获得用于选择权重的全局和综合表示;select操作根据选择权重聚合不同大小内核的特征图。

2.1.1 Split

对输入X使用不同的卷积核生成不同的特征输出,上图所示的是使用3x3和5x5的卷积核进行的卷积操作,为了提高运算效率,5x5的卷积操作是用空洞率为2、卷积核为3x3的空洞卷积实现的,并且使用了分组卷积、深度可分离卷积、BatchNorm和ReLU。

2.1.2 Fuze

将得到的多个特征输出进行信息融合,即pytorch中的sum操作,得到新的特征图U,即下图中的公式(1);然后利用Squeeze相同的操作生成通道这个维度的信息,即下图中的公式(2);最后利用1层全连接层FC学习通道之间的依赖关系,最后使用ReLU和BatchNorm进行归一,即下图中的公式(3)。相关公式如下:

2.1.3 Select

在通道这个维度对多个分支得到的最终特征图进行赋权,使用sigmoid函数。最后将所有分支加权后的特征图相加,得到最终的输出。

2.2 代码实现

结合上述的讲解,代码其实很明了,具体的定义及操作我都作了注释,可以参看注释进行理解。

class SKConv(nn.Module):def __init__(self, features, WH, M, G, r, stride=1 ,L=32):super(SKConv, self).__init__()d = max(int(features/r), L)self.M = Mself.features = featuresself.convs = nn.ModuleList([])# 生成M个分支,将其添加到convs中,每个分支采用不同的卷积核和不同规模的padding,保证最终得到的特征图大小一致for i in range(M):self.convs.append(nn.Sequential(nn.Conv2d(features, features, kernel_size=3+i*2, stride=stride, padding=1+i, groups=G),nn.BatchNorm2d(features),nn.ReLU(inplace=False)))# 学习通道间依赖的全连接层self.fc = nn.Linear(features, d)self.fcs = nn.ModuleList([])for i in range(M):self.fcs.append(nn.Linear(d, features))self.softmax = nn.Softmax(dim=1)def forward(self, x):for i, conv in enumerate(self.convs):fea = conv(x).unsqueeze_(dim=1)if i == 0:feas = feaelse:feas = torch.cat([feas, fea], dim=1)fea_U = torch.sum(feas, dim=1)# 将多个分支得到的特征图进行融合fea_s = fea_U.mean(-1).mean(-1)# 在channel这个维度进行特征抽取fea_z = self.fc(fea_s)# 学习通道间的依赖关系# 赋权操作,由于是对多维数组赋权,所以看起来比SENet麻烦一些for i, fc in enumerate(self.fcs):vector = fc(fea_z).unsqueeze_(dim=1)if i == 0:attention_vectors = vectorelse:attention_vectors = torch.cat([attention_vectors, vector], dim=1)attention_vectors = self.softmax(attention_vectors)attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)fea_v = (feas * attention_vectors).sum(dim=1)return fea_v

3 CBAM

原文链接:CBAM原文
源码链接:CBAM源码

CBAM( Convolutional Block Attention Module )是一种轻量化的通道注意力机制,也是目前应用比较广泛的一种视觉注意力机制,在2018年的ECCV中提出。文章同时使用了Channel Attention和Spatial Attention,发现将两种attention串联在一起效果较好。

3.1 Convolutional Block Attention Module

下图是CBAM的网络结构图。

可以看到CBAM包含2个独立的子模块, 通道注意力模块(Channel Attention Module,CAM) 和空间注意力模块(Spartial Attention Module,SAM) ,分别进行通道与空间上的赋权。这样不只能够节约参数和计算力,并且保证了其能够做为即插即用的模块集成到现有的网络架构中去。

3.1.1 Channel attention module


通道注意力机制的基本思想与SENet相同,但是具体操作与SENet略有不同,不同部分我用红色进行了标记。
首先,将输入的特征图F(H×W×C)分别经过基于H和W两个维度的 全局最大池化(MaxPool)和全局平均池化(AvgPool),得到两个1×1×C的特征图; 然后,将两个特征图送入一个 共享权值的双层神经网络(MLP)进行通道间依赖关系的学习,两层神经层之间通过压缩比r实现降维。 最后,将MLP输出的特征进行基于element-wise的加和操作,再经过sigmoid激活操作,生成最终的通道加权,即M_c。其公式如下图:

3.1.2 Spatial attention module

本来本篇专题主要对通道注意力机制进行讨论,想着下一篇空间注意力机制的时候再说CBAM的后续,但按照我懿姐的说法就是,算法都送到嘴边了,那我干脆一块解决了。


空间注意力机制将通道注意力模块输出的特征图F‘作为本模块的输入特征图。
首先,基于channel这个维度进行最大池化(MaxPool)和平均池化(AvgPool)操作,得到两个H×W×1 的特征图; 然后,将两个特征图基于通道维度进行拼接,即concat操作; 再然后,使用7×7卷积核(作者通过实验验证了7x7效果好于其他维度卷积核)进行通道降维,降维为单通道的特征图,即H×W×1; 最后,经过sigmoid学习空间元素之间的依赖关系,生成空间维度的权重,即M_s。其公式如下:

3.2 代码实现

3.2.1 CA&SA

具体的网络定义及操作实现参见我的代码注释。

class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)# 定义全局平均池化self.max_pool = nn.AdaptiveMaxPool2d(1)# 定义全局最大池化# 定义CBAM中的通道依赖关系学习层,注意这里是使用1x1的卷积实现的,而不是全连接层self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),nn.ReLU(),nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))# 实现全局平均池化max_out = self.fc(self.max_pool(x))# 实现全局最大池化out = avg_out + max_out# 两种信息融合# 最后利用sigmoid进行赋权return self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()# 定义7*7的空间依赖关系学习层self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)# 实现channel维度的平均池化max_out, _ = torch.max(x, dim=1, keepdim=True)# 实现channel维度的最大池化x = torch.cat([avg_out, max_out], dim=1)# 拼接上述两种操作的到的两个特征图x = self.conv1(x)# 学习空间上的依赖关系# 对空间元素进行赋权return self.sigmoid(x)

3.2.2 CBAM_ResNet

篇幅限制,这里仅展示BasicBlock的CA&SA的添加

class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)# 定义ca和sa,注意CA与channel num有关,需要指定这个超参!!!self.ca = ChannelAttention(planes)self.sa = SpatialAttention()self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.ca(out) * out# 对channel赋权out = self.sa(out) * out# 对spatial赋权if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out
系列一到这里就结束了,点赞越多更新越快哦!

【注意力机制集锦】Channel Attention通道注意力网络结构、源码解读系列一相关推荐

  1. Android事件分发机制完全解析,带你从源码的角度彻底理解(上)

    <div id="container">         <div id="header">     <div class=&qu ...

  2. 计算机视觉中的注意力机制(Visual Attention)

    ,欢迎关注公众号:论文收割机(paper_reader) 原文链接:计算机视觉中的注意力机制(Visual Attention) 本文将会介绍计算机视觉中的注意力(visual attention)机 ...

  3. 一文读懂——全局注意力机制(global attention)详解与代码实现

    废话不多说,直接先上全局注意力机制的模型结构图. 如何通过Global Attention获得每个单词的上下文向量,从而获得子句向量呢?如下几步: 代码如下所示: x = Embedding(inpu ...

  4. 注意力机制(一):注意力提示、注意力汇聚、Nadaraya-Watson 核回归

    专栏:神经网络复现目录 注意力机制 注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分.在自然语言处理 ...

  5. 【动手深度学习-笔记】注意力机制(四)自注意力、交叉注意力和位置编码

    文章目录 自注意力(Self-Attention) 例子 Self-Attention vs Convolution Self-Attention vs RNN 交叉注意力(Cross Attenti ...

  6. netty源码分析系列——Channel

    2019独角兽企业重金招聘Python工程师标准>>> 前言 Channel是netty中作为核心的一个概念,我们从启动器(Bootstrap)中了解到最终启动器的两个关键操作con ...

  7. 【转】Android事件分发机制完全解析,带你从源码的角度彻底理解(下)

    转载请注明出处:http://blog.csdn.net/guolin_blog/article/details/9153761 记得在前面的文章中,我带大家一起从源码的角度分析了Android中Vi ...

  8. dubbo源码分析系列(1)扩展机制的实现

    1 系列目录 dubbo源码分析系列(1)扩展机制的实现 dubbo源码分析系列(2)服务的发布 dubbo源码分析系列(3)服务的引用 dubbo源码分析系列(4)dubbo通信设计 2 SPI扩展 ...

  9. Android6.0源码解读之ViewGroup点击事件分发机制

    本篇博文是Android点击事件分发机制系列博文的第三篇,主要是从解读ViewGroup类的源码入手,根据源码理清ViewGroup点击事件分发原理,明白ViewGroup和View点击事件分发的关系 ...

最新文章

  1. 动手自己写一个 xcode 插件(Xcode Source Editor Extensions)附源码
  2. Ubuntu caffe 测试matlab接口
  3. 需求获取的三阶段:需求背景、需求调研、需求分析 (3)
  4. 如何使用Alert 组件
  5. java导出服务器已经配置好的excel模板
  6. 使用源码安装 PostgreSQL 12.5 主从集群
  7. 6.6 AdaBoost实战
  8. armv6、armv7、armv7s及arm64
  9. Python网络爬虫经典书籍推荐
  10. 路由器装mentohust插件破解锐捷认证(Pandorabox固件)
  11. 2020 GKCTF
  12. 如何在香港主机上尽可能多的建站
  13. 基于 Windows系统的 KingbaseES 数据库软件安装指南(3. 安装前准备工作)
  14. S3C2440之IIC
  15. java s3_Amazon S3 功能介绍
  16. 网络水军第一课:手写自动弹幕
  17. 【软件测试】思维开拓—用软件测试的思维测试QQ好友是在线或者离线
  18. STC89C52 小车-舵机转向/蓝牙控制/寻迹,有PCB有讲解,更新
  19. iOS播放器、Flutter高仿书旗小说、卡片动画、二维码扫码、菜单弹窗效果等源码...
  20. 三星android8 日期,三星披露升级Android 8.0时间 明年年初

热门文章

  1. K_A12_002 基于STM32等单片机采集光敏电阻传感器参数串口与OLED0.96双显示
  2. DO447Ansible Tower的维护和常规管理--备份和修复
  3. JDBC简介(Statement接口)
  4. 算法设计与分析——活动安排问题(Java)
  5. Java单个文件下载
  6. DevOps实战:版本管理实践指南
  7. 高等数学(第七版)同济大学 习题9-10 个人解答
  8. 【LeetCode33】搜索旋转排序数组
  9. NetTool v2.0 IP配置工具
  10. NC UAP STUDIO授权