关于SSD的源代码详细讲解,请参考文章:https://blog.csdn.net/c20081052/article/details/80391627  代码详解

本文是实战系列的第四篇,逼自己抽空写篇博客,把之前运行的程序po出来,供需要的人参考。

下载 SSD-Tensorflow-master 解压找到里面notebooks文件夹,本文主要针对这个文件夹下提供的事例做讲解;

主要涉及的文件有ssd_notebook.ipynbvisualization.py

通过cmd切换到这个目录下,然后用jupyter notebook打开ssd_notebook.ipynb 运行这个文件,run 每个cell,你会得到这个源代码提供的事例检测结果。这个只是针对图片做检测。接下来将做视频流的目标检测。操作和运行结果如下。

结果如下;

下面要做的是对ssd_notebook.ipynb 做些更改。我们将其保存成.py文件,然后重命名个,名字叫:ssd_notebook_camera.py ; 还有原来事例中图片bbox上方没有显示目标的类别名称,接下来还要对visualization.py 做些更改,我们copy一份它,重命名叫做visualization_camera.py 吧。

*****

说明下图片中SSD_Tensorflow_master这个文件夹其实就是你下载的SSD-Tensorflow-master这个文件解压得到的,我把它copy了一份并做了重命名(看文件夹的 ‘ _ ’ 不同)放在notebook文件夹下了,因为visualization_camera.py中会引用master里的一些函数。为了图方便,就这么操作了,其实就是文件包含路径的问题。

*****

下面是更改后的两个文件:

ssd_notebook_camera.py代码如下:

# coding: utf-8import os
import math
import randomimport numpy as np
import tensorflow as tf
import cv2slim = tf.contrib.slim#get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
import matplotlib.image as mpimgimport sys
sys.path.append('../')from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing
from notebooks import visualization_camera    #visualization# TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!!
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)# ## SSD 300 Model
#
# The SSD 300 network takes 300x300 image inputs. In order to feed any image, the latter is resize to this input shape (i.e.`Resize.WARP_RESIZE`). Note that even though it may change the ratio width / height, the SSD model performs well on resized images (and it is the default behaviour in the original Caffe implementation).
#
# SSD anchors correspond to the default bounding boxes encoded in the network. The SSD net output provides offset on the coordinates and dimensions of these anchors.# Input placeholder.
net_shape = (300, 300)
data_format = 'NHWC'
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
# Evaluation pre-processing: resize to SSD net shape.
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
image_4d = tf.expand_dims(image_pre, 0)# Define the SSD model.
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)# Restore SSD model.
ckpt_filename = '../checkpoints/ssd_300_vgg.ckpt'   #可更改为自己的模型路径
# ckpt_filename = '../checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)# SSD default anchor boxes.
ssd_anchors = ssd_net.anchors(net_shape)# ## Post-processing pipeline
#
# The SSD outputs need to be post-processed to provide proper detections. Namely, we follow these common steps:
#
# * Select boxes above a classification threshold;
# * Clip boxes to the image shape;
# * Apply the Non-Maximum-Selection algorithm: fuse together boxes whose Jaccard score > threshold;
# * If necessary, resize bounding boxes to original image shape.# Main image processing routine.
def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):# Run SSD network.rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],feed_dict={img_input: img})# Get classes and bboxes from the net outputs.rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(rpredictions, rlocalisations, ssd_anchors,select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)# Resize bboxes to original image shape. Note: useless for Resize.WARP!rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)return rclasses, rscores, rbboxes# # Test on some demo image and visualize output.
# path = '../demo/'
# image_names = sorted(os.listdir(path))# img = mpimg.imread(path + image_names[-5])
# rclasses, rscores, rbboxes =  process_image(img)# # visualization.bboxes_draw_on_img(img, rclasses, rscores, rbboxes, visualization.colors_plasma)
# visualization.plt_bboxes(img, rclasses, rscores, rbboxes)##### following are added for camera demo####
cap = cv2.VideoCapture(r'D:\person.avi')
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fourcc = cap.get(cv2.CAP_PROP_FOURCC)
#fourcc = cv2.CAP_PROP_FOURCC(*'CVID')
print('fps=%d,size=%r,fourcc=%r'%(fps,size,fourcc))
delay=30/int(fps)while(cap.isOpened()):ret,frame = cap.read()if ret==True:
#          image = Image.open(image_path)
#          gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)image = frame# the array based representation of the image will be used later in order to prepare the# result image with boxes and labels on it.image_np = image
#          image_np = load_image_into_numpy_array(image)# Expand dimensions since the model expects images to have shape: [1, None, None, 3]image_np_expanded = np.expand_dims(image_np, axis=0)# Actual detection.rclasses, rscores, rbboxes =  process_image(image_np)# Visualization of the results of a detection.visualization_camera.bboxes_draw_on_img(image_np, rclasses, rscores, rbboxes)
#          plt.figure(figsize=IMAGE_SIZE)
#          plt.imshow(image_np)cv2.imshow('frame',image_np)cv2.waitKey(np.uint(delay))print('Ongoing...')  else:break
cap.release()
cv2.destroyAllWindows()

其中

cap = cv2.VideoCapture(r'D:\person.avi') 是你读取视频的文件目录,自行更改。

以下是visualization_camera.py内容:

# Copyright 2017 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import cv2
import randomimport matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.cm as mpcm#added 20180516#####
def num2class(n):import SSD_Tensorflow_master.datasets.pascalvoc_2007 as pasx=pas.pascalvoc_common.VOC_LABELS.items()for name,item in x:if n in item:#print(name)return name
#adden end ########## =========================================================================== #
# Some colormaps.
# =========================================================================== #
def colors_subselect(colors, num_classes=21):dt = len(colors) // num_classessub_colors = []for i in range(num_classes):color = colors[i*dt]if isinstance(color[0], float):sub_colors.append([int(c * 255) for c in color])else:sub_colors.append([c for c in color])return sub_colorscolors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)
colors_tableau = [(255, 255, 255), (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),(44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),(148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),(227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),(188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]# =========================================================================== #
# OpenCV drawing.
# =========================================================================== #
def draw_lines(img, lines, color=[255, 0, 0], thickness=2):"""Draw a collection of lines on an image."""for line in lines:for x1, y1, x2, y2 in line:cv2.line(img, (x1, y1), (x2, y2), color, thickness)def draw_rectangle(img, p1, p2, color=[255, 0, 0], thickness=2):cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)def draw_bbox(img, bbox, shape, label, color=[255, 0, 0], thickness=2):p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)p1 = (p1[0]+15, p1[1])cv2.putText(img, str(label), p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)def bboxes_draw_on_img(img, classes, scores, bboxes, colors=dict(), thickness=2):shape = img.shape####add 20180516######colors=dict()####add #############for i in range(bboxes.shape[0]):bbox = bboxes[i]if classes[i] not in colors:colors[classes[i]] = (random.random(), random.random(), random.random())p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))cv2.rectangle(img, p1[::-1], p2[::-1], colors[classes[i]], thickness)s = '%s/%.3f' % (num2class(classes[i]), scores[i])p1 = (p1[0]-5, p1[1])cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.4, colors[classes[i]], 1)  # =========================================================================== #
# Matplotlib show...
# =========================================================================== #
def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5):"""Visualize bounding boxes. Largely inspired by SSD-MXNET!"""fig = plt.figure(figsize=figsize)plt.imshow(img)height = img.shape[0]width = img.shape[1]colors = dict()for i in range(classes.shape[0]):cls_id = int(classes[i])if cls_id >= 0:score = scores[i]if cls_id not in colors:colors[cls_id] = (random.random(), random.random(), random.random())ymin = int(bboxes[i, 0] * height)xmin = int(bboxes[i, 1] * width)ymax = int(bboxes[i, 2] * height)xmax = int(bboxes[i, 3] * width)rect = plt.Rectangle((xmin, ymin), xmax - xmin,ymax - ymin, fill=False,edgecolor=colors[cls_id],linewidth=linewidth)plt.gca().add_patch(rect)##class_name = str(cls_id) #commented 20180516#### added 20180516#####class_name = num2class(cls_id)#### added end #########plt.gca().text(xmin, ymin - 2,'{:s} | {:.3f}'.format(class_name, score),bbox=dict(facecolor=colors[cls_id], alpha=0.5),fontsize=12, color='white')plt.show()

OK 了,运行上面那个ssd_notebook_camera.py文件,以下是视频检测结果(带目标类别名称):视频流的检测效果没有那么好,可能是训练模型用的是它自带推荐的,可自行训练试试效果。(我这还是CPU跑的……)

【深度学习实战04】——SSD tensorflow图像和视频的目标检测相关推荐

  1. Keras深度学习实战(38)——图像字幕生成

    Keras深度学习实战(38)--图像字幕生成 0. 前言 1. 模型与数据集分析 1.1 数据集分析 1.2 模型分析 2. 实现图像字幕生成模型 2.1 数据集加载与预处理 2.2 模型构建与训练 ...

  2. 深度学习实战22(进阶版)-AI漫画视频生成模型,做自己的漫画视频

    大家好,我是微学AI,今天给大家带来深度学习实战22(进阶版)-AI漫画视频生成模型. 回顾之前给大家介绍了<深度学习实战8-生活照片转化漫画照片应用>,今天我借助这篇文章的原理做一个AI ...

  3. 深度学习在计算机视觉领域(图像,视频,3D点云,深度图等)应用全览

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨黄浴@知乎 来源丨https://zhuanlan.zhihu.com/p/55747295 编 ...

  4. python发音机器人_Python深度学习实战:基于TensorFlow和Keras的聊天机器人以及人脸、物体和语音识别...

    序 第1章 TensorFlow基础 1 1.1 张量 2 1.2 计算图与会话 2 1.3 常量.占位符与变量 4 1.4 占位符 6 1.5 创建张量 8 1.5.1 固定张量 9 1.5.2 序 ...

  5. 【深度学习】一种关注于重要样本的目标检测方法!

    作者:宋志龙,浙江工业大学,Datawhale成员 在目标检测中训练模型时,样本间往往有差异性,不能被简单地同等对待.这次介绍的论文提出了一种重要样本的关注机制,在训练过程中帮助模型分辨哪些是重要的样 ...

  6. 【Tensorflow】深度学习实战04——Tensorflow实现VGGNet

    [fishing-pan:https://blog.csdn.net/u013921430转载请注明出处] 前言 现在已经到了Tensorflow实现卷积神经网络的第四讲了,既然是学习.实践,我一直坚 ...

  7. 深度学习实战23(进阶版)-语义分割实战,实现人物图像抠图的效果(计算机视觉)

    大家好,我是微学AI,今天给大家带来深度学习实战23(进阶版)-语义分割实战,实现人物图像抠图的效果.语义分割是计算机视觉中的一项重要任务,其目标是将图像中的每个像素都分配一个语义类别标签.与传统的目 ...

  8. 深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

    文章目录 一.前期工作 1. 设置GPU 2. 导入预处理词库类 二.导入预处理词库类 三.参数设定 四.创建模型 五.训练模型函数 六.测试模型函数 七.训练模型与预测 今天给大家带来一个简单的中文 ...

  9. 深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

    大家好,我是微学AI,今天给大家带来手写OCR识别的项目.手写的文稿在日常生活中较为常见,比如笔记.会议记录,合同签名.手写书信等,手写体的文字到处都有,所以针对手写体识别也是有较大的需求.目前手写体 ...

最新文章

  1. @2021 高考生,用 Python 分析专业“钱景”
  2. Spring容器初始化和bean创建过程
  3. 请举例说明@Qualifier 注解?
  4. 2020年国家电网计算机类考纲,终于发布!详解2020届国家电网考试大纲,带你读懂考纲变化!...
  5. 1015 德才论 (25分)
  6. python批量裁剪图片_用Python写了一个图片格式批量处理工具
  7. Spring(5)---松耦合实例
  8. linux 网卡无效 设置_请教,关于更改linux网卡配置文件后重启IP不生效的问题~
  9. 论文记载:FRAP:Learning Phase Competition for Traffic Signal Control
  10. SPSS学习笔记【二】-回归分析
  11. qpython 3h下载_QPython 3Hv3.0.0 Android
  12. AMD处理器的发展历程
  13. 农户在集市上卖西瓜,他总共有1020个西瓜,第一天卖掉一半多两个,第二天卖掉剩下的一半多两个, 问照此规律实下去,该农户几天能将所有的西瓜卖完。C语言
  14. Kali对网站进行DDOS攻击
  15. The Bequeath Protocol Adapter [ID 16653.1]
  16. html合并单元格和其中的数据,巧妙提取合并单元格及对应单元格数据
  17. 孙正义从阿里巴巴董事会辞职,原因是什么?
  18. java事件处理入门
  19. 2018年12月8日国际项目经理PMP培训考试报名中
  20. 服务器系统都有哪些?

热门文章

  1. grafana设置主页面板
  2. IE8允许ActiveX控件设置
  3. 计算机格式化后数据恢复的基础,电脑格式化了怎么恢复?——格式化数据恢复教程...
  4. 在Raspberry PI上搭建LMS服务器/Squeeze lite 播放器
  5. 后端面试之系统设计-短网址(Short URL)服务怎么设计?
  6. 2021/9/4王者荣耀服务器崩掉
  7. javafx-webview中加载的网页有弹窗不显示问题
  8. 极客时间《软件工程之美》学完感
  9. android开机字库加载过程,小米手机字库维修更换和EMMC字库编程烧写方法教程
  10. [小说连载]张小庆,在路上(6)- 真心话和大冒险