该版本的SSD实现github路径 GitHub - balancap/SSD-Tensorflow: Single Shot MultiBox Detector in TensorFlow

所用库配置: python 3.6.0

tensorflow 1.11

Keras 2.1.5

下载完毕后checkpoints下已经有训练好的模型,可以用此模型来预测下自带的测试图片;以及对视频内物体进行定位;此可参见该博客【SSD目标检测】1:图片、视频内的物体检测与定位_zzZ_CMing的博客-CSDN博客_ssd物体识别

接下来咱们配置下自己的数据集。

1. 主目录下新建一个文件夹,用于存放原图、标注图、及参与训练和验证集的样本分布文本,这里取名为VOC2007

Annotations和JPEGImages的制作见我的之前的博客ultralytics/yolov3训练预测自己数据集的配置过程_竹叶青lvye的博客-CSDN博客_ultralytics yolov3

2. 下面要生成满足VOC2007数据集格式的ImageSets\Main里的四个txt文件

可以如上新建一个GenerateTXT.py文件

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:Icecream.Shao
# -*- coding:utf-8 -*-
# -*- author:zzZ_CMing  CSDN address:https://blog.csdn.net/zzZ_CMing
# -*- 2018/07/18; 15:19
# -*- python3.5
import os
import randomtrainval_percent = 0.7
train_percent = 0.8
xmlfilepath = 'Annotations/'
txtsavepath = 'ImageSets/Main'
total_xml = os.listdir(xmlfilepath)num = len(total_xml)
list = range(num)
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval = random.sample(list,tv)
train = random.sample(trainval,tr)ftrainval = open(txtsavepath+'/trainval.txt', 'w')
ftest = open(txtsavepath+'/test.txt', 'w')
ftrain = open(txtsavepath+'/train.txt', 'w')
fval = open(txtsavepath+'/val.txt', 'w')for i in list:name = total_xml[i][:-4]+'\n'if i in trainval:ftrainval.write(name)if i in train:ftrain.write(name)else:fval.write(name)else:ftest.write(name)ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
print('Well Done!!!')

3.每个框架所用的文件格式是不一样的,这里需要做转化,可以使用主目录 下的tf_convert_data.py文件

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert a dataset to TFRecords format, which can be easily integrated into
a TensorFlow pipeline.Usage:
```shell
python tf_convert_data.py \--dataset_name=pascalvoc \--dataset_dir=/tmp/pascalvoc \--output_name=pascalvoc \--output_dir=/tmp/
```
"""
import tensorflow as tffrom datasets import pascalvoc_to_tfrecordsFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_string('dataset_name', 'pascalvoc','The name of the dataset to convert.')
tf.app.flags.DEFINE_string('dataset_dir', '.\\VOC2007\\','Directory where the original dataset is stored.')
tf.app.flags.DEFINE_string('output_name', 'mydata_train','Basename used for TFRecords output files.')
tf.app.flags.DEFINE_string('output_dir', '.\\tfrecords\\','Output directory where to store TFRecords files.')def main(_):if not FLAGS.dataset_dir:raise ValueError('You must supply the dataset directory with --dataset_dir')print('Dataset directory:', FLAGS.dataset_dir)print('Output directory:', FLAGS.output_dir)if FLAGS.dataset_name == 'pascalvoc':pascalvoc_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name)else:raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)if __name__ == '__main__':tf.app.run()

4. 修改datasets目录下的pascalvoc_common.py文件中的VOC_LABELS变量

5. 修改pascalvoc_to_tfrecords.py中的代码

# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts Pascal VOC data to TFRecords file format with Example protos.The raw Pascal VOC data set is expected to reside in JPEG files located in the
directory 'JPEGImages'. Similarly, bounding box annotations are supposed to be
stored in the 'Annotation directory'This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of 1024 and 128 TFRecord files, respectively.Each validation TFRecord file contains ~500 records. Each training TFREcord
file contains ~1000 records. Each record within the TFRecord file is a
serialized Example proto. The Example proto contains the following fields:image/encoded: string containing JPEG encoded image in RGB colorspaceimage/height: integer, image height in pixelsimage/width: integer, image width in pixelsimage/channels: integer, specifying the number of channels, always 3image/format: string, specifying the format, always'JPEG'image/object/bbox/xmin: list of float specifying the 0+ human annotatedbounding boxesimage/object/bbox/xmax: list of float specifying the 0+ human annotatedbounding boxesimage/object/bbox/ymin: list of float specifying the 0+ human annotatedbounding boxesimage/object/bbox/ymax: list of float specifying the 0+ human annotatedbounding boxesimage/object/bbox/label: list of integer specifying the classification index.image/object/bbox/label_text: list of string descriptions.Note that the length of xmin is identical to the length of xmax, ymin and ymax
for each example.
"""
import os
import sys
import randomimport numpy as np
import tensorflow as tfimport xml.etree.ElementTree as ETfrom datasets.dataset_utils import int64_feature, float_feature, bytes_feature
from datasets.pascalvoc_common import VOC_LABELS# Original dataset organisation.
DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = 'JPEGImages/'# TFRecords convertion parameters.
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 5def _process_image(directory, name):"""Process a image and annotation file.Args:filename: string, path to an image file e.g., '/path/to/example.JPG'.coder: instance of ImageCoder to provide TensorFlow image coding utils.Returns:image_buffer: string, JPEG encoding of RGB image.height: integer, image height in pixels.width: integer, image width in pixels."""# Read the image file.filename = directory + DIRECTORY_IMAGES + name + '.bmp'image_data = tf.gfile.FastGFile(filename, 'rb').read()# Read the XML annotation file.filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')tree = ET.parse(filename)root = tree.getroot()# Image shape.size = root.find('size')shape = [int(size.find('height').text),int(size.find('width').text),int(size.find('depth').text)]# Find annotations.bboxes = []labels = []labels_text = []difficult = []truncated = []for obj in root.findall('object'):label = obj.find('name').textlabels.append(int(VOC_LABELS[label][0]))labels_text.append(label.encode('ascii'))if obj.find('difficult'):difficult.append(int(obj.find('difficult').text))else:difficult.append(0)if obj.find('truncated'):truncated.append(int(obj.find('truncated').text))else:truncated.append(0)bbox = obj.find('bndbox')bboxes.append((float(bbox.find('ymin').text) / shape[0],float(bbox.find('xmin').text) / shape[1],float(bbox.find('ymax').text) / shape[0],float(bbox.find('xmax').text) / shape[1]))return image_data, shape, bboxes, labels, labels_text, difficult, truncateddef _convert_to_example(image_data, labels, labels_text, bboxes, shape,difficult, truncated):"""Build an Example proto for an image example.Args:image_data: string, JPEG encoding of RGB image;labels: list of integers, identifier for the ground truth;labels_text: list of strings, human-readable labels;bboxes: list of bounding boxes; each box is a list of integers;specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belongto the same label as the image label.shape: 3 integers, image shapes in pixels.Returns:Example proto"""xmin = []ymin = []xmax = []ymax = []for b in bboxes:assert len(b) == 4# pylint: disable=expression-not-assigned[l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]# pylint: enable=expression-not-assignedimage_format = b'JPEG'example = tf.train.Example(features=tf.train.Features(feature={'image/height': int64_feature(shape[0]),'image/width': int64_feature(shape[1]),'image/channels': int64_feature(shape[2]),'image/shape': int64_feature(shape),'image/object/bbox/xmin': float_feature(xmin),'image/object/bbox/xmax': float_feature(xmax),'image/object/bbox/ymin': float_feature(ymin),'image/object/bbox/ymax': float_feature(ymax),'image/object/bbox/label': int64_feature(labels),'image/object/bbox/label_text': bytes_feature(labels_text),'image/object/bbox/difficult': int64_feature(difficult),'image/object/bbox/truncated': int64_feature(truncated),'image/format': bytes_feature(image_format),'image/encoded': bytes_feature(image_data)}))return exampledef _add_to_tfrecord(dataset_dir, name, tfrecord_writer):"""Loads data from image and annotations files and add them to a TFRecord.Args:dataset_dir: Dataset directory;name: Image name to add to the TFRecord;tfrecord_writer: The TFRecord writer to use for writing."""image_data, shape, bboxes, labels, labels_text, difficult, truncated = \_process_image(dataset_dir, name)example = _convert_to_example(image_data, labels, labels_text,bboxes, shape, difficult, truncated)tfrecord_writer.write(example.SerializeToString())def _get_output_filename(output_dir, name, idx):return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)def run(dataset_dir, output_dir, name='voc_train', shuffling=False):"""Runs the conversion operation.Args:dataset_dir: The dataset directory where the dataset is stored.output_dir: Output directory."""if not tf.gfile.Exists(dataset_dir):tf.gfile.MakeDirs(dataset_dir)# Dataset filenames, and shuffling.path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)filenames = sorted(os.listdir(path))if shuffling:random.seed(RANDOM_SEED)random.shuffle(filenames)# Process dataset files.i = 0fidx = 0while i < len(filenames):# Open new TFRecord file.tf_filename = _get_output_filename(output_dir, name, fidx)with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:j = 0while i < len(filenames) and j < SAMPLES_PER_FILES:sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))sys.stdout.flush()filename = filenames[i]img_name = filename[:-4]_add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)i += 1j += 1fidx += 1# Finally, write the labels file:# labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))# dataset_utils.write_label_file(labels_to_class_names, dataset_dir)print('\nFinished converting the Pascal VOC dataset!')

6.执行上面的tf_convert_data文件后,在tfrecords目录下有。tfrecord后缀的文件

7. 再修改如下几处:

datasets/pascalvoc_2007.py

nets/ssd_vgg_300.py

eval_ssd_network.py

train_ssd_network.py

点击train_ssd_network.py开始训练,训练结束后,在train_model目录下存在了所间隔时间内保存的模型

7.用训练好的模型来预测一张人行步道图片

# -*- coding:utf-8 -*-
# -*- author:zzZ_CMing  CSDN address:https://blog.csdn.net/zzZ_CMing
# -*- 2018/07/14; 15:19
# -*- python3.5
"""
address: https://blog.csdn.net/qq_35608277/article/details/78660469
本文代码来自于github中微软官方仓库
"""
import os
import cv2
import math
import random
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as mpcm
import matplotlib.image as mpimg
from notebooks import visualization
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing
import syssys.path.append('./SSD-Tensorflow-master/')slim = tf.contrib.slim
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)l_VOC_CLASS = ['sidewalk']net_shape = (300, 300)
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
data_format = 'NHWC'image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(img_input, None, None, net_shape, data_format,resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)image_4d = tf.expand_dims(image_pre, 0)reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)ckpt_filename = '../train_model/model.ckpt-20000'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)ssd_anchors = ssd_net.anchors(net_shape)
def colors_subselect(colors, num_classes=2):dt = len(colors) // num_classessub_colors = []for i in range(num_classes):color = colors[i * dt]if isinstance(color[0], float):sub_colors.append([int(c * 255) for c in color])else:sub_colors.append([c for c in color])return sub_colorsdef bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=2):shape = img.shapefor i in range(bboxes.shape[0]):bbox = bboxes[i]color = colors[classes[i]]# Draw bounding box...p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)# Draw text...s = '%s/%.3f' % (l_VOC_CLASS[int(classes[i]) - 1], scores[i])p1 = (p1[0] - 5, p1[1])# cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 1.5, color, 3)colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)# 主流程函数
def process_image(img, case, select_threshold=0.15, nms_threshold=.1, net_shape=(300, 300)):# select_threshold:box阈值——每个像素的box分类预测数据的得分会与box阈值比较,高于一个box阈值则认为这个box成功框到了一个对象# nms_threshold:重合度阈值——同一对象的两个框的重合度高于该阈值,则运行下面去重函数# 执行SSD模型,得到4维输入变量,分类预测,坐标预测,rbbox_img参数为最大检测范围,本文固定为[0,0,1,1]即全图rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions,localisations, bbox_img], feed_dict={img_input: img})# ssd_bboxes_select()函数根据每个特征层的分类预测分数,归一化后的映射坐标,# ancohor_box的大小,通过设定一个阈值计算得到每个特征层检测到的对象以及其分类和坐标rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(rpredictions, rlocalisations, ssd_anchors,select_threshold=select_threshold,img_shape=net_shape,num_classes=21, decode=True)"""这个函数做的事情比较多,这里说的细致一些:首先是输入,输入的数据为每个特征层(一共6个,见上文)的:rpredictions: 分类预测数据,rlocalisations: 坐标预测数据,ssd_anchors: anchors_box数据其中:分类预测数据为当前特征层中每个像素的每个box的分类预测坐标预测数据为当前特征层中每个像素的每个box的坐标预测anchors_box数据为当前特征层中每个像素的每个box的修正数据函数根据坐标预测数据和anchors_box数据,计算得到每个像素的每个box的中心和长宽,这个中心坐标和长宽会根据一个算法进行些许的修正,从而得到一个更加准确的box坐标;修正的算法会在后文中详细解释,如果只是为了理解算法流程也可以不必深究这个,因为这个修正算法属于经验算法,并没有太多逻辑可循。修正完box和中心后,函数会计算每个像素的每个box的分类预测数据的得分,当这个分数高于一个阈值(这里是0.5)则认为这个box成功框到了一个对象,然后将这个box的坐标数据,所属分类和分类得分导出,从而得到:rclasses:所属分类rscores:分类得分rbboxes:坐标最后要注意的是,同一个目标可能会在不同的特征层都被检测到,并且他们的box坐标会有些许不同,这里并没有去掉重复的目标,而是在下文中专门用了一个函数来去重"""# 检测有没有超出检测边缘rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)# 去重,将重复检测到的目标去掉rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)# 将box的坐标重新映射到原图上(上文所有的坐标都进行了归一化,所以要逆操作一次)rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)if case == 1:bboxes_draw_on_img(img, rclasses, rscores, rbboxes, colors_plasma, thickness=8)return imgelse:return rclasses, rscores, rbboxes"""
# 只做目标定位,不做预测分析
case = 1
img = cv2.imread("../demo/person.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(process_image(img, case))
plt.show()
"""
# 做目标定位,同时做预测分析
case = 2
path = '../VOC2007/JPEGImages/166.bmp'
# 读取图片
img = mpimg.imread(path)
# 执行主流程函数
rclasses, rscores, rbboxes = process_image(img, case)
# visualization.bboxes_draw_on_img(img, rclasses, rscores, rbboxes, visualization.colors_plasma)
# 显示分类结果图
visualization.plt_bboxes(img, rclasses, rscores, rbboxes), rscores, rbboxes

预测结果如下:

原图

预测结果

后续有精力会对参数进行精调,获得一张好的检测效果图。

工程代码见如下链接:

链接:https://pan.baidu.com/s/1EDWix2XvzF8URTxlbNLJCA 
提取码:3kyb

balancap/SSD-Tensorflow使用及训练预测自己的数据集相关推荐

  1. ultralytics/yolov3训练预测自己数据集的配置过程

    需要使用https://github.com/ultralytics/yolov3提供的pytorch yolov3版本来训练预测自己的数据集,以检测出感兴趣目标,目前还没有看到详细的资料,这边系统记 ...

  2. 阿里巴巴开源大规模稀疏模型训练/预测引擎DeepRec

    简介:经历6年时间,在各团队的努力下,阿里巴巴集团大规模稀疏模型训练/预测引擎DeepRec正式对外开源,助力开发者提升稀疏模型训练性能和效果. 作者 | 烟秋 来源 | 阿里技术公众号 经历6年时间 ...

  3. [转]TensorFlow如何进行时序预测

    TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组 ...

  4. DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测

    DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测 目录 输出结果 核心代码 输出结果 数据集 tensorboard可视化 iter: 0 loss: 0.010 ...

  5. [深度学习TF2][RNN-LSTM]文本情感分析包含(数据预处理-训练-预测)

    基于LSTM的文本情感分析 0. 前言 1. 数据下载 2. 训练数据介绍 3. 用到Word2Vector介绍 wordsList.npy介绍 wordVectors.npy介绍 4 数据预处理 4 ...

  6. 用TensorFlow的Linear/DNNRegrressor预测数据

    五月两场 | NVIDIA DLI 深度学习入门课程 5月19日/5月26日一天密集式学习  快速带你入门阅读全文> 正文共2762个字,预计阅读时间8分钟. 今天要处理的问题对于一个只学了线性 ...

  7. 作为深度学习最强框架的TensorFlow如何进行时序预测!(转)

    作为深度学习最强框架的TensorFlow如何进行时序预测! BigQuant 2 个月前 摘要: 2017年深度学习框架关注度排名tensorflow以绝对的优势占领榜首,本文通过一个小例子介绍了T ...

  8. Kesci:Tensorflow 实现 LSTM——时间序列预测(超详细)

    云脑项目3 -真实业界数据的时间序列预测挑战 https://www.kesci.com/home/project/5a391c670e1fc52691fde623 这篇文章将讲解如何使用lstm进行 ...

  9. [翻译] 使用 TensorFlow 进行分布式训练

    [翻译] 使用 TensorFlow 进行分布式训练 文章目录 [翻译] 使用 TensorFlow 进行分布式训练 0x00 摘要 1. 概述 2. 策略类型 2.1 MirroredStrateg ...

  10. 完整实现利用tensorflow训练自己的图片数据集

    经过差不多一个礼拜的时间的学习,终于把完整的一个利用自己爬取的图片做训练数据集的卷积神经网络的实现(基于tensorflow) 目录 总体思路 第三部分:预处理 第四部分:网络模型 第五部分:训练 2 ...

最新文章

  1. 待在小公司好多年了,微服务还没怎么玩过。。。
  2. linux pfn,ARM Linux下的page和pfn之间转换的宏。
  3. 计划策略的配置参数(2)
  4. c char转int_C/C++ 各数据类型占用字节数
  5. 频数直方图的步骤_如何运用QC七大手法和九大步骤分析问题?
  6. win10 Security Center服务无法禁用,启动类型灰色不可改解决方法
  7. MySQL的安装与配置——详细教程
  8. iOS开发iPhone竖屏icon尺寸与启动页尺寸汇总
  9. 基于中国天气网的数据库设计与开发(python+MySQL)
  10. 企业OKR终极目标:让员工成功
  11. win7设置无线wifi连接到服务器,自动连接wifi怎么设置_如何设置无线网自动连接...
  12. Kettle PDI工具连接Mysql时报Driver class ‘org.gjt.mm.mysql.Driver‘ could not be found, make sure the ‘MySQL
  13. Delphi中多库关联查询
  14. 机器学习----PyTorch入门
  15. 来自一位搞算法的本科生的学习感想
  16. 重要开源协议的比较(BSD,Apache,GPL,LGPL,MIT) – 整理
  17. 几款常用的表单设计器解决方案
  18. 计算机三级单片机考试试题及答案,2008秋计算机三级单片机试卷及部分答案
  19. 学习资料:Chisel汇总
  20. 关于brainstorm的正确打开方式

热门文章

  1. 计算机视觉领域专家主页代码
  2. 【听课笔记】复旦大学遗传学_05染色体畸变
  3. TypeScript系列教程十一《装饰器》 -- reflect-metadata
  4. (WSI分类)WSI分类文献小综述
  5. CUDA编程.cu文件
  6. 十本经典JavaScript书籍
  7. 社区团购小程序走热,小程序商城将创造新的电商神话
  8. 基于mina框架的GPS设备与服务器之间的交互
  9. 桶装水同城预订下单送水小程序开发制作(水站桶装水配送系统)
  10. w7计算机虚拟内存设置,win7虚拟内存怎么设置最好