yolov5 + second_classify -- 代码
因内容重要,故做此笔记,也仅做笔记。
detect_correct.py
from yolov5 import YOLOv5
import torch
from torchvision import transforms
# import numpy as np
from PIL import Image
import cv2
import os


def detect(image_path, yolov5_model, recog_model):
    """Detect objects in one image, classify every crop and display the result.

    The image is read with OpenCV, passed to ``yolov5_model.detect`` with a
    confidence threshold of 0.1, and every returned box is clamped to the
    image bounds, cropped, converted to an RGB PIL image, run through the
    module-level ``data_transforms`` and fed to ``recog_model`` on CUDA.
    A rectangle is drawn for every classified box; the class label is drawn
    only when the classifier score exceeds 0.5. If at least one box was
    returned, the annotated image is shown for five seconds.

    Args:
        image_path: Path of the image file to process.
        yolov5_model: Object whose ``detect(img, conf_thres=...)`` returns a
            ``(bboxes, scores)`` pair of tensors; boxes are read as
            ``x1, y1, x2, y2``.
        recog_model: Callable classifier applied to a ``1 x C x H x W`` CUDA
            tensor, returning per-class scores along dimension 1.

    Returns:
        ``False`` when the image cannot be read, otherwise ``None``.
    """
    img = cv2.imread(image_path)
    if img is None:
        print('None image', image_path)
        return False

    h, w = img.shape[:2]  # cv2 read format: (height, width, channels)
    bboxes, _ = yolov5_model.detect(img, conf_thres=0.1)
    bboxes = bboxes.numpy()

    for bbox in bboxes:
        # Clamp into local variables so the detector output is not mutated.
        x1 = int(max(bbox[0], 0))
        y1 = int(max(bbox[1], 0))
        x2 = int(min(bbox[2], w))
        y2 = int(min(bbox[3], h))
        if x2 <= x1 or y2 <= y1:
            # A box with no area inside the image yields an empty crop,
            # which cv2.cvtColor rejects; skip it instead of crashing.
            continue

        crop_rgb = cv2.cvtColor(img[y1:y2, x1:x2, :], cv2.COLOR_BGR2RGB)
        batch = data_transforms(Image.fromarray(crop_rgb)).unsqueeze(0).cuda()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = recog_model(batch)
        # NOTE(review): the 0.5 threshold below is applied to the raw model
        # output; if recog_model returns logits rather than probabilities a
        # softmax is probably wanted here - confirm against the model.
        confidence, preds = torch.max(outputs, 1)
        class_index = preds.item()
        print(class_index)

        # class_dict may not contain every index (it is empty as published),
        # so fall back to the numeric index instead of raising KeyError.
        class_name = class_dict.get(class_index, str(class_index))

        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), thickness=2)
        if confidence.item() > 0.5:
            cv2.putText(img, class_name, (x1+5, y1+20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), thickness=2)

    if len(bboxes) > 0:
        cv2.imshow('image', img)
        cv2.waitKey(5000)


# Maps classifier output index -> label text drawn on the image.
class_dict = dict()
# Preprocessing applied to every cropped detection before classification:
# resize, centre-crop to 224x224, convert to tensor, normalise per channel.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image_root = '/home/img_test'
yolov5_weight = '/home/weights/best.pt'
efficient_weight = '/home/models/efficientnet-b0.pth'

# NOTE: torch.load unpickles the whole model object; only load checkpoint
# files that come from a trusted source.
efficient_model = torch.load(efficient_weight).cuda().eval()
yolov5 = YOLOv5(yolov5_weight)

# Run detection + classification on every image file found in image_root.
for image_name in os.listdir(image_root):
    ext = os.path.splitext(image_name)[1]
    # Compare case-insensitively so files such as "photo.JPG" are processed.
    if ext.lower() not in ('.jpg', '.png', '.jpeg'):
        continue
    print(image_name)
    image_path = os.path.join(image_root, image_name)
    detect(image_path, yolov5, efficient_model)
yolov5.py
import time
import torch
import numpy as np
import cv2
import os
import argparse
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.torch_utils import select_device, time_sync


class YOLOv5:
    """Thin inference wrapper around a YOLOv5 checkpoint."""

    def __init__(self, weights, imgsz=640):
        """Load the checkpoint and prepare it for inference.

        Args:
            weights: Path of the YOLOv5 weights file.
            imgsz: Requested inference size; it is passed through
                ``check_img_size`` together with the model's maximum stride.
        """
        print('Loading YOLO from', weights)
        self.device = select_device()
        # Half precision is only used when not running on the CPU.
        self.half = self.device.type != 'cpu'

        self.model = attempt_load(weights, map_location=self.device)  # FP32 model
        self.imgsz = check_img_size(imgsz, s=int(self.model.stride.max()))
        if self.half:
            self.model.half()  # to FP16

    def detect(self, orig_img, augment=True, conf_thres=0.25):
        """Run detection on one BGR image and return the class-0 boxes.

        Args:
            orig_img: Image array as read by ``cv2.imread`` (H x W x BGR).
            augment: Forwarded to the model's ``augment`` argument.
            conf_thres: Confidence threshold passed to non-max suppression.

        Returns:
            ``(bboxes, scores)`` as torch tensors: ``bboxes`` holds the first
            four columns of each kept detection, rescaled to ``orig_img``
            coordinates, and ``scores`` holds the fifth column.
        """
        iou_thres = 0.45

        # Padded resize to the network input size.
        img = letterbox(orig_img, new_shape=self.imgsz)[0]

        # BGR -> RGB and HWC -> CHW, then make the memory contiguous.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            pred = self.model(img, augment=augment)[0]

        # Apply NMS, keeping class 0 only.
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes=0)

        # Rescale boxes from the network input size to the original image size.
        for det in pred:
            if len(det):
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], orig_img.shape).round()

        results = pred[0].cpu().numpy()
        results = results[results[:, 5] == 0]
        bboxes, scores = results[:, :4], results[:, 4]
        return torch.from_numpy(bboxes), torch.from_numpy(scores)


def _parse_args():
    """Parse the command-line options of the stand-alone labelling script."""
    parser = argparse.ArgumentParser(description='Retinaface')
    parser.add_argument('--image_dir', type=str, required=True)
    parser.add_argument('--target_dir', type=str, required=True)
    parser.add_argument('--weights', type=str, default='/home/weights/yolov5x.pt')
    parser.add_argument('--image_list', type=str, default=None)
    return parser.parse_args()


def _load_image_names(img_root, img_list_file):
    """Return image names from the list file if given, else from img_root."""
    if img_list_file is None:
        return os.listdir(img_root)
    with open(img_list_file, 'r') as f:
        # The first whitespace-separated token of each line is the image
        # name; blank lines are ignored.
        return [line.split()[0] for line in f if line.strip()]


def _main():
    """Detect on every image and write one tab-separated .txt file per image."""
    # Arguments are parsed here, not at import time, so that importing
    # YOLOv5 from another script does not consume that script's argv.
    args = _parse_args()
    yolo = YOLOv5(args.weights)
    img_root = args.image_dir
    save_dir = args.target_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print(save_dir)

    for img_name in tqdm(_load_image_names(img_root, args.image_list)):
        img = cv2.imread(os.path.join(img_root, img_name))
        if img is None:
            continue
        bboxes, scores = yolo.detect(img, conf_thres=0.5)
        # Images with fewer than two detections produce no output file.
        if len(bboxes) < 2:
            continue

        save_path = os.path.join(save_dir, os.path.splitext(img_name)[0] + '.txt')
        # The context manager guarantees the file is flushed and closed.
        with open(save_path, 'w') as op_f:
            op_f.write(img_name + '\n')
            op_f.write(str(len(bboxes)) + '\n')
            for bbox, score in zip(bboxes, scores):
                op_array = [str(int(_)) for _ in bbox] + [str(score.numpy())]
                op_f.write('\t'.join(op_array) + '\n')


if __name__ == "__main__":  # only when this file is run directly
    _main()
其他调用文件与yolov5默认一致。
《yolov5 + second_classify -- 代码》相关推荐
- GitHub上YOLOv5开源代码的训练数据定义
GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...
- YOLOV5训练代码train.py注释与解析
YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...
- YOLOV5检测代码detect.py注释与解析
YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...
- yolov5s 预训练模型_GitHub上YOLOv5开源代码的训练数据定义
GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...
- GitHub YOLOv5 开源代码项目系列讲解(五)------链接手机摄像头实现目标检测
本专栏将从安装到实例运用全方位系列讲解 GitHub YOLOv5 开源代码. 专栏地址:GitHub YOLOv5 开源代码项目系列讲解 目录 1 手机下载 "IP摄像头" AP ...
- YOLOV5测试代码test.py注释与解析
YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...
- yolov5核心代码: anchor匹配策略,compute_loss和build_targets理解
yolov5核心代码理解: anchor匹配策略-跨网格预测,compute_loss(p, targets, model)和build_targets(p, targets, model)理解 本文 ...
- 【yolov5检测代码简化】Yolov5 detect.py推理代码简化,输入图片,输出图片和结果
前言 最近的项目里有yolov5的嵌入,需求是只需要推理,模型文件是已有的,输入需要是图片(原yolov5是输入路径),输出结果的图片和标签.这样的话需要对原来的代码进行一些简化和变更. 路径 模型这 ...
- yolov5检测代码解析
yolov5检测部分代码解析 yolov5源代码:https://github.com/ultralytics/yolov5 设置目标检测的配置参数: parser = argparse.Argume ...
最新文章
- 直接插入排序与希尔排序
- 对Python课的看法
- 【Android View绘制之旅】Layout过程
- SQL Server 临时表
- [Go] golang的MPG调度模型
- 使用 ipmitool 实现远程管理Dell 系列服务器
- webservice 出现No service was found
- Android应用程序之间共享文字和图片(一)
- LDMS 8.8 简明使用手册之客户端配置及部署
- 预处理criteo数据集以预测广告的点击率
- Strom完整攻略(一)
- android音频驱动工程师,4.Android音频驱动(底层1)
- uva 10128 队伍
- 3、太阳能电池板参数解析
- 安卓日志:拍照、文件读取的问题
- html显示宇宙星星,回复评论
- java计算机毕业设计体育城场地预定系统前台源码+系统+数据库+lw文档+mybatis+运行部署
- Kubernetes 集群仓库 Harbor Helm3 部署
- linux终端命令大全
- java常用加密算法及MD5的使用