点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

开发环境

软件版本信息:

Windows10 64位
Tensorflow1.15
Tensorflow object detection API 1.x
Python3.6.5
VS2015 VC++
CUDA10.0

硬件:

CPUi7
GPU 1050ti

如何安装tensorflow object detection API框架,看这里:

Tensorflow Object Detection API 终于支持tensorflow1.x与tensorflow2.x了

数据集处理与生成

首先需要下载数据集,下载地址为:

https://pan.baidu.com/s/1UbFkGm4EppdAU660Vu7SdQ

总计7581张图像,基于Pascal VOC2012完成标注。分为两个类别,分别是安全帽与人(hat与person),json格式如下:

item {id: 1name: 'hat'
}item {id: 2name: 'person'
}

数据集下载之后,并不能被tensorflow object detection API框架中的脚本转换为tfrecord,主要是有几个XML跟JPEG图像格式错误,本人经过一番磨难之后把它们全部修正了。修正之后的数据运行下面两个脚本即可生成训练集与验证集的tfrecord数据,命令行如下:

这里需要注意的是create_pascal_tf_record.py 脚本的165行把

'aeroplane_' + FLAGS.set + '.txt')

修改为:

FLAGS.set + '.txt')

原因是这里的数据集没有做分类train/val。所以需要修改一下,修改完成之后保存。运行上述的命令行,就可以正确生成tfrecord,否则会遇到错误。

模型训练

基于faster_rcnn_inception_v2_coco对象检测模型实现迁移学习,首先需要配置迁移学习的config文件,对应的配置文件可以从:

research\object_detection\samples\configs

中发现,发现文件:

faster_rcnn_inception_v2_coco.config

之后,修改配置文件的中相关部分,关于如何修改,修改什么,可以看这里:

修完完成之后,在D盘下新建好几个目录之后,执行下面的命令行参数:

就会开始训练,总计训练40000 step。训练过程中可以通过tensorboard查看训练结果:

模型导出

完成了40000 step训练之后,就可以看到对应的检查点文件,借助tensorflow object detection API框架提供的模型导出脚本,可以把检查点文件导出为冻结图格式的PB文件。相关的命令行参数如下:

得到pb文件之后,使用OpenCV4.x中的tf_text_graph_faster_rcnn.py脚本,转换生成graph.pbtxt配置文件。最终得到:

- frozen_inference_graph.pb
- frozen_inference_graph.pbtxt

如何导出PB模型到OpenCV DNN支持看这里:

干货 | tensorflow模型导出与OpenCV DNN中使用

使用OpenCV DNN调用模型

在OpenCV DNN中直接调用训练出来的模型完成自定义对象检测,这里需要特别说明一下的,因为在训练阶段我们选择了模型支持600~1024保持比率的图像输入。所以在推理预测阶段,我们可以直接使用输入图像的真实大小,模型的输出格式依然是1x1xNx7,按照格式解析即可得到预测框与对应的类别。最终的代码实现如下:

1import cv2 as cv23labels = ['hat', 'person']4model = "D:/safehat_train/models/train/frozen_inference_graph.pb"5config = "D:/safehat_train/models/train/frozen_inference_graph.pbtxt"67# 读取测试图像8image = cv.imread("D:/123.jpg")9h, w = image.shape[:2]
10cv.imshow("input", image)
11
12# 加载模型,执行推理
13net = cv.dnn.readNetFromTensorflow(model, config)
14blob = cv.dnn.blobFromImage(cv.resize(image, (w, h)), swapRB=True, crop=False)
15net.setInput(blob)
16detectOut = net.forward()
17
18# 解析输出
19classIds = []
20confidences = []
21boxes = []
22for detection in detectOut[0,0,:,:]:
23    score = detection[2]
24    if score > 0.4:
25        left = detection[3]*w
26        top = detection[4]*h
27        right = detection[5]*w
28        bottom = detection[6]*h
29        classId = int(detection[1]) + 1
30        classIds.append(classId)
31        boxes.append([int(left), int(top), int(right), int(bottom)])
32        confidences.append(float(score))
33
34# 非最大抑制
35nms_indices = cv.dnn.NMSBoxes(boxes, confidences, 0.4, 0.4)
36for i in range(len(nms_indices)):
37    index = nms_indices[i][0]
38    box = boxes[index]
39    cid = classIds[index]
40    if cid == 1:
41        cv.rectangle(image, (box[0], box[1]), (box[2], box[3]), (140, 199, 0), 4, 8, 0)
42    else:
43        cv.rectangle(image, (box[0], box[1]), (box[2], box[3]), (255, 0, 255), 4, 8, 0)
44    cv.putText(image, labels[cid-1], (box[0], box[1]), cv.FONT_HERSHEY_SIMPLEX, 0.75, (255, 0, 0), 2)
45
46# 显示输出
47cv.imshow("safetyhat-detection-demo", image)
48cv.imwrite("D:/result123.png", image)
49cv.waitKey(0)
50cv.destroyAllWindows()

一些测试图像的运行结果如下:

可以看到第二张途中有误识别情况发生!可见模型还可以继续训练!

避坑指南:

1. 下载的公开数据集,记得用opencv重新读取一遍,然后resave为jpg格式,这个会避免在生成tfrecord时候的图像格式数据错误。

ValueError: Image format not JPEG

2. 公开数据集中xml文件的filename有跟真实图像文件名称不一致的情况,要程序处理一下。不然会遇到

Windows fatal exception: access violation error 

3. 使用非最大抑制之后,

SystemError: <built-in function NMSBoxes> returned NULL without setting an error, 解决:boxes 必须是int类型,confidences必须是浮点数类型

参考资料:

使用OpenCV 4.1.2的DNN模块部署深度学习模型

https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset

https://github.com/opencv/opencv/wiki/Deep-Learning-in-OpenCV

https://github.com/tensorflow/models/tree/master/research/object_detection

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

Tensorflow + OpenCV4 安全帽检测模型训练与推理相关推荐

  1. YOLOV3目标检测模型训练实例

    YOLOV3目标检测 从零开始学习使用keras-yolov3进行图片的目标检测,比较详细地记录了准备以及训练过程,提供一个信号灯的目标检测模型训练实例,并提供相关代码与训练集. DEMO测试 YOL ...

  2. 车牌检测模型训练(含源码和数据集)

    车牌检测模型训练(含源码和数据集) 本教程利用NVIDIA TAO进行车牌检测模型的训练: 模型框架:SSD 数据集: CRPD, 连接:https://github.com/yxgong0/CRPD ...

  3. 分析《 yolov7人脸+手机检测模型训练》

    现在是大三的下学期,已经到了四月份,最近在搞一个华为云的"揭榜挂帅"挑战杯的项目. 项目的初赛阶段是用model arts(华为云的一个ai开发平台)去做一个云端的疲劳/分神检测的 ...

  4. float32精度_混合精度对模型训练和推理的影响

    单精度/双精度/半精度/混合精度 计算机使用0/1来标识信息,每个0或每个1代表一个bit.信息一般会以下面的三种形式表示: 1 字符串 字符串的最小单元是char,每个char占8个bit,也就是1 ...

  5. tensorflow 1.14 ssd_mobilenet_v1 模型训练

    tensorflow 1.14 ssd_mobilenet_v1 模型训练 1 工具版本 序号 软件名称 版本 安装命令 1 操作系统 ubuntu 18.04 2 python 3.6.9 3 te ...

  6. 详谈大模型训练和推理优化技术

    详谈大模型训练和推理优化技术 作者:王嘉宁,转载请注明出处:https://wjn1996.blog.csdn.net/article/details/130764843 ChatGPT于2022年1 ...

  7. PTMs:QLoRA技巧之源码解读(qlora.py文件)—解析命令与加载参数→数据预处理→模型训练+评估+推理

    PTMs:QLoRA技巧之源码解读(qlora.py文件)-解析命令与加载参数→数据预处理→模型训练+评估+推理 目录 QLoRA技巧之源码解读(qlora.py文件)-解析命令与加载参数→数据预处理 ...

  8. Metis异常检测模型训练源码深入刨析

    Metis异常检测模型训练源码深入刨析 模型训练 数据集说明 process_train 方法(detect_service.py) __generate_model方法(detect_service ...

  9. 基于Yolov5的烟火检测——模型训练与C++实现部署

    前言 1.系统环境是win10,显卡RTX3080;cuda10.2,cudnn7.1;OpenCV4.5;yolov5用的是5s的模型,2020年8月13日的发布v3.0这个版本; ncnn版本是2 ...

最新文章

  1. 点云数据的类型主要分为_点云学习在自动驾驶中的研究概述
  2. [C# 基础知识梳理系列]专题二:委托的本质论
  3. 【编程】位(bit)、字节(byte)和字(word)的区别
  4. luogu1355 神秘大三角
  5. 飞鸽传书2011怎么用之启动常见问题(二)
  6. 实现远程连接MySQL
  7. VCSA 6.5 HA配置 之四 开启vCenter HA
  8. flutter实战1:完成一个有侧边栏的主界面
  9. OpenCV对矩形填充透明颜色
  10. 【原创】基于SSM框架的小说网站开发与设计
  11. 《WiscKey: Separating Keys from Values in SSD-conscious Storage》阅读笔记
  12. 计算机应该玩什么游戏,电脑玩游戏主要靠什么配置
  13. !和!!的区别和用法
  14. thinkpad T430改装WiFi6网卡ax200,加装4G模块
  15. 如何查美国公司的年报
  16. cs224w(图机器学习)2021冬季课程学习笔记12 Knowledge Graph Embeddings
  17. java判断车牌号正确性
  18. 苹果Airplay2学习
  19. 即将2023年了,我好想念那些2022年离职的兄弟
  20. linux hd4000显卡驱动,如何在Linux上为Intel HD 4000获取OpenGL 4

热门文章

  1. “Jupyter的杀手”:Netflix发布新开发工具Polynote
  2. NLP学习思维导图,非常的全面和清晰
  3. 领跑交互新时代 蓦然认知助力传统产业智能化升级
  4. 一分钟AI | 腾讯市值超5000亿美元创亚洲最高!CV杀红脸了:AI国家队云从科技完成25亿B轮融资
  5. Redlock——Redis集群分布式锁
  6. 一行代码:你的纯文本秒变Markdown
  7. Redis的一些“锁”事
  8. 强化学习,路在何方?
  9. 这三所985,博士生毕业,可能不再要求发表论文!
  10. 从清华退学,他赴美读博又两次退学!离开谷歌后,如今他怎样了?