自从yolov5-5.0加入se、cbam、eca、ca发布后,反响不错,也经常会有同学跑过来私信我能不能出一期6.0版本加入注意力的博客。个人认为是没有必要专门写一篇来讲,因为步骤几乎一样,但是问的人也慢慢多了,正好上一篇加入注意力的文章写的略有瑕疵,那就再重新写一篇。

yolo加入注意力三部曲

1.common.py中加入注意力模块

2.yolo.py中增加判断条件

3.yaml文件中添加相应模块

所有版本都是一致的,加入注意力机制能否使模型有效的关键在于添加的位置,这一步需要视数据集中目标大小的数量决定。

第一部曲:common.py加入注意力模块

class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordAtt(nn.Module):def __init__(self, inp, oup, reduction=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn, c, h, w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn outclass ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)self.relu = nn.ReLU()self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))max_out = self.f2(self.relu(self.f1(self.max_pool(x))))out = self.sigmoid(avg_out + max_out)return outclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv(x)return self.sigmoid(x)class CBAM(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansionsuper(CBAM, self).__init__()self.channel_attention = ChannelAttention(c1, ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn outclass SE(nn.Module):def __init__(self, c1, c2, r=16):super(SE, self).__init__()self.avgpool = nn.AdaptiveAvgPool2d(1)self.l1 = nn.Linear(c1, c1 // r, bias=False)self.relu = nn.ReLU(inplace=True)self.l2 = nn.Linear(c1 // r, c1, bias=False)self.sig = nn.Sigmoid()def forward(self, x):b, c, _, _ = x.size()y = self.avgpool(x).view(b, c)y = self.l1(y)y = self.relu(y)y = self.l2(y)y = self.sig(y)y = y.view(b, c, 1, 1)return x * y.expand_as(x)

第二部曲:yolo.py中增加判断条件

      if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost,CBAM,CoordAtt,SE]:c1, c2 = ch[f], args[0]if c2 != no:  # if not outputc2 = make_divisible(c2 * gw, 8)

第三部曲:yaml文件中添加注意力

  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, CoordAtt,[1024]],[-1, 1, SPPF, [1024, 5]],  # 9

这是6.0版本的yolov5的骨干层,CoordAtt的位置可以换成以上任意一个注意力,其他参数不需要调整,傻瓜式复制粘贴,即可跑通。

以上就是具体将注意力添加至yolov5-6.0版本中的步骤。注意力模块并没有刻板规定一定要加在什么地方,使用者可随意调整。

接下来是理论部分。关于注意力的理论部分,各位博主大佬已经讲的非常细致了(这个地方想了半天成语想不出来),但是为了证明本人出色的复制粘贴能力,决定再写一下。

SE(挤压-激励注意力)

se注意力增强模型关注对象的方法分为两步:

挤压

对输入的特征图进行通道信息的全局平均池化,即挤压

激励

将挤压之后的信息通过两个全连接层、激活函数再归一化后乘到输入特征图上

CA(位置注意力)

位置注意力将通道注意力分解为两个1维特征编码过程,分别沿2个空间方向聚合特征。这样,可以沿一个空间方向捕获远程依赖关系,同时可以沿另一空间方向保留精确的位置信息。然后将生成的特征图分别编码为一对方向感知和位置敏感的attention map,可以将其互补地应用于输入特征图,以增强关注对象的表示。

coordinate信息嵌入

coordinate attention生成

(注:理论部分非本人原创,如有侵权,请联系我删除。)

总结

各个注意力机制使用后的感受,yolov5中个人感觉CA效果最好。

yolov5-6.0/6.1加入SE、CBAM、CA注意力机制(理论及代码)相关推荐

  1. YOLOV5 6.0加入CA注意力机制(看了包会)

    YOLOV5 6.0手把手教你加入CA注意力机制 文章目录 YOLOV5 6.0手把手教你加入CA注意力机制 yolov5加入注意力机制步骤 一.common.py 二.yolo.py 三.创建自定义 ...

  2. Yolov5添加注意力机制

    一.在backbone后面引入注意力机制 1.先把注意力结构代码放到common.py文件中,以SE举例,将这段代码粘贴到common.py文件中 2.找到yolo.py文件里的parse_model ...

  3. YOLOv5添加注意力机制的具体步骤

    本文以CBAM和SE注意力机制的添加过程为例,主要介绍了向YOLOv5中添加注意力机制的具体步骤 本文在此篇博客的基础上向YOLOv5-5.0版本代码中添加注意力机制 yolov5模型训练---使用y ...

  4. 在yolov5的网络结构中添加注意力机制模块

    知足知不足,有为有不为 目录 前言 一.模块添加步骤 二.相应注意力机制介绍及其代码 1.SE注意力 2.CBAM注意力

  5. Yolov5 网络改进之增加SE、CBAM、CA、ECA等注意力机制

    本文以Yolov5 6.0版本为例,讲解如何添加SE.CA.ECA.CBAM等即插即用的小模块,可同时适配其他网络结构.在这之前需要明白yolov5文件夹的三个小点: models\common.py ...

  6. yolov5-5.0加入CBAM,SE,CA,ECA注意力机制

    CBAM注意力 yolo.py和yaml文件中相应的CBAMC3也要换成CBAM,下面的SE同理 class ChannelAttention(nn.Module):def __init__(self ...

  7. 注意力机制(SE、Coordinate Attention、CBAM、ECA)、即插即用的模块整理

    总结曾经使用过的一些即插即用的模块以及一些注意力机制 ** 注意力模块:SE ** 代码源自这位大佬的仓库:https://github.com/moskomule/senet.pytorch cla ...

  8. YOLOv5改进--添加CBAM注意力机制

    注意力机制包括CBAM.CA.ECA.SE.S2A.SimAM等,接下来介绍具体添加方式. CBAM代码,在common文件中添加以下模块: class CBAMC3(nn.Module):# CSP ...

  9. 目标检测算法——YOLOv5/YOLOv7改进之结合CBAM注意力机制

    深度学习Tricks,第一时间送达 论文题目:<CBAM: Convolutional Block Attention Module> 论文地址:  https://arxiv.org/p ...

最新文章

  1. windows 下使用 Filezilla server 搭建 ftp 服务器
  2. 蓝牙mesh网络基础
  3. 工作51:后端vue学习地址
  4. 微信公众平台中的openid是什么
  5. leetcode题解62-不同路径
  6. 浏览器登录_谷歌浏览器在Android 7.0及以上版本支持使用指纹进行无密码登录
  7. 字符串转数字函数 atol、atoll和strtol、strtoll、strtoul、strtoull 分析
  8. POI读取word文档后插入内容以及设置标题样式
  9. 【HTML 教程系列第 9 篇】什么是 HTML 中的换行标签 br
  10. 几款软件界面模型设计工具
  11. 白嫖阿里-----搭建个人服务
  12. 【vue系列-03】vue的计算属性,列表,监视属性及原理
  13. 中风(脑卒中)研究意义和背景
  14. React全家桶(技术栈) redux 代码
  15. Apollo Planning决策规划算法代码详细解析 (1):Scenario选择
  16. SkiaSharp 之 WPF 自绘 拖曳小球(案例版)
  17. oracle系统中poord是什么,______A.tiredB.weakC.poorD.slow
  18. 网络经济与企业管理【十一】之企业文化管理
  19. Image-Level 弱监督图像语义分割汇总简析
  20. 如何面试 iOS 工程师

热门文章

  1. UART串行通信模式
  2. nacos注册服务的时候报错server is DOWN now, please try again later!
  3. iojs 版本管理ivm
  4. 绝对不变性原理、内模原理
  5. 数据结构学习笔记(7.查找 8.排序)
  6. CLOCs:一种相机-激光雷达3D目标检测后融合方法
  7. c226打印机驱动安装_打印机驱动怎么装?网络打印机驱动的安装方法
  8. 用js给自己照相并修图
  9. python opencv 通过hsv阈值法扣取药盒 并矫正
  10. 高中教师计算机面试什么时候,高中信息技术教师资格证备考经验分享(面试篇)...