构建自己的模型之前,推荐先跑一下Tensorflow object detection API的demo

JustDoIT:目标检测Tensorflow object detection API​zhuanlan.zhihu.com


比较喜欢杰伦和奕迅,那就来构建检测他们的模型吧

1.准备训练数据和测试数据

D:python3models-masterresearchobject_detection新建一个名为images的文件夹

再images文件下创建两个文件夹,一个名为train,另一个名为test,文件结构如下图

train里面有杰伦的55张图片和奕迅的30张图片(其实本来打算都弄100张图片的,无奈宿舍网速不好)

test里面有杰伦的10张图片和奕迅的10张图片

图片命名格式均是image+数字,图片类型是jpg格式,如图

对每一张图片做标签,生成包含该图片标签以及位置信息的xml文件,推荐一款小软件LabelImg,方便快捷做标签

tzutalin/labelImg​github.com

点击所圈处然后下载最新的版本,然后解压,解压完如图

打开该应用程序,界面如图

点击左上角所圈处的Open Dir然后选择对应目录并点击右下角所圈处的选择文件夹,我这里选择的是train文件(test文件夹也要执行和train文件夹一样的操作,即生成xml和record文件),效果如图

然后点击左边的Create RectBox按钮,然后圈出杰伦,会跳出对话框,输入标签,我这里输的是ZJL,如果是陈奕迅就输入CYX,效果如图

点击ok然后再点左边的save按钮即可生成对应该图片的xml文件,效果如图

最后的效果如图

然后把所有的xml集合成csv文件,需要用到Python代码来实现,代码如下,把如下代码复制粘贴到一个python文件里

'''

运行上述代码,生成如下图所示的csv文件

因为Tensorflow object detection API的输入数据格式是TFRcords Format格式的,所以我们要把csv文件转化成record文件,先把上面生成的train.csv和test.csv复制粘贴到D:python3models-masterresearchobject_detectiondata,如图

然后需要用到Python代码来实现csv到record的转换,代码如下,把如下代码复制粘贴到一个D:python3models-masterresearchobject_detection 下的名为generate_TFR.py文件里

"""
Usage:# From tensorflow/models/# Create train data:
python generate_TFR.py --csv_input=data/train.csv  --output_path=data/train.record# Create test data:
python generate_TFR.py --csv_input=data/test.csv  --output_path=data/test.record需要修改三处os.chdir('D:python3models-masterresearchobject_detection')path = os.path.join(os.getcwd(), 'images/train')def class_text_to_int(row_label): #对应的标签返回一个整数,后面会有文件用到if row_label == 'ZJL':return 1elif row_label == 'CYX':return 2else:None
"""import os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDictos.chdir('D:python3models-masterresearchobject_detection')flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS# TO-DO replace this with label map
def class_text_to_int(row_label):if row_label == 'ZJL':return 1elif row_label == 'CYX':return 2else:Nonedef split(df, group):data = namedtuple('data', ['filename', 'object'])gb = df.groupby(group)return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width)xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.python_io.TFRecordWriter(FLAGS.output_path)path = os.path.join(os.getcwd(), 'images/test') #20180418做了修改examples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()output_path = os.path.join(os.getcwd(), FLAGS.output_path)print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.app.run()

然后再“开始-Anaconda3-Anaconda Prompt”调出命令行,改变工作目录至 models-masterresearchobject_detection,输入下面命令行

转换train.csv对应的是
python generate_TFR.py --csv_input=data/train.csv --output_path=data/train.record
转换test.csv对应的是
python generate_TFR.py --csv_input=data/test.csv --output_path=data/test.record

出现下图即为转换成功

到此,数据的准备工作已经完成


2.配置文件和模型

进入 Object Detection github寻找目标模型

tensorflow/models​github.com

我这里选择的是ssd_mobilenet_v1_coco.config ,点击打开并复制里面的代码到新建的名为ssd_mobilenet_v1_coco.config的文件里,并在D:python3models-masterresearchobject_detection目录下新建一个名为training的文件夹,并把ssd_mobilenet_v1_coco.config放到train文件夹中,如下图

用文本编辑器打开ssd_mobilenet_v1_coco.config文件,如下所示

# SSD with Mobilenet v1 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
'''
需修改5处
1、train_input_reader: {tf_record_input_reader {input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"}label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
这的input_path是训练数据的路径,改为对应的路径,这里是input_path:data/train.record
这的label_map_path是label路径,这里是label_map_path:data/ZJL_CYX.pbtxt
2、eval_input_reader: {tf_record_input_reader {input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"}label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"shuffle: falsenum_readers: 1
}
这的input_path是测试数据的路径,改为对应的路径,这里是input_path:data/test.record
这的label_map_path是label路径,这里是label_map_path:data/ZJL_CYX.pbtxt
3、ssd {num_classes: 90box_coder {faster_rcnn_box_coder {y_scale: 10.0x_scale: 10.0height_scale: 5.0width_scale: 5.0}}
num_classes是标签类别数,这里只有杰伦和奕迅,所以 num_classes: 2
4、train_config: {batch_size: 24optimizer {rms_prop_optimizer: {learning_rate: {exponential_decay_learning_rate {initial_learning_rate: 0.004decay_steps: 800720decay_factor: 0.95}}momentum_optimizer_value: 0.9decay: 0.9epsilon: 1.0}
}
batch_size是每次迭代的数据数,我这里设为1
5、fine_tune_checkpoint: "ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"from_detection_checkpoint: true
这两行注释掉或者删除掉,否则会运行很慢
'''
model {ssd {num_classes: 90box_coder {faster_rcnn_box_coder {y_scale: 10.0x_scale: 10.0height_scale: 5.0width_scale: 5.0}}matcher {argmax_matcher {matched_threshold: 0.5unmatched_threshold: 0.5ignore_thresholds: falsenegatives_lower_than_unmatched: trueforce_match_for_each_row: true}}similarity_calculator {iou_similarity {}}anchor_generator {ssd_anchor_generator {num_layers: 6min_scale: 0.2max_scale: 0.95aspect_ratios: 1.0aspect_ratios: 2.0aspect_ratios: 0.5aspect_ratios: 3.0aspect_ratios: 0.3333}}image_resizer {fixed_shape_resizer {height: 300width: 300}}box_predictor {convolutional_box_predictor {min_depth: 0max_depth: 0num_layers_before_predictor: 0use_dropout: falsedropout_keep_probability: 0.8kernel_size: 1box_code_size: 4apply_sigmoid_to_scores: falseconv_hyperparams {activation: RELU_6,regularizer {l2_regularizer {weight: 0.00004}}initializer {truncated_normal_initializer {stddev: 0.03mean: 0.0}}batch_norm {train: true,scale: true,center: true,decay: 0.9997,epsilon: 0.001,}}}}feature_extractor {type: 'ssd_mobilenet_v1'min_depth: 16depth_multiplier: 1.0conv_hyperparams {activation: RELU_6,regularizer {l2_regularizer {weight: 0.00004}}initializer {truncated_normal_initializer {stddev: 0.03mean: 0.0}}batch_norm {train: true,scale: true,center: true,decay: 0.9997,epsilon: 0.001,}}}loss {classification_loss {weighted_sigmoid {}}localization_loss {weighted_smooth_l1 {}}hard_example_miner {num_hard_examples: 3000iou_threshold: 0.99loss_type: CLASSIFICATIONmax_negatives_per_positive: 3min_negatives_per_image: 0}classification_weight: 1.0localization_weight: 1.0}normalize_loss_by_num_matches: truepost_processing {batch_non_max_suppression {score_threshold: 1e-8iou_threshold: 0.6max_detections_per_class: 100max_total_detections: 100}score_converter: SIGMOID}}
}train_config: {batch_size: 24optimizer {rms_prop_optimizer: {learning_rate: {exponential_decay_learning_rate {initial_learning_rate: 0.004decay_steps: 800720decay_factor: 0.95}}momentum_optimizer_value: 0.9decay: 0.9epsilon: 1.0}}fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"from_detection_checkpoint: true# Note: The below line limits the training process to 200K steps, which we# empirically found to be sufficient enough to train the pets dataset. This# effectively bypasses the learning rate schedule (the learning rate will# never decay). Remove the below line to train indefinitely.num_steps: 200000data_augmentation_options {random_horizontal_flip {}}data_augmentation_options {ssd_random_crop {}}
}train_input_reader: {tf_record_input_reader {input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"}label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}eval_config: {num_examples: 8000# Note: The below line limits the evaluation process to 10 evaluations.# Remove the below line to evaluate indefinitely.max_evals: 10
}eval_input_reader: {tf_record_input_reader {input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"}label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"shuffle: falsenum_readers: 1
}

上面代码注释的第1,2处的data/ZJL_CYX.pbtxt文件需要自己新建,可以复制一个文件然后把文件名改了即可,如图

打开该文件,修改文件内容为

item {
name: "ZJL"
id: 1
display_name: "ZJL"
}
item {
name: "CYX"
id: 2
display_name: "CYX"
}

配置到此完成,开始训练


3.开始训练模型

“开始-Anaconda3-Anaconda Prompt”调出命令行,改变工作目录至 models-masterresearchobject_detection,输入下面命令行

python train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/ssd_mobilenet_v1_coco.config

如果没有报错的话那就慢慢等待结果

如果有报以下的错 TypeError: `pred` must be a Tensor, or a Python bool, or 1 or 0. Found instead: None,那么需要把下图所圈文件的109行 is_training=None 改为 is_training=True

如果有报以下的错Tensorflow object detection API之InvalidArgumentError: image_size must contain 3 elements[4],请参考下面链接

JustDoIT:Tensorflow object detection API之InvalidArgumentError: image_size must contain 3 elements[4]​zhuanlan.zhihu.com

可以通过可视化的页面看优化的的情况

通过“开始-Anaconda3-Anaconda Prompt”调出命令行,改变工作目录至 models-masterresearchobject_detection 执行下面的命令

tensorboard --logdir=training

出现下图

复制上图所圈处的地址到火狐浏览器打开,会出现下图的界面

可以看到每迭代一次的情况

..................................................................大概经过两个多小时的等待(可能会出现训练中断或者卡顿,那应该是显存不足,所以重新输入上述命令接着训练,是的,是接着上次中断的地方开始训练),迭代到了8000次。

我们可以先来测试一下目前的模型效果如何,关闭命令行。在 D:python3models-masterresearchobject_detection 文件夹下找到 export_inference_graph.py 文件,要运行这个文件,还需要传入config以及checkpoint的相关参数。

“开始-Anaconda3-Anaconda Prompt”调出命令行,改变工作目录至 models-masterresearchobject_detection 执行下面的命令

python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_mobilenet_v1_coco.config --trained_checkpoint_prefix training/model.ckpt-31012 --output_directory ZJL_CYX_inference_graph

这里的--output_directory 是输出模型的文件夹名称

运行上述命令后会在object_detection文件夹下生成ZJL_CYX_inference_graph文件夹,内容如下图

到此为止,我们的模型已经构建完成了,接下来是开始测试效果了


4.测试模型效果

对以下的代码做一点修改即可

'''

把以上代码复制到新建的python文件中,我这里命名为ZJLCYX_test.py 将其保存到D:python3models-masterresearchobject_detection 文件夹下

然后再D:python3models-masterresearchobject_detectiontest_images文件夹下放测试的图,如下图

最后打开spider运行ZJLCYX_test.py 文件

运行结果如下图

至此,整个过程就完成了

ssd目标检测训练自己的数据_目标检测Tensorflow object detection API之训练自己的数据集...相关推荐

  1. 建立自己的数据集 并用Tensorflow object detection API进行训练

    ps: 欢迎大家光临我的博客 建立数据集 标注工具: ubuntu 图像标注工具labelImg sudo apt-get install pyqt5-dev-tools sudo pip3 inst ...

  2. 使用tensorflow object detection API 训练自己的目标检测模型 (三)

    在上一篇博客"使用tensorflow object detection API 训练自己的目标检测模型 (二)"中介绍了如何使用LabelImg标记数据集,生成.xml文件,经过 ...

  3. 转 TensorFlow Object Detection API 多GPU 卡平行计算,加速模型训练速度教学

    本篇记录如何使用多张GPU 显示卡,加速TensorFlow Object Detection API 模型训练的过程. 虽然TensorFlow Object Detection API 已经有支援 ...

  4. TensorFlow Object Detection API 多GPU 卡平行计算,加速模型训练速度教学

    本篇记录如何使用多张GPU 显示卡,加速TensorFlow Object Detection API 模型训练的过程. 虽然TensorFlow Object Detection API 已经有支援 ...

  5. Tensorflow object detection API训练自己的目标检测模型 详细配置教程 (一)

    Tensorflow object detection API 简单介绍Tensorflow object detection API: 这个API是基于tensorflow构造的开源框架,易于构建. ...

  6. 使用tensorflow object detection API 训练自己的目标检测模型 (二)labelImg的安装配置过程

    上一篇博客介绍了goggle的tensorflow object detection API 的配置和使用, 这次介绍一下如何用这个API训练一个私人定制的目标检测模型. 第一步:准备自己的数据集.比 ...

  7. Tensorflow object detection API 搭建自己的目标检测模型并迁移到Android上

    参考链接:https://blog.csdn.net/dy_guox/article/details/79111949 之前参考上述一系列博客在Windows10下面成功运行了TensorFlow A ...

  8. 关于使用tensorflow object detection API训练自己的模型-补充部分(代码,数据标注工具,训练数据,测试数据)

    之前分享过关于tensorflow object detection API训练自己的模型的几篇博客,后面有人陆续碰到一些问题,问到了我解决方法.所以在这里补充点大家可能用到的东西.声明一下,本人专业 ...

  9. TensorFlow Object Detection API 技术手册(5)——制作自己的目标检测数据集

    TensorFlow Object Detection API 技术手册(5)--制作自己的目标检测数据集 (一)收集图片 (二)安装图像打标工具labelImg (三)将XML文件转化为CSV文件 ...

最新文章

  1. 【S操作】一个简单粗暴易用的远程调试方案——OTA http update
  2. 使用DataSet对象添加记录
  3. Java黑皮书课后题第10章:*10.17(平方数)找出大于Long.MAX_VALUE的前10个平方数。平方数是指形式为n^2的数
  4. TextPaint绘制文字
  5. Maven:导入Oracle的jar包时出现错误
  6. Oracle修改表空间大小
  7. Nodejs-增删改查-案列方法
  8. 软考倒计时27天:信息系统集成专业技术知识
  9. Win11玩永劫无间闪退怎么办?Win11玩永劫无间闪退的解决方法
  10. python入门经典100题-零基础学习Python开发练习100题实例(1)
  11. 从XmlDocument到XDocument的转换
  12. java对象调用方法,java 对象调用
  13. Eclipse环境搭建-scala
  14. 河南省邓州市计算机学校,2019年邓州市职业技术学校招生简章及招生专业
  15. 一种用于压力传感器的温度控制系统设计
  16. android 360度环拍,Android 4.2系统360度全景图拍摄试玩
  17. BERYL和COMPIZ FUSION的安装与使用
  18. 胡玮炜离职,摩拜成美团大包袱,王兴后悔了吗?
  19. 记录红米K20pro至尊版刷机安装httpcanary抓包全过程
  20. 个总开源License授权

热门文章

  1. 华为云企业级Redis评测第一期:稳定性与扩容表现
  2. 直播丨数据安全:Oracle多场景下比特币勒索的揭密与恢复实战
  3. 两万字深度介绍分布式系统原理,一篇通透
  4. 云图说|一张图带你了解华为云分布式数据库中间件
  5. Developer 转型记:一个开发平台的“魔力”
  6. 移动端开发语言的未来的猜想#华为云·寻找黑马程序员#
  7. 从0到1入门:7天玩转IoT物联网实战营丨IoT喊你加入学习之旅!
  8. 云+AI+5G时代,华为云已准备好多元化云服务架构
  9. android ndk 在project中加入引入dll,在Android-Studio中导入“预建库”(NDK支持)
  10. SpringBoot与Redis缓存