YOLO v5 引入解耦头部


文章目录

  • YOLO v5 引入解耦头部
  • 前言
  • 一、解耦头部示意图
  • 二、在YOLO v5 中引入解耦头部
    • 1.修改common.py文件
    • 2.修改yolo.py文件
    • 3.修改模型的yaml文件
  • 总结

前言

在 YOLO x中,使用了解耦头部的方法,从而加快网络收敛速度和提高精度,因此解耦头被广泛应用于目标检测算法任务中。因此也想在YOLO v5的检测头部引入了解耦头部的方法,从而来提高检测精度和加快网络收敛,但这里与 YOLO x 解耦头部使用的检测方法稍微不同,在YOLO v5中引入的解耦头部依旧还是基于 anchor 检测的方法。


一、解耦头部示意图

在YOLO x中,使用了解耦头部的方法,具体论文请参考:https://arxiv.org/pdf/2107.08430.pdf
于是按照论文中的介绍就可以简单的画出解耦头部,在YOLO v5中引入的解耦头部最终还是基于 anchor 检测的方法。

二、在YOLO v5 中引入解耦头部

1.修改common.py文件

在common.py文件中加入以下代码。

class DecoupledHead(nn.Module):def __init__(self, ch=256, nc=80, anchors=()):super().__init__()self.nc = nc  # number of classesself.nl = len(anchors)  # number of detection layersself.na = len(anchors[0]) // 2  # number of anchorsself.merge = Conv(ch, 256, 1, 1)self.cls_convs1 = Conv(256, 256, 3, 1, 1)self.cls_convs2 = Conv(256, 256, 3, 1, 1)self.reg_convs1 = Conv(256, 256, 3, 1, 1)self.reg_convs2 = Conv(256, 256, 3, 1, 1)self.cls_preds = nn.Conv2d(256, self.nc * self.na, 1)self.reg_preds = nn.Conv2d(256, 4 * self.na, 1)self.obj_preds = nn.Conv2d(256, 1 * self.na, 1)def forward(self, x):x = self.merge(x)x1 = self.cls_convs1(x)x1 = self.cls_convs2(x1)x1 = self.cls_preds(x1)x2 = self.reg_convs1(x)x2 = self.reg_convs2(x2)x21 = self.reg_preds(x2)x22 = self.obj_preds(x2)out = torch.cat([x21, x22, x1], 1)return out

2.修改yolo.py文件

修改后common.py文件后,需要修改yolo.py文件,主要修改两个部分:
1.在model函数,只需修改一句代码,修改后如下:

if isinstance(m, Detect) or isinstance(m, Decoupled_Detect):

2.在parse_model函数中,修改后代码如下:

3.在yolo.py增加Decoupled_Detect代码

class Decoupled_Detect(nn.Module):stride = None  # strides computed during buildonnx_dynamic = False  # ONNX export parameterexport = False  # export modedef __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layersuper().__init__()self.nc = nc  # number of classesself.no = nc + 5  # number of outputs per anchorself.nl = len(anchors)  # number of detection layersself.na = len(anchors[0]) // 2  # number of anchorsself.grid = [torch.zeros(1)] * self.nl  # init gridself.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor gridself.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)self.m = nn.ModuleList(DecoupledHead(x, nc, anchors) for x in ch)self.inplace = inplace  # use in-place ops (e.g. slice assignment)def forward(self, x):z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whelse:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xywh = (wh * 2) ** 2 * self.anchor_grid[i]  # why = torch.cat((xy, wh, conf), 4)z.append(y.view(bs, -1, self.no))return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)def _make_grid(self, nx=20, ny=20, i=0):d = self.anchors[i].devicet = self.anchors[i].dtypeshape = 1, self.na, ny, nx, 2  # grid shapey, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibilityyv, xv = torch.meshgrid(y, x, indexing='ij')else:yv, xv = torch.meshgrid(y, x)grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)return grid, anchor_grid

3.在model函数中,修改Build strides, anchors部分代码,修改后代码如下:

# Build strides, anchorsm = self.model[-1]  # Detect()if isinstance(m, Detect) or isinstance(m, Decoupled_Detect):s = 256  # 2x min stridem.inplace = self.inplacem.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forwardcheck_anchor_order(m)  # must be in pixel-space (not grid-space)m.anchors /= m.stride.view(-1, 1, 1)self.stride = m.stride# self._initialize_biases()  # only run oncetry :self._initialize_biases()  # only run onceLOGGER.info('initialize_biases done')except :LOGGER.info('decoupled no biase ')initialize_weights(self)self.info()LOGGER.info('')

3.修改模型的yaml文件

在模型的yaml文件中,修改最后一层检测的头的结构,我修改yolo v5s模型的最后一层检测结构如下:

 [[17, 20, 23], 1, Decoupled_Detect, [nc, anchors]],         # Detect(P3, P4, P5)

总结

至于单独的增加解耦头部,我还没有对自己的数据集进行单独的训练,一般都是解耦头部和其他模型结合在一起进行训练,如果后期在训练的时候map有提升的话,我在把实验结果放在上面,最近也在跑实验结果对比。

YOLO v5 引入解耦头部相关推荐

  1. 如何使用 Yolo v5 对象检测算法进行自定义对象检测

    介绍 在本文中,将向你解释如何使用 Yolo v5 算法检测和分类60+个不同类型的道路交通标志. 我们将从非常基础的开始,涵盖每个步骤,如准备数据集.训练和测试等.在本文中,我们将使用 Window ...

  2. 【深度学习】用 YOLO v5+DeepSORT,打造实时多目标跟踪模型

    内容概要:目标跟踪作为一个非常有前景的研究方向,常常因为场景复杂导致目标跟丢的情况发生.本文按照跟踪目标数量的差异,分别介绍了单目标跟踪及多目标跟踪. 关键词:目标跟踪   计算机视觉    教程 目 ...

  3. 简述yolo1-yolo3_YOLO v4或YOLO v5或PP-YOLO?

    简述yolo1-yolo3 Object detection is a computer vision task that involves predicting the presence of on ...

  4. 二十. 在ROS系统上实现基于PyTorch YOLO v5的实时物体检测

    一. 背景介绍 在我前面的博文 十八.在JetsonNano上为基于PyTorch的物体检测网络测速和选型 中,我介绍过在基于Jetson Nano硬件平台和Ubuntu 18.04 with Jet ...

  5. 用 YOLO v5+DeepSORT,打造实时多目标跟踪模型

    内容概要:目标跟踪作为一个非常有前景的研究方向,常常因为场景复杂导致目标跟丢的情况发生.本文按照跟踪目标数量的差异,分别介绍了单目标跟踪及多目标跟踪. 目标跟踪 (Object Tracking) 是 ...

  6. YOLO v5算法详解

    1.YOLO v5网络结构 2.输入端 3.Backone网络 4.Neck网络 5.Head网络 1.YOLO v5网络结构 图 1 YOLO v5网络结构图 由上图可知,YOLO v5主要由输入端 ...

  7. YOLO v5训练时报fitness错误,求解

    在跑yolo v5训练时出现报错 Traceback (most recent call last):   File "D:/3.7/yolov5-master/train.py" ...

  8. 【支线】输电杆塔识别-YOLO v5在Aidlux的部署

    目录 0.前言 1.模型训练 1.1 任务描述 1.2 输电杆塔数据集采集 1.3 输电杆塔数据集标注 1.4 数据增强 1.5 折腾 1.6 训练 1.7 测试 2.NX部署 2.1 软硬件 2.2 ...

  9. YOLO v5 实现目标检测(参考数据集自制数据集)

    YOLO v5 实现目标检测(参考数据集&自制数据集) Author: Labyrinthine Leo   Init_time: 2020.10.26 GitHub: https://git ...

最新文章

  1. 07-11 Linux命令操作
  2. matlab fftshift_MATLAB信号频谱分析FFT详解
  3. IDEA卡顿问题解决-加大内存
  4. 大数据分析目前面临哪些问题
  5. Activiti5.22:删除工作流引擎自动创建的外键约束
  6. 查看SQL执行计划的方法
  7. think-cell中类别或系列标签在多个图表中未对齐
  8. matlab曲面的最小值,MATLAB最小二乘法拟合曲面
  9. 面向对象之魔术方法_call
  10. 影响ae渲染时间的计算机配置,分享两套影视后期电脑配置2019 能流畅使用ae和pr的电脑主机推荐...
  11. python 异常学习1
  12. php逐个汉字遍历字符串
  13. shell脚本系列-grep用法总结
  14. 洛谷P1830 轰炸
  15. android打开hex文件怎么打开,hex文件怎么打开
  16. OSChina 周六乱弹 ——程序员的女朋友注意了,当你男友说:
  17. 4-Arm PEG-DSPE 四臂-聚乙二醇-磷脂 可用于修饰小分子材料
  18. Android基于NSD实现网络服务发现功能
  19. Redis安装教程(Windows版)
  20. 逐浪字库打造最全的书法字体,最全的合集(字体书法欣赏下载)

热门文章

  1. 速度收藏 | 100+大数据开源处理工具汇总
  2. zkSnarks:QAP上构造零知识证明
  3. php导出excel无边框线,phpexcel设置边框不全或者只有竖线问题解决方法
  4. 【Unity3D 灵巧小知识点】☀️ | Unity 中 怎样切换 天空盒 背景
  5. 最简洁的呼吸灯实验verilog
  6. 编译原理(一)编译程序、解释程序、程序设计语言范型
  7. python 降低图片分辨率的两种方法
  8. 手把手第一篇:写出第一行 Hello World
  9. java 导出复杂格式的 Excel 留着自己备用
  10. java 二进制最大值_java int型最大值/最小值,最大值+1,最小值-1