Pytorch | yolov3代码详解七

  • test.py

test.py

from __future__ import divisionfrom models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *import os
import sys
import time
import datetime
import argparse
import tqdmimport torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
##########################################################################
#测试
###########################################################################1.定义evaluate评估函数
#2.解析输入的参数
#3.打印当前使用的参数
#4.解析评估数据集的路径和class_names
#5.创建model
#6.加载模型的权重
#7.调用evaluate评估函数得到评估结果
#8.打印每一种class的评估结果ap
#9.打印所有class的平均评估结果mAP#1.定义evaluate评估函数
def evaluate(model, path, iou_thres, conf_thres, nms_thres, img_size, batch_size):#输入模型model,拟评估数据集地址valid_path,iou_thres阀值,conf_thres阀值,nms_thres阀值,img_size,batch_sizemodel.eval()#设置为验证模式# Get dataloaderdataset = ListDataset(path, img_size=img_size, augment=False, multiscale=False)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn)Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensorlabels = []sample_metrics = []  # List of tuples (TP, confs, pred)for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc="Detecting objects")):   #Tqdm 是 Python 进度条库##targets(0,类别,x,y,w,h)即在0到1之间# Extract labelslabels += targets[:, 1].tolist()# Rescale targettargets[:, 2:] = xywh2xyxy(targets[:, 2:])  #转换为左上右下形式targets[:, 2:] *= img_size#调整为原图大小imgs = Variable(imgs.type(Tensor), requires_grad=False)with torch.no_grad():outputs = model(imgs)#通过Darknet的forward() 函数得到检测结果,yolo_outputs   输出形状为【1,507,85】  在特征图13*13上,其中507为3*13*13outputs = non_max_suppression(outputs, conf_thres=conf_thres, nms_thres=nms_thres)#得到 [;,  x,y,x,y,置信度,最大类别值,最大类别索引]#即detections(经过置信度阈值筛选后的数量,7)sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=iou_thres)# Concatenate sample statisticstrue_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in list(zip(*sample_metrics))]  #函数是对矩阵元素的扩充precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels, labels)return precision, recall, AP, f1, ap_classif __name__ == "__main__":#2.解析输入的参数parser = argparse.ArgumentParser()parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")parser.add_argument("--iou_thres", type=float, default=0.5, help="iou threshold required to qualify as detected")parser.add_argument("--conf_thres", type=float, default=0.001, help="object confidence threshold")parser.add_argument("--nms_thres", type=float, default=0.5, help="iou thresshold for non-maximum suppression")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")opt = parser.parse_args()#3.打印当前使用的参数print(opt)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#4.解析评估数据集的路径和class_namesdata_config = parse_data_config(opt.data_config)valid_path = data_config["valid"]class_names = load_classes(data_config["names"])#5.创建model# Initiate modelmodel = Darknet(opt.model_def).to(device)#6.加载模型的权重if opt.weights_path.endswith(".weights"):# Load darknet weightsmodel.load_darknet_weights(opt.weights_path)else:# Load checkpoint weightsmodel.load_state_dict(torch.load(opt.weights_path))print("Compute mAP...")#7.调用evaluate评估函数得到评估结果precision, recall, AP, f1, ap_class = evaluate(model,path=valid_path,iou_thres=opt.iou_thres,conf_thres=opt.conf_thres,nms_thres=opt.nms_thres,img_size=opt.img_size,batch_size=8,)#8.打印每一种class的评估结果apprint("Average Precisions:")for i, c in enumerate(ap_class):print(f"+ Class '{c}' ({class_names[c]}) - AP: {AP[i]}")#9.打印平均的评估结果mAPprint(f"mAP: {AP.mean()}")

yolov3代码详解(七)相关推荐

  1. Keras YOLOv3代码详解(三):目标检测的流程图和源代码+中文注释

    Keras YOLOv3源代码下载地址:https://github.com/qqwweee/keras-yolo3 YOLOv3论文地址:https://pjreddie.com/media/fil ...

  2. yoloV3代码详解(注释)

    原文链接:https://www.cnblogs.com/hujinzhou/p/guobao_2020_3_13.html yolo3各部分代码详解(超详细) </h1><div ...

  3. yolov3代码详解_代码资料

    faster RCNN TensorFlow版本: 龙鹏:[技术综述]万字长文详解Faster RCNN源代码(一) buptscdc:tensorflow 版faster rcnn代码理解(1) l ...

  4. YOLOv3 代码详解(2) —— 数据处理 dataset.py解析:输入图片增强、制作模型的每层输出的标签

    前言: yolo系列的论文阅读 论文阅读 || 深度学习之目标检测 重磅出击YOLOv3 论文阅读 || 深度学习之目标检测yolov2 论文阅读 || 深度学习之目标检测yolov1   该篇讲解的 ...

  5. pytorch yolov3 代码详解_PyTorch C++ libtorch的使用方法(1)-nightly 版本的 libtorch

    问题描述: 按照PyTorch中文教程的[ 在 C++ 中加载 PYTORCH 模型 ]一文,尝试调用 PyTorch模型. 1. 例子来源 在 C++ 中加载 PYTORCH 模型 我是使用Qt新建 ...

  6. Pytorch | yolov3原理及代码详解(一)

    YOLO相关原理 : https://blog.csdn.net/leviopku/article/details/82660381 https://www.jianshu.com/p/d13ae10 ...

  7. 深度篇——目标检测史(七) 细说 YOLO-V3目标检测 之 代码详解

    返回主目录 返回 目标检测史 目录 上一章:深度篇--目标检测史(六) 细说 YOLO-V3目标检测 下一章:深度篇--目标检测史(八) 细说 CornerNet-Lite 目标检测 论文地址:< ...

  8. Pytorch | yolov3原理及代码详解(二)

    阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...

  9. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

最新文章

  1. FastDFS文件上传和下载流程
  2. 黑马Go语言与区块链学习笔记
  3. python获取坐标颜色,python – 根据一组坐标的数据着色地图
  4. centos 安装java_自己动手基于centos7安装docker及如何发布tomcat镜像
  5. Educational Codeforces Round 41 (Rated for Div. 2)
  6. Python并发编程:多线程-Thread对象的其它属性和方法
  7. 【mobile】安卓图案解锁尝试次数过多导致 要解锁需要GOOGLE账户登录,解决方案...
  8. 抱米花-豆丁文档下载器 20100529
  9. dmg文件如何安装linux,我怎么能打开.dmg文件?
  10. 方向余弦矩阵与四元数
  11. 2014年南京航空航天大学计算机学院推荐研究生公示,南京航空航天大学2013-2014学年研究生评优评奖公示...
  12. 如何修改Tomcat的默认主页
  13. 基于Java实现一个简单的记事本Android App
  14. Java期末考试题(附答案)
  15. 我帮粉丝赚了10w+
  16. 简单的网页设计,以学校官网为例
  17. Leetcode 971 C++代码
  18. 最牛逼android上的图表库MpChart(一) 介绍篇
  19. AI技术领跑、23个国际冠军,2019百度AI如何彰显核心竞争力
  20. js Proxy 从入门到废掉的整个过程

热门文章

  1. 服务器宝塔Error: connect ETIMEDOUT
  2. 计算机操作系统计算题及答案(5),5计算机操作系统练习题及答案.doc
  3. 数据平台建设的痛点,如何进行元数据治理?
  4. 1003 Emergency (25 point(s))
  5. java 多线程分段等待执行完成状况,循环屏障CyclicBarrier | Java工具类
  6. WPF——ViewBox控件
  7. MySQL安装版本Navicat连接报错2509解决方案
  8. 【正点原子Linux连载】第三章 RV1126开发环境搭建 摘自【正点原子】ATK-DLRV1126系统开发手册
  9. 机器学习系列2 BP神经网络+代码实现
  10. 快慢指针 ——链表 | Leetcode 练习