文章目录

  • 【图像分类】2018-CBAM ECCV
    • 1. 简介
      • 1.1 简介
    • 2. 网络
      • 2.1 通道注意力(CA)
      • 2.2 空间注意力(SA)
    • 3. 代码
      • 3.1 模块
      • 3.2 改装后的Unet

【图像分类】2018-CBAM ECCV

卷积注意力模块(CBAM)

论文题目:CBAM: Convolutional Block Attention Module

论文地址:https://arxiv.org/abs/1807.06521

代码地址: https://github.com/Jongchan/attention-module

发表时间:2018年7月

引用:Woo S, Park J, Lee J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 3-19.

引用数:6096

1. 简介

1.1 简介

CBAM 是对标于SENet所提出的一种结合了通道注意力和空间注意力的轻量级模块,它和SENet一样,几乎可以嵌入任何CNN网络中,在带来小幅计算量和参数量的情况下,大幅提升模型性能。

SENet(Sequeeze and Excitation Net)是2017届ImageNet分类比赛的冠军网络,本质上是一个基于通道的Attention模型,它通过建模各个特征通道的重要程度,然后针对不同的任务增强或者抑制不同的通道

CBAM 是对标于SENet所提出的一种结合了通道注意力和空间注意力的轻量级模块,它和SENet一样,几乎可以嵌入任何CNN网络中,在带来小幅计算量和参数量的情况下,大幅提升模型性能。

卷积神经网络在很大程度上推动了计算机视觉任务的发展,最近的众多研究主要关注了网络的三个重要因素:深度、宽度、基数(cardinality)

深度的代表:VGG、ResNet

宽度的代表:GooLeNet

基数的代表:Xception、ResNeXt

而本文作者承接SENet的思想,从attention(注意力)这个维度出发,研究提升网络性能的方法。

2. 网络

人类视觉系统的一个重要特性是,人们不会试图同时处理看到的整个场景。取而代之的是,为了更好地捕捉视觉结构,人类利用一系列的局部瞥见,有选择性地聚集于显著部分。近年来,有人尝试将注意力机制引入到卷积神经网络中,以提高其在大规模分类任务中的性能。

本文作者为了强调空间通道这两个维度上的有意义特征,依次应用通道和空间注意力模块,来分别在通道和空间维度上学习关注什么、在那里关注。CBAM如下图1 所示。

给定一个中间特征图,我们沿着通道和空间两个维度依次推断出注意力权重,然后与原特征图相乘来对特征进行自适应调整。 由于 CBAM 是一个轻量级的通用模块,它可以无缝地集成到任何 CNN 架构中,额外开销忽略不计,并且可以与基本 CNN 一起进行端到端的训练。 在不同的分类和检测数据集上,将 CBAM 集成到不同的模型中后,模型的表现都有了一致的提升,展示了其广泛的可应用性。

输入特征依次通过通道注意力模块空间注意力模块的筛选,最后获得经过了重标定的特征,即强调重要特征,压缩不重要特征

2.1 通道注意力(CA)

Channel Attention

通道注意力有SE-Net,ECA-Net机制,可以理解为让网络在看什么。

特征的每一个通道都代表着一个专门的检测器,因此,通道注意力是关注什么样的特征是有意义的。为了汇总空间特征,作者采用了全局平均池化和最大池化两种方式来分别利用不同的信息。

简而言之:注意力机制可对特征进行校正,校正后的特征可保留有价值的特征,剔除没价值的特征。

步骤

  • 挤压(Squeeze)输入图像

    对输入特征图的空间维度进行压缩,这一步可以通过全局平均池化(GAP)和全局最大池化(GMP)(全局平均池化效果相对来说会比最大池化要好),通过这一步。 H × W × C H\times W\times C H×W×C的输入图像被压缩成为 1 × 1 × C 1\times 1\times C 1×1×C的通道描述符。下方公式输入为 S × S × B S\times S\times B S×S×B的 f e a t u r e m a p feature map featuremap:
    s b l + 1 = 1 S × S ∑ i = 1 S ∑ j = 1 S u i , j , b ( l + 1 ) s_b^{l+1}=\frac{1}{S\times S}\sum_{i=1}^S \sum_{j=1}^S u_{i,j,b}^{(l+1)} sbl+1​=S×S1​i=1∑S​j=1∑S​ui,j,b(l+1)​
    将全局空间信息压缩到通道描述符,既降低了网络参数,也能达到防止过拟合的作用。

  • excitation通道描述符

    这一步主要是将上一步得到的通道描述符送到两个全连接网络中,得到注意力权重矩阵,再与原图做乘法运算得到校准之后的注意力特征图。

class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),nn.ReLU(),nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)

2.2 空间注意力(SA)

来源于 空间域注意力(spatial transformer network, STN)

空间域注意力机制的论文:Spatial Transformer Networks,

pytorch实现:https://github.com/fxia22/stn.pytorch。

使用通道注意力的目的:找到关键信息在map上哪个位置上最多,是对通道注意力的补充,简单来说,通道注意力是为了找到哪个通道上有重要信息,而空间注意力则是在这个基础上,基于通道的方向,找到哪一块位置信息聚集的最多。

空间注意力步骤:

  • 沿着通道轴应用平均池化和最大池操作,然后将它们连接起来生成一个有效的特征描述符。

    注意:池化操作是沿着通道轴进行的,即每次池化时对比的是不同通道之间的数值,而非同一个通道不同区域的数值。

  • 将特征描述符送入一个卷积网络进行卷积,将得到的特征图通过激活函数得到最终的空间注意特征图。

M S ( F ) = σ ( f 7 × 7 ( [ A v g P o o l ( F ) ; M a x P o o l ( F ) ] ) ) = σ ( f 7 × 7 ( F a v g s ; F m a x s ) ) M_S(F)=\sigma(f^{7\times 7}([AvgPool(F);MaxPool(F)])) \\ =\sigma(f^{7\times 7}(F^s_{avg};F^s_{max})) MS​(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))=σ(f7×7(Favgs​;Fmaxs​))

具体来说,使用两个pooling操作聚合成一个feature map的通道信息,生成两个2D图: Fsavg大小为 1 × H × W 1×H×W 1×H×W,Fsmax大小为 1 × H × W 1×H×W 1×H×W。 σ σ σ表示sigmoid函数, f 7 × 7 f^{7×7} f7×7表示一个滤波器大小为 7 × 7 7×7 7×7的卷积运算。

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)

3. 代码

3.1 模块

import torch
import torch.nn as nnclass ChannelAttentionModule(nn.Module):def __init__(self, channel, ratio=16):super(ChannelAttentionModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Conv2d(channel, channel // ratio, 1, bias=False),nn.ReLU(),nn.Conv2d(channel // ratio, channel, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x))maxout = self.shared_MLP(self.max_pool(x))return self.sigmoid(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.sigmoid = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.sigmoid(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, channel):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(channel)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xprint(self.spatial_attention(out).shape)out = self.spatial_attention(out) * outreturn outif __name__ == '__main__':x=torch.randn(1,32,64,64)model=CBAM(32)y=model(x)print(y.shape)

3.2 改装后的Unet

class conv_block(nn.Module):def __init__(self,ch_in,ch_out):super(conv_block,self).__init__()self.conv = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True),nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.conv(x)return xclass up_conv(nn.Module):def __init__(self,ch_in,ch_out):super(up_conv,self).__init__()self.up = nn.Sequential(nn.Upsample(scale_factor=2),nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.up(x)return xclass U_Net_v1(nn.Module):   #添加了空间注意力和通道注意力def __init__(self,img_ch=3,output_ch=2):super(U_Net_v1,self).__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) #64self.Conv2 = conv_block(ch_in=64,ch_out=128)  #64 128self.Conv3 = conv_block(ch_in=128,ch_out=256) #128 256self.Conv4 = conv_block(ch_in=256,ch_out=512) #256 512self.Conv5 = conv_block(ch_in=512,ch_out=1024) #512 1024self.cbam1 = CBAM(channel=64)self.cbam2 = CBAM(channel=128)self.cbam3 = CBAM(channel=256)self.cbam4 = CBAM(channel=512)self.Up5 = up_conv(ch_in=1024,ch_out=512)  #1024 512self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)  self.Up4 = up_conv(ch_in=512,ch_out=256)  #512 256self.Up_conv4 = conv_block(ch_in=512, ch_out=256)  self.Up3 = up_conv(ch_in=256,ch_out=128)  #256 128self.Up_conv3 = conv_block(ch_in=256, ch_out=128) self.Up2 = up_conv(ch_in=128,ch_out=64) #128 64self.Up_conv2 = conv_block(ch_in=128, ch_out=64)  self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)  #64def forward(self,x):# encoding pathx1 = self.Conv1(x)x1 = self.cbam1(x1) + x1x2 = self.Maxpool(x1)x2 = self.Conv2(x2)x2 = self.cbam2(x2) + x2x3 = self.Maxpool(x2)x3 = self.Conv3(x3)x3 = self.cbam3(x3) + x3x4 = self.Maxpool(x3)x4 = self.Conv4(x4)x4 = self.cbam4(x4) + x4x5 = self.Maxpool(x4)x5 = self.Conv5(x5)# decoding + concat pathd5 = self.Up5(x5)d5 = torch.cat((x4,d5),dim=1)d5 = self.Up_conv5(d5)d4 = self.Up4(d5)d4 = torch.cat((x3,d4),dim=1)d4 = self.Up_conv4(d4)d3 = self.Up3(d4)d3 = torch.cat((x2,d3),dim=1)d3 = self.Up_conv3(d3)d2 = self.Up2(d3)d2 = torch.cat((x1,d2),dim=1)d2 = self.Up_conv2(d2)d1 = self.Conv_1x1(d2)return d1

参考文章

(4条消息) 通道注意力与空间注意力模块_aMythhhhh的博客-CSDN博客_通道注意力模型

(4条消息) 一张手绘图带你搞懂空间注意力、通道注意力、local注意力及生成过程(附代码注释)_Mr DaYang的博客-CSDN博客_通道注意力

https://blog.csdn.net/qq_43205656/article/details/121191937

【图像分类】2018-CBAM ECCV相关推荐

  1. 【读文献】License Plate Detection and Recognition in Unconstrained Scenarios(2018年ECCV)

    [读文献]License Plate Detection and Recognition in Unconstrained Scenarios(2018年ECCV) 参考文章链接:https://bl ...

  2. 计算机视觉注意力网络(三)——CBAM [ECCV 2018]

    CBAM: Convolutional Block Attention Module 论文地址:https://arxiv.org/abs/1807.06521 PyTorch代码:https://g ...

  3. 显著性检测2018(ECCV, CVPR)【part-1】

    1.<Salient Object Detection Driven by Fixation Prediction> ASNet 网络结构图如下: 其中,具体的模块连接实现如下图: (1) ...

  4. 显著性检测2018(ECCV, CVPR)【part-2】

    1.<Salient Objects in Clutter: Bringing Salient Object Detection to the Foreground> 提出新的显著性检测数 ...

  5. ECCV 2018|商汤37篇论文入选,为你解读精选论文(附链接+开源资源)

    整理 | Jane 出品| AI科技大本营 [导读]9 月 8 日-14 日,每两年举办一次的 2018 欧洲计算机视觉大会(ECCV 2018)在德国慕尼黑召开,本次会议总共收到了 2439 篇有效 ...

  6. 改进版ASPP(2):ASPP模块中加入CBAM(卷积注意力模块),即CBAM_ASPP

    1.ASPP模型结构 空洞空间卷积池化金字塔(atrous spatial pyramid pooling (ASPP))通过对于输入的特征以不同的采样率进行采样,即从不同尺度提取输入特征,然后将所获 ...

  7. 论文翻译:2022_PACDNN: A phase-aware composite deep neural network for speech enhancement

    论文地址:PACDNN:一种用于语音增强的相位感知复合深度神经网络 相似代码:https://github.com/phpstorm1/SE-FCN 引用格式:Hasannezhad M,Yu H,Z ...

  8. 基于深度学习的细粒度分类研究及应用

    本文主要介绍深度学习图像分类的经典网络结构及发展历程,就细粒度图像分类中的注意力机制进行了综述,最后给出了汽车之家团队参加CVPR2022细粒度分类竞赛所使用的模型及相关算法.参赛经验等,同时介绍了该 ...

  9. CV中的Attention机制总结

    CV中的Attention机制 注意力机制 CV中的注意力机制 卷积神经网络中常用的Attention 视觉注意力机制在分类网络中的应用 SE-Net(CVPR 2017) ECA-Net(CVPR ...

最新文章

  1. 使用IntelliJ IDEA 13搭建Android集成开发环境(图文教程)
  2. mysql innodb 设置详解_【mysql】mysql innodb 配置详解
  3. Java高级编程细节-动态代理-进阶高级开发必学技能
  4. 设计模式——装饰者(Decorator)模式DEMO——成绩汇报的装饰者模式实现
  5. 在java中补零的作用是什么_浅谈Java中的补零扩展和补符号位扩展
  6. hdu5692 Snacks dfs序+线段树
  7. CentOS 7主机名修改与查看命令详述
  8. Linux 性能测试工具 sysbench 的安装与简单使用
  9. 怎么样把设备管理器弄到计算机处,电脑设备管理器要连接上蓝牙的方法
  10. 【转】腾讯云PCDN:从P2P到万物互联服务框架
  11. 粗糙集(Rough Sets)
  12. OpenCV如何进行图像的平滑和锐化处理?
  13. 0xc0000005 系统应用日志_0xc0000005,小编教你怎么解决应用程序正常初始化0xc0000005失败...
  14. 桌面管理landesk太古案例
  15. 操作系统LAB1实验报告
  16. 字典树c语言,字典树的应用 单词意义查找-C语言实现
  17. Owl Carousel轮播插件介绍
  18. 获取MAC地址的四种方法(转)
  19. navicat连接远程mysql数据库
  20. 基于STM32+ESP8266的HLW8032智能电表超额报警设计

热门文章

  1. 百度地图API获取当前位置
  2. KKS编码的基本知识
  3. opengl模拟太阳效果
  4. matlab仿真瑞利分布与高斯分布
  5. MySQL——单行函数的介绍和使用
  6. cocos2d-x横版格斗游戏教程4
  7. 深度学习在推荐系统中的应用
  8. GPIO内部结构和各种模式
  9. linux修改密码策略
  10. andorid 腾讯IM即时通信集成 (一)