Fast RCNN 训练自己的数据集(3训练和检测)

转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/

https://github.com/YihangLou/fast-rcnn-train-another-dataset 这是我在github上修改的几个文件的链接,求星星啊,求星星啊(原谅我那么不要脸~~)

在之前两篇文章中我介绍了怎么编译Fast RCNN,和怎么修改Fast RCNN的读取数据接口,接下来我来说明一下怎么来训练网络和之后的检测过程
先给看一下极好的检测效果

https://github.com/YihangLou/fast-rcnn-train-another-dataset

1.预训练模型介绍

首先在data目录下,有两个目录就是之前在1中解压好

  • fast_rcnn_models/
  • imagenet_models/

fast_rcnn_model文件夹下面是作者用fast rcnn训练好的三个网络,分别对应着小、中、大型网络,大家可以试用一下这几个网络,看一些检测效果,他们训练都迭代了40000次,数据集都是pascal_voc的数据集。

  1. caffenet_fast_rcnn_iter_40000.caffemodel
  2. vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel
  3. vgg16_fast_rcnn_iter_40000.caffemodel

imagenet_model文件夹下面是在Imagenet上训练好的通用模型,在这里用来初始化网络的参数

  1. CaffeNet.v2.caffemodel
  2. VGG_CNN_M_1024.v2.caffemodel
  3. VGG16.v2.caffemodel

在这里我比较推荐先用中型网络训练,中型网络训练和检测的速度都比较快,效果也都比较理想,大型网络的话训练速度比较慢,我当时是5000多个标注信息,网络配置默认,中型网络训练大概两三个小时,大型网络的话用十几个小时,需要注意的是网络训练最好用GPU,CPU的话太慢了,我当时用的实验室的服务器,有16块Tesla K80,用起来真的是灰常爽!

2. 修改模型文件配置

模型文件在models下面对应的网络文件夹下,在这里我用中型网络的配置文件修改为例子
比如:我的检测目标物是car ,那么我的类别就有两个类别即 background 和 car
因此,首先打开网络的模型文件夹,打开train.prototxt
修改的地方重要有三个
分别是个地方

  1. 首先在data层把num_classes 从原来的21类 20类+背景 ,改成 2类 车+背景
  2. 接在在cls_score层把num_output 从原来的21 改成 2
  3. 在bbox_pred层把num_output 从原来的84 改成8, 为检测类别个数乘以4,比如这里是2类那就是2*4=8

OK,如果你要进一步修改网络训练中的学习速率,步长,gamma值,以及输出模型的名字,需要在同目录下的solver.prototxt中修改。
如下图:

train_net: "models/VGG_CNN_M_1024/train.prototxt"
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 30000
display: 20
average_loss: 100
momentum: 0.9
weight_decay: 0.0005
# We disable standard caffe solver snapshotting and implement our own snapshot
# function
snapshot: 0
# We still use the snapshot prefix, though
snapshot_prefix: "vgg_cnn_m_1024_fast_rcnn"
#debug_info: true

3.启动Fast RCNN网络训练

启动训练:
./tools/train_net.py --gpu 11 --solver models/VGG_CNN_M_1024_LOUYIHANG/solver.prototxt --weights data/imagenet_models/VGG_CNN_M_1024.v2.caffemodel --imdb KakouTrain

参数讲解:

  • 这里的--是两个-,markdown写的,大家不要输错
  • train_net.py是网络的训练文件,之后的参数都是附带的输入参数
  • --gpu 代表机器上的GPU编号,如果是nvidia系列的tesla显卡,可以在终端中输入nvidia-smi来查看当前的显卡负荷,选择合适的显卡
  • --solver 代表模型的配置文件,train.prototxt的文件路径已经包含在这个文件之中
  • --weights 代表初始化的权重文件,这里用的是Imagenet上预训练好的模型,中型的网络我们选择用VGG_CNN_M_1024.v2.caffemodel
  • --imdb 这里给出的训练的数据库名字需要在factory.py的__sets中,我在文件里面有__sets['KakouTrain'],train_net.py这个文件会调用factory.py再生成kakou这个类,来读取数据

4.启动Fast RCNN网络检测

我修改了tools下面的demo.py这个文件,用来做检测,并且将检测的坐标结果输出到相应的txt文件中
可以看到原始的demo.py 是用网络测试了两张图像,并做可视化输出,有具体的检测效果,但是我是在Linux服务器的终端下,没有display device,因此部分代码要少做修改

下面是原始的demo.py:

#!/usr/bin/env python# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------"""
Demo script showing detections in sample images.See README.md for installation instructions before running.
"""import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from utils.cython_nms import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparseCLASSES = ('__background__','aeroplane', 'bicycle', 'bird', 'boat','bottle', 'bus', 'car', 'cat', 'chair','cow', 'diningtable', 'dog', 'horse','motorbike', 'person', 'pottedplant','sheep', 'sofa', 'train', 'tvmonitor')NETS = {'vgg16': ('VGG16','vgg16_fast_rcnn_iter_40000.caffemodel'),'vgg_cnn_m_1024': ('VGG_CNN_M_1024','vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),'caffenet': ('CaffeNet','caffenet_fast_rcnn_iter_40000.caffemodel')}def vis_detections(im, class_name, dets, thresh=0.5):"""Draw detected bounding boxes."""inds = np.where(dets[:, -1] >= thresh)[0]if len(inds) == 0:returnim = im[:, :, (2, 1, 0)]fig, ax = plt.subplots(figsize=(12, 12))ax.imshow(im, aspect='equal')for i in inds:bbox = dets[i, :4]score = dets[i, -1]ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),bbox[2] - bbox[0],bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))ax.text(bbox[0], bbox[1] - 2,'{:s} {:.3f}'.format(class_name, score),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')ax.set_title(('{} detections with ''p({} | box) >= {:.1f}').format(class_name, class_name,thresh),fontsize=14)plt.axis('off')plt.tight_layout()plt.draw()def demo(net, image_name, classes):"""Detect object classes in an image using pre-computed object proposals."""# Load pre-computed Selected Search object proposalsbox_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo',image_name + '_boxes.mat')obj_proposals = sio.loadmat(box_file)['boxes']# Load the demo imageim_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name + '.jpg')im = cv2.imread(im_file)# Detect all object classes and regress object boundstimer = Timer()timer.tic()scores, boxes = im_detect(net, im, obj_proposals)timer.toc()print ('Detection took {:.3f}s for ''{:d} object proposals').format(timer.total_time, boxes.shape[0])# Visualize detections for each classCONF_THRESH = 0.8NMS_THRESH = 0.3for cls in classes:cls_ind = CLASSES.index(cls)cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]cls_scores = scores[:, cls_ind]dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)keep = nms(dets, NMS_THRESH)dets = dets[keep, :]print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,CONF_THRESH)vis_detections(im, cls, dets, thresh=CONF_THRESH)def parse_args():"""Parse input arguments."""parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int)parser.add_argument('--cpu', dest='cpu_mode',help='Use CPU mode (overrides --gpu)',action='store_true')parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',choices=NETS.keys(), default='vgg16')args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()prototxt = os.path.join(cfg.ROOT_DIR, 'models', NETS[args.demo_net][0],'test.prototxt')caffemodel = os.path.join(cfg.ROOT_DIR, 'data', 'fast_rcnn_models',NETS[args.demo_net][1])if not os.path.isfile(caffemodel):raise IOError(('{:s} not found.\nDid you run ./data/script/''fetch_fast_rcnn_models.sh?').format(caffemodel))if args.cpu_mode:caffe.set_mode_cpu()else:caffe.set_mode_gpu()caffe.set_device(args.gpu_id)net = caffe.Net(prototxt, caffemodel, caffe.TEST)print '\n\nLoaded network {:s}'.format(caffemodel)print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Demo for data/demo/000004.jpg'demo(net, '000004', ('car',))print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Demo for data/demo/001551.jpg'demo(net, '001551', ('sofa', 'tvmonitor'))plt.show()

复制这个demo.py 修改成CarFaceTest.py,下面是修改后的文件
修改后的文件主要是添加了outputDetectionResult和runDetection两个函数, 添加了部分注释

#!/usr/bin/env python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------"""
Demo script showing detections in sample images.See README.md for installation instructions before running.
"""import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from utils.cython_nms import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse#CLASSES = ('__background__','aeroplane','bicycle','bird','boat',
#       'bottle','bus','car','cat','chair','cow','diningtable','dog','horse'
#       'motorbike','person','pottedplant','sheep','sofa','train','tvmonitor')CLASSES = ('__background__','car') #需要跟自己训练的数据集中的类别一致,原来是21类的voc数据集,自己的数据集就是car和backgroundNETS = {'vgg16': ('VGG16','vgg16_fast_rcnn_iter_40000.caffemodel'),'vgg_cnn_m_1024': ('VGG_CNN_M_1024','vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),'vgg_cnn_m_1024_louyihang': ('VGG_CNN_M_1024_LOUYIHANG','vgg_cnn_m_1024_fast_rcnn_louyihang_iter_40000.caffemodel'),'caffenet': ('CaffeNet','caffenet_fast_rcnn_iter_40000.caffemodel'),'caffenet_louyihang':('CaffeNet_LOUYIHANG','caffenet_fast_rcnn_louyihang_iter_40000.caffemodel'),'vgg16_louyihang':('VGG16_LOUYIHANG','vgg16_fast_rcnn_louyihang_iter_40000.caffemodel')}#映射到对应的模型文件def outputDetectionResult(im, class_name, dets, thresh=0.5): #打开相应的输出文件outputFile = open('CarDetectionResult.txt')inds = np.where(dets[:,-1] >= thresh)[0]if len(inds) == 0:return
def runDetection (net, basePath, testFileName,classes):#这个函数是自己后加的,取代了demo函数,给定测试数据列表ftest = open(testFileName,'r')imageFileName = basePath+'/' + ftest.readline().strip()num = 1outputFile = open('CarDetectionResult.txt','w')while imageFileName:print imageFileNameprint 'now is ', numnum +=1imageFileBaseName = os.path.basename(imageFileName)imageFileDir = os.path.dirname(imageFileName)boxFileName = imageFileDir +'/'+imageFileBaseName.replace('.jpg','_boxes.mat')print boxFileNameobj_proposals = sio.loadmat(boxFileName)['boxes']#obj_proposals[:,2] = obj_proposals[:, 2] + obj_proposals[:, 0]#这里也需要注意,OP里面的坐标数据是否为x1y1x2y2还是x1y1wh#obj_proposals[:,3] = obj_proposals[:, 3] + obj_proposals[:, 1]im = cv2.imread(imageFileName)timer = Timer()timer.tic()scores, boxes = im_detect(net, im, obj_proposals)#检测函数timer.toc()print ('Detection took {:.3f} for ''{:d} object proposals').format(timer.total_time, boxes.shape[0])CONF_THRESH = 0.8NMS_THRESH = 0.3#NMS参数用来控制非极大值抑制for cls in classes:cls_ind = CLASSES.index(cls)cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]cls_scores = scores[:, cls_ind]dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)keep = nms(dets, NMS_THRESH)dets = dets[keep, :]print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,CONF_THRESH)inds = np.where(dets[:, -1] >= CONF_THRESH)[0]print 'inds.size', inds.sizeif len(inds) != 0:outputFile.write(imageFileName+' ')outputFile.write(str(inds.size)+' ')将检测的结果写出相应的文件里for i in inds:bbox = dets[i, :4]outputFile.write(str(int(bbox[0]))+' '+ str(int(bbox[1]))+' '+ str(int(bbox[2]))+' '+ str(int(bbox[3]))+' ')outputFile.write('\n')else:outputFile.write(imageFileName +' 0' '\n')temp = ftest.readline().strip()if temp:imageFileName = basePath+'/' + tempelse:break
def vis_detections(im, class_name, dets, thresh=0.5):#这个函数需要加以说明,这个函数虽然没有用,但是我的服务器上没有输出设备"""Draw detected bounding boxes."""#因此要将部分用到显示的函数给注释掉,否则运行会报错inds = np.where(dets[:, -1] >= thresh)[0]print 'inds.shape', inds.shapeprint indsprint 'inds.size', inds.sizeif len(inds) == 0:return#im = im[:, :, (2, 1, 0)]#fig, ax = plt.subplots(figsize=(12, 12))#ax.imshow(im, aspect='equal')#for i in inds:#    bbox = dets[i, :4]#    score = dets[i, -1]#    ax.add_patch(#        plt.Rectangle((bbox[0], bbox[1]),#                      bbox[2] - bbox[0],#                      bbox[3] - bbox[1], fill=False,#                      edgecolor='red', linewidth=3.5)#        )#    ax.text(bbox[0], bbox[1] - 2,#            '{:s} {:.3f}'.format(class_name, score),#            bbox=dict(facecolor='blue', alpha=0.5),#            fontsize=14, color='white')#ax.set_title(('{} detections with '#              'p({} | box) >= {:.1f}').format(class_name, class_name,#                                              thresh),#              fontsize=14)#plt.axis('off')#plt.tight_layout()#plt.draw()def demo(net, image_name, classes):#原来的demo函数,没有修改"""Detect object classes in an image using pre-computed object proposals."""# Load pre-computed Selected Search object proposals#box_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo',image_name + '_boxes.mat')basePath='/home/chenjie/DataSet/500CarTestDataSet2'box_file = os.path.join(basePath,image_name + '_boxes.mat')obj_proposals = sio.loadmat(box_file)['boxes']# Load the demo image#im_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name + '.jpg')im_file = os.path.join(basePath, image_name + '.jpg')im = cv2.imread(im_file)# Detect all object classes and regress object boundstimer = Timer()timer.tic()scores, boxes = im_detect(net, im, obj_proposals)timer.toc()print ('Detection took {:.3f}s for ''{:d} object proposals').format(timer.total_time, boxes.shape[0])# Visualize detections for each classCONF_THRESH = 0.8NMS_THRESH = 0.3for cls in classes:cls_ind = CLASSES.index(cls)cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]cls_scores = scores[:, cls_ind]dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)keep = nms(dets, NMS_THRESH)dets = dets[keep, :]print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,CONF_THRESH)vis_detections(im, cls, dets, thresh=CONF_THRESH)def parse_args():"""Parse input arguments."""parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int)parser.add_argument('--cpu', dest='cpu_mode',help='Use CPU mode (overrides --gpu)',action='store_true')parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',choices=NETS.keys(), default='vgg16')args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()prototxt = os.path.join(cfg.ROOT_DIR, 'models', NETS[args.demo_net][0],'test.prototxt')#caffemodel = os.path.join(cfg.ROOT_DIR, 'data', 'fast_rcnn_models',#                          NETS[args.demo_net][1])#caffemodel = '/home/chenjie/fast-rcnn/output/default/KakouTrain/vgg16_fast_rcnn_louyihang_iter_40000.caffemodel'#caffemodel = '/home/chenjie/louyihang/fast-rcnn/output/default/KakouTrain/caffenet_fast_rcnn_louyihang_iter_40000.caffemodel'caffemodel = '/home/chenjie/fast-rcnn/output/default/KakouTrain/vgg_cnn_m_1024_fast_rcnn_louyihang_iter_40000.caffemodel'#我在这里直接指定了训练好的模型文件,训练好的模型文件是在工程根目录下的,output/default/对应的数据库名字下面if not os.path.isfile(caffemodel):raise IOError(('{:s} not found.\nDid you run ./data/script/''fetch_fast_rcnn_models.sh?').format(caffemodel))if args.cpu_mode:caffe.set_mode_cpu()else:caffe.set_mode_gpu()caffe.set_device(args.gpu_id)net = caffe.Net(prototxt, caffemodel, caffe.TEST)print '\n\nLoaded network {:s}'.format(caffemodel)#demo(net, 'Target0/000001', ('car',))#输入对应的测试图像列表,需要在同级目录下摆放同名的_boxes.mat文件,它会自动的替换后缀名!#runDetection(net, '/home/chenjie/DataSet/temptest','/home/chenjie/DataSet/temptest/Imagelist.txt',('car',))runDetection(net, '/home/chenjie/DataSet/500CarTestDataSet2','/home/chenjie/DataSet/500CarTestDataSet2/Imagelist.txt',('car',))#runDetection(net, '/home/chenjie/DataSet/Kakou_Test_Scale0.25/','/home/chenjie/DataSet/Kakou_Test_Scale0.25/imagelist.txt',('car',))#runDetection(net, '/home/chenjie/DataSet/Images_Version1_Test_Boxes','/home/chenjie/DataSet/Images_Version1_Test_Boxes/ImageList_Version1_List.txt',('car',))#plt.show()

5.检测结果

训练数据集

首先给出我的训练数据集,其实我的训练数据集并不是太复杂的

测试数据集

输出检测结果到txt文件中,

测试效果

**在复杂场景下的测试效果非常好,速度也非常快,中型网络监测平均每张在K80显卡下时0.1~0.2S左右,图像的尺寸是480*640,6000张测试数据集下达到的准确率是98%!!!**

Fast RCNN 训练自己的数据集(3训练和检测)相关推荐

  1. faster rcnn在自己的数据集上训练

    本文是一个总结,参考了网上的众多资料,汇集而成,以供自己后续参考. 一般说来,训练自己的数据,有两种方法:第一种就是将自己的数据集完全改造成VOC2007的形式,然后放到py-faster-rcnn/ ...

  2. Win10下用TensorFlow训练自己的数据集来做目标检测

    本文主要来源于YouTube博主 EdjeElectronics的视频https://www.youtube.com/watch?v=Rgpfk6eYxJA (对于视频看不了的朋友,我把视频录制下来放 ...

  3. Fast R-CNN解读:单阶段,多任务完成检测

    Title: Fast R-CNN(2015) Link: paper , code Tips: R-CNN and SPPnet 的缺点以及 Fast R-CNN 的优点(1.1) Fast R-C ...

  4. 使用yolov5训练自己的数据集(苹果成熟度检测)

    先从Github上下载YOLOv5,下载好解压配置好就可以使用,地址: https://github.com/ultralytics/yolov5 1.训练数据集的准备工作 在yolov5 目录的da ...

  5. yolov4训练自己的数据集实现安全帽佩戴检测

    昨儿立下来的flag,今天还是要含泪完成的,抓紧时间赶呀!!! 本次用yolov4来实现人佩戴安全帽检测,若未佩戴安全帽则将人脸框出来,若佩戴安全帽,则将安全帽以及人脸框出来,多说无益,直接看效果吧! ...

  6. gluoncv 训练自己的数据集,进行目标检测

    跑了一晚上的模型,实在占GPU资源,这两天已经有很多小朋友说我了.我选择了其中一个参数. https://github.com/dmlc/gluon-cv/blob/master/scripts/de ...

  7. YOLOv3 Darknet安装编译与训练自己的数据集

    YOLOv3 Darknet安装编译与训练自己的数据集 文章目录: 1安装编译darknet与测试darknet 1.1 安装编译darknet 1.1.1 下载darknet安装包 1.1.2 编译 ...

  8. internetreadfile读取数据长度为0_YOLOV3的TensorFlow2.0实现,支持在自己的数据集上训练...

    GitHub链接: calmisential/YOLOv3_TensorFlow2​github.com 我主要参考了yolov3的一个keras实现版本: qqwweee/keras-yolo3​g ...

  9. YOLO训练自己的数据集的一些心得

    YOLO训练自己的数据集 YOLO-darknet训练自己的数据 [Darknet][yolo v2]训练自己数据集的一些心得----VOC格式 YOLO模型训练可视化训练过程中的中间参数 项目开源代 ...

  10. 目标检测系列(四)——Fast R-CNN译文

    文章目录 摘要 1. 引言 1.1 R-CNN和SPPnet 1.2 本文贡献点 2. Fast R-CNN的框架和训练过程 2.1 RoI pooling层 2.2 从预训练网络初始化 2.3 针对 ...

最新文章

  1. 批处理命令——goto 和 :
  2. Session 时间
  3. cf559C. Gerald and Giant Chess(容斥原理)
  4. Eclipse下配置struts很让我痛苦
  5. pthread多线程编程的学习小结
  6. MySQL中的整数类型
  7. P4317-花神的数论题【组合数学】
  8. 揭秘 TI-ONE 大型赛事支撑能力是如何练就的
  9. Zabbix实战-简易教程--拓扑图(Maps)
  10. zendstudio for mac破解版
  11. 熬了几个通宵,我写了份CUDA新手入门代码
  12. python 最优解 ma_python遗传算法求最优解
  13. android后台定时执行任务,后台执行的定时任务
  14. 领导力培训知识点汇总及感悟-1
  15. iPad——添加学校邮箱到邮件解决方案
  16. 瀚高数据库debug问题
  17. python之路金角大王_python 之路,致那些年,我们依然没搞明白的编码
  18. 经典排序算法动图图解
  19. 导览讲解 | 二维码语音讲解有什么优势?
  20. 迎接“119”消防宣传日|助力企事业单位实现消防安全,前景互联在行动

热门文章

  1. java Android版_java-Android系统版本
  2. java 读取html模板文件_Java根据html模板创建 html文件
  3. leetcode算法题--不同路径
  4. java一定时间间隔的定时任务_Java 定时任务---Timer
  5. php for循环可以变量关联数组,数组字符串变量之间的转换,数组元素的回调,for循环遍历关联数组...
  6. Windows CE 程序设计 (3rd 版)
  7. 数制之非十进制数转十进制数
  8. Redis设计与实现 第一部分
  9. 使用ifconfig取出网卡eth0的ip地址-看看你有多少方法 ?
  10. cefsharp wpf 中文输入问题解决方法