yolov5代码详解-compute_loss(p, targets, model)
def compute_loss(p, targets, model): # predictions, targets, modeldevice = targets.device#创建用来保存三层特征图的损失lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)#build_targets详见https://blog.csdn.net/a1874738854/article/details/112789533#获取gt和对应的anchortcls, tbox, indices, anchors = build_targets(p, targets, model) # targetsh = model.hyp # hyperparameters# Define criteria#分类和confidence损失函数BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3#是否对label采取平滑cp, cn = smooth_BCE(eps=0.0)# Focal lossg = h['fl_gamma'] # focal loss gammaif g > 0:BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)# Lossesnt = 0 # number of targets#获取输出特征图的层数no = len(p) # number of outputsbalance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6#对每个特征图进行计算损失for i, pi in enumerate(p): # layer index, layer predictions#获取该层特征图上的gt信息:图像序号,anchor序号,位于特征图上的格网坐标b, a, gj, gi = indices[i] # image, anchor, gridy, gridx#tobj存储gt中的置信度真值tobj = torch.zeros_like(pi[..., 0], device=device) # target objn = b.shape[0] # number of targetsif n:#有gt才计算分类和回归损失,否则只计算置信度损失nt += n # cumulative targets#获取真值对应的预测值box信息ps = pi[b, a, gj, gi] # prediction subset corresponding to targets# Regression 对预测值进行预处理pxy = ps[:, :2].sigmoid() * 2. - 0.5pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box#计算CIOUiou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)#box坐标回归损失lbox += (1.0 - iou).mean() # iou loss# Objectness#利用IOU对gt中的置信度进行加权(对应与build_targets中的gt扩充)tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio# Classification计算分类损失if model.nc > 1: # cls loss (only if multiple classes)#label smootht = torch.full_like(ps[:, 5:], cn, device=device) # targetst[range(n), tcls[i]] = cplcls += BCEcls(ps[:, 5:], t) # BCE# Append targets to text file# with open('targets.txt', 'a') as file:# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]#获取置信度损失lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj losss = 3 / no # output count scalinglbox *= h['box'] * slobj *= h['obj'] * s * (1.4 if no == 4 else 1.)lcls *= h['cls'] * sbs = tobj.shape[0] # batch sizeloss = lbox + lobj + lclsreturn loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
yolov5代码详解-compute_loss(p, targets, model)相关推荐
- yolov5代码详解-build_targets(p, targets, model)
build_target函数是将box标签和anchor进行匹配,其中FPN特征图为三层,分别为降采样8,16,32,代码展示与解释如下: def build_targets(p, targets, ...
- Yolov5代码详解——detect.py
首先执行扩展包的导入: import argparse import os import platform import sys from pathlib import Path import t ...
- yolov5的detect.py代码详解
目标检测系列之yolov5的detect.py代码详解 前言 哈喽呀!今天又是小白挑战读代码啊!所写的是目标检测系列之yolov5的detect.py代码详解.yolov5代码对应的是官网v6.1版本 ...
- YOLOv5算法详解
目录 1.需求解读 2.YOLOv5算法简介 3.YOLOv5算法详解 3.1 YOLOv5网络架构 3.2 YOLOv5实现细节详解 3.2.1 YOLOv5基础组件 3.2.2 输入端细节详解 3 ...
- Pytorch|YOWO原理及代码详解(二)
Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...
- 目标检测Tensorflow:Yolo v3代码详解 (2)
目标检测Tensorflow:Yolo v3代码详解 (2) 三.解析Dataset()数据预处理部分 四. 模型训练 yolo_train.py 五. 模型冻结 model_freeze.py 六. ...
- yolov3代码详解(七)
Pytorch | yolov3代码详解七 test.py test.py from __future__ import divisionfrom models import * from utils ...
- Pytorch | yolov3原理及代码详解(二)
阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...
- Pytorch | yolov3原理及代码详解(一)
YOLO相关原理 : https://blog.csdn.net/leviopku/article/details/82660381 https://www.jianshu.com/p/d13ae10 ...
最新文章
- 百所学校寒假时长排行,看看你的学校排多少名~
- IDE接口驱动程序移植
- 数据中心如何建设,数据中心机房维护方法详解!
- 用Python爬取Bilibili上二次元妹子的视频
- POJ1845-Sumdiv【逆元,等比数列,约数】
- Javascript中NaN、null和undefinded的区别
- 植物大战僵尸不能保存进度
- 自动驾驶芯片_自动驾驶芯片“争夺战”
- FPGA图像处理 两路sensor的色调不一致
- UE4实时渲染需要注意的点——RTR(Real Time Rendering)
- WPS2003排版位置错误一例(转)
- 电动车结构及其工作原理
- nodejs之koa配置koa-views中间件
- 在TPU上运行PyTorch的技巧总结
- 2021全球程序员收入报告!字节高级码农年薪274万元排第5
- 解决 Chrome 浏览器地址栏字体发虚模糊
- SharedPreferences和SQlite数据库
- VLC Plugin JS 方法
- Python常用轮子下载网站
- 2021年东城区文菁计划资金补助政策及申报条件,部分项目补贴100万
热门文章
- 小白的努力——此时少年山巅客,凭栏尽收快哉风
- 分布式任务调度中心xxl-job
- DNS无法区域传送(axfr,ixfr)
- 支付宝:APP支付接口2.0(alipay.trade.app.pay)
- 30以上java程序员出路,详细说明
- 不用找,你想要的游戏3d纹理图片素材都在这里
- 【自然语言处理】实验1布置:Word2Vec TransE案例
- CentOS 7 安装 Postfix Dovecot
- 校园导游咨询系统(数据结构课程设计)
- scala读取mysql_转: spark:scala读取mysql的4种方法