训练样本准备

中国交通标志检测数据集(CCTSDB),百度网盘:https://pan.baidu.com/s/1-se8J8fQ0FgmUalu8873CQ,
提取码:9fov

图片制作

1、按图片的最大边长生成一个正方形的白板
2、将原图粘贴上去
3、将2的图片resize成416x416大小

标签制作

因为YOLOV3网络分别进行了32倍,16倍和8倍下采样,输出了13x13,26x26,52x52三个特征图,因此标签也要分别作三份
16倍与8倍同理类似
代码实现

# cfg配置
IMG_H=416
IMG_W=416
CLS_NUM=6
ANCHORS_GROUP={13:[[116,90],[156,198],[373,326]],26:[[30,61],[62,45],[59,119]],52:[[10,13],[16,30],[33,23]]
}
ANCHORS_GROUP_AREA={13:[x*y for x,y in ANCHORS_GROUP[13]],26:[x*y for x,y in ANCHORS_GROUP[26]],52:[x*y for x,y in ANCHORS_GROUP[52]]
}
from torch.utils.data import Dataset
import torchvision
import numpy as np
import cfg
import os
from PIL import Image
import math
label_path=r"label_position.txt"
img_path=r"YOLO_V3"
def one_hot(cls_num,v):b=np.zeros(cls_num)b[v]=1return b
class mydatasset(Dataset):def __init__(self):self.m=torchvision.transforms.ToTensor()with open(label_path) as f:self.dataset=f.readlines()def __len__(self):return len(self.dataset)def __getitem__(self, index):labels={}line=self.dataset[index] # 按行读取txt文件里的内容strs=line.split()  # 将每行的数字切割成一个列表img_data_=Image.open(os.path.join(img_path,strs[0]))img_data=self.m(img_data_)boxes_=np.array([float(x) for x in strs[1:]])boxes=np.split(boxes_,len(boxes_)//5) # 将每行的没五个元素封装成一个元素for feature_size, anchors in cfg.ANCHORS_GROUP.items():labels[feature_size] = np.zeros(shape=(feature_size, feature_size, 3, 5 + cfg.CLS_NUM))for box in boxes:cls, cx, cy, w, h = boxcx_offset, cx_index = math.modf(cx * feature_size / cfg.IMG_W)   # modf返回小数和整数部分cy_offset, cy_index = math.modf(cy * feature_size / cfg.IMG_W)for i, anchor in enumerate(anchors):anchor_area = cfg.ANCHORS_GROUP_AREA[feature_size][i]  # 得到建议框的面积p_w, p_h = w / anchor[0], h / anchor[1]p_area = w * hiou = min(p_area, anchor_area) / max(p_area, anchor_area)labels[feature_size][int(cy_index), int(cx_index), i] = np.array([iou, cx_offset, cy_offset, np.log(p_w), np.log(p_h), *one_hot(cfg.CLS_NUM, int(cls))])return labels[13], labels[26], labels[52], img_data

网络模型训练

Loss function的设计:均方差损失函数
loss = loss_obj + loss_noobj
(也就是通过重叠度IOU判断出哪些格子是有目标中心点的,哪些是没有的,然后分开做损失,并且没有目标的格子只对其重叠度做损失)
sum_loss=loss_13+loss_26+loss_52
(因为输出了三个特征图,因此对应有三个损失)
代码实现:

import dataset
import DarkNet_53
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
def loss_fn(output, target):output = output.permute(0, 2, 3, 1)output = output.reshape(output.size(0), output.size(1), output.size(2), 3, -1)mask_obj = target[..., 0]>0mask_noobj = target[..., 0]==0target=target.cuda().float()loss_obj = torch.mean((output[mask_obj] - target[mask_obj]) ** 2)loss_noobj = torch.mean((output[mask_noobj][:,0] - target[mask_noobj][:,0]) ** 2)loss =  loss_obj + loss_noobjreturn loss
if __name__ == '__main__':device = torch.device("cuda" if torch.cuda.is_available() else "cpu")myDataset = dataset.mydatasset()train_loader =DataLoader(myDataset, batch_size=1, shuffle=True)save_para_path = r"net_3.pt"net = DarkNet_53.MainNet().to(device)net.train()opt = torch.optim.Adam(net.parameters(),lr=0.001)count=0# net.load_state_dict(torch.load(save_para_path))while True:for i,(target_13, target_26, target_52, img_data) in enumerate(train_loader):target_13=target_13.to(device)target_26 = target_26.to(device)target_52 = target_52.to(device)img_data=img_data.to(device)torch.cuda.empty_cache()output_13, output_26, output_52 = net(img_data)torch.cuda.empty_cache()loss_13 = loss_fn(output_13, target_13)loss_26 = loss_fn(output_26, target_26)loss_52 = loss_fn(output_52, target_52)loss = loss_13 + loss_26 + loss_52opt.zero_grad()loss.backward()opt.step()print(loss.item())if i % 1000 == 0 and i > 0:torch.save(net.state_dict(), save_para_path)count = count + 1print("第{0}轮训练完毕".format(count))torch.save(net.state_dict(),save_para_path)print("保存成功")

网络模型的测试

测试注意要点:
1、首先要将测试图片的长宽变为32的倍数
2、需要设定一个阈值
3、选择重叠度大于阈值的返回框和分类,要注意这里的返回框包括三个不同尺度,三个不同形状的九个建议框各自的返回框
4、最后对每一个分类的返回框做非极大值抑制
(如果又不知道非极大值抑制抑制的可以参考:https://mp.csdn.net/mdeditor/103183879
代码实现:

import DarkNet_53
import torch.nn as nn
import torch
import cfg
import PIL.Image as pimg
from PIL import ImageDraw
from torchvision import transforms
from utils import nms
import numpy as np
import os
path=r"net_3.pt"
def resize(img):w,h=img.sizeimage=img.resize(((w//32)*w,(h//32)*h))return image
class Detector(nn.Module):def __init__(self):super(Detector, self).__init__()self.m=transforms.ToTensor()self.net = DarkNet_53.MainNet().cuda()self.net.load_state_dict(torch.load(path)self.net.eval()def forward(self, input, thresh, anchors,):input=self.m(image).cuda()input.unsqueeze_(0)output_13, output_26, output_52 = self.net(input)output_13=output_13.cpu().detach()output_26 = output_26.cpu().detach()output_52 = output_52.cpu().detach()idxs_13, vecs_13 = self._filter(output_13, thresh)boxes_13,cls_13 = self._parse(idxs_13, vecs_13, 32, anchors[26])idxs_26, vecs_26 = self._filter(output_26, thresh)boxes_26 ,cls_26= self._parse(idxs_26, vecs_26, 16, anchors[52])idxs_52, vecs_52 = self._filter(output_52, thresh)boxes_52,cls_52 = self._parse(idxs_52, vecs_52, 8, anchors[104])boxes=torch.cat([boxes_13, boxes_26, boxes_52], dim=0)cls=torch.cat([cls_13,cls_26,cls_52],dim=0)return boxes,clsdef _filter(self, output, thresh):output = output.permute(0, 2, 3, 1)output = output.reshape(output.size(0), output.size(1), output.size(2), 3, -1)mask = output[..., 0] > threshidxs = mask.nonzero()vecs = output[mask]return idxs, vecsdef _parse(self, idxs, vecs, t, anchors):if vecs.size()[0]==0:return torch.tensor([]),torch.tensor([])else:anchors = torch.tensor(anchors).float()n = idxs[:, 0]  # 所属的图片a = idxs[:, 3]  # 建议框cond = vecs[:, 0].float()cls = vecs[:, 5:15]cls_ = torch.argmax(cls, dim=1).float()cy = (idxs[:, 1].float() + vecs[:, 2]) * t  # 原图的中心点ycx = (idxs[:, 2].float() + vecs[:, 1]) * t  # 原图的中心点xw = anchors[a, 0] * torch.exp(vecs[:, 3])h = anchors[a, 1] * torch.exp(vecs[:, 4])x1 = (cx - w / 2).float()y1 = (cy - h / 2).float()x2 = (cx + w / 2).float()y2 = (cy + h / 2).float()return torch.stack([cond, x1, y1, x2, y2, cls_, n.float()], dim=1), cls_
if __name__ == '__main__':with torch.no_grad():j=0for name in os.listdir("test"):image_path = os.path.join("test",name)image_=pimg.open(image_path)image=resize(image_)detector = Detector()imgdraw = ImageDraw.Draw(image_)color = ["#FF7F24", "#FF0000", "#00FF00"]boxes,cls = detector(image,0.5, cfg.ANCHORS_GROUP)cls=np.array(cls)for i in list(set(list(cls))):boxes_=[]for box in boxes:if box[5] == i:box=box.detach().numpy()box_=np.array(box)boxes_.append(box)frame=nms(torch.tensor(boxes_),0.1)for x in frame:x1,y1,x2,y2=x[1], x[2], x[3], x[4]imgdraw.rectangle((x1, y1, x2, y2), fill=None, outline=color[int(i)],width=2)j=j+1# image_.show()image_.save("测试保存/"+str(j)+".jpg")

效果展示



召回率:84.71% 精确度:82.40%
注:由于数据集的效果不是很好,再加上YOLO对小目标的侦测不强的缺点,导致召回率和精确度不是很高

基于YOLOV3实现交通标志识别(Pytorch实现)相关推荐

  1. matlab交通标志神经网络识别,基于神经网络的交通标志识别方法

    Municipal & Traffic Construction SCIENCE & TECHNOLOGY FOR DEVELOPMENT 149 基于神经网络的交通标志识别方法 赵丹 ...

  2. Python基于YOLOv5的交通标志识别系统[源码]

    1.图片演示: 2.视频演示: [项目分享]Python基于YOLOv5的交通标志识别系统[源码&技术文档&部署视频&数据集]_哔哩哔哩_bilibili 3.标注好的数据集: ...

  3. Python基于YOLOv5的交通标志识别系统[源码&技术文档&部署视频&数据集]

    1.图片演示: 2.视频演示: 3.标注好的数据集: 4.YOLO网络的构建: 网络结构是首先用Focus将计算图长宽变为原先1/4, channel 数量乘4.再用bottlenectCSP 提取特 ...

  4. 基于OpenCV的交通标志识别

    前几天看新闻得知微软为美国执法机关研发了一套基于AI识别,追踪并提取编辑视频中出现的人脸的算法,只要输入一段带人脸信息的视频文件,运行后即可输出一段所有人脸已被提取并且按要求编辑好的视频文件.当然该算 ...

  5. 毕业设计-基于机器视觉的交通标志识别系统

    目录 前言 课题背景和意义 实现技术思路 一.交通标志识别系统 二.交通标志识别整体方案 三.实验分析 四.总结 实现效果图样例 最后 前言

  6. 深度学习100例 | 第3天:交通标志识别 - PyTorch实现

    文章目录 一.导入数据 1. 获取类别名 2. 数据可视化 3. 加载数据文件 4. 划分数据 二.自建模型 三.模型训练 1. 优化器与损失函数 2. 模型的训练 四.结果分析 大家好,我是K同学啊 ...

  7. 【YOLOv5实战2】基于YOLOv5的交通标志识别系统-自定义数据集

    实战博客指引: 实战环境搭建 自定义数据集 模型训练 模型测试与评估 YOLOv5整合PyQt5 项目源代码可联系博主获取. 一.数据准备 1.1 从官网下载YOLOv5 打开官网YOLOv5,使用g ...

  8. 【YOLOv5实战3】基于YOLOv5的交通标志识别系统-模型训练

    实战博客指引: 实战环境搭建 自定义数据集 模型训练 模型测试与评估 YOLOv5整合PyQt5 项目源代码可联系博主获取. 一.参数说明 再经历前两个步骤后,开始进行模型训练与测试.首先进行模型训练 ...

  9. 基于Keare的交通标志识别

    前两天体验了一下腾讯云的在线实验,内容如题,在这里记录一下一些必要知识( 水 实验步骤 这个实验分为训练过程和测试过程两部分. 训练过程流程及实现: 解析脚本输入参数:使用argparse解析,由ar ...

最新文章

  1. 零基础学编程学java还是python-学编程选Python还是Java?就业发展哪个好?
  2. JAVA中JPasswordField实现密码的确认
  3. 邮局--dp经典问题
  4. IE6 透明遮挡falsh解决方案
  5. AngularJS学习笔记(1) - AngularJS入门
  6. 中国知名个人站长排行TOP91
  7. Gamit 数据处理,相关的文件配置
  8. qte5编译dub.json
  9. BIOS实战之Super IO-Smart Fan
  10. c语言程序设计21点扑克牌,C语言程序设计21点扑克牌游戏.doc
  11. 解释外显子,内含子,CDS、cDNA、EST、mRNA、ORF间的区别
  12. 阿里云ECS云服务器实例重置-更换操作系统
  13. STM32下载程序至SRAM——基于正点原子精英STM32F103ZET6开发板
  14. WD蓝盘绿盘黑盘红盘的区别
  15. 区块链单组群多节点部署
  16. sip是什么?Mac电脑如何关闭sip?关闭系统完整性保护SIP的方法教程
  17. 如果你是CEO,你打算给自己开多少工资?
  18. 自动阅读项目又出新情况?一天秒封47个账号
  19. 光学心率传感器工作原理
  20. 合肥工业大学2022大数据技术实验一

热门文章

  1. LeetCode(数据库)- Users That Actively Request Confirmation Messages
  2. 使用 Java 故意消耗 Cpu 和内存的代码
  3. 毕业三年...(转载)
  4. 【转】我的助理辞职了!
  5. 玄虚子:巧记易经64卦,分宫卦象次序表。
  6. 八块腹肌:硅谷程序员的新标配
  7. 聚乙烯亚胺(PEI)超细纤维负载Pd纳米粒子,GA-PEG-PLA 甘草次酸-聚乙二醇-聚乳酸定制合成
  8. emqtt 启动报错 Erlang closed the connection 查看状态报错 Node 'emq@192.168.*.*' not responding to pings.
  9. 云计算的技术架构与实现分析
  10. DevicePolicyManager(三)设备管理器使用案例——实现一键锁屏