1. 环境配置

使用protobuf来配置模型和训练参数,所以API正常使用必须先编译protobuf库,这里可以下载直接编译好的pb库(https://github.com/google/protobuf/releases ),解压压缩包后,把protoc加入到环境变量中:

$ cd tensorflow/models

$ protoc object_detection/protos/*.proto --python_out=. #注意: *在这里有时会报错,找不到文件,可以手动添加文件命名

(我是把protoc加到环境变量中,遇到找不到*.proto文件的报错,后来把protoc.exe放到models/object_detection目录下,重新执行才可以)

然后将models和slim(tf高级框架)加入python环境变量:

PYTHONPATH=$PYTHONPATH:/your/path/to/tensorflow/models:/your/path/to/tensorflow/models/slim

2.数据准备

自己制作了限速标志的数据,总共250张,训练200,测试50.有了数据以后我们需要给他们打标签。我们需要手动在每一张图中框出限速标志的位置。一个比较好的打标工具是LabelImg,标签:sign,生成VOC格式的数据。制作VOC格式数据文件夹形式:my_images

__VOCdevkit

__VOC2012

__Annotations(文件名:2007_0000)#xml格式的标签

__JPEGImages(文件名:2007_0000)#jpg图像

__ImageSets

__Main

__train

__val

import os
import randompt="/tensorflow/model/research/object_detection/my_images/VOCdevkit/VOC2012/JPEGImages"
image_name=os.listdir(pt)
for temp in image_name:if temp.endswith(".jpg"):print (temp.replace('.jpg',''))

以上代码可以生成train和val列表

将VOC数据转化成tf.recoord数据 :参考 dataset_tools/create_pascal_tf_record.py 根据自己路径文件格式做出适当修改,我的如下;

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================r"""Convert raw PASCAL dataset to TFRecord for object_detection.Example usage:python object_detection/dataset_tools/create_pascal_tf_record.py \--data_dir=/home/user/VOCdevkit \--year=VOC2012 \--output_path=/home/user/pascal.record
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport hashlib
import io
import logging
import osfrom lxml import etree
import PIL.Image
import tensorflow as tffrom object_detection.utils import dataset_util
from object_detection.utils import label_map_utilflags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or ''merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations','(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt','Path to label map proto') #此处修改自己标签
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore ''difficult instances')
FLAGS = flags.FLAGSSETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']def dict_to_tf_example(data,dataset_directory,label_map_dict,ignore_difficult_instances=False,image_subdirectory='JPEGImages'):"""Convert XML derived dict to tf.Example proto.Notice that this function normalizes the bounding box coordinates providedby the raw data.Args:data: dict holding PASCAL XML fields for a single image (obtained byrunning dataset_util.recursive_parse_xml_to_dict)dataset_directory: Path to root directory holding PASCAL datasetlabel_map_dict: A map from string label names to integers ids.ignore_difficult_instances: Whether to skip difficult instances in thedataset  (default: False).image_subdirectory: String specifying subdirectory within thePASCAL dataset directory holding the actual image data.Returns:example: The converted tf.Example.Raises:ValueError: if the image pointed to by data['filename'] is not a valid JPEG"""img_path = os.path.join(image_subdirectory,data['filename'])#image_subdirectory=JPEGImages  此处修改了year1='VOC2012'#此处修改full_path = os.path.join(dataset_directory,year1, img_path)#此处修改with tf.gfile.GFile(full_path, 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = PIL.Image.open(encoded_jpg_io)if image.format != 'JPEG':raise ValueError('Image format not JPEG')key = hashlib.sha256(encoded_jpg).hexdigest()width = int(data['size']['width'])height = int(data['size']['height'])xmin = []ymin = []xmax = []ymax = []classes = []classes_text = []truncated = []poses = []difficult_obj = []if 'object' in data:for obj in data['object']:difficult = bool(int(obj['difficult']))if ignore_difficult_instances and difficult:continuedifficult_obj.append(int(difficult))xmin.append(float(obj['bndbox']['xmin']) / width)ymin.append(float(obj['bndbox']['ymin']) / height)xmax.append(float(obj['bndbox']['xmax']) / width)ymax.append(float(obj['bndbox']['ymax']) / height)classes_text.append(obj['name'].encode('utf8'))classes.append(label_map_dict[obj['name']])truncated.append(int(obj['truncated']))poses.append(obj['pose'].encode('utf8'))example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),'image/object/truncated': dataset_util.int64_list_feature(truncated),'image/object/view': dataset_util.bytes_list_feature(poses),}))return exampledef main(_):if FLAGS.set not in SETS:raise ValueError('set must be in : {}'.format(SETS))if FLAGS.year not in YEARS:raise ValueError('year must be in : {}'.format(YEARS))data_dir = FLAGS.data_diryears = ['VOC2007', 'VOC2012']if FLAGS.year != 'merged':years = [FLAGS.year]writer = tf.python_io.TFRecordWriter(FLAGS.output_path)label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)for year in years:logging.info('Reading from PASCAL %s dataset.', year)examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',FLAGS.set + '.txt')annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)examples_list = dataset_util.read_examples_list(examples_path)for idx, example in enumerate(examples_list):if idx % 100 == 0:logging.info('On image %d of %d', idx, len(examples_list))path = os.path.join(annotations_dir, example + '.xml')#此处修改with tf.gfile.GFile(path, 'r') as fid:xml_str = fid.read()xml = etree.fromstring(xml_str)data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,FLAGS.ignore_difficult_instances)writer.write(tf_example.SerializeToString())writer.close()if __name__ == '__main__':tf.app.run()

修改data/pascal_label_map.pbtxt 改成自己的标签数据,我只有一类,从id:1开始:

item {
  id: 1
  name: 'sign'
}

生成训练数据:pascal_train.record

python dataset_tools/create_pascal_tf_record.py --data_dir=my_images/VOCdevkit/ --year=VOC2012 --output_path=my_images/VOCdevkit/pascal_train.record --set=train

生成测试数据:pascal_val.record

python dataset_tools/create_pascal_tf_record.py --data_dir=my_images/VOCdevkit/ --year=VOC2012 --output_path=my_images/VOCdevkit/pascal_val.record --set=val

3,下载模型:ssd_mobilenet_v1_coco

下载链接https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

下载ssd_mobilenet_v1_coco完后,将其解压到my_images文件夹下,将model.ckpt  3个文件放在VOC2012下面

接下来呢,新建配置文件,samples/configs/文件夹下有一些示例文件,我们就模仿它们配置,参考faster_rcnn_inception_resnet_v2_atrous_coco.config文件,将其复制在VOC2012下面

文件名:ssd_mobilenet_v1.config

model {
ssd {
num_classes: 1
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 6
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 300
width: 300
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 1
box_code_size: 4
apply_sigmoid_to_scores: false
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_mobilenet_v1'
min_depth: 16
depth_multiplier: 1.0
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 0
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}train_config: {
batch_size: 24
optimizer {
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
fine_tune_checkpoint: "my_images/VOCdevkit/VOC2012/model.ckpt"
from_detection_checkpoint: true
# Note: The below line limits the training process to 200K steps, which we
# empirically found to be sufficient enough to train the pets dataset. This
# effectively bypasses the learning rate schedule (the learning rate will
# never decay). Remove the below line to train indefinitely.
num_steps: 20000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
}train_input_reader: {
tf_record_input_reader {
input_path: "my_images/pascal_train.record"
}
label_map_path: "data/pascal_label_map.pbtxt"
}eval_config: {
num_examples: 50
# Note: The below line limits the evaluation process to 10 evaluations.
# Remove the below line to evaluate indefinitely.
max_evals: 10
}eval_input_reader: {
tf_record_input_reader {
input_path: "my_images/pascal_val.record"
}
label_map_path: "data/pascal_label_map.pbtxt"
shuffle: false
num_readers: 1
}

4,训练模型

在object_detection下执行:创建检查点文件在my_images/train

python legacy/train.py --train_dir=my_images/train/ --pipeline_config_path=my_images/VOCdevkit/VOC2012/ssd_mobilenet_v1.config

my_images目录下新建一个eval目录,用于保存eval的文件。另开终端,执行如下命令

/tensorflow/models/research/object_detection$ python legacy/eval.py 
    --logtostderr \
    --pipeline_config_path=my_images/VOCdevkit/VOC2012/ssd_mobilenet_v1_raccoon.config \
    --checkpoint_dir=my_images/train \
    --eval_dir=my_images/eval

特别提醒:建议使用legacy/train.py  用object_detection/model_main.py本人没成功过。可能还会遇到GPU内存方面的错误,建议指定GPU训练。训练了20k steps.效果测试精确度在94%,97%左右,数据太少每次测试有变化。但是效果还是不错的,贴几张效果图如下:

下一篇:tensorflow  训练完模型的导出和测试模型

tensorflow 物体检测(检测限速标志)相关推荐

  1. Google发布新的TensorFlow物体检测API

    \\ Google发布TensorFlow物体检测API,帮助开发人员和研究人员识别图片中的物体.Google专注于提高API的易用性和性能,新的模型于6月16号发布,在基准测试中表现出良好的性能,并 ...

  2. 使用SSD网络模型进行Tensorflow物体检测(V1.1摄像头检测)

    使用SSD网络模型进行Tensorflow物体检测?(V1.1摄像头检测) 文章目录 使用SSD网络模型进行Tensorflow物体检测?(V1.1摄像头检测) 1.模型的加载和utils库环境的配置 ...

  3. 使用SSD网络模型进行Tensorflow物体检测(V1.2视频检测)

    使用SSD网络模型进行Tensorflow物体检测?(V1.2视频检测) 文章目录 使用SSD网络模型进行Tensorflow物体检测?(V1.2视频检测) 1.模型的加载和utils库环境的配置? ...

  4. vuforia的物体识别能识别大物体吗_衢州sensopart 物体识别检测视觉-灵测信息

    首页 > 新闻列表 > 浏览文章 发布时间:2020-11-03 12:07:00 浏览量: 6 导读:灵测信息为您提供衢州sensopart 物体识别检测视觉的相关知识与详情: 本文的关 ...

  5. YOLOv3物体/目标检测之实战篇(Windows系统、Python3、TensorFlow2版本)

    前言 基于YOLO进行物体检测.对象识别,在搭建好开发环境后,先和大家进行实践应用中,体验YOLOv3物体/目标检测效果和魅力:同时逐步了解YOLOv3的不足和优化思路. 开发环境参数 系统:Wind ...

  6. 人车目标检测、交通标志识别…云测数据喊你参加第九届CCF大数据与计算智能大赛啦

    2021年大数据与AI领域年度盛事--第九届CCF大数据与计算智能大赛已全面开赛! 云测数据携手OneFlow一流科技,发布基于自动驾驶场景下的"人车目标检测.交通标志分类识别.交通灯识别. ...

  7. 线扫相机——机器视觉中无限制物体的检测

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 在机器视觉中,在检测连续物体或者滚动物体时,线扫相机是最佳的解决方 ...

  8. 利用SIFT和RANSAC算法(openCV框架)实现物体的检测与定位,并求出变换矩阵(findFundamentalMat和findHomography的比较)

    本文目标是通过使用SIFT和RANSAC算法,完成特征点的正确匹配,并求出变换矩阵,通过变换矩阵计算出要识别物体的边界(文章中有部分源码,整个工程我也上传了,请点击这里). SIFT算法是目前公认的效 ...

  9. 【自动驾驶】鸽了很久的小物体目标检测代码【小物体目标检测】

    鸽了很久的小物体目标检测代码 https://github.com/LT1st/SmallObstacleDetection/tree/main/code#readme Pytorch: Small ...

最新文章

  1. win10怎么设置开机启动项目_苹果mac开机启动项怎么设置
  2. 关于Virtual-Hosting的理解
  3. latex中怎样使公式居中_LaTeX_多行公式对齐居中的同时选择性的加编号
  4. accept和select的区别
  5. mybatis那些事~
  6. ruby gem 安装mysql2_Ruby gem mysql2安装错误
  7. 【Calcite】Calcite入门
  8. 一个系统管理员的自白
  9. 蚂蚁上市,身边又多了一堆千万富翁!
  10. arcpy 实现列举目录下的要素类与描述矢量数据要素类
  11. 561. 数组拆分 I
  12. 最简单的WIN7内核PE系统的U盘安装方法+WIN7密码破解
  13. 什么是静态测试、动态测试、黑盒测试、白盒测试、α测试 β测试?
  14. [历朝通俗演义-蔡东藩-前汉]第011回 降真龙光韬泗水 斩大蛇夜走丰乡
  15. python中pixels函数_Python的PIL库中getpixel方法的使用
  16. 如何关闭WPS烦人的广告推送
  17. 「Ubuntu」ubuntu18.04键盘输入卡顿、延迟输入
  18. 【从0到1搭建LoRa物联网】15、LoRa连接到The Things Network
  19. 设计模式之 Composite(组合)通俗理解
  20. [2016 版] 常见操作性能对比

热门文章

  1. 迎来全民网络支付新时代
  2. 1034 Head of a Gang (30 分)
  3. 【深圳线下】FMI人工智能和大数据线下技术沙龙第870期
  4. 第3周作业 #高级编程技术
  5. Twincat导出Scope数据(机器人控制),并采用origin绘图
  6. C++ 实现点与圆的位置关系
  7. redis防止表单重复提交
  8. 2019电赛--无人机题目OpenMV总结
  9. 【文字版】津津有味:感染新冠后没食欲,那咱也得吃饭啊!
  10. LaTex(0):LaTex工具TeXStudio、TeXLive安装与基本使用