转载自:http://blog.csdn.net/u010668907/article/details/51439503

faster rcnn用Python版本https://github.com/rbgirshick/py-faster-rcnn

以demo.py中默认网络VGG16.

原本demo.py地址https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py

图有点多,贴一个图的本分结果出来:

上图是原图,下面第一张是网络中命名为“conv1_1”的结果图;第二张是命名为“rpn_cls_prob_reshape”的结果图;第三张是“rpnoutput”的结果图

看一下我修改后的代码:

[python] view plaincopyprint?
  1. #!/usr/bin/env python
  2. # --------------------------------------------------------
  3. # Faster R-CNN
  4. # Copyright (c) 2015 Microsoft
  5. # Licensed under The MIT License [see LICENSE for details]
  6. # Written by Ross Girshick
  7. # --------------------------------------------------------
  8. """
  9. Demo script showing detections in sample images.
  10. See README.md for installation instructions before running.
  11. """
  12. import _init_paths
  13. from fast_rcnn.config import cfg
  14. from fast_rcnn.test import im_detect
  15. from fast_rcnn.nms_wrapper import nms
  16. from utils.timer import Timer
  17. import matplotlib.pyplot as plt
  18. import numpy as np
  19. import scipy.io as sio
  20. import caffe, os, sys, cv2
  21. import argparse
  22. import math
  23. CLASSES = ('__background__',
  24. 'aeroplane', 'bicycle', 'bird', 'boat',
  25. 'bottle', 'bus', 'car', 'cat', 'chair',
  26. 'cow', 'diningtable', 'dog', 'horse',
  27. 'motorbike', 'person', 'pottedplant',
  28. 'sheep', 'sofa', 'train', 'tvmonitor')
  29. NETS = {'vgg16': ('VGG16',
  30. 'VGG16_faster_rcnn_final.caffemodel'),
  31. 'zf': ('ZF',
  32. 'ZF_faster_rcnn_final.caffemodel')}
  33. def vis_detections(im, class_name, dets, thresh=0.5):
  34. """Draw detected bounding boxes."""
  35. inds = np.where(dets[:, -1] >= thresh)[0]
  36. if len(inds) == 0:
  37. return
  38. im = im[:, :, (2, 1, 0)]
  39. fig, ax = plt.subplots(figsize=(12, 12))
  40. ax.imshow(im, aspect='equal')
  41. for i in inds:
  42. bbox = dets[i, :4]
  43. score = dets[i, -1]
  44. ax.add_patch(
  45. plt.Rectangle((bbox[0], bbox[1]),
  46. bbox[2] - bbox[0],
  47. bbox[3] - bbox[1], fill=False,
  48. edgecolor='red', linewidth=3.5)
  49. )
  50. ax.text(bbox[0], bbox[1] - 2,
  51. '{:s} {:.3f}'.format(class_name, score),
  52. bbox=dict(facecolor='blue', alpha=0.5),
  53. fontsize=14, color='white')
  54. ax.set_title(('{} detections with '
  55. 'p({} | box) >= {:.1f}').format(class_name, class_name,
  56. thresh),
  57. fontsize=14)
  58. plt.axis('off')
  59. plt.tight_layout()
  60. #plt.draw()
  61. def save_feature_picture(data, name, image_name=None, padsize = 1, padval = 1):
  62. data = data[0]
  63. #print "data.shape1: ", data.shape
  64. n = int(np.ceil(np.sqrt(data.shape[0])))
  65. padding = ((0, n ** 2 - data.shape[0]), (0, 0), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
  66. #print "padding: ", padding
  67. data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))
  68. #print "data.shape2: ", data.shape
  69. data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
  70. #print "data.shape3: ", data.shape, n
  71. data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
  72. #print "data.shape4: ", data.shape
  73. plt.figure()
  74. plt.imshow(data,cmap='gray')
  75. plt.axis('off')
  76. #plt.show()
  77. if image_name == None:
  78. img_path = './data/feature_picture/'
  79. else:
  80. img_path = './data/feature_picture/' + image_name + "/"
  81. check_file(img_path)
  82. plt.savefig(img_path + name + ".jpg", dpi = 400, bbox_inches = "tight")
  83. def check_file(path):
  84. if not os.path.exists(path):
  85. os.mkdir(path)
  86. def demo(net, image_name):
  87. """Detect object classes in an image using pre-computed object proposals."""
  88. # Load the demo image
  89. im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
  90. im = cv2.imread(im_file)
  91. # Detect all object classes and regress object bounds
  92. timer = Timer()
  93. timer.tic()
  94. scores, boxes = im_detect(net, im)
  95. for k, v in net.blobs.items():
  96. if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1:
  97. save_feature_picture(v.data, k.replace("/", ""), image_name)#net.blobs["conv1_1"].data, "conv1_1")
  98. timer.toc()
  99. print ('Detection took {:.3f}s for '
  100. '{:d} object proposals').format(timer.total_time, boxes.shape[0])
  101. # Visualize detections for each class
  102. CONF_THRESH = 0.8
  103. NMS_THRESH = 0.3
  104. for cls_ind, cls in enumerate(CLASSES[1:]):
  105. cls_ind += 1 # because we skipped background
  106. cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
  107. cls_scores = scores[:, cls_ind]
  108. dets = np.hstack((cls_boxes,
  109. cls_scores[:, np.newaxis])).astype(np.float32)
  110. keep = nms(dets, NMS_THRESH)
  111. dets = dets[keep, :]
  112. vis_detections(im, cls, dets, thresh=CONF_THRESH)
  113. def parse_args():
  114. """Parse input arguments."""
  115. parser = argparse.ArgumentParser(description='Faster R-CNN demo')
  116. parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
  117. default=0, type=int)
  118. parser.add_argument('--cpu', dest='cpu_mode',
  119. help='Use CPU mode (overrides --gpu)',
  120. action='store_true')
  121. parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
  122. choices=NETS.keys(), default='vgg16')
  123. args = parser.parse_args()
  124. return args
  125. def print_param(net):
  126. for k, v in net.blobs.items():
  127. print (k, v.data.shape)
  128. print ""
  129. for k, v in net.params.items():
  130. print (k, v[0].data.shape)
  131. if __name__ == '__main__':
  132. cfg.TEST.HAS_RPN = True  # Use RPN for proposals
  133. args = parse_args()
  134. prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
  135. 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
  136. #print "prototxt: ", prototxt
  137. caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
  138. NETS[args.demo_net][1])
  139. if not os.path.isfile(caffemodel):
  140. raise IOError(('{:s} not found.\nDid you run ./data/script/'
  141. 'fetch_faster_rcnn_models.sh?').format(caffemodel))
  142. if args.cpu_mode:
  143. caffe.set_mode_cpu()
  144. else:
  145. caffe.set_mode_gpu()
  146. caffe.set_device(args.gpu_id)
  147. cfg.GPU_ID = args.gpu_id
  148. net = caffe.Net(prototxt, caffemodel, caffe.TEST)
  149. #print_param(net)
  150. print '\n\nLoaded network {:s}'.format(caffemodel)
  151. # Warmup on a dummy image
  152. im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
  153. for i in xrange(2):
  154. _, _= im_detect(net, im)
  155. im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
  156. '001763.jpg', '004545.jpg']
  157. for im_name in im_names:
  158. print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
  159. print 'Demo for data/demo/{}'.format(im_name)
  160. demo(net, im_name)
  161. #plt.show()
#!/usr/bin/env python# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------"""
Demo script showing detections in sample images.See README.md for installation instructions before running.
"""import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
import mathCLASSES = ('__background__','aeroplane', 'bicycle', 'bird', 'boat','bottle', 'bus', 'car', 'cat', 'chair','cow', 'diningtable', 'dog', 'horse','motorbike', 'person', 'pottedplant','sheep', 'sofa', 'train', 'tvmonitor')NETS = {'vgg16': ('VGG16','VGG16_faster_rcnn_final.caffemodel'),'zf': ('ZF','ZF_faster_rcnn_final.caffemodel')}def vis_detections(im, class_name, dets, thresh=0.5):"""Draw detected bounding boxes."""inds = np.where(dets[:, -1] >= thresh)[0]if len(inds) == 0:returnim = im[:, :, (2, 1, 0)]fig, ax = plt.subplots(figsize=(12, 12))ax.imshow(im, aspect='equal')for i in inds:bbox = dets[i, :4]score = dets[i, -1]ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),bbox[2] - bbox[0],bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))ax.text(bbox[0], bbox[1] - 2,'{:s} {:.3f}'.format(class_name, score),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')ax.set_title(('{} detections with ''p({} | box) >= {:.1f}').format(class_name, class_name,thresh),fontsize=14)plt.axis('off')plt.tight_layout()#plt.draw()
def save_feature_picture(data, name, image_name=None, padsize = 1, padval = 1):data = data[0]#print "data.shape1: ", data.shapen = int(np.ceil(np.sqrt(data.shape[0])))padding = ((0, n ** 2 - data.shape[0]), (0, 0), (0, padsize)) + ((0, 0),) * (data.ndim - 3)#print "padding: ", paddingdata = np.pad(data, padding, mode='constant', constant_values=(padval, padval))#print "data.shape2: ", data.shapedata = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))#print "data.shape3: ", data.shape, ndata = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])#print "data.shape4: ", data.shapeplt.figure()plt.imshow(data,cmap='gray')plt.axis('off')#plt.show()if image_name == None:img_path = './data/feature_picture/' else:img_path = './data/feature_picture/' + image_name + "/"check_file(img_path)plt.savefig(img_path + name + ".jpg", dpi = 400, bbox_inches = "tight")
def check_file(path):if not os.path.exists(path):os.mkdir(path)
def demo(net, image_name):"""Detect object classes in an image using pre-computed object proposals."""# Load the demo imageim_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)im = cv2.imread(im_file)# Detect all object classes and regress object boundstimer = Timer()timer.tic()scores, boxes = im_detect(net, im)for k, v in net.blobs.items():if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1:save_feature_picture(v.data, k.replace("/", ""), image_name)#net.blobs["conv1_1"].data, "conv1_1") timer.toc()print ('Detection took {:.3f}s for ''{:d} object proposals').format(timer.total_time, boxes.shape[0])# Visualize detections for each classCONF_THRESH = 0.8NMS_THRESH = 0.3for cls_ind, cls in enumerate(CLASSES[1:]):cls_ind += 1 # because we skipped backgroundcls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]cls_scores = scores[:, cls_ind]dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)keep = nms(dets, NMS_THRESH)dets = dets[keep, :]vis_detections(im, cls, dets, thresh=CONF_THRESH)def parse_args():"""Parse input arguments."""parser = argparse.ArgumentParser(description='Faster R-CNN demo')parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int)parser.add_argument('--cpu', dest='cpu_mode',help='Use CPU mode (overrides --gpu)',action='store_true')parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',choices=NETS.keys(), default='vgg16')args = parser.parse_args()return argsdef print_param(net):for k, v in net.blobs.items():print (k, v.data.shape)print ""for k, v in net.params.items():print (k, v[0].data.shape)  if __name__ == '__main__':cfg.TEST.HAS_RPN = True  # Use RPN for proposalsargs = parse_args()prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')#print "prototxt: ", prototxtcaffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',NETS[args.demo_net][1])if not os.path.isfile(caffemodel):raise IOError(('{:s} not found.\nDid you run ./data/script/''fetch_faster_rcnn_models.sh?').format(caffemodel))if args.cpu_mode:caffe.set_mode_cpu()else:caffe.set_mode_gpu()caffe.set_device(args.gpu_id)cfg.GPU_ID = args.gpu_idnet = caffe.Net(prototxt, caffemodel, caffe.TEST)#print_param(net)print '\n\nLoaded network {:s}'.format(caffemodel)# Warmup on a dummy imageim = 128 * np.ones((300, 500, 3), dtype=np.uint8)for i in xrange(2):_, _= im_detect(net, im)im_names = ['000456.jpg', '000542.jpg', '001150.jpg','001763.jpg', '004545.jpg']for im_name in im_names:print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Demo for data/demo/{}'.format(im_name)demo(net, im_name)#plt.show()

1.在data下手动创建“feature_picture”文件夹就可以替换原来的demo使用了。

2.上面代码主要添加方法是:save_feature_picture,它会对网络测试的某些阶段的数据处理然后保存。

3.某些阶段是因为:if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1这行代码(110行),保证网络层name有这三个词的才会被保存,因为其他层无法用图片

保存,如全连接(参数已经是二维的了)等层。

4.放开174行print_param(net)的注释,就可以看到网络参数的输出。

5.执行的最终结果 是在data/feature_picture产生以图片名字为文件夹名字的文件夹,文件夹下有以网络每层name为名字的图片。

6.另外部分网络的层name中有非法字符不能作为图片名字,我在代码的111行只是把‘字符/’剔除掉了,所以建议网络名字不要又其他字符。

图片下载和代码下载方式:

[plain] view plaincopyprint?
  1. git clone https://github.com/meihuakaile/faster-rcnn.git

faster rcnn可视化(修改demo.py保存网络中间结果)相关推荐

  1. Caffe版Faster R-CNN可视化——网络模型,图像特征,Loss图,PR曲线

    可视化网络模型   Caffe目前有两种常用的可视化模型方式: 使用Netscope在线可视化 Caffe代码包内置的draw_net.py文件可以可视化网络模型 Netscope能可视化神经网络体系 ...

  2. 深度学习之windows python faster rcnn 配置及demo运行

    写这篇文章主要是针对深度学习零基础的新手,因为我也是新手,在配置环境这一块花了我很大的心血,网上的资料很多都只是说配置,然后直接运行就完了,可是对于我这样的新手在配置的过程中会遇见各种各样的问题,所以 ...

  3. python pr曲线_Py-Faster R-CNN可视化——网络模型,图像特征,Loss图,PR曲线

    可视化网络模型 使用Netscope在线可视化 Netscope Netscope能可视化神经网络体系结构(或技术上说,Netscope能可视化任何有向无环图).目前Netscope能可视化Caffe ...

  4. faster rcnn接口_Faster R-CNN教程

    Faster R-CNN教程 最后更新日期:2016年4月29日 本教程主要基于python版本的faster R-CNN,因为python layer的使用,这个版本会比matlab的版本速度慢10 ...

  5. [计算机视觉][神经网络与深度学习]Faster R-CNN配置及其训练教程

    Faster R-CNN教程 Faster R-CNN教程 最后更新日期:2016年4月29日 本教程主要基于python版本的faster R-CNN,因为python layer的使用,这个版本会 ...

  6. Faster R-CNN教程

    转载自:Faster R-CNN教程 - CarryPotMan的博客 - 博客频道 - CSDN.NET http://blog.csdn.net/u012891472/article/detail ...

  7. Faster R-CNN WINDOWS CPU环境搭建(详细版)

    操作系统: bigtop@bigtop-SdcOS-Hypervisor:~/py-faster-rcnn/tools$ cat /etc/issue Ubuntu 14.04.2 LTS \n \l ...

  8. 一文读懂Faster RCNN

    来源:信息网络工程研究中心本文约7500字,建议阅读10+分钟 本文从四个切入点为你介绍Faster R-CNN网络. 经过R-CNN和Fast RCNN的积淀,Ross B. Girshick在20 ...

  9. Faster R-CNN论文及源码解读

    R-CNN是目标检测领域中十分经典的方法,相比于传统的手工特征,R-CNN将卷积神经网络引入,用于提取深度特征,后接一个分类器判决搜索区域是否包含目标及其置信度,取得了较为准确的检测结果.Fast R ...

最新文章

  1. 【Redis学习笔记】2018-07-11 Redis指令学习5
  2. Matlab计算多项式的值(数值)
  3. Matlab对话框总结
  4. Tomcat 全攻略
  5. 控制台修改应用端口_应用架构六边型架构:三个原则和一个实现示例
  6. 好用的格式化SQL工具SQL Prompt
  7. 嵌入式系统开发笔记94:使用FlyMcu连接STM32开发板
  8. 黑马十次方2.0项目
  9. 给大家推荐一个免费职业评测
  10. 彩蛋-管理员root@‘locahost‘ 密码丢失,处理方案。
  11. [雪浪小镇启动仪式]阿里王坚:没有制造业的互联网没有未来?
  12. 如何安装和搭建wordpress个人网站(超详细+零基础)
  13. 关于剪枝对象的分类(weights剪枝、神经元剪枝、filters剪枝、layers剪枝、channel剪枝、对channel分组剪枝、Stripe剪枝)
  14. 【直达本质讲运放】运放的“第一原理”式定量分析法
  15. GOF设计模式之适配器模式的理解
  16. 13 个最好用的免费服务器和网络监控工具,不看吃亏!
  17. cocos2d带冷却的菜单按钮封装
  18. GIS二次开发(C#+AE)
  19. 解决R6025-pure virtual function call弹窗
  20. 有可以模拟钢琴弹奏乐曲的手机软件吗?

热门文章

  1. 互联网主要安全威胁解读及应对方案大讨论 | 高可用架构系列
  2. 基于Redis实现分布式应用限流--转
  3. 2015!我来了,你在哪里?今年第一篇
  4. 也谈BIO | NIO | AIO (Java版--转)
  5. 【Python】Pandas基础:结构化数据处理
  6. 【采用】百度大规模知识图谱构建及智能应用
  7. python对XML的解析
  8. 元宇宙:Facebook正式改名为Meta,要砸600亿做这件事
  9. 京东DNN Lab首席科学家:用深度学习搞定80%的客服工作
  10. 机器学习入门系列二(关键词:多变量(非)线性回归,批处理,特征缩放,正规方程