def compute_loss(p, targets, model):  # predictions, targets, modeldevice = targets.device#创建用来保存三层特征图的损失lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)#build_targets详见https://blog.csdn.net/a1874738854/article/details/112789533#获取gt和对应的anchortcls, tbox, indices, anchors = build_targets(p, targets, model)  # targetsh = model.hyp  # hyperparameters# Define criteria#分类和confidence损失函数BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3#是否对label采取平滑cp, cn = smooth_BCE(eps=0.0)# Focal lossg = h['fl_gamma']  # focal loss gammaif g > 0:BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)# Lossesnt = 0  # number of targets#获取输出特征图的层数no = len(p)  # number of outputsbalance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6#对每个特征图进行计算损失for i, pi in enumerate(p):  # layer index, layer predictions#获取该层特征图上的gt信息:图像序号,anchor序号,位于特征图上的格网坐标b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx#tobj存储gt中的置信度真值tobj = torch.zeros_like(pi[..., 0], device=device)  # target objn = b.shape[0]  # number of targetsif n:#有gt才计算分类和回归损失,否则只计算置信度损失nt += n  # cumulative targets#获取真值对应的预测值box信息ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets# Regression 对预测值进行预处理pxy = ps[:, :2].sigmoid() * 2. - 0.5pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box#计算CIOUiou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)#box坐标回归损失lbox += (1.0 - iou).mean()  # iou loss# Objectness#利用IOU对gt中的置信度进行加权(对应与build_targets中的gt扩充)tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio# Classification计算分类损失if model.nc > 1:  # cls loss (only if multiple classes)#label smootht = torch.full_like(ps[:, 5:], cn, device=device)  # targetst[range(n), tcls[i]] = cplcls += BCEcls(ps[:, 5:], t)  # BCE# Append targets to text file# with open('targets.txt', 'a') as file:#     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]#获取置信度损失lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj losss = 3 / no  # output count scalinglbox *= h['box'] * slobj *= h['obj'] * s * (1.4 if no == 4 else 1.)lcls *= h['cls'] * sbs = tobj.shape[0]  # batch sizeloss = lbox + lobj + lclsreturn loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

yolov5代码详解-compute_loss(p, targets, model)相关推荐

  1. yolov5代码详解-build_targets(p, targets, model)

    build_target函数是将box标签和anchor进行匹配,其中FPN特征图为三层,分别为降采样8,16,32,代码展示与解释如下: def build_targets(p, targets, ...

  2. Yolov5代码详解——detect.py

    首先执行扩展包的导入: import argparse import os import platform import sys from pathlib import Path ​ import t ...

  3. yolov5的detect.py代码详解

    目标检测系列之yolov5的detect.py代码详解 前言 哈喽呀!今天又是小白挑战读代码啊!所写的是目标检测系列之yolov5的detect.py代码详解.yolov5代码对应的是官网v6.1版本 ...

  4. YOLOv5算法详解

    目录 1.需求解读 2.YOLOv5算法简介 3.YOLOv5算法详解 3.1 YOLOv5网络架构 3.2 YOLOv5实现细节详解 3.2.1 YOLOv5基础组件 3.2.2 输入端细节详解 3 ...

  5. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

  6. 目标检测Tensorflow:Yolo v3代码详解 (2)

    目标检测Tensorflow:Yolo v3代码详解 (2) 三.解析Dataset()数据预处理部分 四. 模型训练 yolo_train.py 五. 模型冻结 model_freeze.py 六. ...

  7. yolov3代码详解(七)

    Pytorch | yolov3代码详解七 test.py test.py from __future__ import divisionfrom models import * from utils ...

  8. Pytorch | yolov3原理及代码详解(二)

    阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...

  9. Pytorch | yolov3原理及代码详解(一)

    YOLO相关原理 : https://blog.csdn.net/leviopku/article/details/82660381 https://www.jianshu.com/p/d13ae10 ...

最新文章

  1. 百所学校寒假时长排行,看看你的学校排多少名~
  2. IDE接口驱动程序移植
  3. 数据中心如何建设,数据中心机房维护方法详解!
  4. 用Python爬取Bilibili上二次元妹子的视频
  5. POJ1845-Sumdiv【逆元,等比数列,约数】
  6. Javascript中NaN、null和undefinded的区别
  7. 植物大战僵尸不能保存进度
  8. 自动驾驶芯片_自动驾驶芯片“争夺战”
  9. FPGA图像处理 两路sensor的色调不一致
  10. UE4实时渲染需要注意的点——RTR(Real Time Rendering)
  11. WPS2003排版位置错误一例(转)
  12. 电动车结构及其工作原理
  13. nodejs之koa配置koa-views中间件
  14. 在TPU上运行PyTorch的技巧总结
  15. 2021全球程序员收入报告!字节高级码农年薪274万元排第5
  16. 解决 Chrome 浏览器地址栏字体发虚模糊
  17. SharedPreferences和SQlite数据库
  18. VLC Plugin JS 方法
  19. Python常用轮子下载网站
  20. 2021年东城区文菁计划资金补助政策及申报条件,部分项目补贴100万

热门文章

  1. 小白的努力——此时少年山巅客,凭栏尽收快哉风
  2. 分布式任务调度中心xxl-job
  3. DNS无法区域传送(axfr,ixfr)
  4. 支付宝:APP支付接口2.0(alipay.trade.app.pay)
  5. 30以上java程序员出路,详细说明
  6. 不用找,你想要的游戏3d纹理图片素材都在这里
  7. 【自然语言处理】实验1布置:Word2Vec TransE案例
  8. CentOS 7 安装 Postfix Dovecot
  9. 校园导游咨询系统(数据结构课程设计)
  10. scala读取mysql_转: spark:scala读取mysql的4种方法