Changing the YOLOv5 head to a decoupled head

YOLOX's decoupled head structure:


I originally intended to change the yolov5 head into a decoupled head matching YOLOX's, but through an oversight I ended up with the following structure:

Thanks to 少年肩上杨柳依依 for pointing this out; if you spot any other problems, please let me know.

1. Modify the Detect class in models/yolo.py

```python
class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        # original coupled head: self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)
        # decoupled head: 1x1 stem to 256 channels, then separate 3x3 cls / reg branches
        self.base_conv = nn.ModuleList(BaseConv(in_channels=x, out_channels=256, ksize=1, stride=1) for x in ch)
        self.cls_convs = nn.ModuleList(BaseConv(in_channels=256, out_channels=256, ksize=3, stride=1) for x in ch)
        self.reg_convs = nn.ModuleList(BaseConv(in_channels=256, out_channels=256, ksize=3, stride=1) for x in ch)
        # 1x1 output convs: box (4), objectness (1) and class scores (nc) for each anchor
        self.m_box = nn.ModuleList(nn.Conv2d(256, 4 * self.na, 1) for x in ch)
        self.m_conf = nn.ModuleList(nn.Conv2d(256, 1 * self.na, 1) for x in ch)
        self.m_labels = nn.ModuleList(nn.Conv2d(256, self.nc * self.na, 1) for x in ch)
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
        self.ch = ch
        # Note (yolov5 v6.x): Model.__init__ calls self._initialize_biases(), which loops over
        # Detect.m; this head has no self.m, so that call has to be removed or adapted.

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x_feature = self.base_conv[i](x[i])
            cls_feature = self.cls_convs[i](x_feature)
            reg_feature = self.reg_convs[i](x_feature)
            m_box = self.m_box[i](reg_feature)
            m_conf = self.m_conf[i](reg_feature)
            m_labels = self.m_labels[i](cls_feature)
            bs, _, ny, nx = m_box.shape
            # Reshape each branch to (bs, na, k, ny, nx) and concatenate per anchor, so every
            # anchor's box/objectness come from the reg branch and its class scores from the
            # cls branch. torch.cat((m_box, m_conf, m_labels), 1) followed by
            # .view(bs, na, no, ny, nx) would spread each branch's channels across anchors.
            x[i] = torch.cat((m_box.view(bs, self.na, 4, ny, nx),
                              m_conf.view(bs, self.na, 1, ny, nx),
                              m_labels.view(bs, self.na, self.nc, ny, nx)), 2)
            x[i] = x[i].permute(0, 1, 3, 4, 2).contiguous()  # x(bs,3,20,20,85)

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    # _make_grid() is unchanged from the stock yolov5 Detect and must be kept
```
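yolov5's loss and decoder expect each anchor's outputs in the order [x, y, w, h, obj, class scores]. A plain torch.cat((m_box, m_conf, m_labels), 1) followed by .view(bs, na, no, ny, nx) does not produce that order. The pure-Python sketch below (no PyTorch needed) maps every (anchor, slot) position back to the conv that feeds it, for that ordering and for reshaping each branch per anchor before concatenating. It assumes na = 3 and nc = 80; all function names are hypothetical helpers.

```python
def expected_roles(nc):
    """Slot roles yolov5's loss and decoder assume for one anchor: x, y, w, h, obj, classes."""
    return ["box"] * 4 + ["obj"] + ["cls"] * nc


def layout_cat_then_view(na, nc):
    """Conv feeding each (anchor, slot) after torch.cat((m_box, m_conf, m_labels), 1)
    followed by .view(bs, na, no, ny, nx)."""
    no = nc + 5
    channels = ["box"] * (4 * na) + ["obj"] * na + ["cls"] * (nc * na)
    # .view(bs, na, no, ...) maps channel c to anchor c // no, slot c % no
    return [[channels[a * no + k] for k in range(no)] for a in range(na)]


def layout_per_anchor(na, nc):
    """Conv feeding each slot when each branch is viewed as (bs, na, k, ny, nx)
    and the three are concatenated along dim 2."""
    return [expected_roles(nc) for _ in range(na)]


def mismatches(layout, nc):
    """Per anchor, how many slots are fed by a different conv than expected."""
    ref = expected_roles(nc)
    return [sum(src != want for src, want in zip(row, ref)) for row in layout]


na, nc = 3, 80  # assumed: 3 anchors per level, 80 COCO classes
print("cat then view:", mismatches(layout_cat_then_view(na, nc), nc))
print("per anchor:   ", mismatches(layout_per_anchor(na, nc), nc))
```

With na = 3 and nc = 80 the first line reports [11, 5, 5] mismatched slots: anchors 1 and 2 read their box and objectness outputs from m_labels, i.e. from the classification branch, so most of the intended decoupling is lost with that ordering.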

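2. Add the following to yolo.py (get_activation and BaseConv, as defined in YOLOX's yolox/models/network_blocks.py):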

```python
def get_activation(name="silu", inplace=True):
    if name == "silu":
        module = nn.SiLU(inplace=inplace)
    elif name == "relu":
        module = nn.ReLU(inplace=inplace)
    elif name == "lrelu":
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError("Unsupported act type: {}".format(name))
    return module


class BaseConv(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""

    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        # used after the BN has been folded into the conv weights
        return self.act(self.conv(x))
```
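BaseConv pads with (ksize - 1) // 2 so that, at stride 1, the output keeps the input's spatial size; the decoupled head relies on this because its box, objectness and class outputs must share the same ny, nx. A quick pure-Python check using the standard Conv2d output-size formula (the helper names are hypothetical):

```python
def conv_out_size(size, ksize, stride, pad, dilation=1):
    """Output length of nn.Conv2d along one axis (floor formula from the PyTorch docs)."""
    return (size + 2 * pad - dilation * (ksize - 1) - 1) // stride + 1


def same_pad(ksize):
    """Padding rule used by BaseConv."""
    return (ksize - 1) // 2


for k in (1, 3, 5):
    # stride 1: spatial size is preserved for odd kernels
    print(k, [conv_out_size(s, k, 1, same_pad(k)) for s in (80, 40, 20)])
```

The rule is exact only for odd kernels: with ksize = 4 it gives pad = 1 and a 20-pixel input shrinks to 19.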

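Characteristics of the decoupled head:
Note: when training, the places that should have been channels = 256 were mistakenly set to channels = x (the input channel count of each level), so this decoupled head has somewhat more parameters than YOLOX's. The results below were obtained with channels = x.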
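1. The model assigns fairly high class scores to targets, so the confidence threshold has to be set higher (0.5, e.g. in detect.py); otherwise precision is low.

```python
parser.add_argument('--conf-thres', type=float, default=0.5, help='confidence threshold')
```

2. It works well on classes with few samples; recall improves more than precision.
3. With conf = 0.25, recall is higher than yolov5s but precision is lower; with conf = 0.5, both recall and precision are higher than yolov5s.
4. It has more parameters and more computation than yolov5s; on my own dataset of about 25,000 samples, mAP improved by more than 3%.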
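To put a number on the parameter difference between channels = 256 and channels = x, the sketch below counts learnable parameters analytically: Conv2d weights and biases, plus the two affine BatchNorm parameters per channel (running statistics are buffers, not parameters). It assumes the yolov5s Detect inputs ch = (128, 256, 512), na = 3 and nc = 80; the helper names are hypothetical, not part of yolov5.

```python
def conv_params(cin, cout, k, bias):
    return cin * cout * k * k + (cout if bias else 0)


def base_conv_params(cin, cout, k):
    # BaseConv: Conv2d without bias + BatchNorm2d (weight and bias per channel)
    return conv_params(cin, cout, k, bias=False) + 2 * cout


def decoupled_head_params(ch, width=None, na=3, nc=80):
    """width=None reproduces the 'channels = x' variant (width = input channels)."""
    total = 0
    for x in ch:
        w = x if width is None else width
        total += base_conv_params(x, w, 1)        # base_conv stem
        total += 2 * base_conv_params(w, w, 3)    # cls_convs + reg_convs
        total += conv_params(w, 4 * na, 1, True)  # m_box
        total += conv_params(w, 1 * na, 1, True)  # m_conf
        total += conv_params(w, nc * na, 1, True)  # m_labels
    return total


def coupled_head_params(ch, na=3, nc=80):
    # stock yolov5 Detect: one 1x1 conv per level with na * (nc + 5) outputs
    return sum(conv_params(x, na * (nc + 5), 1, True) for x in ch)


ch = (128, 256, 512)  # assumed yolov5s Detect input channels
print("stock head:      ", coupled_head_params(ch))
print("decoupled, 256:  ", decoupled_head_params(ch, width=256))
print("decoupled, x:    ", decoupled_head_params(ch))
```

Under these assumptions the stock head has about 0.23M parameters, the 256-channel decoupled head about 3.97M and the channels = x variant about 6.77M, most of which sits in the two 3x3 convs on the 512-channel level.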

Improving the decoupled head


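Changes:
1. Remove the conv marked by the red box in the figure (base_conv) to reduce parameters and computation;
2. Use channels = 256, 512, 1024 (matching each level's input channels) so that no extra parameters are added and the feature-map information is not compressed.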
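As a rough sketch of what the first change saves, assuming the removed conv is the 1x1 base_conv stem and the channels = x widths of yolov5s, ch = (128, 256, 512): each stem is a bias-free 1x1 conv plus BatchNorm, i.e. x*x + 2x parameters per level. The helper name is hypothetical.

```python
def stem_params(ch):
    """Learnable parameters of a 1x1 Conv2d(x, x, bias=False) + BatchNorm2d(x) per level."""
    return {x: x * x + 2 * x for x in ch}


per_level = stem_params((128, 256, 512))  # assumed yolov5s widths, channels = x variant
print(per_level, sum(per_level.values()))
```

That is about 0.35M parameters under these assumptions, roughly three quarters of it at the 512-channel level.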

```python
# As posted, this block is the stock yolov5 (v6.x) Detect head: a single 1x1 output conv per level.
class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    # _make_grid() is unchanged from the stock yolov5 Detect and must be kept
```
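At inference, Detect decodes each prediction after a sigmoid as xy = (2·σ(t_xy) - 0.5 + grid) · stride and wh = (2·σ(t_wh))² · anchor_grid, where anchor_grid holds anchor sizes in pixels. A minimal pure-Python sketch for one anchor at one grid cell follows; the logits are illustrative, and (10, 13) at stride 8 is assumed from yolov5's default P3 anchors.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def decode_box(t, grid_xy, stride, anchor_wh):
    """Decode raw outputs t = (tx, ty, tw, th) for one anchor at one grid cell,
    using the same formulas as Detect.forward."""
    sx, sy, sw, sh = (sigmoid(v) for v in t)
    gx, gy = grid_xy
    x = (sx * 2 - 0.5 + gx) * stride
    y = (sy * 2 - 0.5 + gy) * stride
    w = (sw * 2) ** 2 * anchor_wh[0]
    h = (sh * 2) ** 2 * anchor_wh[1]
    return x, y, w, h


# all-zero logits: sigmoid = 0.5, so the box sits at the cell centre with the anchor's size
print(decode_box((0.0, 0.0, 0.0, 0.0), grid_xy=(10, 5), stride=8, anchor_wh=(10, 13)))
```

Because σ lies in (0, 1), the centre can move from -0.5 to 1.5 cells relative to the grid corner, and the size ranges over (0, 4) times the anchor.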

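Characteristics:
1. The model assigns fairly high class scores to targets, so the confidence threshold has to be set higher (0.4); otherwise precision is low.
2. It works well on classes with few samples, and precision improves more than recall.
This variant gains less than the model above, but it has fewer parameters, needs about 9 GFLOPs less computation, and uses less GPU memory.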
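Why the threshold matters so much: in yolov5's non_max_suppression, candidates are first filtered by objectness > conf_thres, then each box's score becomes obj_conf * cls_conf and must also exceed conf_thres. The sketch below mimics those two checks (best class only, NMS itself omitted) on made-up detections, to show how raising conf_thres from 0.25 to 0.4 or 0.5 drops lower-scoring boxes. The function name and data are hypothetical.

```python
def filter_detections(dets, conf_thres):
    """dets: list of (obj_conf, [cls_probs...]); returns kept (class_id, score).
    Mirrors the two confidence checks in yolov5's non_max_suppression (no NMS)."""
    kept = []
    for obj, cls_probs in dets:
        if obj <= conf_thres:  # candidate filter on objectness
            continue
        scores = [obj * p for p in cls_probs]  # conf = obj_conf * cls_conf
        j = max(range(len(scores)), key=scores.__getitem__)  # best class
        if scores[j] > conf_thres:
            kept.append((j, scores[j]))
    return kept


# made-up detections: (objectness, class probabilities for 2 classes)
dets = [(0.95, [0.90, 0.10]), (0.60, [0.70, 0.30]), (0.45, [0.95, 0.05]), (0.30, [0.99, 0.01])]
for t in (0.25, 0.4, 0.5):
    print(t, len(filter_detections(dets, t)))
```

Here 4, 3 and 1 boxes survive. If the head tends to give high class scores, more boxes clear a low threshold, which fits the observation that a higher conf_thres is needed to keep precision up.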

Why the decoupled head improves the metrics: the original yolov5s head cannot fully extract the information in the feature maps, whereas the decoupled head extracts it more thoroughly.

Open questions

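Why does the decoupled head give targets such high class scores? I haven't figured it out.
Why, after removing base_conv, does recall improve less than precision?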
