YOLOv5、v7改进之三十一:CrissCrossAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv7的如何改进进行详细的介绍,目的是为了给那些搞科研的同学需要创新点或者搞工程项目的朋友需要达到更好的效果提供自己的微薄帮助和参考。由于出到YOLOv7,YOLOv5算法2020年至今已经涌现出大量改进论文,这个不论对于搞科研的同学或者已经工作的朋友来说,研究的价值和新颖度都不太够了,为与时俱进,以后改进算法以YOLOv7为基础,此前YOLOv5改进方法在YOLOv7同样适用,所以继续YOLOv5系列改进的序号。另外改进方法在YOLOv5等其他算法同样可以适用进行改进。希望能够对大家有帮助。
具体改进办法请关注后私信留言!
解决问题:之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入CrissCrossAttention注意力机制,该注意力机制为应用在语义分割中的模块,用于可以让网络更加关注待检测目标,提高检测效果
基本原理:
语义分割的Criss-Cross网络(CCNet)的细节。我们首先介绍了CCNet的总体框架。然后,将介绍在水平和垂直方向捕获上下文信息的2D交叉注意力模块。为了获取密集的全局上下文信息,我们建议对交叉注意力模块采用循环操作。为了进一步改进RCCA,我们引入了判别损失函数来驱动RCCA学习类别一致性特征。最后,我们提出了同时利用时间和空间上下文信息的三维交叉注意模块。
添加方法:
第一步:确定添加的位置,作为即插即用的注意力模块,可以添加到YOLOv5网络中的任何地方。
第二步:common.py构建CoordAtt模块。部分代码如下,关注文章末尾,私信后领取。
class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self, in_dim):super(CrissCrossAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)self.softmax = Softmax(dim=3)self.INF = INFself.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):m_batchsize, _, height, width = x.size()proj_query = self.query_conv(x)proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2,1)proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2,1)proj_key = self.key_conv(x)proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)proj_value = self.value_conv(x)proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width,height,height).permute(0,energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)concate = self.softmax(torch.cat([energy_H, energy_W], 3))att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)# print(concate)# print(att_H)att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)# print(out_H.size(),out_W.size())return self.gamma * (out_H + out_W) + x
第三步:yolo.py中注册 CrissCrossAttention模块
elif m is CrissCrossAttention:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, *args[1:]]
第四步:修改yaml文件,本文以修改head(特征融合网络)为例,将原C3模块后加入该模块。
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2[-1, 1, Conv, [128, 3, 2]], # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]], # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]], # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]], # 9[-1, 1, CrissCrossAttention, [1024]],]
第五步:将train.py中改为本文的yaml文件即可,开始训练。
结 果:本人在遥感数据集上进行实验,有涨点效果。需要请关注留言。
预告一下:下一篇内容将继续分享深度学习算法相关改进方法。有兴趣的朋友可以关注一下我,有问题可以留言或者私聊我哦
PS:该方法不仅仅是适用改进YOLOv5,也可以改进其他的YOLO网络以及目标检测网络,比如YOLOv7、v6、v4、v3,Faster rcnn ,ssd等。
最后,希望能互粉一下,做个朋友,一起学习交流。
YOLOv5、v7改进之三十一:CrissCrossAttention注意力机制相关推荐
- 改进YOLOv5系列:13.添加CrissCrossAttention注意力机制
- 目标检测算法——YOLOv5/YOLOv7改进之结合CBAM注意力机制
深度学习Tricks,第一时间送达 论文题目:<CBAM: Convolutional Block Attention Module> 论文地址: https://arxiv.org/p ...
- YOLOv5、v7改进之三十二:引入SKAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法.此后的系列文章,将重点对YOLOv7 ...
- 改进YOLOv5系列:16.添加SKAttention注意力机制
最新创新点改进推荐 -
- 改进YOLOv5系列:21.添加CBAM注意力机制
- 《YOLOv5/v7改进实战专栏》专栏介绍 专栏目录
- 目标检测算法——YOLOv5/v7改进之结合最强视觉识别模块CotNet(Transformer)
- [YOLOv7/YOLOv5系列算法改进NO.33]引入GAMAttention注意力机制
前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法.此后的系列文章,将重点对YOLOv7 ...
- 小目标检测3_注意力机制_Self-Attention
主要参考: (强推)李宏毅2021/2022春机器学习课程 P38.39 李沐老师:64 注意力机制[动手学深度学习v2] 手把手带你Yolov5 (v6.1)添加注意力机制(一)(并附上30多种顶会 ...
最新文章
- 生成allure测试报告时报错的解决方法
- 法国spin高等计算机学校,法国顶尖“大矿”,一起去矿校挖矿吧!
- 解决U盘无法拷贝大文件问题
- 关于端到端通信的讨论(P2P)
- c#如何实现excel导入到sqlserver,如何实现从sqlserver导出到excel中(详细)
- 本科生 计算机图形学试卷,湖南工程学院《计算机图形学》毕业补考试卷及答案...
- linux 变量引用 和 变量的自动类型转换 c++,c++类型转换 - memristor的个人空间 - OSCHINA - 中文开源技术交流社区...
- Android LowMemoryKiller ADJ原理
- centos 7 yum命令安装 Nginx、PHP 7、MySQL 57 、redis
- Java多线程系列--“JUC原子类”
- java 11下载_jdk11版
- Swiper去除点击选项卡时出现的蓝色边框和蓝色背景
- 将不同数据来源的ggplot图绘制到同一张图中,并添加统一的图例
- 一份还热乎的蚂蚁金服面经(已拿Offer)!附答案!!
- 我读Saliency Filters cvpr 2012
- mysql创建表并指定字符集_mysql 创建表 指定字符集
- 马云新零售再下一城:要把国外东西运进来,先将中国物流搬出去
- LaTeX中常见的宏包及其含义
- 关于sv中宏定义`define的增强使用
- java 方法实现数学黑洞
热门文章
- 女人身体8大隐私部位长得越丑健康指数越高_113
- 电视剧中的程序员,是真的敲代码吗?
- html怎么加深字体颜色,我打印网页的字的颜色非常浅,怎样才能加深? – 手机爱问...
- 【movie】整理了一些电影资料,自己留着慢慢看
- 上网行为管理设备的介绍,部署与使用
- 谷歌地图 替代_Google地图的替代品
- 什么是IP地址冲突?如何解决IP地址冲突?
- 高中计算机教师专业,高中计算机教师资格证,要计算机专业证书吗
- IC设计工程师的职业规划
- 第一次打CF的感受(附A-C题题解) -Codeforces Round #764 (Div. 3)