如何用tensorflow使用自定义数据来训练,做物体检测
本人通过将近两个月的研究,通过收集众多资料,从一个小白来实现利用tensorflow实现物体检测的过程记录如下:
本人此次项目的前端代码:https://github.com/MRchenao/tfjs-customer-objecte-detection下载本项目的前端代码后运行需要根据下文的说明,修改对应的源码方可正常使用。
此次项目的后端代码,这里是简化版的后端代码,由于tensorflow源代码太大无法上传,请大家仅做一些参考,比如某些文件或目录的对比参考加深对博文的理解:https://github.com/MRchenao/python-customer-objecte-detection
首先先说一下应用场景,做物体检测的模型有很多,我最初尝试过yolov3这种模型大约200多M,网上也有很多python的实现,其中最好的是用imageAI实现的。有符合这方面需求的,可以用这个开源库。https://github.com/OlafenwaMoses/ImageAI/blob/master/imageai/Detection/Custom/CUSTOMDETECTION.md
但是我们的应用场景是在浏览器中做物体检测,所以模型的大小必须要很小,所以这个就有限制我们没办法使用上面的模型,摸索了好久,起初一个模型一个模型的尝试,最终在官方这边发现了模型列表,这里列出了所有的模型。https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
针对模型必须很小,我尝试了ssd_mobilenet_v1_coco(22M),ssd_mobilenet_v2_coco(66M),ssd_mobilenet_v3_small_coco(4M),ssdlite_mobilenet_v2_coco(11M)这些模型后。最终发现ssd_mobilenet_v3_small_coco模型最小只有(4M不到),适合浏览器,但是前端tfjs没有实现这个模型的方法报:Unsupported Ops: NonMaxSuppresionV4 when converting model的错误。
后面退而求次,尝试了ssdlite_mobilenet_v2_coco和ssd_mobilenet_v1_coco两个模型,前端又报:Uncaught (in promise) Error: Tensor must have a shape comprised of positive integers but got shape [100,].输入的形状问题的报错。最后发现是我们的模型与前端js的检测模型在转换或者是使用过程中数据对应方面有些出入,到时候我们会修改前端js的源码以适应这种改变。
废话说完,现在进入正题:先说下大致的步骤流程:
一、安装配置训练环境
二、准备数据
三、开始训练
四、训练结果转成能用的pb模型
五、测试pb模型的检测结果
六、将转的pb模型转成前端能用的json模型
一、安装配置训练环境
此安装配置过程可参考官方的方法:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
只有成功执行上面方法的最后一个测试一行无报错,才仅仅只是说明你的环境安装准备成功。
二、准备数据
1、数据准备您需要使用:labelimg将您的图片框出来,最后生成与图片对应的xml文件。目录如下
其中train是训练集,validation是验证集。不想准备数据集想做测试的同学也可直接联系我找我要。说实话准备数据集就是一个体力活。
2、有了上面的数据之后,我们需要将数据转换成tensorflow能用的record数据,这里有两步,一个是将上面的xml转换成csv文件,另一个是就是生成所需的record文件了。下面先给出xml转成csv的代码xml_2_csv.py
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ETdef xml_to_csv(path):xml_list = []for xml_file in glob.glob(path + '/*.xml'):tree = ET.parse(xml_file)root = tree.getroot()for member in root.findall('object'):value = (root.find('filename').text,int(root.find('size')[0].text),int(root.find('size')[1].text),member[0].text,int(member[4][0].text),int(member[4][1].text),int(member[4][2].text),int(member[4][3].text))xml_list.append(value)column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']xml_df = pd.DataFrame(xml_list, columns=column_name)return xml_dfdef main():image_path = os.path.join(os.getcwd(), 'data\\validation\\annotations')xml_df = xml_to_csv(image_path)xml_df.to_csv('data\\test_labels.csv', index=None)print('Successfully converted xml to csv.'+ image_path)main()
这里你只需修改image_path 为你自己目录的path即可使用。生成csv后接下来就是生成record文件的python代码:generate_tfrecord.py
"""
Usage:# From tensorflow/models/# Create train data:python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record# Create test data:python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_importimport os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDictflags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('image_dir', '', 'Path to images')
FLAGS = flags.FLAGS# TO-DO replace this with label map
def class_text_to_int(row_label):if row_label == 'hololens':return 1else:Nonedef split(df, group):data = namedtuple('data', ['filename', 'object'])gb = df.groupby(group)return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width)xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.python_io.TFRecordWriter(FLAGS.output_path)path = os.path.join(FLAGS.image_dir)examples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()output_path = os.path.join(os.getcwd(), FLAGS.output_path)print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.app.run()
只需在命令行执行
# Create train data:python generate_tfrecord.py --image_dir=data/train/images --csv_input=data/train_labels.csv --output_path=train.record# Create test data:python generate_tfrecord.py --image_dir=data/validation/images --csv_input=data/test_labels.csv --output_path=test.record
这样就生成了tensorflow能用的数据集了。结果如下
三、开始训练
1、有了上面的数据之后,接下来我们就需要下载我们上面说的你需要的一种模型,下载完解压之后是类似这样 的。
这里记住要将里面的checkpoint文件删除,这里我遇到过一个坑,这个checkpiont文件其实就是预训练模型的检查点,之前我没有删除,人家都训练完了,结果我启动训练的时候直接就是退出了,也不报错,导致我莫名奇妙的找不出原因为什么训练不了,后来才发现是这个东西的问题。
2、要开始训练,下载模型完后我们还需要修改模型中的配置文件,打开pipeline.config文件,修改其中的
这里的类别根据你之前的数据集框出的类别来定有几类就写几
这个batch_size根据你机器的内存来定,如果你内存小的话就设置小一些,一般8g内存只能设置个16左右就可以了,否则会内存溢出。
这上面图片红框中配置的是之前我们生成的测试record和训练record,配置成你自己对应的路径即可。其中的一个文件label_map.pbtxt的内容如下:这个如果有多个类别就写多个item,我这里就一个就这样写了
item {id: 1name: 'hololens'
}
3、做完上面的2部分的工作之后接下来就可以运行命令训练了。
python /data/FashionDetector/tf-models/research/object_detection/model_main.py --pipeline_config_path="E:\\pyproject\\fashiondetector\\models\\ssd_mobilenet_v1_coco_11_06_2017\\pipeline.config" --model_dir="E:\\pyproject\\fashiondetector\\models\\ssd_mobilenet_v1_coco_11_06_2017" --num_train_steps=50000 --sample_1_of_n_eval_examples=1 --alsologtostderr --verbosity=1 --stderrthreshold=debug
这个model_main.py文件就是我们官网下载的tensorflow里面对象检测的那个目录下的文件。在第一步骤安装环境的时候大家应该都已经知道了。还是给个图给大家看吧
模型训练文件图
整体目录结构
执行上述命令后正常情况下会出现一些,信息,只要不报错不退出就没问题,根据机器性能,将会有不同的等待时间,结果输出可能会很慢,一般训练的话,训练个一天一夜就差不多可以检测了。训练图类似:
会有不断的输出。而训练的结果也会不断的在你之前配置的 --model_dir中出现一个时间段的结果,如图:
四、训练结果转成能用的pb模型
等训练产生结果后,你就可以开始你自己想要的转换了,训练过程可以不停止,可以继续让其保持训练,我们来到模型训练结果的目录也就是上面的–model_dir下面,执行命令
python3 tf-models/research/object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=/data/FashionDetector/models/ssd_mobilenet_v1_coco_11_06_2017/pipeline.config --trained_checkpoint_prefix=/data/FashionDetector/models/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt-7904 --output_directory=/data/FashionDetector/ssd_v1_coco
这里的export_inference_graph.py也是官网的文件跟model_main.py一样都在相同目录下。运行完上面的命令后,就会在你–output_directory配置的对应的目录下生成你转换的pb模型,如图:
到这里,你的模型就可以使用了,接下来我们测试我们的训练效果。
五、测试pb模型的检测结果
1、测试pb模型的检测结果同样使用的是官网的方法根据官网的方法你可以配置模型路径修改为你自己的路径:https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
大家可以删除上面的MODEL_FILE和DOWNLOAD_BASE两行代码,这是下载模型用的,大家也可以根据自己的需要,修改这个测试代码。然后模型路径配置成自己刚刚生成的pb模型的路径即可
2、要检测的图片放置在test_images下。
上面的需要你安装jupyter ,直接使用 pip install jupyter 即可安装
启动jupyter
jupyter notebook --ip=0.0.0.0 --port 8888 --config=/root/.jupyter/jupyter_notebook_config.py --allow-root
这样就可以在浏览其中的8888端口看到你的文件了。其中的–config=/root/.jupyter/jupyter_notebook_config.py配置产生与修改,大家可以自行查资料配置,这是jupyter使用与本教程无关。
打开浏览器后,找到object_detection_tutorial.ipynb,打开,然后在页面点击执行run ALL等待一段时间后,就会在页面上出现检测结果。
六、将转的pb模型转成前端能用的json模型
如果上面的检测结果准确,就说明模型可用,那么接下来前端需要使用的话,就可在生成的pb模型下运行命令
tensorflowjs_converter --input_format=tf_saved_model --output_node_names="detection_boxes,detection_scores,detection_classes" --saved_model_tags=serve --output_json model.json ./saved_model ./web_model
接着只要把web_model下的文件给前端就可以使用了,这里我们使用的前端代码是这里的:
https://github.com/cloud-annotations/object-detection-js不过这边我们需要对这个代码做一些修改,否则无法使用,具体修改如下:
最终的实现效果:
本人文笔有限:参考的链接有
https://medium.com/coinmonks/tensorflow-object-detection-with-custom-objects-34a2710c6de5
https://towardsdatascience.com/real-time-mobile-video-object-detection-using-tensorflow-a75fa0c5859d
如何用tensorflow使用自定义数据来训练,做物体检测相关推荐
- PPv3-OCR自定义数据从训练到部署
PPv3-OCR自定义数据从训练到部署 一.配置Paddle环境 二.配置PaddleOCR 1.安装python包 2.测试环境 三 模型列表及其对应的配置文件 1. 文本检测模型 1.1 中文检测 ...
- 手把手教你用深度学习做物体检测(二):数据标注
"本篇文章将开始我们训练自己的物体检测模型之旅的第一步-- 数据标注." 上篇文章介绍了如何基于训练好的模型检测图片和视频中的物体,若你也想先感受一下物体检测,可以看看上篇 ...
- 手把手教你用深度学习做物体检测(三):模型训练
本篇文章旨在快速试验使用yolov3算法训练出自己的物体检测模型,所以会重过程而轻原理,当然,原理是非常重要的,只是原理会安排在后续文章中专门进行介绍.所以如果本文中有些地方你有原理方面的疑惑,也没关 ...
- tensorflow移植到Android端,实现物体检测自动拍照
tensorflow移植到Android端实现物体检测 一. 说明 1. tensorflow是什么: 是谷歌基于DistBelief进行研发的第二代人工智能学习系统. 2. 为什么要使用tensor ...
- python 提取最小外接矩形_放弃机器学习框架,如何用 Python 做物体检测?
每当我们听说"物体检测"时,就会想到机器学习和各种不同的框架.但实际上,我们可以在不使用机器学习或任何其他框架的情况下进行物体检测.在本文中,我将向你展示如何仅使用Python进行 ...
- 用TensorFlow训练一个物体检测器(手把手教学版)
TensorFlow内包含了一个强大的物体检测API,我们可以利用这API来训练自己的数据集实现特殊的目标检测. 作者软硬件环境配置:CPU: i7-6800k (不重要,主流的CPU均可) OS: ...
- 值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(上)
作者 | 黄浴 转载自知乎专栏自动驾驶的挑战和发展 [导读]上周,我们在激光雷达,马斯克看不上,却又无可替代?>一文中对自动驾驶中广泛使用的激光雷达进行了简单的科普,今天,这篇文章将各大公司和机 ...
- 一种可训练的目标检测系统
麻省理工学院,人工智能实验室,生物与计算学习中心,美国马萨诸塞州剑桥 摘要 本文提出了一种通用的.可训练的.在无约束的.杂乱的场景中的目标检测系统.该系统的功能很大程度上来自于一种表示,该表示用一个过 ...
- Keras图像分割实战:数据整理分割、自定义数据生成器、模型训练
Keras图像分割实战:数据整理分割.自定义数据生成器.模型训练 目录 Keras图像分割实战:数据整理分割.自定义数据生成器.模型训练
最新文章
- 7.wait和waitpid
- 【面试招聘】算法岗通关宝典 | 社招一年经验,字节5轮、阿里7轮
- 局部内部类如何访问外部类方法中变量
- 文件传输-对数据进行加解密的方法!
- 支付宝年度账单被怼;英特尔CPU曝惊天漏洞;甘薇为贾跃亭喊冤 | 一周业界事
- 2019最有意思的五大 ZDI 案例之:通过调色板索引实现 Win32k.sys 本地提权漏洞(上)...
- mysql配置my.cnf文件,以及参数优化提升性能
- c# 字符串编码问题
- 看完《我的前半生》的些许感悟
- 帝国栏目导航点击显示不同样式的实现
- 格(Lattice)基础(一)
- 链栈的数据结构以及链栈的实现
- Atcoder Beginner Contest 174(ABC174) 题解
- ASEMI代理AD633JRZ原装ADI车规级AD633JRZ
- 超有趣,在idea中加入emoji图像!
- 抖音卡点视频怎么制作
- 怎么开网店新手怎么开淘宝网店
- iOS app脚手架
- 北京科技大学 Dog类定义和测试
- 合宙 ESP32C3 使用micropython 驱动配套0.96寸 TFT ST7735 屏幕显示色块和文字
热门文章
- 基于 GitLab CI 的前端工程CI/CD实践
- 量子计算机能为我们做什么,为实现量子计算,我们还需要做些什么
- mysql创建表参数_MySQL创建数据表(CREATE TABLE语句)
- js+css制作导航栏下划线跟随动画,App+H5点击效果
- 42张动图带你走进神奇的物理世界,超震撼!
- Python---元祖、循环
- 如何直观理解AUC评价指标?
- 微信在线EXCEL自动统计人数
- C# 计算指定年月的当月工作日方法
- 嵌入式linux音频播放器设计,基于嵌入式Linux下Madplay音频播放器设计论文.docx