文章目录

  • 前言
  • 初始化
  • 加载数据
  • 预测
    • 预测数据格式
  • 后置处理
  • 完整注释后代码

前言

今天放松一下,随便看看这个YOLOV5 的识别部分的代码是怎么做的,先前的话我们自己手动实现了一个非常简易的分类框架,HuClassfiy(已经上传Gitee,方便各位访问),那么这里的话想要使用YOLOV5做点好玩的,也必须要对整个的代码流程进行梳理。原理就不用说了,老复杂了,所以先从简单的来探索。

我们原来的实现这个detect的代码非常简单,后面会贴出,我注释后的detect代码

import argparse
from PIL import Image
from utils.DataSet.MyDataSet import MyDataSet
from utils.DataSet.TransformAtions import TransFormAtions"""
这里不想写那么多东西,就是简单地去做一个测试就ok了。
其实做法就是在那个train里面的训练
"""import argparse
import torch
from torch.utils.data import DataLoader
from models.LeNet import LeNet
from data.ModelConfig import *
import outProcess
def detect():ways = opt.valid_imgstransformations = TransFormAtions()net = LeNet(classes=Classes)state_dict_load = torch.load(opt.path_state_dict)net.load_state_dict(state_dict_load)if(ways):test_data = MyDataSet(data_dir=opt.valid_dir, transform=transformations.valid_transform)valid_loader = DataLoader(dataset=test_data, batch_size=1)net.eval()with torch.no_grad():for i, data in enumerate(valid_loader):# forwardinputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)# 输出处理器outProcess.Function(predicted.numpy()[0])else:#指定的是单张图片,少给我来奇奇怪怪的输入,这个版本容错很差滴!!!path_img = opt.valid_dirif(".jpg" not in path_img):raise Exception("小爷打不开这图片")image = Image.open(path_img)image = transformations.valid_transform(image)image = torch.reshape(image, (1, 3, 32, 32))net.eval()with torch.no_grad():out = net(image)outProcess.Function(out.argmax(1).item())if __name__ == '__main__':parser = argparse.ArgumentParser()# False表示识别单张图片,True表示多张图片,此时指定路径即可。parser.add_argument('--valid_imgs',type=bool,default=False)parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\MyClassfication\mydata\train\100\1.jpg')parser.add_argument('--path_state_dict', type=str, default='runs/train/epx2/weights/best.pth')opt = parser.parse_args()detect()

在YOLO V 5 里面也不复杂,也就比我多了100多行代码。

我们这边大致的流程就三个。

然后每一个环节都可以有很多细节优化啥的,由于俺们那个是很简陋的,所以没有哈。

好了,我们开始正式进入这个YOLOV5的实际环节。

初始化

我们先来这看到这个环节,这里一共是做了两件事情嘛,读取超参数,加载模型权重文件,加载驱动

这里可以注意到这个函数

这个的话不用想的那么复杂,就是这个玩意

目的就是返回一个 可以正常使用的驱动,要是我写的话,我压根不会管那么多,不行就玩命报错,然后输出日志文件。

加载数据

然后第二步是加载数据,这个说实话,没什么好说的,分两个,一个是读取网络摄像头,一个是读取一张图片,或者视频,本地摄像头。这些逻辑处理细节不一样,但是结果都是一样的。

就是把数据给我封装的dataset里面,然后读取。

预测


我的注释写还是挺不错的。

预测数据格式

这里我们说说那个预测的格式。
我这里还是拿上次的一张图片做演示

这里有两个目标框,所以拿到的数据是这样的

我们发现pred 是一个长度为1,里面有两个list的玩意
之后我们发现最后一个直接是0
这个的话,是这样的

后置处理

之后就是拿到东西之后处理。在yolo里面默认是实现了一个自己绘图的玩意。
当然有时候,我们不仅仅要这玩意,我们想要实现AI压枪的话还需要那啥。

完整注释后代码

import argparse
import time
from pathlib import Pathimport cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import randomfrom models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronizeddef detect(save_img=False):# 读取初始化参数source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_sizesave_img = not opt.nosave and not source.endswith('.txt')  # save inference imageswebcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))# Directoriessave_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir# Initializeset_logging()device = select_device(opt.device)half = device.type != 'cpu'  # half precision only supported on CUDA# Load model#加载模型,这一块,weights是我们传入的参数,是我们权重文件地址#注意到这个model就是我们Huclassfiy的netmodel = attempt_load(weights, map_location=device)  # load FP32 modelstride = int(model.stride.max())  # model stride,维度的变换步长,这个和YOLO的网络结构有关,先忽略#imgsz是我们图片资源,对图片尺寸进行检查imgsz = check_img_size(imgsz, s=stride)  # check img_size#Pytorch 模型加速,这个需要GPU加速,需要先加载模型权重的!!!if half:model.half()  # to FP16# Second-stage classifierclassify = Falseif classify:modelc = load_classifier(name='resnet101', n=2)  # initializemodelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()# Set Dataloadervid_path, vid_writer = None, None#如果是网络摄像头的数据这样处理if webcam:view_img = check_imshow()cudnn.benchmark = True  # set True to speed up constant image size inferencedataset = LoadStreams(source, img_size=imgsz, stride=stride)else:#这部分是加载dataset 和我们那个也是类似的,只不过对于单张图片,我们直接转化为了一个tensor在HuClassFiydataset = LoadImages(source, img_size=imgsz, stride=stride)# Get names and colorsnames = model.module.names if hasattr(model, 'module') else model.namescolors = [[random.randint(0, 255) for _ in range(3)] for _ in names]# Run inferenceif device.type != 'cpu':model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run oncet0 = time.time()for path, img, im0s, vid_cap in dataset:#这里就和我们的那个进入验证是类似的了#path 是你的图片路径# img 自然是image转化为了tensor#im0s 是做了一个转化img0 = cv2.imread(path)  # BGR#vid_cap 就是说这玩意是不是一个视频,我们读入图片当然不是所以是Noneimg = torch.from_numpy(img).to(device)img = img.half() if half else img.float()  # uint8 to fp16/32img /= 255.0  # 0 - 255 to 0.0 - 1.0 #归一化if img.ndimension() == 3:img = img.unsqueeze(0)# Inferencet1 = time_synchronized()pred = model(img, augment=opt.augment)[0]#这个是预测的结果,但是按照那个网络的工作原理,还需要进行NMS非极大值抑制筛选目标框框# Apply NMSpred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)t2 = time_synchronized()print("预测结果是",pred)#按照我们在YOLV1论文里面的推出,应该是有5个参数#x,y,w,h,k可信度,但是这里要显示所以还有一个对应的条件概率#所以应该有6个参数,但是对应参数k,我们的概率计算是需要k的,结合参数opt.iou_thres#所以此时那个参数k应该是iou,之后对应的概率,这里最后通过debug我发现那个完整的参数是这样的#左上角,右下角,然后可信度,然后所属类别,注意那里显示的是按照屏幕100%来的,我的笔记本是125%#得到的坐标是需要除以1.25的# Apply Classifierif classify:pred = apply_classifier(pred, modelc, img, im0s)# Process detections#这部分,就是我们的后置处理了。说实话,应该把这玩意拆开的,这个部分是给Opencv画图用的for i, det in enumerate(pred):  # detections per imageif webcam:  # batch_size >= 1p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.countelse:p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)p = Path(p)  # to Pathsave_path = str(save_dir / p.name)  # img.jpgtxt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txts += '%gx%g ' % img.shape[2:]  # print stringgn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwhif len(det):# Rescale boxes from img_size to im0 sizedet[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()# Print resultsfor c in det[:, -1].unique():n = (det[:, -1] == c).sum()  # detections per classs += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string# Write resultsfor *xyxy, conf, cls in reversed(det):if save_txt:  # Write to filexywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywhline = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label formatwith open(txt_path + '.txt', 'a') as f:f.write(('%g ' * len(line)).rstrip() % line + '\n')if save_img or view_img:  # Add bbox to imagelabel = f'{names[int(cls)]} {conf:.2f}'# plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)im0 = plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)# Print time (inference + NMS)print(f'{s}Done. ({t2 - t1:.3f}s)')# Stream resultsif view_img:cv2.imshow(str(p), im0)cv2.waitKey(1)  # 1 millisecond# Save results (image with detections)if save_img:if dataset.mode == 'image':cv2.imwrite(save_path, im0)else:  # 'video' or 'stream'if vid_path != save_path:  # new videovid_path = save_pathif isinstance(vid_writer, cv2.VideoWriter):vid_writer.release()  # release previous video writerif vid_cap:  # videofps = vid_cap.get(cv2.CAP_PROP_FPS)w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))else:  # streamfps, w, h = 30, im0.shape[1], im0.shape[0]save_path += '.mp4'vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))vid_writer.write(im0)if save_txt or save_img:s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''print(f"Results saved to {save_dir}{s}")print(f'Done. ({time.time() - t0:.3f}s)')if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--weights', nargs='+', type=str, default='runs/train/exp2/weights/best.pt', help='model.pt path(s)')# http://admin:admin@192.168.101.21:8081parser.add_argument('--source', type=str, default=r'F:\projects\PythonProject\yolov5-5.0\mydata\images\003.jpg', help='source')  # file/folder, 0 for webcamparser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')parser.add_argument('--view-img', action='store_true', help='display results')parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')parser.add_argument('--nosave', action='store_true', help='do not save images/videos')parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')parser.add_argument('--augment', action='store_true', help='augmented inference')parser.add_argument('--update', action='store_true', help='update all models')parser.add_argument('--project', default='runs/detect', help='save results to project/name')parser.add_argument('--name', default='exp', help='save results to project/name')parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')opt = parser.parse_args()print(opt)check_requirements(exclude=('pycocotools', 'thop'))with torch.no_grad():if opt.update:  # update all models (to fix SourceChangeWarning)for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:detect()strip_optimizer(opt.weights)else:detect()

那么接下来我们要做的就是提取detect,把这个玩意套在我们自己的项目里面。为了后面便于使用这个yolo,我决定后面对这个玩意进行工程化规范,便于直接进行二次使用,开发。毕竟核心的话其实就和HuClassfiy一样,就那几个块。还是那句话,yolo的难点不在工程上,在原理实现上面…

YOLOV5 Detetct.py 流程分析相关推荐

  1. python开源聊天机器人ChatterBot——聊天机器人搭建、流程分析、源码分析

    开源聊天机器人ChatterBot 3.1  ChatterBot简介 ChatterBot是一个Python库,可以轻松生成对用户输入的自动响应.ChatterBot使用一系列机器学习算法来产生不同 ...

  2. 开源项目Hopsan代码梳理、流程分析

    Hosan开源项目是液压.电力等行业的仿真开源软件,由瑞典林平大学开发,可以仿真电力.液压等.更具体的了解,请参考:<Hopsan -- 液压.电力等行业的仿真开源软件>. Hosan开源 ...

  3. yocto 编译流程分析

    yocto 编译流程分析 2015年04月15日 10:55:13 日月星辰007 阅读数:4955 git clone 一份poky 的工程到本地. source poky/oe-init-buil ...

  4. 高通Android智能平台环境搭建_编译流程分析

    高通Android智能平台环境搭建_编译流程分析 高通平台环境搭建,编译,系统引导流程分析 TOC \o \h \z \u 1. 高通平台android开发总结. 7 1.1 搭建高通平台环境开发环境 ...

  5. 高通平台环境搭建,编译,系统引导流程分析 .

    1.高通平台android开发总结 1.1 搭建高通平台环境开发环境 在高通开发板上烧录文件系统 建立高通平台开发环境 高通平台,android和 modem 编译流程分析 高通平台 7620 启动流 ...

  6. 浅显易懂的Django架构流程分析

    Django的运行方式 运行Django项目的方法很多,一种是在开发和调试中经常用到的runserver方法,使用Django自己的Web Server.另外一种就是使用fastcgi, uWSGI等 ...

  7. 开源项目CookiesPool流程分析

    开源项目CookiesPool流程分析 转载文章请注明出处,邮箱:qiled@qq.com 项目地址:https://github.com/Python3WebSpider/CookiesPool 1 ...

  8. yolov5——train.py代码【注释、详解、使用教程】

    yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] 前言 1. p ...

  9. 高通android智能平台环境搭建_编译流程分析,高通平台环境搭建,编译,系统引导流程分析参考...

    高通有两个cpu,他们分别跑不同的系统,应用程序(ap)端是android系统,modem 端是高通自己的系统. 要编译出可供烧写使用的镜像文件需要三部分代码: 1) 获取经过高通打补丁的 andro ...

  10. VLC架构及流程分析

    0x00 前置信息 VLC是一个非常庞大的工程,我从它的架构及流程入手进行分析,涉及到一些很细的概念先搁置一边,日后详细分析. 0x01 源码结构(Android Java相关的暂未分析) # bui ...

最新文章

  1. win7右键新建文件夹不见了
  2. Working with Symbols (在Balsamiq Mockups中复用自定义控件和页面模板)
  3. mysql和hbase优缺点_hbase优缺点
  4. pythonwith作用_老生常谈Python startswith()函数与endswith函数
  5. 包python_Python 包:
  6. MCU中printf重定向实现
  7. 服务器硬盘 主板,服务器主板和普通主板有什么区别?
  8. FlashBuilder找不到所需要的AdobeFlashPlayer调试器版本的解决方案
  9. BFS广度优先算法, DFS深度优先算法,Python,队列实现,栈实现
  10. GCC,GDB,Makefile
  11. python中goto如何使用,基于python goto的正确用法说明
  12. 固态函数不正确_固态硬盘可靠吗?
  13. 使用notebook 笔记(1)
  14. tengine2.2.3报错502的The proxy server received an invalid response from an upstream server问题处理...
  15. 遗传算法最简单的例子
  16. 史上最全安装Maven教程
  17. 打印机可以打印不能扫描怎么弄_惠普打印机可以复印不能扫描怎么操作
  18. 建筑工地人脸识别门禁通道闸机如何安装
  19. 一个超好看的音乐网站设计与实现(HTML+CSS)
  20. 09.利用U盘PE系统破解Windows7、XP密码

热门文章

  1. 关于搞国外广告联盟的一些思路
  2. 最新Centos7.6 部署ELK日志分析系统
  3. 帅爆! 赛博朋克特效实现
  4. 直播APP软件开发,直播系统开发的技术架构揭秘
  5. U8字符串(u8前缀)的作用
  6. LeetCode刷题复盘笔记—一文搞懂509. 斐波那契数70. 爬楼梯以及递归时间复杂度计算方法(动态规划系列第一篇)
  7. Low-Light Enhancement 数据集 和 论文代码
  8. Mixpanel获Andreessen Horowitz投资 为Viddy及Path提供分析服务
  9. 中国移动开放平台(dev.cmccopen.cn)请求头Header:Authorization验证失败的原因(我遇到的)
  10. [Irving]SqlServer 拆分函数用法