Faster-rcnn 源码学习(二)

本节主要介绍分步训练中利用训练好的RPN网络生成proposal。

产品proposal的网络结构如下:

第二步利用训练好的rpn网络产生proposal,代码如下:

## 第二步,主要是利用第一步训练好的RPN网络来生成proposalprint '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Stage 1 RPN, generate proposals'print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'mp_kwargs = dict(queue=mp_queue,imdb_name=args.imdb_name,rpn_model_path=str(rpn_stage1_out['model_path']),cfg=cfg,rpn_test_prototxt=rpn_test_prototxt)p = mp.Process(target=rpn_generate, kwargs=mp_kwargs) #rpn_generate()产生proposalp.start() #开始生成proposalrpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']p.join()
def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,rpn_test_prototxt=None):"""Use a trained RPN to generate proposals."""cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre NMS filtering,不使用NMScfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS 使用nms后产生2000个bboxprint 'RPN model: {}'.format(rpn_model_path)print('Using config:')pprint.pprint(cfg)import caffe_init_caffe(cfg)# NOTE: the matlab implementation computes proposals on flipped images, too.# We compute them on the image once and then flip the already computed# proposals. This might cause a minor loss in mAP (less proposal jittering).imdb = get_imdb(imdb_name)print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)# Load RPN and configure output directoryrpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST) #加载RPN网络output_dir = get_output_dir(imdb) #显示输出目录print 'Output will be saved to `{:s}`'.format(output_dir)# Generate proposals on the imdbrpn_proposals = imdb_proposals(rpn_net, imdb) #使用imdb_proposals()在所有图片上产生proposals# Write proposals to disk and send the proposal file path through the# multiprocessing queuerpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]rpn_proposals_path = os.path.join(output_dir, rpn_net_name + '_proposals.pkl') #将proposal的文件路径放入多线程中with open(rpn_proposals_path, 'wb') as f:  #以二进制写模式打开cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL) #将之前生成的proposal序列化并存储到之前设置好的路径中print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)queue.put({'proposal_path': rpn_proposals_path})#proposal序列化后存入队列,供其他进程使用

​ 首先设置了预 NMS,还有就是经过 NMS 后产生 2000 个 proposals,然后初始化 caffe,再用 get_imdb() 函数得到 imdb 数据,方法和前面一样,再用 caffe.NET()加载 RPN 网络,再使用 imdb_proposals() 得到 proposal,那我们就进入这个函数:

def imdb_proposals(net, imdb):"""Generate RPN proposals on all images in an imdb.在所有图片上生成proposal"""_t = Timer() #产生一个时钟对象imdb_boxes = [[] for _ in xrange(imdb.num_images)]for i in xrange(imdb.num_images):im = cv2.imread(imdb.image_path_at(i)) #读取图片_t.tic()imdb_boxes[i], scores = im_proposals(net, im) #在单张图像上获得proplsals,包括boxes和scores,注意RPN中的cls只做一个二分类任务_t.toc()print 'im_proposals: {:d}/{:d} {:.3f}s' \.format(i + 1, imdb.num_images, _t.average_time)if 0:dets = np.hstack((imdb_boxes[i], scores)) # from IPython import embed; embed()_vis_proposals(im, dets[:3, :], thresh=0.9)plt.show()return imdb_boxes

​ 该函数的作用就是在所有的图片上生成 proposal,不过作者又嵌套了一个 im_proposals() 函数,即在一张图片上产生 proposals,进入 im_proposals() 函数中:

def im_proposals(net, im):"""Generate RPN proposals on a single image."""blobs = {}blobs['data'], blobs['im_info'] = _get_image_blob(im) #将图片转换为blob的格式net.blobs['data'].reshape(*(blobs['data'].shape)) #将网络中的blob对应的结构相应的进行修改net.blobs['im_info'].reshape(*(blobs['im_info'].shape))blobs_out = net.forward(data=blobs['data'].astype(np.float32, copy=False),im_info=blobs['im_info'].astype(np.float32, copy=False)) #进行一次前向传播scale = blobs['im_info'][0, 2] #获得缩放比例boxes = blobs_out['rois'][:, 1:].copy() / scale #这个产生的boxes是对应与原图的尺寸scores = blobs_out['scores'].copy()return boxes, scores # 获取boxes和scores

首先用_get_image_blob() 函数将图片数据转换为 caffe 的 blob 格式,进入该函数:

def _get_image_blob(im):"""Converts an image into a network input.将输入的RGB图像转化为网络的输入格式Arguments:im (ndarray): a color image in BGR orderReturns:blob (ndarray): a data blob holding an image pyramidim_scale_factors (list): list of image scales (relative to im) usedin the image pyramid"""im_orig = im.astype(np.float32, copy=True)  #实现变量类型转换,im_org是像素矩阵im_orig -= cfg.PIXEL_MEANS #减去像素均值im_shape = im_orig.shape #获得图像的像素尺寸im_size_min = np.min(im_shape[0:2])im_size_max = np.max(im_shape[0:2])processed_ims = []assert len(cfg.TEST.SCALES) == 1 #确保测试时只有一种图像尺寸target_size = cfg.TEST.SCALES[0] #图片的最短边im_scale = float(target_size) / float(im_size_min) #变换后的最短边除以原始图像的最短边得到缩放比例# Prevent the biggest axis from being more than MAX_SIZEif np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE: # 如果scale处理后的图像最大边大于要求的最大值MAX_SIZE,则修改放大比例im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,interpolation=cv2.INTER_LINEAR)  #利用双线性插值进行图像缩放,按比例im_info = np.hstack((im.shape[:2], im_scale))[np.newaxis, :] #使用np.newaxis得到in_info,格式为【M,N,im_scale】processed_ims.append(im) #将调整好大小后的图片添加到processed_ims中# Create a blob to hold the input imagesblob = im_list_to_blob(processed_ims) #blob格式为(batch elem,channel,height,width)return blob, im_info

​ 上述将缩放后的图像调用im_list_to_blob函数转换成blob格式的图像,进入该函数

def im_list_to_blob(ims):"""Convert a list of images into a network input.Assumes images are already prepared (means subtracted, BGR order, ...)."""max_shape = np.array([im.shape for im in ims]).max(axis=0)num_images = len(ims)blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),dtype=np.float32)for i in xrange(num_images):im = ims[i]blob[i, 0:im.shape[0], 0:im.shape[1], :] = im# Move channels (axis 3) to axis 1# Axis order will become: (batch elem, channel, height, width)channel_swap = (0, 3, 1, 2)blob = blob.transpose(channel_swap)return blob

最终得到的 blob 格式为 (batch elem , channel , height , width),im_info 格式为[M,N,im_scale],其中 im_scale 是缩放比例,原始图片输入 faster rcnn 中进行训练时都需要先缩放成统一的规格;再回到 im_proposals() 函数中,使用 net.forward()函数进行一次前向传播,获得blobs_out为计算得到的proposal,数据格式如下:

之后再回到 imdb_proposals()函数中,最后返回得到的 imdb_boxes, 即我们从 RPN 上产生的 proposals。再回到 rpn_generate()函数中,变量rpn_proposals值如下图,测试中用到了5张图,因此是一个长为5的list,每个list的结构为(2000,4)即每张图有2000个预测框。

接着就是将生成的 proposals 保存并传输到多线程中去供下一步训练使用,这个函数使命就暂时完成了;

Faster-rcnn 源码学习(二)相关推荐

  1. 【Faster R-CNN论文精度系列】从Faster R-CNN源码中,我们“学习”到了什么?

    [Faster R-CNN论文精度系列] (如下为建议阅读顺序) 1[Faster R-CNN论文精度系列]从Faster R-CNN源码中,我们"学习"到了什么? 2[Faste ...

  2. faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)

    转载自:faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.n ...

  3. Faster R-CNN源码中RPN的解析(自用)

    参考博客(一定要看前面两个) 一文看懂Faster R-CNN 详细的Faster R-CNN源码解析之RPN源码解析 关于RPN一些我的想法 rpn的中心思想就是在了anchors了,如何产生anc ...

  4. faster rcnn源码解读(六)之minibatch

    转载自:faster rcnn源码解读(六)之minibatch - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/ ...

  5. faster rcnn源码解读(五)之layer(网络里的input-data)

    转载自:faster rcnn源码解读(五)之layer(网络里的input-data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668 ...

  6. faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说)

    转载自:faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说) - 野孩子的专栏 - 博客频道 - CSDN.NET ht ...

  7. faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py

    转载自:faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u ...

  8. faster rcnn源码解读总结

    转载自:faster rcnn源码解读总结 - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/details/519 ...

  9. 人工智能学习07--pytorch18--目标检测:Faster RCNN源码解析(pytorch)

    参考博客: https://blog.csdn.net/weixin_46676835/article/details/130175898 VOC2012 1.代码的使用 查看pytorch中的fas ...

  10. 详细的Faster R-CNN源码解析之RPN源码解析

    在阔别了将近三个月之后,笔者又准备更新博客了.对于前两个多月的未及时更新,笔者在此向大家表示歉意,请大家原谅. 本次博客的更新是关于Faster R-CNN的源码.首先说一下笔者为什么要更新Faste ...

最新文章

  1. IOS自定义表格UITableViewCell
  2. 深入解析Spring架构与设计原理-AOP
  3. Spark中常用的算法
  4. Scala闭包特性的一个测试
  5. POJ1236Network of Schools——强连通分量缩点建图
  6. 股票自动交易使用协议
  7. Go单测测试 — 数据库 CRUD 的 Mock 测试
  8. 为什么c语言一用windows.h就报错_C代码里面加一行网址依然可以运行,并不会报错,为何...
  9. 进阶04 4 Collection集合类+Iterator迭代器+增强for+泛型
  10. 设计原则 里氏替换原则
  11. 常用 ASCII 码整理
  12. opencv查看版本路径
  13. 中兴V889DRoot后可删和不可删
  14. 世界上第一次网络瘫痪 | 历史上的今天
  15. python列表del_python删除列表元素的三种方法(remove,pop,del)
  16. 谷歌浏览器清除dns缓存
  17. hashmap是单向链表吗_LRU(Least Recent Used) java 实现为这么采用HashMap+双向链表
  18. RK3308 WIFI驱动调试
  19. 骞云科技携手上海电力、兴业证券,双案例入选2022年CMP优秀案例
  20. android gps locationCb 数据

热门文章

  1. [Minitab]如何製作柏拉圖(Pareto chart)?
  2. UVC Extension Unit 相关资料整理
  3. android双卡切换流量代码,双卡双待手机流量怎么切换 方法有哪些【图文】
  4. 解决Android studio运行代码手机出现xxx keeps stopping
  5. python怎么爬虎牙_【python】虎牙直播爬虫项目
  6. 扯ruan蛋的房价,恶心死了我
  7. 零基础入门,资深吃货带你搞懂大数据
  8. 梅特勒-托利多 TCS-35 电子台秤
  9. 计算机内存占用过高,如果内存使用率过高怎么办? Win10计算机内存占用率高的原因和解决方案...
  10. CMMI3-CMMI5评估认证需要遵循七大原则