这里我们实际推演一下yolov5训练过程中的anchor匹配策略,为了简化数据和便于理解,设定以下训练参数。

  • 输入分辨率(img-size):608x608
  • 分类数(num_classes):2
  • batchsize:1
  • 网络结构如下图所示:

def build_targets(pred, targets, model):"""pred:type(pred) : <class 'list'>"""#p:predict,targets:gt# Build targets for compute_loss(), input targets(image,class,x,y,w,h)det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module

输入参数pred为网络的预测输出,它是一个list包含三个检测头的输出tensor。

(Pdb) print(type(pred))
<class 'list'>
(Pdb) print(len(pred))
3
(Pdb) print(pred[0].shape)
torch.Size([1, 3, 76, 76, 7])  #1:batch-size,3:该层anchor的数量,7:位置(4),obj(1),分类(2)
(Pdb) print(pred[1].shape)
torch.Size([1, 3, 38, 38, 7])
(Pdb) print(pred[2].shape)
torch.Size([1, 3, 19, 19, 7])

targets为标签信息(gt),我这里只有一张图片,包含14个gt框,且类别id为0,在我自己的训练集里面类别0表示行人。其中第1列为图片在当前batch的id号,第2列为类别id,后面依次是归一化了的gt框的x,y,w,h坐标。

(Pdb) print(targets.shape)
torch.Size([14, 6])
(Pdb) print(targets)
tensor([[0.00000, 0.00000, 0.56899, 0.42326, 0.46638, 0.60944],
        [0.00000, 0.00000, 0.27361, 0.59615, 0.02720, 0.02479],
        [0.00000, 0.00000, 0.10139, 0.59295, 0.04401, 0.03425],
        [0.00000, 0.00000, 0.03831, 0.59863, 0.06223, 0.02805],
        [0.00000, 0.00000, 0.04395, 0.57031, 0.02176, 0.06153],
        [0.00000, 0.00000, 0.13498, 0.57074, 0.01102, 0.03152],
        [0.00000, 0.00000, 0.25948, 0.59213, 0.01772, 0.03131],
        [0.00000, 0.00000, 0.29733, 0.63080, 0.07516, 0.02536],
        [0.00000, 0.00000, 0.16594, 0.57749, 0.33188, 0.13282],
        [0.00000, 0.00000, 0.79662, 0.89971, 0.40677, 0.20058],
        [0.00000, 0.00000, 0.14473, 0.96773, 0.01969, 0.03341],
        [0.00000, 0.00000, 0.10170, 0.96792, 0.01562, 0.03481],
        [0.00000, 0.00000, 0.27727, 0.95932, 0.03071, 0.07851],
        [0.00000, 0.00000, 0.18102, 0.98325, 0.00749, 0.01072]])

model自然就是表示的模型,det是模型的检测头,从该对象中可以拿到anchor数量(na)以及尺寸,检测头数量(nl)等信息。

    na, nt = det.na, targets.shape[0]  # number of anchors, targetstcls, tbox, indices, anch = [], [], [], []gain = torch.ones(7, device=targets.device)  # normalized to gridspace gainai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) 

这里的骚操作还挺多,pytorch不熟练的话only look once还真看不明白,我稍微拆解一下。

(Pdb) na,nt,gain
(3, 14, tensor([1., 1., 1., 1., 1., 1., 1.]))
(Pdb) torch.arange(na).float().view(na,1)
tensor([[0.],
        [1.],
        [2.]])
(Pdb) torch.arange(na).float().view(na,1).repeat(1,nt) #第二个维度复制nt遍
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]])

(Pdb) targets.shape
torch.Size([14, 6])
(Pdb) targets.repeat(na,1,1).shape #targets原本只有两维,该repeat操作过后会增加一维。
torch.Size([3, 14, 6])

(Pdb) ai[:,:,None].shape #原本两维的ai也会增加一维
torch.Size([3, 14, 1])

(Pdb)  torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2).shape #两个3维的tensort在第2维上concat
torch.Size([3, 14, 7])

(Pdb)  torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)
tensor([[[0.00000, 0.00000, 0.56899, 0.42326, 0.46638, 0.60944, 0.00000],
         [0.00000, 0.00000, 0.27361, 0.59615, 0.02720, 0.02479, 0.00000],
         [0.00000, 0.00000, 0.10139, 0.59295, 0.04401, 0.03425, 0.00000],
         ......],

[[0.00000, 0.00000, 0.56899, 0.42326, 0.46638, 0.60944, 1.00000],
         [0.00000, 0.00000, 0.27361, 0.59615, 0.02720, 0.02479, 1.00000],
        ......],

[[0.00000, 0.00000, 0.56899, 0.42326, 0.46638, 0.60944, 2.00000],
         [0.00000, 0.00000, 0.27361, 0.59615, 0.02720, 0.02479, 2.00000],
        ......]])

    g = 0.5  # biasoff = torch.tensor([[0, 0],[1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m# [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm], device=targets.device).float() * g  # offsets

off是偏置矩阵。

(Pdb) print(off)
tensor([[ 0.00000,  0.00000],
        [ 0.50000,  0.00000],
        [ 0.00000,  0.50000],
        [-0.50000,  0.00000],
        [ 0.00000, -0.50000]])

 for i in range(det.nl): #nl=>3anchors = det.anchors[i] #shape=>[3,3,2]gain[2:6] = torch.tensor(pred[i].shape)[[3, 2, 3, 2]]  # Match targets to anchorst = targets * gain

det.nl为预测层也就是检测头的数量,anchor匹配需要逐层进行。不同的预测层其特征图的尺寸不一样,而targets是相对于输入分辨率的宽和高作了归一化,targets*gain通过将归一化的box乘以特征图尺度从而将box坐标投影到特征图上。

(Pdb) pred[0].shape
torch.Size([1, 3, 76, 76, 7])  #1,3,h,w,7
(Pdb) torch.tensor(pred[0].shape)[[3,2,3,2]]
tensor([76, 76, 76, 76])

          if nt:# Matchesr = t[:, :, 4:6] / anchors[:, None]  # wh ratioj = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))t = t[j]  # filter

yolov5抛弃了MaxIOU匹配规则而采用shape匹配规则,计算标签box和当前层的anchors的宽高比,即:wb/wa,hb/ha。如果宽高比大于设定的阈值说明该box没有合适的anchor,在该预测层之间将这些box当背景过滤掉(是个狠人!)。

(Pdb) torch.max(r,1./r).shape
torch.Size([3, 14, 2])
(Pdb) torch.max(r,1./r).max(2) #返回两组值,values和indices
torch.return_types.max(
values=tensor([[28.50301,  1.65375,  2.67556,  3.78370,  2.87777,  1.49309,  1.46451,  4.56943, 20.17829, 24.73137,  1.56263,  1.62791,  3.67186,  2.19651],
        [17.72234,  1.99010,  1.67222,  2.36481,  1.24703,  2.38895,  1.57575,  2.85589, 12.61143, 15.45711,  1.47680,  1.68486,  1.59114,  4.60130],
        [16.11040,  1.99547,  1.23339,  1.34871,  2.49381,  4.92720,  3.06377,  1.49178,  6.11463,  7.49436,  2.75656,  3.47502,  2.07540,  7.24849]]),
indices=tensor([[1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0],
        [0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1],
        [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]]))
(Pdb) torch.max(r,1./r).max(2)[0] < model.hyp['anchor_t']
tensor([[False,  True,  True,  True,  True,  True,  True, False, False, False,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True, False, False,  True,  True,  True, False],
        [False,  True,  True,  True,  True, False,  True,  True, False, False,  True,  True,  True, False]])
(Pdb) print(j.shape)
torch.Size([3, 14])
(Pdb) print(t.shape)
torch.Size([3, 14, 7])
(Pdb) t[j].shape
torch.Size([29, 7])
(Pdb) t[j]
tensor([[ 0.00000,  0.00000, 20.79421, 45.30740,  2.06718,  1.88433,  0.00000],
        [ 0.00000,  0.00000,  7.70598, 45.06429,  3.34444,  2.60274,  0.00000],
        [ 0.00000,  0.00000,  2.91188, 45.49583,  4.72962,  2.13167,  0.00000],
        [ 0.00000,  0.00000,  3.34012, 43.34355,  1.65410,  4.67637,  0.00000],
        [ 0.00000,  0.00000, 10.25882, 43.37595,  0.83719,  2.39581,  0.00000],
        [ 0.00000,  0.00000, 19.72059, 45.00159,  1.34638,  2.37982,  0.00000],
        [ 0.00000,  0.00000, 10.99985, 73.54744,  1.49643,  2.53927,  0.00000],
        [ 0.00000,  0.00000,  7.72917, 73.56174,  1.18704,  2.64536,  0.00000],
        [ 0.00000,  0.00000, 21.07247, 72.90799,  2.33363,  5.96677,  0.00000],
        [ 0.00000,  0.00000, 13.75753, 74.72697,  0.56908,  0.81499,  0.00000],
        [ 0.00000,  0.00000, 20.79421, 45.30740,  2.06718,  1.88433,  1.00000],
        [ 0.00000,  0.00000,  7.70598, 45.06429,  3.34444,  2.60274,  1.00000],
        [ 0.00000,  0.00000,  2.91188, 45.49583,  4.72962,  2.13167,  1.00000],
        [ 0.00000,  0.00000,  3.34012, 43.34355,  1.65410,  4.67637,  1.00000],
        [ 0.00000,  0.00000, 10.25882, 43.37595,  0.83719,  2.39581,  1.00000],
        [ 0.00000,  0.00000, 19.72059, 45.00159,  1.34638,  2.37982,  1.00000],
        [ 0.00000,  0.00000, 22.59712, 47.94083,  5.71178,  1.92723,  1.00000],
        [ 0.00000,  0.00000, 10.99985, 73.54744,  1.49643,  2.53927,  1.00000],
        [ 0.00000,  0.00000,  7.72917, 73.56174,  1.18704,  2.64536,  1.00000],
        [ 0.00000,  0.00000, 21.07247, 72.90799,  2.33363,  5.96677,  1.00000],
        [ 0.00000,  0.00000, 20.79421, 45.30740,  2.06718,  1.88433,  2.00000],
        [ 0.00000,  0.00000,  7.70598, 45.06429,  3.34444,  2.60274,  2.00000],
        [ 0.00000,  0.00000,  2.91188, 45.49583,  4.72962,  2.13167,  2.00000],
        [ 0.00000,  0.00000,  3.34012, 43.34355,  1.65410,  4.67637,  2.00000],
        [ 0.00000,  0.00000, 19.72059, 45.00159,  1.34638,  2.37982,  2.00000],
        [ 0.00000,  0.00000, 22.59712, 47.94083,  5.71178,  1.92723,  2.00000],
        [ 0.00000,  0.00000, 10.99985, 73.54744,  1.49643,  2.53927,  2.00000],
        [ 0.00000,  0.00000,  7.72917, 73.56174,  1.18704,  2.64536,  2.00000],
        [ 0.00000,  0.00000, 21.07247, 72.90799,  2.33363,  5.96677,  2.00000]])
按照该匹配策略,一个gt box可能同时匹配上多个anchor。

【玩转yolov5】之anchor匹配策略(build_targets)分析(1)相关推荐

  1. yolov5核心代码: anchor匹配策略,compute_loss和build_targets理解

    yolov5核心代码理解: anchor匹配策略-跨网格预测,compute_loss(p, targets, model)和build_targets(p, targets, model)理解 本文 ...

  2. 找不到匹配的key exchange算法_目标检测--匹配策略

    CVPR2020中的文章ATSS揭露到anchor-based和anchor-free的目标检测算法之间的效果差异原因是由于正负样本的选择造成的.而在目标检测算法中正负样本的选择是由gt与anchor ...

  3. 无锡室内设计——紫色匹配策略

    紫色匹配策略 紫色是具有最高纯度和最低亮度的颜色.在可见光谱中,紫色的光波短,眼睛对紫色的感知是最低的.它可以用来表达孤独,贵族,豪华,优雅和神秘的情绪.紫色代表高贵的气质.它根据一定的平衡比例与红色 ...

  4. 快速玩转Yolov5目标检测—没有好的显卡也能玩(二)

    上篇  快速玩转Yolov5目标检测-没有好的显卡也能玩(一) 已经将YoloV5在我的笔记本电脑上快速跑起来了,因为电脑显卡一般,所以运行的CPU版本,从推理结果来看,耗时还是蛮高的,如下图,平均每 ...

  5. mysql部署策略_MySQL延迟问题和数据刷盘策略流程分析

    一.MySQL复制流程 官方文档流程如下: MySQL延迟问题和数据刷盘策略 1.绝对的延时,相对的同步 2.纯写操作,线上标准配置下,从库压力大于主库,最起码从库有relaylog的写入. 二.My ...

  6. python hook pc微信_一起来用python玩一波微信呀 | 防撤回, 好友分析, 聊天机器人~...

    原文链接一起来用python玩一波微信呀 | 防撤回, 好友分析, 聊天机器人~​mp.weixin.qq.com 导语 众所周知,前段时间微信彻底关闭了网页版微信登录入口.于是一大波基于itchat ...

  7. Python量化编程如何判断均线数据是金叉还是死叉?-股市数据均线策略编程分析

    Python量化编程如何判断均线数据是金叉还是死叉?-股市数据均线策略编程分析 以平安银行的股票数据为例进行分析 1.使用tushare获取股市数据,直接上代码: import pandas as p ...

  8. mysql数据刷盘过程详解_MySQL延迟问题和数据刷盘策略流程分析

    一.MySQL复制流程 官方文档流程如下: MySQL延迟问题和数据刷盘策略 1.绝对的延时,相对的同步 2.纯写操作,线上标准配置下,从库压力大于主库,最起码从库有relaylog的写入. 二.My ...

  9. LibreOJ #2006. 「SCOI2015」小凸玩矩阵 二分答案+二分匹配

    #2006. 「SCOI2015」小凸玩矩阵 内存限制:256 MiB时间限制:1000 ms标准输入输出 题目类型:传统评测方式:文本比较 上传者: 匿名 提交提交记录统计讨论测试数据 题目描述 小 ...

最新文章

  1. (Mybatis)Mybatis简介和初步使用
  2. 光伏双反闹剧何时休?
  3. centos 安装配置ftp服务器
  4. 用nagios监控ORACLE服务器
  5. 实实在在做一位教书匠(来自网络)
  6. gpio 树莓派3a+_树莓派4正式发布:35美元起售!真香
  7. 计算机二级 java和web_2016计算机二级web程序设计判断题及答案
  8. mysql 日期与索引问题
  9. 基于AI的恶意软件分类技术(4)
  10. po模型---tpshop项目
  11. STM32 CubeMX 1ms定时中断的实现
  12. JAVA标准异常分两大类_java异常分类
  13. 5G标准协议中的简写和缩略语
  14. MATLAB工具学习:cftool(曲线拟合工具)
  15. 同是IT小小鸟——《我是一只IT小小鸟》读书笔记
  16. 【ceph】ceph osd blacklist cep黑名单|MDS问题分析
  17. TZOJ 7034: 竹取飞翔 ~ Lunatic Princess 并查集+数学。
  18. NSN sprint904 总结回顾
  19. 韩国2018年GDP增速为2.7% 人均GNI或超3.1万美元
  20. R、D、E、U、T、A命令

热门文章

  1. CSS 绝对定位 div 水平居中(两种)
  2. oracle修改分区范围,如何更改现有表以在Oracle中创建范围分区
  3. 劳务人员实名制中的二维码应用
  4. arduino设备驱动程序安装失败
  5. 视频教程-绝对干货-springspringmvc源码深入解读,不是干货可无条件退款-Java
  6. 面向对象设计思想总结及代码
  7. go(基础09)——defer
  8. 视频清晰度优化指南 | 得物技术
  9. 陈果《懂你》读书笔记
  10. 广告联盟 怎么赚钱?