转载自:faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py - 野孩子的专栏 - 博客频道 - CSDN.NET

http://blog.csdn.net/u010668907/article/details/51945320

faster用python版本的https://github.com/rbgirshick/py-faster-rcnn

train_faster_rcnn_alt_opt.py源码在https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/train_faster_rcnn_alt_opt.py

faster rcnn训练的开始是:faster_rcnn_alt_opt.sh。下面命令是训练的,还有它的参数说明。

1.调用最初脚本的说明

cd $FRCN_ROOT

# ./experiments/scripts/faster_rcnn_alt_opt.sh  GPU  NET  DATASET [options args to {train,test}_net.py]

# GPU_ID is the GPU you want to train on

# NET in {ZF, VGG_CNN_M_1024, VGG16} is the network arch to use

# DATASET is only pascal_voc for now

train_faster_rcnn_alt_opt.py的源码:

[python] view plaincopy print?
  1. #!/usr/bin/env python
  2. # --------------------------------------------------------
  3. # Faster R-CNN
  4. # Copyright (c) 2015 Microsoft
  5. # Licensed under The MIT License [see LICENSE for details]
  6. # Written by Ross Girshick
  7. # --------------------------------------------------------
  8. """Train a Faster R-CNN network using alternating optimization.
  9. This tool implements the alternating optimization algorithm described in our
  10. NIPS 2015 paper ("Faster R-CNN: Towards Real-time Object Detection with Region
  11. Proposal Networks." Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun.)
  12. """
  13. import _init_paths
  14. from fast_rcnn.train import get_training_roidb, train_net
  15. from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
  16. from datasets.factory import get_imdb
  17. from rpn.generate import imdb_proposals
  18. import argparse
  19. import pprint
  20. import numpy as np
  21. import sys, os
  22. import multiprocessing as mp
  23. import cPickle
  24. import shutil
  25. def parse_args():
  26. """
  27. Parse input arguments
  28. """
  29. parser = argparse.ArgumentParser(description='Train a Faster R-CNN network')
  30. parser.add_argument('--gpu', dest='gpu_id',
  31. help='GPU device id to use [0]',
  32. default=0, type=int)
  33. parser.add_argument('--net_name', dest='net_name',
  34. help='network name (e.g., "ZF")',
  35. default=None, type=str)
  36. parser.add_argument('--weights', dest='pretrained_model',
  37. help='initialize with pretrained model weights',
  38. default=None, type=str)
  39. parser.add_argument('--cfg', dest='cfg_file',
  40. help='optional config file',
  41. default=None, type=str)
  42. parser.add_argument('--imdb', dest='imdb_name',
  43. help='dataset to train on',
  44. default='voc_2007_trainval', type=str)
  45. parser.add_argument('--set', dest='set_cfgs',
  46. help='set config keys', default=None,
  47. nargs=argparse.REMAINDER)
  48. if len(sys.argv) == 1:
  49. parser.print_help()
  50. sys.exit(1)
  51. args = parser.parse_args()
  52. return args
  53. def get_roidb(imdb_name, rpn_file=None):
  54. imdb = get_imdb(imdb_name)
  55. print 'Loaded dataset `{:s}` for training'.format(imdb.name)
  56. imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
  57. print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
  58. if rpn_file is not None:
  59. imdb.config['rpn_file'] = rpn_file
  60. roidb = get_training_roidb(imdb)
  61. return roidb, imdb
  62. def get_solvers(net_name):
  63. # Faster R-CNN Alternating Optimization
  64. n = 'faster_rcnn_alt_opt'
  65. # Solver for each training stage
  66. solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],
  67. [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],
  68. [net_name, n, 'stage2_rpn_solver60k80k.pt'],
  69. [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]
  70. solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]
  71. # Iterations for each training stage
  72. max_iters = [80000, 40000, 80000, 40000]
  73. # max_iters = [100, 100, 100, 100]
  74. # Test prototxt for the RPN
  75. rpn_test_prototxt = os.path.join(
  76. cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')
  77. return solvers, max_iters, rpn_test_prototxt
  78. # ------------------------------------------------------------------------------
  79. # Pycaffe doesn't reliably free GPU memory when instantiated nets are discarded
  80. # (e.g. "del net" in Python code). To work around this issue, each training
  81. # stage is executed in a separate process using multiprocessing.Process.
  82. # ------------------------------------------------------------------------------
  83. def _init_caffe(cfg):
  84. """Initialize pycaffe in a training process.
  85. """
  86. import caffe
  87. # fix the random seeds (numpy and caffe) for reproducibility
  88. np.random.seed(cfg.RNG_SEED)
  89. caffe.set_random_seed(cfg.RNG_SEED)
  90. # set up caffe
  91. caffe.set_mode_gpu()
  92. caffe.set_device(cfg.GPU_ID)
  93. def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,
  94. max_iters=None, cfg=None):
  95. """Train a Region Proposal Network in a separate training process.
  96. """
  97. # Not using any proposals, just ground-truth boxes
  98. cfg.TRAIN.HAS_RPN = True
  99. cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression
  100. cfg.TRAIN.PROPOSAL_METHOD = 'gt'
  101. cfg.TRAIN.IMS_PER_BATCH = 1
  102. print 'Init model: {}'.format(init_model)
  103. print('Using config:')
  104. pprint.pprint(cfg)
  105. import caffe
  106. _init_caffe(cfg)
  107. roidb, imdb = get_roidb(imdb_name)
  108. print 'roidb len: {}'.format(len(roidb))
  109. output_dir = get_output_dir(imdb)
  110. print 'Output will be saved to `{:s}`'.format(output_dir)
  111. model_paths = train_net(solver, roidb, output_dir,
  112. pretrained_model=init_model,
  113. max_iters=max_iters)
  114. # Cleanup all but the final model
  115. for i in model_paths[:-1]:
  116. os.remove(i)
  117. rpn_model_path = model_paths[-1]
  118. # Send final model path through the multiprocessing queue
  119. queue.put({'model_path': rpn_model_path})
  120. def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,
  121. rpn_test_prototxt=None):
  122. """Use a trained RPN to generate proposals.
  123. """
  124. cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre NMS filtering
  125. cfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS
  126. print 'RPN model: {}'.format(rpn_model_path)
  127. print('Using config:')
  128. pprint.pprint(cfg)
  129. import caffe
  130. _init_caffe(cfg)
  131. # NOTE: the matlab implementation computes proposals on flipped images, too.
  132. # We compute them on the image once and then flip the already computed
  133. # proposals. This might cause a minor loss in mAP (less proposal jittering).
  134. imdb = get_imdb(imdb_name)
  135. print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)
  136. # Load RPN and configure output directory
  137. rpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)
  138. output_dir = get_output_dir(imdb)
  139. print 'Output will be saved to `{:s}`'.format(output_dir)
  140. # Generate proposals on the imdb
  141. rpn_proposals = imdb_proposals(rpn_net, imdb)
  142. # Write proposals to disk and send the proposal file path through the
  143. # multiprocessing queue
  144. rpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]
  145. rpn_proposals_path = os.path.join(
  146. output_dir, rpn_net_name + '_proposals.pkl')
  147. with open(rpn_proposals_path, 'wb') as f:
  148. cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL)
  149. print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)
  150. queue.put({'proposal_path': rpn_proposals_path})
  151. def train_fast_rcnn(queue=None, imdb_name=None, init_model=None, solver=None,
  152. max_iters=None, cfg=None, rpn_file=None):
  153. """Train a Fast R-CNN using proposals generated by an RPN.
  154. """
  155. cfg.TRAIN.HAS_RPN = False           # not generating prosals on-the-fly
  156. cfg.TRAIN.PROPOSAL_METHOD = 'rpn'   # use pre-computed RPN proposals instead
  157. cfg.TRAIN.IMS_PER_BATCH = 2
  158. print 'Init model: {}'.format(init_model)
  159. print 'RPN proposals: {}'.format(rpn_file)
  160. print('Using config:')
  161. pprint.pprint(cfg)
  162. import caffe
  163. _init_caffe(cfg)
  164. roidb, imdb = get_roidb(imdb_name, rpn_file=rpn_file)
  165. output_dir = get_output_dir(imdb)
  166. print 'Output will be saved to `{:s}`'.format(output_dir)
  167. # Train Fast R-CNN
  168. model_paths = train_net(solver, roidb, output_dir,
  169. pretrained_model=init_model,
  170. max_iters=max_iters)
  171. # Cleanup all but the final model
  172. for i in model_paths[:-1]:
  173. os.remove(i)
  174. fast_rcnn_model_path = model_paths[-1]
  175. # Send Fast R-CNN model path over the multiprocessing queue
  176. queue.put({'model_path': fast_rcnn_model_path})
  177. if __name__ == '__main__':
  178. args = parse_args()
  179. print('Called with args:')
  180. print(args)
  181. if args.cfg_file is not None:
  182. cfg_from_file(args.cfg_file)
  183. if args.set_cfgs is not None:
  184. cfg_from_list(args.set_cfgs)
  185. cfg.GPU_ID = args.gpu_id
  186. # --------------------------------------------------------------------------
  187. # Pycaffe doesn't reliably free GPU memory when instantiated nets are
  188. # discarded (e.g. "del net" in Python code). To work around this issue, each
  189. # training stage is executed in a separate process using
  190. # multiprocessing.Process.
  191. # --------------------------------------------------------------------------
  192. # queue for communicated results between processes
  193. mp_queue = mp.Queue()
  194. # solves, iters, etc. for each training stage
  195. solvers, max_iters, rpn_test_prototxt = get_solvers(args.net_name)
  196. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  197. print 'Stage 1 RPN, init from ImageNet model'
  198. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  199. cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'
  200. mp_kwargs = dict(
  201. queue=mp_queue,
  202. imdb_name=args.imdb_name,
  203. init_model=args.pretrained_model,
  204. solver=solvers[0],
  205. max_iters=max_iters[0],
  206. cfg=cfg)
  207. p = mp.Process(target=train_rpn, kwargs=mp_kwargs)
  208. p.start()
  209. rpn_stage1_out = mp_queue.get()
  210. p.join()
  211. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  212. print 'Stage 1 RPN, generate proposals'
  213. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  214. mp_kwargs = dict(
  215. queue=mp_queue,
  216. imdb_name=args.imdb_name,
  217. rpn_model_path=str(rpn_stage1_out['model_path']),
  218. cfg=cfg,
  219. rpn_test_prototxt=rpn_test_prototxt)
  220. p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)
  221. p.start()
  222. rpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']
  223. p.join()
  224. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  225. print 'Stage 1 Fast R-CNN using RPN proposals, init from ImageNet model'
  226. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  227. cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'
  228. mp_kwargs = dict(
  229. queue=mp_queue,
  230. imdb_name=args.imdb_name,
  231. init_model=args.pretrained_model,
  232. solver=solvers[1],
  233. max_iters=max_iters[1],
  234. cfg=cfg,
  235. rpn_file=rpn_stage1_out['proposal_path'])
  236. p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)
  237. p.start()
  238. fast_rcnn_stage1_out = mp_queue.get()
  239. p.join()
  240. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  241. print 'Stage 2 RPN, init from stage 1 Fast R-CNN model'
  242. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  243. cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'
  244. mp_kwargs = dict(
  245. queue=mp_queue,
  246. imdb_name=args.imdb_name,
  247. init_model=str(fast_rcnn_stage1_out['model_path']),
  248. solver=solvers[2],
  249. max_iters=max_iters[2],
  250. cfg=cfg)
  251. p = mp.Process(target=train_rpn, kwargs=mp_kwargs)
  252. p.start()
  253. rpn_stage2_out = mp_queue.get()
  254. p.join()
  255. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  256. print 'Stage 2 RPN, generate proposals'
  257. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  258. mp_kwargs = dict(
  259. queue=mp_queue,
  260. imdb_name=args.imdb_name,
  261. rpn_model_path=str(rpn_stage2_out['model_path']),
  262. cfg=cfg,
  263. rpn_test_prototxt=rpn_test_prototxt)
  264. p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)
  265. p.start()
  266. rpn_stage2_out['proposal_path'] = mp_queue.get()['proposal_path']
  267. p.join()
  268. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  269. print 'Stage 2 Fast R-CNN, init from stage 2 RPN R-CNN model'
  270. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  271. cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'
  272. mp_kwargs = dict(
  273. queue=mp_queue,
  274. imdb_name=args.imdb_name,
  275. init_model=str(rpn_stage2_out['model_path']),
  276. solver=solvers[3],
  277. max_iters=max_iters[3],
  278. cfg=cfg,
  279. rpn_file=rpn_stage2_out['proposal_path'])
  280. p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)
  281. p.start()
  282. fast_rcnn_stage2_out = mp_queue.get()
  283. p.join()
  284. # Create final model (just a copy of the last stage)
  285. final_path = os.path.join(
  286. os.path.dirname(fast_rcnn_stage2_out['model_path']),
  287. args.net_name + '_faster_rcnn_final.caffemodel')
  288. print 'cp {} -> {}'.format(
  289. fast_rcnn_stage2_out['model_path'], final_path)
  290. shutil.copy(fast_rcnn_stage2_out['model_path'], final_path)
  291. print 'Final model: {}'.format(final_path)

2. train_faster_rcnn_alt_opt.py的部分参数说明

net_name:      {ZF, VGG_CNN_M_1024, VGG16}

pretrained_model:      data/imagenet_models/${net_name}.v2.caffemodel

cfg_file:     experiments/cfgs/faster_rcnn_alt_opt.yml

imdb_name:     "voc_2007_trainval" or "voc_2007_test"

cfg.TRAIN.HAS_RPN = True表示用xml提供的propoal

cfg是配置文件,它的默认值放在上面的cfg_file里,其他还可以自己写配置文件之后与默认配置文件融合。

2.1 net_name是用get_solvers()找到网络。还要用到cfg的参数MODELS_DIR,

例子是join(MODELS_DIR, net_name, 'faster_rcnn_alt_opt', 'stage1_rpn_solver60k80k.pt')

2.2 imdb_name在factory中被拆成‘2007’(year)和‘trainval’/‘test’(split)到类pascal_voc中产生相应的imdb

2.3 整个step的大致流程:

(ImageNet model)->stage1_rpn_train->rpn_test

|(proposal_path)

(ImageNetmodel)->stage1_fast_rcnn_train-> stage2_rpn_train-> rpn_test-> stage2_fast_rcnn_train

2.4 数据imdb和roidb

roidb原本是imdb的一个属性,但imdb其实是为了计算roidb存在的,他所有的其他属性和方法都是为了计算roidb

faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py相关推荐

  1. faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说)

    转载自:faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说) - 野孩子的专栏 - 博客频道 - CSDN.NET ht ...

  2. faster rcnn源码解读(六)之minibatch

    转载自:faster rcnn源码解读(六)之minibatch - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/ ...

  3. faster rcnn源码解读(五)之layer(网络里的input-data)

    转载自:faster rcnn源码解读(五)之layer(网络里的input-data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668 ...

  4. faster rcnn源码解读总结

    转载自:faster rcnn源码解读总结 - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/details/519 ...

  5. 还不懂目标检测嘛?一起来看看Faster R-CNN源码解读

  6. 【Faster R-CNN论文精度系列】从Faster R-CNN源码中,我们“学习”到了什么?

    [Faster R-CNN论文精度系列] (如下为建议阅读顺序) 1[Faster R-CNN论文精度系列]从Faster R-CNN源码中,我们"学习"到了什么? 2[Faste ...

  7. Mask RCNN源码解读

    Mask RCNN源码解读 前言 数据集 数据载入 模型搭建 模型输入 模型输出 resnet101 RPN网络 ProposalLayer DetectionTargetLayer fpn_clas ...

  8. faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)

    转载自:faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.n ...

  9. Faster R-CNN源码中RPN的解析(自用)

    参考博客(一定要看前面两个) 一文看懂Faster R-CNN 详细的Faster R-CNN源码解析之RPN源码解析 关于RPN一些我的想法 rpn的中心思想就是在了anchors了,如何产生anc ...

最新文章

  1. jndi mysql数据库_数据库连接池技术中dbcp、c3p0、jndi
  2. C语言开发笔记(八)static
  3. asp.net 根据当前时间计算是否股票、期货、黄金交易日期
  4. SQL Server中常用全局变量介绍
  5. android内存占用分析,Android App性能评测分析-内存篇
  6. 关于推荐系统中的冷启动问题探讨(Approaching the Cold Start Problem in Recommender Systems)...
  7. 工具使用——印象笔记(5)
  8. linux安装google浏览器
  9. 古罗马帝国莱茵河-多瑙河防线之谜
  10. 设计模式回顾——模板模式(C++)
  11. Linux实战(20):Docker部署EKL入门环境记录文档
  12. 极限、可导、可微、连续之间的关系
  13. 【开发工具下载汇总】
  14. 视频怎么水平翻转画面并做锐化处理?
  15. 【Python】 标准差计算(std)
  16. 正则表达式对密码限定格式:必须包含英文,数字,字符且密码长度大于8位
  17. 当代世界顶级彩铅牛人的画作,每一副仿佛照片一样逼真!
  18. 前程无忧签订合并协议:交易对价降至61美元/股,相对降幅约30%
  19. 经济不确定环境下,制造业的数字化转型之道
  20. 十天 教你从创意到上线APP

热门文章

  1. 剖析Elasticsearch集群系列第三篇 近实时搜索、深层分页问题和搜索相关性权衡之道...
  2. php和java的memcached使用的兼容性问题解决过程
  3. quartz源码分析之深刻理解job,sheduler,calendar,trigger及listener之间的关系
  4. 金融风控实战——生肖属性单变量分析
  5. 解析金融反欺诈技术的应用与实践
  6. insightface和facenet效果+性能比较
  7. python pandas dataframe 行列选择,切片操作 原创 2017年02月15日 21:43:18 标签: python 30760 python pandas dataframe
  8. 微软创立全新人工智能实验室,与DeepMind、OpenAI同台竞技
  9. 简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码)
  10. 乐视姓孙还是姓贾?反正我不知道