yolov5损失函数的几点理解
所用代码:https://github.com/ultralytics/yolov5
参考文献:https://www.cnblogs.com/pprp/p/12590801.html
感谢知乎网友:Ancy贝贝

重要的代码块在build_targets内。

def build_targets(p, targets, model):# Build targets for compute_loss(), input targets(image,class,x,y,w,h)det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() modulena, nt = det.na, targets.shape[0]  # number of anchors, targetstcls, tbox, indices, anch = [], [], [], []gain = torch.ones(7, device=targets.device)  # normalized to gridspace gainai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)#将targets复制3份,每份分配一个anchor编号,如0,1,2. 也就是每个anchor分配一份targets。targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indicesg = 0.5  # bias# 这里off表示了5个偏移,原点不动,往右、往下、往左、往上。# 其中坐标原点在图像的左上角,x轴往右(列),y轴往下(行)。off = torch.tensor([[0, 0],[1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m# [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm], device=targets.device).float() * g  # offsetsfor i in range(det.nl):#det.anchors在导入model的时候就除以了步长,因此此时anchor大小不是相对于原图,而是相对于对应特征层的尺寸anchors = det.anchors[i]gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain# Match targets to anchors#这里主要是将gt的cx,cy,w,h换算到当前特征层对应的尺寸,以便和该层的anchor大小相对应t = targets * gainif nt:# Matches#这个部分是计算gt和anchor的匹配程度#即w_gt/w_anchor  h_gt/h_anchorr = t[:, :, 4:6] / anchors[:, None]  # wh ratio#这里判断了r和1/r与model.hyp['anchor_t']的大小关系,即只有不大于这个数,也就是说gt与anchor的宽高差距不过大的时候,才认为匹配。代码中 model.hyp['anchor_t']=4j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare       # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))#将满足条件的targets筛选出来。          t = t[j]  # filter# Offsets#这个部分就是扩充targets的数量,将比较targets附近的4个点,选取最近的2个点作为新targets中心,新targets的w、h使用与原targets一致,只是中心点坐标的不同。gxy = t[:, 2:4]  # grid xygxi = gain[[2, 3]] - gxy  # inversej, k = ((gxy % 1. < g) & (gxy > 1.)).Tl, m = ((gxi % 1. < g) & (gxi > 1.)).Tj = torch.stack((torch.ones_like(j), j, k, l, m))t = t.repeat((5, 1, 1))[j] #筛选后t的数量是原来t的3倍。offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]else:t = targets[0]offsets = 0# Defineb, c = t[:, :2].long().T  # image, classgxy = t[:, 2:4]  # grid xygwh = t[:, 4:6]  # grid whgij = (gxy - offsets) #自己加的代码,方便查看gij的分布。plot_gxy(gxy=gij, scale_i=i, size=gain, flag='gij') #自己编的代码,用于查看gij的分布。gij = (gxy - offsets).long() #将所有targets中心点坐标进行偏移。gi, gj = gij.T  # grid xy indices# Appenda = t[:, 6].long()  # anchor indicesindices.append((b, a, gj, gi))  # image, anchor, grid indicestbox.append(torch.cat((gxy - gij, gwh), 1))  # boxanch.append(anchors[a])  # anchorstcls.append(c)  # classreturn tcls, tbox, indices, anch

下图是20x20的特征图上的gij的分布示意图,从图中可以看出每个targets都扩充了2个临近的targets。关于为什么扩充,我还没理解,有知道的网友欢迎留言。另外,知乎网友Ancy贝贝的理解是:之前通过筛选,去掉了一些匹配不上anchor的gt,本来正样本就比负样本少很多,经过筛选,少得更多了,所以每个gt扩充2个出来,增加正样本比例。

# Regression
    pxy = ps[:, :2].sigmoid() * 2. - 0.5
    pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
    pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
    giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # giou(prediction, target)
    lbox += (1.0 - giou).mean()  # giou loss

代码中的pxy对应bxy,ps[:, :2]对应txy。由此可知bxy的取值范围是[-0.5,1.5]。因此有可能偏移到临近的单元格内,但偏移不多,不知道作者是什么考虑的。

代码中的pwh对应bwh,anchors[i]对应Pwh。因此可知bwh的范围是[0,4]*Pwh。这和前面
j = torch.max(r, 1. / r).max(2)[0] < model.hyp[‘anchor_t’] # model.hyp[‘anchor_t’]=4 是一致的。

Objectness
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
此处 tobj[b, a, gj, gi]用giou(真实的是ciou)取代1,代表该点对应置信度。为什么要用giou来代替,我也没想明白,有知道的网友欢迎留言。

其余的部分比较好理解,在此不再赘述。

附:
plot_gxy的代码:

def plot_gxy(gxy, scale_i, size, flag):
    s = int(size[2].cpu().numpy())
    ax = plt.subplot(111)
    ax.axis([0, s, 0, s])
    lxx = np.arange(0, s + 1, 1)
    lxx = np.repeat(lxx, s + 1, axis=0)
    lxx = lxx.reshape(s + 1, s + 1)
    lyy = np.arange(0, s + 1, 1)
    lyy = np.repeat(lyy, s + 1, axis=0)
    lyy = lyy.reshape(s + 1, s + 1)
    lyy = lyy.T

for i in range(len(lxx)):
        plt.plot(lxx[i], lyy[i], color='k', linewidth=0.05, linestyle='-')
        plt.plot(lyy[i], lxx[i], color='k', linewidth=0.05, linestyle='-')

for i in range(len(gxy)):
        x1, y1 = gxy.cpu().numpy().T
        plt.scatter(x1, y1, s=0.02, color='k')

ax = plt.gca()  # 获取到当前坐标轴信息
    ax.xaxis.set_ticks_position('top')  # 将X坐标轴移到上面
    ax.invert_yaxis()
    plt.savefig("gxy_{}_{}.png".format(scale_i, flag))
    plt.close()
————————————————
版权声明:本文为CSDN博主「tpz789」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/tpz789/article/details/108844004

yolov5损失函数笔记相关推荐

  1. yolov5学习笔记(毕业设计)

    yolov5学习笔记 一,基本准备 1.配置基本环境 2.程序跑起来 3.核心代码解读和自定义训练目标 二,训练yolov5神经网络 (1,本地训练yolov5 1, 首先把进程改为0,不然有的电脑会 ...

  2. Yolov5自学笔记之二--在游戏中实时推理并应用(实例:哈利波特手游跳舞小游戏中自动按圈圈)

    上一篇帖子我已经自学了Yolov5的基本流程,并运用yolov5进行图片.视频.摄像头.网络视频流等多种方式的推理,这些结合到实际工作中就可以有很广泛的应用了.但是还有一类情况,就是在电脑中的某个程序 ...

  3. YOLOv5损失函数(白勇课件)

    YOLOv5损失函数 总损失各项前可加入系数权重 边界框回归:寻找边界框位置.大小,作为计算机视觉最重要组件,之前一直使用L1\L2 IoU损失 ground truth box 在数据上真实标注的框 ...

  4. 【目标检测算法】IOU、GIOU、DIOU、CIOU与YOLOv5损失函数

    1 常见IOU汇总 classification loss 分类损失 localization loss, 定位损失(预测边界框与GT之间的误差) confidence loss 置信度损失(框的目标 ...

  5. (转载)yolov5理论学习笔记

    算法创新分为三种方式 第一种:面目一新的创新,比如Yolov1.Faster-RCNN.Centernet等,开创出新的算法领域,不过这种也是最难的 第二种:守正出奇的创新,比如将图像金字塔改进为特征 ...

  6. 目标检测算法实现(八)——YOLOV5学习笔记

    非常感谢江大白大佬的研究与分享 附链接 深入浅出Yolo系列之Yolov5核心基础知识完整讲解 目录 1.网络结构图+v5性能对比 2.v5的改进和优势 2.1 输入端 2.1.1 Mosaic数据增 ...

  7. yolov5组件笔记

    深度学习模型组件 ------ 深度可分离卷积.瓶颈层Bottleneck.CSP瓶颈层BottleneckCSP.ResNet模块.SPP空间金字塔池化模块 目录 1.标准卷积: Conv + BN ...

  8. 人脸识别损失函数笔记

    识别改进方法: 剪枝, 蒸馏 参考链接:人脸识别0-05:insightFace-损失函数arcface-史上最全_江南才尽江南山,年少无知年少狂!-CSDN博客_insightface损失函数 以下 ...

  9. yolov5学习笔记

    用已有模型预测自己的图片和视频 配置环境略. 在detect.py文件中改一下路径,或者把下载好的图片放入对应文件夹即可. 目标检测指标 IoU 的全称为交并比(Intersection over U ...

最新文章

  1. 浅析ado.net获取数据库元数据信息
  2. PyCharm使用笔记
  3. -bash:/etc/profile Permission Denied
  4. oracle中delete、truncate、drop的区别 (转载)
  5. 树莓派(TCP客户端 )和Wemos(TCP服务端连接红外模块)通讯实现对红外设备的控制
  6. 255.0.0.0子网掩码相应的cidr前缀表示法是?_六十四、前缀,后缀,中缀表达式转化求值问题...
  7. C# 读写锁 ReaderWriteLock
  8. 项目实战-电商(网上书城)
  9. 分布滞后与自回归模型 ADL
  10. 如何在android lolipop中开启google now
  11. 计算机网络为何使用分组交换,而不是电路交换
  12. 微信登录优化方案设计
  13. 基于知识图谱的知识泛化让AI学会“举一反三”
  14. 上传Excel文件进度条原理
  15. JavaScript面向对象
  16. 《Spring系列》第15章 声明式事务(一) 基础使用
  17. 安卓APP按键美化——圆角按键
  18. C#调用斑马打印机打印条码标签(含源码)(支持COM、LPT、USB、TCP连接方式和ZPL、EPL、CPCL指令)
  19. 用轻量服务器搭建在线协作绘图白板
  20. 如何定制博客园的个人空间

热门文章

  1. linux下 x86、i386、i486、i586、i686、x86_64区别
  2. wireshark过虑规则
  3. Android--SlidingDrawer的使用介绍
  4. TCP/IP详解--学习笔记(13)-TCP坚持定时器,TCP保活定时器
  5. linux 组调度浅析
  6. Linux Systemcall Int0x80方式、Sysenter/Sysexit Difference Comparation
  7. append 降低数组位数_腿粗有理!研究发现腿部脂肪多,能大幅降低患高血压的风险!...
  8. 新浪博客服务器维护,服务器安全维护
  9. Java集合查找Map,java:使用hashmap或其他一些java集合创建查找...
  10. java hasfocus_Android·Focus机制解析和常见问题