Introduction

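The previous post covered training an SSD model on my own data with the TensorFlow Object Detection API and then calling it from the C++ version of OpenCV. Experiments showed, however, that although SSD is fast, its accuracy was far too low, so I retrained with Faster RCNN instead. Without further ado, on to the main content!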

Training

Setup: GTX 1060, i7-8700K; inference uses OpenCV 3.4.3.
I won't go into how to set up and use the Object Detection API here; for details, see:
https://blog.csdn.net/zong596568821xp/article/details/82015126
https://blog.csdn.net/chuquanchang1051/article/details/79804965
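This post is based on the faster_rcnn_resnet50_coco model (download link: faster_rcnn_resnet50_coco_2018_01_28). After downloading, extract the archive; the extracted files are shown in the figure below: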

After extracting the model, find the matching config file in the object_detection/samples/configs/ folder, here faster_rcnn_resnet50_coco.config. Open it and change the number of classes: since there is only one class here, garbage, set num_classes=1.

model {
  faster_rcnn {
    num_classes: 1  # set this to the number of classes you have
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
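The keep_aspect_ratio_resizer above rescales each input image, as I understand it, so that its shorter side becomes min_dimension. If that would push the longer side past max_dimension, the longer side is capped at max_dimension instead. Below is a minimal sketch of that rule, which assumes results are rounded to the nearest integer. The helper name resized_shape is hypothetical and is not part of the Object Detection API. It can be handy for checking what resolution the network actually sees during training.

```python
def resized_shape(height, width, min_dimension=600, max_dimension=1024):
    """Return (new_height, new_width) under a keep-aspect-ratio resize rule.

    Sketch only: scale so the shorter side equals min_dimension; if the longer
    side would then exceed max_dimension, scale so the longer side equals
    max_dimension instead. Aspect ratio is preserved in both cases.
    """
    short_side, long_side = min(height, width), max(height, width)
    scale = min_dimension / short_side
    if round(long_side * scale) > max_dimension:
        scale = max_dimension / long_side  # cap the longer side
    return round(height * scale), round(width * scale)


print(resized_shape(480, 640))   # shorter side scaled to 600
print(resized_shape(300, 1000))  # longer side capped at 1024
```

With the 600/1024 settings above, a 480x640 image becomes 600x800, while a very wide 300x1000 image is limited by max_dimension and ends up 307x1024.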

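In addition, several file paths need to be set. fine_tune_checkpoint is the model.ckpt from the downloaded model, and input_path and label_map_path point to the training data and the label map. object-detection.pbtxt is a file you create yourself that records your label information. Since there is only one class here, its format is as follows: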

item {
  id: 1
  name: 'garbage'
}
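With more classes, the label map just repeats this item block with increasing ids. In the Object Detection API, label map ids start at 1 because 0 is reserved for the background class. Below is a small sketch that generates the file contents for any list of class names. The helper name make_label_map is hypothetical, and names are assumed not to contain single quotes.

```python
def make_label_map(class_names):
    """Build label map text in the item { id / name } format shown above.

    IDs start at 1; 0 is reserved for the background class.
    Assumes class names contain no single quotes (no escaping is done).
    """
    items = []
    for class_id, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return "".join(items)


print(make_label_map(["garbage"]))
```

The result can then be written to object-detection.pbtxt with an ordinary open(..., "w") call.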
The relevant settings further down in the .config file look like this:

    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "model_zoo/faster_rcnn_resnet50_coco_2018_01_28/model.ckpt"
  from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # empirically found to be sufficient enough to train the pets dataset. This
  # effectively bypasses the learning rate schedule (the learning rate will
  # never decay). Remove the below line to train indefinitely.
  num_steps: 200000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "train_data/garbage/train/tf_record/train.record"
  }
  label_map_path: "model_zoo/faster_rcnn_resnet50_coco_2018_01_28/object-detection.pbtxt"
}

eval_config: {
  num_examples: 116  # number of images in the test set; change this to match your data
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "train_data/garbage/test/tf_record/test.record"
  }
  label_map_path: "model_zoo/faster_rcnn_resnet50_coco_2018_01_28/object-detection.pbtxt"
  shuffle: false
  num_readers: 1
}
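The num_examples value in eval_config should equal the number of images in the test set. If you only have the test .record file at hand, you can count its entries without TensorFlow by walking the TFRecord framing. Each record consists of an 8-byte little-endian length, a 4-byte CRC of that length, the payload, and a 4-byte CRC of the payload.

The sketch below skips CRC verification and reads the whole file into memory, which is fine for a test set of this size. The helper names count_records and count_tfrecords are hypothetical.

```python
import struct


def count_records(data):
    """Count TFRecord entries in a bytes buffer (CRCs are skipped, not checked).

    Framing per record: uint64 little-endian length, uint32 masked CRC of the
    length, `length` payload bytes, uint32 masked CRC of the payload.
    """
    offset, count = 0, 0
    while offset < len(data):
        if offset + 8 > len(data):
            raise ValueError("truncated record header at byte %d" % offset)
        (length,) = struct.unpack_from("<Q", data, offset)
        offset += 8 + 4 + length + 4  # header, length CRC, payload, payload CRC
        if offset > len(data):
            raise ValueError("truncated record payload")
        count += 1
    return count


def count_tfrecords(path):
    """Count records in a TFRecord file on disk."""
    with open(path, "rb") as f:
        return count_records(f.read())
```

For example, count_tfrecords("train_data/garbage/test/tf_record/test.record") should return the value to put in num_examples, assuming each record holds one image.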

After editing the config file, go back to train.py in the Object Detection API. train.py originally lives in the legacy folder; just copy it out. In train.py, change the following two items:

# directory where the trained model (checkpoints) is saved
flags.DEFINE_string('train_dir', default='training_model/garbage/', help='')
# path to the edited .config file
flags.DEFINE_string('pipeline_config_path', default='samples/configs/faster_rcnn_resnet50_coco.config', help='')
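Because these two settings are defined as flags, you can also pass them on the command line instead of editing the defaults in train.py; tf.app.run() parses flags from the command line. Below is a small sketch that assembles such a command. The helper name train_command is hypothetical, and the paths are the ones used above.

```python
import shlex


def train_command(train_dir, pipeline_config_path, script="train.py"):
    """Build the argv list for launching the legacy train.py with explicit flags."""
    return [
        "python", script,
        "--logtostderr",
        "--train_dir=" + train_dir,
        "--pipeline_config_path=" + pipeline_config_path,
    ]


cmd = train_command("training_model/garbage/",
                    "samples/configs/faster_rcnn_resnet50_coco.config")
print(" ".join(shlex.quote(part) for part in cmd))
```

The list can be passed directly to subprocess.run(cmd, check=True) from the folder that contains train.py.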

Then you can start training.

Results after training overnight:

Model conversion

When training finishes, the trained model must be converted to a .pb file. The figure below shows the files produced by training.

The conversion script is export_inference_graph.py, shown below. The checkpoint you convert should be the best-performing one:

import tensorflow as tf
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
                    'one of [`image_tensor`, `encoded_image_string_tensor`, '
                    '`tf_example`]')
flags.DEFINE_string('input_shape', None,
                    'If input_type is `image_tensor`, this can explicitly set '
                    'the shape of this input tensor to a fixed size. The '
                    'dimensions are to be provided as a comma-separated list '
                    'of integers. A value of -1 can be used for unknown '
                    'dimensions. If not specified, for an `image_tensor, the '
                    'default shape will be partially specified as '
                    '`[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', 'samples/configs/faster_rcnn_resnet50_coco.config',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file.')
flags.DEFINE_string('trained_checkpoint_prefix', 'training_model/garbage/model.ckpt-103122',
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', 'pb_model/faster_rcnn_resnet50_coco', 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
                    'pipeline_pb2.TrainEvalPipelineConfig '
                    'text proto to override pipeline_config_path.')
flags.DEFINE_boolean('write_inference_graph', False,
                     'If true, writes inference graph to disk.')
tf.app.flags.mark_flag_as_required('pipeline_config_path')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  text_format.Merge(FLAGS.config_override, pipeline_config)
  if FLAGS.input_shape:
    input_shape = [
        int(dim) if dim != '-1' else None
        for dim in FLAGS.input_shape.split(',')
    ]
  else:
    input_shape = None
  exporter.export_inference_graph(
      FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix,
      FLAGS.output_directory, input_shape=input_shape,
      write_inference_graph=FLAGS.write_inference_graph)


if __name__ == '__main__':
  tf.app.run()
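The trained_checkpoint_prefix flag must name a single checkpoint in train_dir, such as model.ckpt-103122 above. Each checkpoint is a group of files like model.ckpt-103122.index, .meta and .data-....

The sketch below lists the available training steps from the .index files and builds the prefix for the latest one. The latest checkpoint is not necessarily the best one, so pick according to your evaluation results. The helper names checkpoint_steps and latest_checkpoint_prefix are hypothetical; TensorFlow's own tf.train.latest_checkpoint(train_dir) serves a similar purpose by reading the checkpoint bookkeeping file.

```python
import os
import re

# Matches the index file of a checkpoint, e.g. "model.ckpt-103122.index".
_INDEX_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def checkpoint_steps(filenames):
    """Extract sorted training steps from names like 'model.ckpt-103122.index'."""
    steps = []
    for name in filenames:
        match = _INDEX_RE.match(name)
        if match:
            steps.append(int(match.group(1)))
    return sorted(steps)


def latest_checkpoint_prefix(train_dir):
    """Return e.g. 'training_model/garbage/model.ckpt-103122', or None if none exist."""
    steps = checkpoint_steps(os.listdir(train_dir))
    if not steps:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % steps[-1])
```

Calling checkpoint_steps(os.listdir("training_model/garbage/")) shows every step you can choose from for --trained_checkpoint_prefix.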

After the conversion finishes, a .pb file is generated in the corresponding folder.

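This looks very similar to the model originally downloaded, but its internal structure differs. When OpenCV loads object detection models such as Faster RCNN, it also needs a .pbtxt file that tells the function how to read the model. That file is generated by tf_text_graph_faster_rcnn.py.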

For details see https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#generate-a-config-file. The script is part of OpenCV's dnn module (download link: dnn module); after downloading it, place it in the object_detection folder.

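The things to modify are input, output, config and num_classes. If there is no num_classes argument, you can add one, but you don't have to, because the .config file already contains this setting.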
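Because the .config file already carries num_classes, you can read the value back from it to double-check what the converter will see. Below is a regex-based sketch. It makes one simplifying assumption: it returns the first num_classes entry, which in this config is the one inside the faster_rcnn block. A real protobuf text-format parser would be more robust. The helper name read_num_classes is hypothetical.

```python
import re


def read_num_classes(config_text):
    """Return the first `num_classes: N` value in a pipeline config, or None."""
    match = re.search(r"^\s*num_classes\s*:\s*(\d+)", config_text, re.MULTILINE)
    return int(match.group(1)) if match else None


sample = """model {
  faster_rcnn {
    num_classes: 1  # number of classes
  }
}"""
print(read_num_classes(sample))
```

You could pass open("samples/configs/faster_rcnn_resnet50_coco.config").read() to the function to check the file you actually trained with.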


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'Faster-RCNN model from TensorFlow Object Detection API. '
                                                 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', default='F:/tensorflow_object_detection/object_detection/pb_model/faster_rcnn_resnet50_coco/frozen_inference_graph.pb',
                        help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', default='F:/tensorflow_object_detection/object_detection/pb_model/faster_rcnn_resnet50_coco/faster_rcnn.pbtxt',
                        help='Path to output text graph.')
    parser.add_argument('--config', default='F:/tensorflow_object_detection/object_detection/samples/configs/faster_rcnn_resnet50_coco.config',
                        help='Path to a *.config file is used for training.')
    # parser.add_argument('--num_classes', required=True, default=1, help='Path to a *.config file is used for training.')
    args = parser.parse_args()

    createFasterRCNNGraph(args.input, args.config, args.output)

Once the conversion is done, a faster_rcnn.pbtxt file is generated, and the model can then be loaded from C++! Besides C++, here is also a test program for opencv-python:

import cv2 as cv

cvNet = cv.dnn.readNetFromTensorflow('C:/Users/18301/Desktop/faster_rcnn_resnet50_coco/frozen_inference_graph.pb',
                                     'C:/Users/18301/Desktop/faster_rcnn_resnet50_coco/faster_rcnn.pbtxt')
img = cv.imread('C:/Users/18301/Desktop/images/image12.jpg')
rows = img.shape[0]
cols = img.shape[1]
cvNet.setInput(cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False))
cvOut = cvNet.forward()
print(cvOut)

# Each detection row: [batchId, classId, confidence, left, top, right, bottom],
# with box coordinates normalized to [0, 1]
for detection in cvOut[0, 0, :, :]:
    score = float(detection[2])
    if score > 0.3:
        left = detection[3] * cols
        top = detection[4] * rows
        right = detection[5] * cols
        bottom = detection[6] * rows
        cv.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (23, 230, 210), thickness=2)

cv.imshow('img', img)
cv.waitKey()
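The loop above relies on the layout of OpenCV's detection output. The output is a 1x1xNx7 blob, and each row holds [batchId, classId, confidence, left, top, right, bottom] with coordinates normalized to [0, 1]. Below is the same conversion as a pure-Python function, so the logic can be checked without OpenCV or a model. The name to_pixel_boxes is hypothetical, and detections are plain lists here rather than a NumPy array.

```python
def to_pixel_boxes(detections, rows, cols, threshold=0.3):
    """Convert normalized detection rows into integer pixel boxes.

    Each row is [batch_id, class_id, confidence, left, top, right, bottom],
    with coordinates in [0, 1]. Rows with confidence <= threshold are dropped.
    Returns a list of (class_id, confidence, (left, top, right, bottom)).
    """
    boxes = []
    for det in detections:
        score = float(det[2])
        if score > threshold:
            left, top = int(det[3] * cols), int(det[4] * rows)
            right, bottom = int(det[5] * cols), int(det[6] * rows)
            boxes.append((int(det[1]), score, (left, top, right, bottom)))
    return boxes
```

With real output, the rows would come from cvOut[0, 0, :, :], and rows and cols from img.shape[:2].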

OpenCV C++ source code for running Faster RCNN

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <time.h>

using namespace std;
using namespace cv;

// The width and height here must not be too small, or accuracy drops;
// larger sizes, however, take considerably more time.
const size_t inWidth = 600;
const size_t inHeight = 600;
//const float WHRatio = inWidth / (float)inHeight;
const char* classNames[] = { "garbage" };  // with only one class, list just that one

// These are the COCO dataset classes:
//const char* classNames[] = { "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
//"fire hydrant", "background", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "background", "backpack",
//"umbrella", "background", "background", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
//"bottle", "background", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
//"cake", "chair", "couch", "potted plant", "bed", "background", "dining table", "background", "background", "toilet", "background", "tv", "laptop", "mouse", "remote", "keyboard",
//"cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "background", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "background" };
const float WHRatio = inWidth / (float)inHeight;

int main()
{
    String weights = "C:/Users/18301/Desktop/faster_rcnn_resnet50_coco/frozen_inference_graph.pb";
    String prototxt = "C:/Users/18301/Desktop/faster_rcnn_resnet50_coco/faster_rcnn.pbtxt";
    dnn::Net net = cv::dnn::readNetFromTensorflow(weights, prototxt);

    Mat frame = cv::imread("C:/Users/18301/Desktop/images/image17.jpg");
    if (frame.empty())
    {
        cerr << "Could not read the input image" << endl;
        return -1;
    }
    Size frame_size = frame.size();

    // The argument format here is a pitfall: settings taken from other blogs did not work,
    // and it took a long search on Google to find this one.
    // 4th argument false -> mean = Scalar(0) (no mean subtraction); 5th argument true -> swapRB (BGR to RGB).
    cv::Mat blob = cv::dnn::blobFromImage(frame, 1, Size(inWidth, inHeight), false, true);
    net.setInput(blob);
    Mat output = net.forward();

    // output has shape 1 x 1 x N x 7; each row is
    // [batchId, classId, confidence, left, top, right, bottom], coordinates normalized to [0, 1].
    Mat detectionMat(output.size[2], output.size[3], CV_32F, output.ptr<float>());
    float confidenceThreshold = 0.5;
    for (int i = 0; i < detectionMat.rows; i++)
    {
        float confidence = detectionMat.at<float>(i, 2);
        if (confidence > confidenceThreshold)
        {
            size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));
            int xLeftBottom = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
            int yLeftBottom = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
            int xRightTop = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
            int yRightTop = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);

            ostringstream ss;
            ss << confidence;
            String conf(ss.str());

            Rect object((int)xLeftBottom, (int)yLeftBottom,
                        (int)(xRightTop - xLeftBottom),
                        (int)(yRightTop - yLeftBottom));
            rectangle(frame, object, Scalar(0, 255, 0), 2);

            // Draw a filled label background above the box, then the "class: confidence" text.
            String label = String(classNames[objectClass]) + ": " + conf;
            int baseLine = 0;
            Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
            rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height),
                                  Size(labelSize.width, labelSize.height + baseLine)),
                      Scalar(0, 255, 0), FILLED);
            putText(frame, label, Point(xLeftBottom, yLeftBottom),
                    FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
        }
    }
    namedWindow("image", WINDOW_NORMAL);
    imshow("image", frame);
    waitKey(0);
    return 0;
}

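Note: cv::Mat blob = cv::dnn::blobFromImage(frame, 1, Size(inWidth, inHeight), false, true); is where things go wrong. With other argument settings, the Faster RCNN detection results came out as garbage. It took me a long time to track this down; use exactly this form, which I have tested and confirmed works!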
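The arguments matter for the following reasons. In the call above, the fourth argument false is implicitly converted to a mean of Scalar(0), so no mean is subtracted. The fifth argument true is swapRB, so the BGR image returned by imread reaches the network in RGB order. The result is a float blob laid out as 1x3xHxW.

Below is a pure-Python sketch of just that channel handling (scale, optional B/R swap, channels-first layout) on a tiny image already at the target size. It deliberately leaves out the resizing, cropping and mean subtraction that blobFromImage can also perform. The name to_blob is hypothetical.

```python
def to_blob(image_bgr, scale=1.0, swap_rb=True):
    """Turn an H x W x 3 BGR image (nested lists) into a 1 x 3 x H x W blob.

    Multiplies every value by `scale`, optionally swaps the B and R channels,
    and moves channels first. Resizing, cropping and mean subtraction are
    not modeled here.
    """
    height, width = len(image_bgr), len(image_bgr[0])
    # Source channel for each output plane: B,G,R -> R,G,B when swapping.
    order = (2, 1, 0) if swap_rb else (0, 1, 2)
    planes = [
        [[float(image_bgr[y][x][c]) * scale for x in range(width)] for y in range(height)]
        for c in order
    ]
    return [planes]  # leading batch dimension of 1


print(to_blob([[[10, 20, 30], [40, 50, 60]]]))
```

For this 1x2 BGR image, the first plane of the blob holds the red values (30 and 60) when swap_rb is true. With swap_rb false, the first plane holds the blue values instead, which is exactly the difference a model trained on RGB images will notice.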

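The width and height here must not be too small, otherwise accuracy drops, but larger sizes take considerably more time, so find a trade-off that suits you!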

const size_t inWidth = 600;
const size_t inHeight = 600;

Experimental results



