SE、CBAM 以及 ECA 三种注意力机制的结构实现与代码详解如下所示。

代码可参考:https://github.com/XuecWu/External-Attention-pytorch

import torch
import torch.nn as nn
import math#----------------------------#
# SE注意力机制
#----------------------------#
class se_block(nn.Module):def __init__(self, channel, ratio=16):super(se_block, self).__init__()#--------------------------------------------------## 此为自适应的二维平均全局池化操作# 通道数不会发生改变# The output is of size H x W, for any input size.# AdaptiveAvgPool2d(1) = AdaptiveAvgPool2d((1,1))#--------------------------------------------------#self.avg_pool = nn.AdaptiveAvgPool2d(1)#-------------------------------------------------------## 对于全连接层,第一个参数为输入的通道数,第二个参数为输入的通道数# 之后经过ReLU来提升模型的非线性表达能力,以及对特征信息进行编码# sigmoid还是一个激活函数,来提升模型的非线性表达能力# ratio越大其对于特征融合以及信息表达所产生的影响越大,# 压缩降维对于学习通道之间的依赖关系有着不利影响#-------------------------------------------------------#self.fc = nn.Sequential(nn.Linear(channel, channel // ratio, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // ratio, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()#-----------------------------------------------------------------## 先进行自适应二维全局平均池化,然后进行一个reshape的操作# 之后使其通过一个全连接层、一个ReLU、一个全连接层、一个Sigmoid层# 再将其reshape成之前的shape即可# 最后将注意力权重y和输入X按照通道加权相乘,调整模型对输入x不同通道的重视程度#------------------------------------------------------------=----#y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y#--------------------------------------#
# CBAM注意力机制 包含通道注意力以及空间注意力
#--------------------------------------#
class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=8):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)#-----------------------------------------## 利用1x1卷积代替全连接,以减小计算量以及模型参数# 整体前向传播过程为 卷积 ReLU 卷积 进而实现特征# 信息编码以及增强网络的非线性表达能力#-----------------------------------------#self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)self.relu1 = nn.ReLU()self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)#----------------------------------------## 定义sigmoid激活函数以增强网络的非线性表达能力#----------------------------------------#self.sigmoid = nn.Sigmoid()def forward(self, x):#-------------------------------------## 分成两部分,一部分为平均池化,一部分为最大池化# 之后将两部分的结果相加再经过sigmoid作用#-------------------------------------#avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))out     = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding      = 3 if kernel_size == 7 else 1#-------------------------------------## 这个卷积操作为大核卷积操作,其虽然可以计算# 空间注意力但是仍无法有效建模远距离依赖关系#-------------------------------------#self.conv1   = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):#----------------------------------------------------## 整体前向传播过程即为先分别做平均池化操作,再做最大池化操作# 其中的dim:# 指定为1时,求得是列的平均值# 指定为0时,求得是行的平均值# 之后将两个输出按照列维度进行拼接,此时通道数为2# 拼接之后通过一个大核卷积将双层特征图转为单层特征图,此时通道为1# 最后通过sigmoid来增强模型的非线性表达能力#-----------------------------------------------------#avg_out    = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x          = torch.cat([avg_out, max_out], dim=1)x          = self.conv1(x)return self.sigmoid(x)class cbam_block(nn.Module):def __init__(self, channel, ratio=8, kernel_size=7):super(cbam_block, self).__init__()#----------------------------## 定义好通道注意力以及空间注意力#----------------------------#self.channelattention = ChannelAttention(channel, ratio=ratio)self.spatialattention = SpatialAttention(kernel_size=kernel_size)#----------------------------## 输入x先与通道注意力权重相乘# 之后将输出与空间注意力权重相乘#----------------------------#def forward(self, x):x = x * self.channelattention(x)x = x * self.spatialattention(x)return x#----------------------------#
# ECA注意力机制
#----------------------------#
class eca_block(nn.Module):def __init__(self, channel, b=1, gamma=2):super(eca_block, self).__init__()#----------------------------------## 根据通道数求出卷积核的大小kernel_size#----------------------------------#kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv     = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)self.sigmoid  = nn.Sigmoid()def forward(self, x):#------------------------------------------## 显示全局平均池化,再是k*k的卷积,# 最后为Sigmoid激活函数,进而得到每个通道的权重w# 最后进行回承操作,得出最终结果#------------------------------------------#y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)

【视觉注意力机制】SE、CBAM、ECA三种可插拔注意力模块结构实现与详解相关推荐

  1. leetcode84- 柱状图中最大的矩形(三种思路:暴力,单调栈+哨兵(详解),分治)

    leetcode84- 柱状图中最大的矩形(三种思路:暴力,单调栈+哨兵(详解),分治) 介绍 题目 解题思路 解法一:暴力向两边搜索 解法二:单调栈 画图演示 宽度计算: 解法三:单调栈+哨兵 解法 ...

  2. python下载url_三种Python下载url并保存文件的代码详解

    利用程序自己编写下载文件挺有意思的. Python中最流行的方法就是通过Http利用urllib或者urllib2模块. 当然你也可以利用ftplib从ftp站点下载文件.此外Python还提供了另外 ...

  3. springboot整合elasticJob实战(纯代码开发三种任务类型用法)以及分片系统,事件追踪详解...

    一 springboot整合 介绍就不多说了,只有这个框架是当当网开源的,支持分布式调度,分布式系统中非常合适(两个服务同时跑不会重复,并且可灵活配置分开分批处理数据,贼方便)! 这里主要还是用到zo ...

  4. P2P技术详解(三):P2P技术之STUN、TURN、ICE详解

    本文是<P2P理论详解>系列文章中的第2篇,总目录如下: <P2P技术详解(一):NAT详解--详细原理.P2P简介> <P2P技术详解(二):P2P中的NAT穿越(打洞 ...

  5. Framework事件机制—Android事件处理的三种方法

    1.1.背景 Android的事件处理的三种方法: 1.基于监听的事件处理机制 setOnClickListener,setOnLongClickListener.setOnTouchListener ...

  6. Vue学习笔记(三)Vue2三种slot插槽的概念与运用 | ES6 对象的解构赋值 | 基于Vue2使用axios发送请求实现GitHub案例 | 浏览器跨域问题与解决

    文章目录 一.参考资料 二.运行环境 三.Vue2插槽 3.1 默认插槽 3.2 具名插槽 3.3 作用域插槽 ES6解构赋值概念 & 作用域插槽的解构赋值 3.4 动态插槽名 四.GitHu ...

  7. c语言三种循环结构特点,c语言循环结构(c语言循环结构特点)

    1.while循环 while语句的一般形式为:while(表达式)语句.其中表达式是循环条件,语句为循环体.while语句中的表达式一般是关系表达或逻辑表达式,只要表达式的. for语句循环1 fo ...

  8. JAVA SE知识整合(暂时完结 五万七字)后续分点详解

    目录 1.别再问为什么在类里面写个sysout语句爆红了,类里面有且只有五个成分: 2.面向对象三大特征: 封装,继承,多态 (扫盲扫盲,别这个都不知道) 3.讲一下static这个很重要的关键词 4 ...

  9. Linux内存管理(三十九):页面回收简介和 kswapd详解(1)

    源码基于:Linux5.4 0. 前言 在 LRU简介 一文和 LRU 第二次机会法 一文中,提到当内存出现紧张的时候,会将 inactive list 尾部的 page 进行换出,从而将page 释 ...

最新文章

  1. 分布式深度学习DDL解析
  2. JVM运行时栈帧结构
  3. php获取循环,PHP循环获取GET和POST值的代码
  4. python九九乘法表代码知乎_二年级上册表内乘法教学反思_二年级6的乘法口诀教学反思...
  5. jquery选择器连续选择_JQuery中的选择器
  6. PP视频如何设置关闭的时候直接退出程序
  7. 与context的关系_Android-Context
  8. 冯诺依曼计算机的组成
  9. 《objective-c基础教程》学习笔记(四)—— OC面向对象编程初探
  10. Android--android 中自定义菜单
  11. 游戏开发论坛_微信小游戏增速35% 重度游戏最高单款累计流水8亿 | 游戏茶馆
  12. Python考试题库(含答案)
  13. suitecrm 如何backup and restore ,从一个server 转移到另一个 server . 并保证customer package , customer module 不丢...
  14. 如何启用计算机的休眠,win7休眠-win7如何启用休眠,我已经google过了,没用,请大家帮忙我? 爱问知识人...
  15. 今日头条 巨量引擎 marketing api
  16. 第十一届 蓝桥杯 单片机设计与开发项目 决赛
  17. 五款最优秀的java微服务框架
  18. RFID固定资产管理降低人工成本,实现智能化的管理-新导智能
  19. 快速学习COSMIC方法之九:如何识别兴趣对象?
  20. 使用SmartUpload组件上传文件,自己踩过的坑

热门文章

  1. 基于OpenCV的视频场景切割神器
  2. 【Java】9、Java IO 流
  3. liferay6.2.2GA2中CKEditor在IE11与SAFARI中BUG解决方案
  4. 3D导航栏翻转(css)
  5. “大多数”餐馆收银系统被用于盗用信用卡信息的恶意软件感染
  6. [转]中英文停止词表(stopword)
  7. ogre 学习笔记 - Day 7
  8. mysql多条件count_Mysql中使用count加条件统计
  9. 计算机网络 (头歌平台)实验二
  10. J​a​v​a​S​c​r​i​p​t​针​对​D​o​m​相​关​的​优​化​心​得...