YOLOv5 Model Training Tutorial

Data Annotation

We use labelimg to annotate the data. Install it with pip:

pip install labelimg
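labelImg accepts an optional predefined class file after the image directory (labelImg IMAGE_DIR CLASS_FILE), which keeps class names consistent across annotators. Below is a minimal sketch that writes the eight class names used later in this tutorial, one per line; the file name predefined_classes.txt and the helper write_class_file are assumptions for illustration.

```python
from pathlib import Path

# Class names used throughout this tutorial (same order as in convert_data.py)
CLASSES = ['face', 'normal', 'phone', 'write', 'smoke', 'eat', 'computer', 'sleep']


def write_class_file(classes, path='predefined_classes.txt'):
    """Hypothetical helper: write one class name per line for labelImg to load."""
    Path(path).write_text('\n'.join(classes) + '\n', encoding='utf-8')
    return Path(path)
```

For example, call write_class_file(CLASSES) and then start labelImg with your image folder and predefined_classes.txt as its two arguments.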

Scraping Images with a Baidu Crawler

import os
import re
import sys
import urllib
import json
import socket
import urllib.request
import urllib.parse
import urllib.error
from random import randint
import time

# set a global socket timeout (seconds)
timeout = 5
socket.setdefaulttimeout(timeout)


class Crawler:
    # sleep interval between requests
    __time_sleep = 0.1
    __amount = 0
    __start_amount = 0
    __counter = 0
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:23.0) Gecko/20100101 Firefox/23.0'}
    __per_page = 30

    # fetch image URLs and content
    # t: delay between image downloads (seconds)
    def __init__(self, t=0.1):
        self.time_sleep = t

    # get the file extension from a URL
    @staticmethod
    def get_suffix(name):
        m = re.search(r'\.[^\.]*$', name)
        # fall back to .jpeg if there is no extension or it is too long to be one
        if m and len(m.group(0)) <= 5:
            return m.group(0)
        else:
            return '.jpeg'

    # save images
    def save_image(self, rsp_data, word):
        if not os.path.exists("./" + word):
            os.mkdir("./" + word)
        # continue numbering after the images already in the folder
        self.__counter = len(os.listdir('./' + word)) + 1
        for image_info in rsp_data['data']:
            try:
                if 'replaceUrl' not in image_info or len(image_info['replaceUrl']) < 1:
                    continue
                obj_url = image_info['replaceUrl'][0]['ObjUrl']
                thumb_url = image_info['thumbURL']
                url = 'https://image.baidu.com/search/down?tn=download&ipn=dwnl&word=download&ie=utf8&fr=result&url=%s&thumburl=%s' % (urllib.parse.quote(obj_url), urllib.parse.quote(thumb_url))
                time.sleep(self.time_sleep)
                suffix = self.get_suffix(obj_url)
                # send a browser User-Agent to reduce 403 errors
                opener = urllib.request.build_opener()
                opener.addheaders = [
                    ('User-agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.116 Safari/537.36'),
                ]
                urllib.request.install_opener(opener)
                # save the image, trying up to 5 times until the file is at least 5 bytes
                filepath = './{}/PME_{}_A{}'.format(word, randint(1000000, 500000000), str(self.__counter) + str(suffix))
                for _ in range(5):
                    urllib.request.urlretrieve(url, filepath)
                    if os.path.getsize(filepath) >= 5:
                        break
                if os.path.getsize(filepath) < 5:
                    print("Downloaded an empty file, skipping!")
                    os.unlink(filepath)
                    continue
            except urllib.error.HTTPError as urllib_err:
                print(urllib_err)
                continue
            except Exception as err:
                time.sleep(1)
                print(err)
                print("Unknown error, image not saved")
                continue
            else:
                print("Image saved, " + str(self.__counter) + " images so far")
                self.__counter += 1
        return

    # start fetching
    def get_images(self, word):
        search = urllib.parse.quote(word)
        # pn (int): offset of the first image on the current page
        pn = self.__start_amount
        while pn < self.__amount:
            url = 'https://image.baidu.com/search/acjson?tn=resultjson_com&ipn=rj&ct=201326592&is=&fp=result&queryWord=%s&cl=2&lm=-1&ie=utf-8&oe=utf-8&adpicid=&st=-1&z=&ic=&hd=&latest=&copyright=&word=%s&s=&se=&tab=&width=&height=&face=0&istype=2&qc=&nc=1&fr=&expermode=&force=&pn=%s&rn=%d&gsm=1e&1594447993172=' % (search, search, str(pn), self.__per_page)
            # send headers to avoid 403
            try:
                time.sleep(self.time_sleep)
                req = urllib.request.Request(url=url, headers=self.headers)
                page = urllib.request.urlopen(req)
                rsp = page.read()
            except UnicodeDecodeError as e:
                print(e)
                print('-----UnicodeDecodeError url:', url)
            except urllib.error.URLError as e:
                print(e)
                print("-----urlError url:", url)
            except socket.timeout as e:
                print(e)
                print("-----socket timeout:", url)
            else:
                # parse the JSON response
                try:
                    rsp_data = json.loads(rsp)
                    self.save_image(rsp_data, word)
                    # move on to the next page
                    print("Downloading next page")
                    pn += self.__per_page
                except Exception as e:
                    # on failure the same page is requested again
                    continue
                finally:
                    page.close()
        print("Download task finished")
        return

    def start(self, word, total_page=2, start_page=1, per_page=30):
        """
        Crawler entry point
        :param word: search keyword
        :param total_page: number of pages to fetch; total images = total_page x per_page
        :param start_page: first page number
        :param per_page: images per page
        :return:
        """
        self.__per_page = per_page
        self.__start_amount = (start_page - 1) * self.__per_page
        self.__amount = total_page * self.__per_page + self.__start_amount
        self.get_images(word)


if __name__ == '__main__':
    crawler = Crawler(0.05)  # 0.05 s delay between downloads
    crawler.start('玩手机')  # search keyword: "playing with a phone"

Once annotation is finished, each image has a corresponding XML annotation file (labelimg's default Pascal VOC format).
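Before converting, it can help to confirm that every image really has its XML file and vice versa: convert_data.py only processes images that have an XML, and it writes a .jpg path for every XML it finds. A minimal sketch, assuming images and XML files sit side by side in one folder (as in VOCData/images/) and that images use the extensions in IMAGE_EXTS, which is an assumption to adjust for your data; check_pairs is a hypothetical helper.

```python
import os

# Image extensions to look for (an assumption; adjust to your data)
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp'}


def check_pairs(folder):
    """Return (images_without_xml, xml_without_image) as sorted lists of file stems."""
    image_stems, xml_stems = set(), set()
    for name in os.listdir(folder):
        stem, ext = os.path.splitext(name)
        ext = ext.lower()
        if ext in IMAGE_EXTS:
            image_stems.add(stem)
        elif ext == '.xml':
            xml_stems.add(stem)
    return sorted(image_stems - xml_stems), sorted(xml_stems - image_stems)
```

Usage: missing_xml, orphan_xml = check_pairs('VOCData/images'); both lists should be empty before you generate labels.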

Data Preprocessing

Create a file named convert_data.py with the following content:

# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
from tqdm import tqdm
import os
from os import getcwd


def convert(size, box):
    # size = (width, height); box = (xmin, xmax, ymin, ymax) in pixels
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h


def convert_annotation(image_id):
    # try:
    with open('VOCData/images/{}.xml'.format(image_id), encoding='utf-8') as in_file, \
            open('VOCData/labels/{}.txt'.format(image_id), 'w', encoding='utf-8') as out_file:
        tree = ET.parse(in_file)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        for obj in root.iter('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            b1, b2, b3, b4 = b
            # clip boxes that extend past the right/bottom image border
            if b2 > w:
                b2 = w
            if b4 > h:
                b4 = h
            b = (b1, b2, b3, b4)
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    # except Exception as e:
    #     print(e, image_id)


if __name__ == '__main__':
    sets = ['train', 'val']
    image_ids = [os.path.splitext(v)[0] for v in os.listdir('VOCData/images/') if v.endswith('.xml')]
    split_num = int(0.95 * len(image_ids))  # 95% train, 5% val
    # classes is a module-level global read by convert_annotation()
    classes = ['face', 'normal', 'phone', 'write', 'smoke', 'eat', 'computer', 'sleep']
    if not os.path.exists('VOCData/labels/'):
        os.makedirs('VOCData/labels/')
    with open('train.txt', 'w') as list_file:
        for image_id in tqdm(image_ids[:split_num]):
            list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
            convert_annotation(image_id)
    with open('val.txt', 'w') as list_file:
        for image_id in tqdm(image_ids[split_num:]):
            list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
            convert_annotation(image_id)
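The convert() function turns a Pascal VOC box in pixels, ordered (xmin, xmax, ymin, ymax), into YOLO's normalized (x_center, y_center, width, height). The sketch below restates that formula as a standalone function, keeping the same 1-pixel shift on the centre (presumably because VOC pixel coordinates are 1-based), and adds a hypothetical inverse, yolo_to_voc, so you can spot-check a label by converting it back. Both function names are illustrative, not part of the original script.

```python
def voc_to_yolo(size, box):
    """size = (img_w, img_h); box = (xmin, xmax, ymin, ymax) in pixels."""
    img_w, img_h = size
    xmin, xmax, ymin, ymax = box
    x = ((xmin + xmax) / 2.0 - 1) / img_w  # centre x, shifted by 1 px like convert()
    y = ((ymin + ymax) / 2.0 - 1) / img_h  # centre y, shifted the same way
    w = (xmax - xmin) / img_w
    h = (ymax - ymin) / img_h
    return x, y, w, h


def yolo_to_voc(size, yolo_box):
    """Hypothetical inverse of voc_to_yolo, for spot-checking labels."""
    img_w, img_h = size
    x, y, w, h = yolo_box
    cx, cy = x * img_w + 1, y * img_h + 1  # undo the normalization and the 1 px shift
    bw, bh = w * img_w, h * img_h
    return cx - bw / 2, cx + bw / 2, cy - bh / 2, cy + bh / 2


if __name__ == '__main__':
    yolo = voc_to_yolo((640, 480), (100, 300, 50, 250))
    print(yolo)
    print(yolo_to_voc((640, 480), yolo))
```

A 200 x 200 px box in a 640 x 480 image becomes a width of 200/640 and a height of 200/480; converting back should give the original corners up to floating-point rounding.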

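After running convert_data.py, the corresponding .txt label files are generated under VOCData/labels, and the image lists train.txt and val.txt are written to the current directory.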
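A quick way to catch bad labels (a wrong class index, or boxes outside the image) is to read every generated .txt back. A minimal sketch, assuming the label layout convert_data.py writes: one "class x_center y_center width height" line per object, with coordinates normalized to [0, 1]. The helper names check_label_file and check_label_dir are hypothetical.

```python
import os


def check_label_file(path, nc):
    """Return a list of (line_number, message) problems found in one label file."""
    problems = []
    with open(path, encoding='utf-8') as f:
        for i, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # ignore blank lines
            if len(parts) != 5:
                problems.append((i, 'expected 5 fields, got %d' % len(parts)))
                continue
            cls_id = int(parts[0])
            coords = [float(v) for v in parts[1:]]
            if not 0 <= cls_id < nc:
                problems.append((i, 'class id %d out of range' % cls_id))
            if any(not 0.0 <= v <= 1.0 for v in coords):
                problems.append((i, 'coordinate outside [0, 1]'))
    return problems


def check_label_dir(label_dir, nc=8):
    """Check every .txt file in label_dir; return {file_name: problems} for bad files."""
    report = {}
    for name in sorted(os.listdir(label_dir)):
        if name.endswith('.txt'):
            problems = check_label_file(os.path.join(label_dir, name), nc)
            if problems:
                report[name] = problems
    return report
```

Usage: print(check_label_dir('VOCData/labels', nc=8)); an empty dict means every line passed.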

Create a file named myvoc.yaml in the data folder with the following content:

train: train.txt
val: val.txt

# number of classes
nc: 8

# class names
names: ["face", "normal", "phone", "write", "smoke", "eat", "computer", "sleep"]
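Because nc must equal the number of entries in names, and names must follow the same order as the classes list in convert_data.py, one option is to generate myvoc.yaml from that list instead of typing it by hand. A minimal stdlib-only sketch; write_data_yaml is a hypothetical helper. A JSON list of strings is also valid YAML flow syntax, which is why json.dumps can produce the names line.

```python
import json

CLASSES = ['face', 'normal', 'phone', 'write', 'smoke', 'eat', 'computer', 'sleep']


def write_data_yaml(path, classes, train='train.txt', val='val.txt'):
    """Write a YOLOv5 data YAML whose nc is derived from the class list."""
    text = (
        'train: {}\n'
        'val: {}\n\n'
        '# number of classes\n'
        'nc: {}\n\n'
        '# class names\n'
        'names: {}\n'
    ).format(train, val, len(classes), json.dumps(classes))
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
    return text
```

Usage: write_data_yaml('data/myvoc.yaml', CLASSES) reproduces the file above, so adding a class only means editing one list.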

Download the Pretrained Model

I train the yolov5m model, so I download its pretrained weights (yolov5m.pt) into the weights folder.

Model Training

Change the number of classes (nc) in models/yolov5m.yaml to 8 so it matches myvoc.yaml, then start training:

python train.py --img 640 --batch 4 --epoch 300 --data ./data/myvoc.yaml --cfg ./models/yolov5m.yaml --weights weights/yolov5m.pt --workers 0
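The nc edit in models/yolov5m.yaml can also be scripted. A minimal sketch with a hypothetical helper, set_model_nc, assuming the model config has a top-level line starting with "nc:" (as the YOLOv5 model configs do); it rewrites only the number and keeps any trailing comment on that line.

```python
import re


def set_model_nc(cfg_path, nc):
    """Replace the value of the top-level nc key in a YOLOv5 model YAML."""
    with open(cfg_path, encoding='utf-8') as f:
        text = f.read()
    # re.M makes ^ match at every line start; count=1 replaces only the first match
    new_text, count = re.subn(r'^(nc:\s*)\d+', r'\g<1>{}'.format(nc), text, count=1, flags=re.M)
    if count == 0:
        raise ValueError('no top-level "nc:" line found in %s' % cfg_path)
    with open(cfg_path, 'w', encoding='utf-8') as f:
        f.write(new_text)
    return new_text
```

Usage: set_model_nc('models/yolov5m.yaml', 8) before running train.py.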

Inference Test

When training finishes, two trained weight files, best.pt and last.pt, are saved in runs/train/exp/weights. Copy last.pt to the repository root and run:

python detect.py --source data/images --weights last.pt --conf 0.25
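Each new training run is saved to a new folder (exp, exp2, exp3, ...), so the weights from your latest run are not always under runs/train/exp. A minimal sketch that picks the most recently modified last.pt under runs/train and copies it to a destination folder; the helper name copy_latest_weights and the glob pattern are assumptions based on that folder layout.

```python
import glob
import os
import shutil


def copy_latest_weights(runs_dir='runs/train', name='last.pt', dest='.'):
    """Copy the newest runs_dir/*/weights/<name> into dest and return the new path."""
    candidates = glob.glob(os.path.join(runs_dir, '*', 'weights', name))
    if not candidates:
        raise FileNotFoundError('no %s found under %s' % (name, runs_dir))
    newest = max(candidates, key=os.path.getmtime)  # most recently written run
    return shutil.copy2(newest, os.path.join(dest, name))
```

Usage: run copy_latest_weights() from the repository root, then run the detect.py command with --weights last.pt as before.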

Model Quantization

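At this point you may notice that the trained last.pt is 172 MB, while the official yolov5m.pt is only 40 MB. To shrink it, we export the model in half precision (FP16) and save it again. Create a slim.py file (code below) and run: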

python slim.py --in_weights last.pt --out_weights slim_model.pt --device 0

slim.py

import os
import torch
import torch.nn as nn
from tqdm import tqdm


def autopad(k, p=None):  # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    # ch_in, ch_out, kernel, stride, padding, groups
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.cat(y, 1)  # nms ensemble
        y = torch.stack(y).mean(0)  # mean ensemble
        return y, None  # inference, train output


def attempt_load(weights, map_location=None):
    # torch.load unpickles YOLOv5's own model classes, so run this script
    # from the YOLOv5 repository root (where the models/ package lives)
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        # load the FP32 model, fuse Conv+BN layers, switch to eval mode
        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())

    # Compatibility updates
    for m in tqdm(model.modules()):
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print('Ensemble created with %s\n' % weights)
        for k in ['names', 'stride']:
            setattr(model, k, getattr(model[-1], k))
        return model  # return ensemble


def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availability

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
    return torch.device('cuda:0' if cuda else 'cpu')


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_weights', type=str, default='last.pt', help='initial weights path')
    parser.add_argument('--out_weights', type=str, default='slim_model.pt', help='output weights path')
    parser.add_argument('--device', type=str, default='0', help='device')
    opt = parser.parse_args()

    device = select_device(opt.device)
    model = attempt_load(opt.in_weights, map_location=device)
    model.to(device).eval()
    model.half()  # convert weights to FP16 (half precision)
    # save in the {'model': ...} layout that YOLOv5's loaders (e.g. detect.py) read
    torch.save({'model': model}, opt.out_weights)
    print('done.')
    print('-[INFO] before: {:.0f} kb, after: {:.0f} kb'.format(
        os.path.getsize(opt.in_weights) / 1024, os.path.getsize(opt.out_weights) / 1024))
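The size gap can be estimated with simple arithmetic. A rough sketch, assuming yolov5m has about 21 million parameters (an approximate figure that varies slightly between YOLOv5 releases) and that a checkpoint saved during training also carries one SGD momentum buffer per parameter; both are assumptions, not measurements of the files above. The helper checkpoint_size_mb is hypothetical and ignores pickle overhead and non-tensor metadata.

```python
def checkpoint_size_mb(n_params, bytes_per_value, copies=1):
    """Approximate tensor payload in MB (1 MB = 10**6 bytes), ignoring pickle overhead."""
    return n_params * bytes_per_value * copies / 1e6


N_PARAMS = 21_000_000  # rough yolov5m parameter count (assumption)

fp32_weights = checkpoint_size_mb(N_PARAMS, 4)                  # float32 weights only
fp32_with_momentum = checkpoint_size_mb(N_PARAMS, 4, copies=2)  # weights + one momentum buffer
fp16_weights = checkpoint_size_mb(N_PARAMS, 2)                  # float16 weights only

if __name__ == '__main__':
    print('FP32 weights:            %.0f MB' % fp32_weights)
    print('FP32 weights + momentum: %.0f MB' % fp32_with_momentum)
    print('FP16 weights:            %.0f MB' % fp16_weights)
```

These estimates (about 168 MB and 42 MB) are in the same range as the 172 MB and 40 MB observed above, which suggests the difference comes mostly from numeric precision and training-only state rather than from the network itself. slim.py keeps only the 'model' entry of the checkpoint and converts it to FP16, addressing both.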
