Tensorflow的物体检测API是非常强大的工具,它可以使得没有机器学习背景的人都可以快速部署强大的图像识别,物体检测模型。但是,考虑到使用的指导文档不够丰富,具体如何使用成了很多人的门槛。

本篇包含以下几个部分的内容:

  • 选择模型
  • 适应当前数据集
  • 创建并标注自己的数据集
  • 修改模型配置文件
  • 训练模型
  • 存储模型
  • 部署模型

如何安装

git clone https://github.com/tensorflow/models.git

拿到模型库,这里提供了很多很多好用的机器学习模型,但是我们现在只用其中的models/research/object_detection模型。

使用模型

object_detection文件夹下,jupyter notebook,使用object_detection_tutorial.ipynb,这个notebook的内容就是一步一步带我们使用预训练模型进行物体检测。

这里把代码全部贴出来,供参考:

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfilefrom distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_opsif StrictVersion(tf.__version__) < StrictVersion('1.9.0'):raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
# 环境设置
%matplotlib inline # 这在jupyter内使用# 导入物体检测需要的包,值得注意的是,这里我置换了刚刚下的protos文件夹,因为最新下的不全
from utils import label_map_util
from utils import visualization_utils as vis_util# 模型准备
# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')# 下载模型
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():file_name = os.path.basename(file.name)if 'frozen_inference_graph.pb' in file_name:tar_file.extract(file, os.getcwd())# 将frozen的模型加载到内存
detection_graph = tf.Graph()
with detection_graph.as_default():od_graph_def = tf.GraphDef()with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:serialized_graph = fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def, name='')# 加载标记字典,比如预测结果输出是5,我们可以知道对应的是airplane
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)# 辅助代码
def load_image_into_numpy_array(image):(im_width, im_height) = image.sizereturn np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
# 检测
# 为了简单只用两个图片,如果想修改,只需要将图片放到test_images文件夹即可
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)def run_inference_for_single_image(image, graph):with graph.as_default():with tf.Session() as sess:# Get handles to input and output tensorsops = tf.get_default_graph().get_operations()all_tensor_names = {output.name for op in ops for output in op.outputs}tensor_dict = {}for key in ['num_detections', 'detection_boxes', 'detection_scores','detection_classes', 'detection_masks']:tensor_name = key + ':0'if tensor_name in all_tensor_names:tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)if 'detection_masks' in tensor_dict:# The following processing is only for single imagedetection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])# Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(detection_masks, detection_boxes, image.shape[0], image.shape[1])detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)# Follow the convention by adding back the batch dimensiontensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')# Run inferenceoutput_dict = sess.run(tensor_dict,feed_dict={image_tensor: np.expand_dims(image, 0)})# all outputs are float32 numpy arrays, so convert types as appropriateoutput_dict['num_detections'] = int(output_dict['num_detections'][0])output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)output_dict['detection_boxes'] = output_dict['detection_boxes'][0]output_dict['detection_scores'] = output_dict['detection_scores'][0]if 'detection_masks' in output_dict:output_dict['detection_masks'] = output_dict['detection_masks'][0]return output_dict
for image_path in TEST_IMAGE_PATHS:image = Image.open(image_path)# the array based representation of the image will be used later in order to prepare the# result image with boxes and labels on it.image_np = load_image_into_numpy_array(image)# Expand dimensions since the model expects images to have shape: [1, None, None, 3]image_np_expanded = np.expand_dims(image_np, axis=0)# Actual detection.output_dict = run_inference_for_single_image(image_np, detection_graph)# Visualization of the results of a detection.vis_util.visualize_boxes_and_labels_on_image_array(image_np,output_dict['detection_boxes'],output_dict['detection_classes'],output_dict['detection_scores'],category_index,instance_masks=output_dict.get('detection_masks'),use_normalized_coordinates=True,line_thickness=8)plt.figure(figsize=IMAGE_SIZE)plt.imshow(image_np)

其中,设定自己想要检测的图片,只需要修改:

PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

将图片放置在test_images文件夹下,并按照命名格式,重新定义读取图片名字的方式。

总体看,需要修改的地方只有两处:

  • 指定模型
  • 指定检测图片路径

其中,模型仓库地址是:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

选择模型

想对模型有更深入的了解,可以去这里阅读相关材料。

object_detection文件夹下的g3doc文件夹内,有一些好用的.md文档,值得仔细阅读一番。

END.

参考:

https://medium.com/@WuStangDan/step-by-step-tensorflow-object-detection-api-tutorial-part-1-selecting-a-model-a02b6aabe39e

https://ai.googleblog.com/2017/06/supercharge-your-computer-vision-models.html

https://github.com/tensorflow/models/tree/master/research/object_detection

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

【CV】如何使用Tensorflow提供的Object Detection API --1--使用预训练模型相关推荐

  1. 【CV】如何使用Tensorflow提供的Object Detection API --2--数据转换为TFRecord格式

    本篇主要讲的是如何将现存的数据变成Tensorflow记录格式,然后我们就可以用这些数据来进行微调模型,以解决我们关心的问题了. 什么是TFRecord格式 一般使用TF读取数据有四种方式: 预先把所 ...

  2. 【CV】如何使用Tensorflow提供的Object Detection API--3--手工标注数据

    前面两篇看完,我们已经知道如何选用预训练模型以及将现存的其他数据集变成TFRecord格式的数据了. 但是如果需要用你自己的数据集,该怎么办呢? 本篇主要讲如何创建自己的数据集,并用object_de ...

  3. 【CV】如何使用Tensorflow提供的Object Detection API--4--开始训练模型

    至此已经学习了如何选择预训练模型,将数据集转为TFRecord格式.模型和数据都准备好了,是时候开启训练了. 这些在COCO数据集上的模型都是针对90类进行识别的,如果自己的任务没有这么多类,或者类不 ...

  4. TensorFlow Object Detection API Custom Object Hangs On

    TensorFlow Object Detection API Hangs On - Training and Evaluating using Custom Object Detector *The ...

  5. 【物体检测快速入门系列 | 01 】基于Tensorflow2.x Object Detection API构建自定义物体检测器

    这是机器未来的第1篇文章 原文首发地址:https://blog.csdn.net/RobotFutures/article/details/124745966 CSDN话题挑战赛第1期 活动详情地址 ...

  6. Tensorflow object detection API 搭建物体识别模型

    ----------------------------------------------------先把别人博客教程跑通-------------------------------------- ...

  7. 谷歌开放的TensorFlow Object Detection API 效果如何?对业界有什么影响

    ? 谷歌开放了一个 Object Detection API: Supercharge your C 写个简单的科普帖吧. 熟悉TensorFlow的人都知道,tf在Github上的主页是:tenso ...

  8. tensorflow精进之路(二十五)——Object Detection API目标检测(下)(VOC数据集训练自己的模型进行目标检测)

    1.概述 上一讲,我们使用了别人根据COCO数据集训练好的模型来做目标检测,这一讲,我们就来训练自己的模型. 2.下载数据集 为了方便学习,我们先使用别人整理好的数据集来训练---VOC 2012数据 ...

  9. tensorflow精进之路(二十四)——Object Detection API目标检测(中)(COCO数据集训练的模型—ssd_mobilenet_v1_coco模型)

    1.概述 上一讲简单的讲了目标检测的原理以及Tensorflow Object Detection API的安装,这一节继续讲Tensorflow Object Detection API怎么用. 2 ...

最新文章

  1. [云炬创业学笔记]第一章创业是什么测试7
  2. html2canvas源码修改,html2canvas把div保存高清图的方法代码
  3. gridview 在已有数据的基础上添加数据_基于Python的数据分析-1.语法基础(上)
  4. 2018年 第09届 蓝桥杯 Java B组 决赛真题详解及小结
  5. 中list如何清空_如何根据索引删除 list 中的元素
  6. 多年软件测试大牛分享成长经历,一个好的软件测试工程师应该做到这些!
  7. scala的list源码解密
  8. java程序的入口点_Java程序的入口点
  9. MATLAB高斯平顶化,一种高斯光束变换为平顶光束整形透镜的粒子群设计方法与流程...
  10. BMP/JPG/PNG/GIF/有损压缩和无损压缩【转载整理】
  11. Db2性能问题:临时表空间太大,导致连不上数据库
  12. 微信公众平台开发技术文档
  13. python期货基本面分析_Python量化炒期货入门与实战技巧
  14. 给定两个字符串 s 和 t,它们只包含小写字母。
  15. 四款软件,提高团队工作效率
  16. mysql主从复制(一):一主多从
  17. 2021 MCU WiFi竞争新格局,国产MCU WiFi芯片盘点,附录2020/2021 MCU WiFi排行
  18. 心形符号c语言程序,c语言心形代码及图形
  19. 数据分析,怎么做才能有前瞻性
  20. Python-爬虫(生产者、消费者模型爬取王者荣耀壁纸)

热门文章

  1. asp网上书店的代码_使用Helm将ASP.NET Core应用程序部署到Kubernetes容器集群
  2. 大数据工程师简历_大数据工程师简历范本02
  3. 麦克纳姆轮运动原理怎么安装_家用中央空调水系统原理是什么?怎么样安装比较好呢?...
  4. python解码和编码的区别_python基础小知识,is和==的区别,编码和解码
  5. numactl mysql_CentOS学习笔记 - 10. 开发机mysql安装
  6. windows10计算机用户密码,忘记Windows 10系统密码?教你重置
  7. c语言栈的实现以及操作_C++语言实现顺序栈
  8. Python-Matplotlib可视化(8)——图形的输出与保存
  9. linux web故障,网络故障处理与优化 linux服务器配置及故障排除 项目9 配置与管理web服务器.docx...
  10. 腐蚀rust电脑分辨率调多少_腐蚀Rust画面怎么设置 _游侠网