Faster-rcnn 源码学习(二)
Faster-rcnn 源码学习(二)
本节主要介绍分步训练中利用训练好的RPN网络生成proposal。
产品proposal的网络结构如下:
第二步利用训练好的rpn网络产生proposal,代码如下:
## 第二步,主要是利用第一步训练好的RPN网络来生成proposalprint '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Stage 1 RPN, generate proposals'print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'mp_kwargs = dict(queue=mp_queue,imdb_name=args.imdb_name,rpn_model_path=str(rpn_stage1_out['model_path']),cfg=cfg,rpn_test_prototxt=rpn_test_prototxt)p = mp.Process(target=rpn_generate, kwargs=mp_kwargs) #rpn_generate()产生proposalp.start() #开始生成proposalrpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']p.join()
def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,rpn_test_prototxt=None):"""Use a trained RPN to generate proposals."""cfg.TEST.RPN_PRE_NMS_TOP_N = -1 # no pre NMS filtering,不使用NMScfg.TEST.RPN_POST_NMS_TOP_N = 2000 # limit top boxes after NMS 使用nms后产生2000个bboxprint 'RPN model: {}'.format(rpn_model_path)print('Using config:')pprint.pprint(cfg)import caffe_init_caffe(cfg)# NOTE: the matlab implementation computes proposals on flipped images, too.# We compute them on the image once and then flip the already computed# proposals. This might cause a minor loss in mAP (less proposal jittering).imdb = get_imdb(imdb_name)print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)# Load RPN and configure output directoryrpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST) #加载RPN网络output_dir = get_output_dir(imdb) #显示输出目录print 'Output will be saved to `{:s}`'.format(output_dir)# Generate proposals on the imdbrpn_proposals = imdb_proposals(rpn_net, imdb) #使用imdb_proposals()在所有图片上产生proposals# Write proposals to disk and send the proposal file path through the# multiprocessing queuerpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]rpn_proposals_path = os.path.join(output_dir, rpn_net_name + '_proposals.pkl') #将proposal的文件路径放入多线程中with open(rpn_proposals_path, 'wb') as f: #以二进制写模式打开cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL) #将之前生成的proposal序列化并存储到之前设置好的路径中print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)queue.put({'proposal_path': rpn_proposals_path})#proposal序列化后存入队列,供其他进程使用
首先设置了预 NMS,还有就是经过 NMS 后产生 2000 个 proposals,然后初始化 caffe,再用 get_imdb() 函数得到 imdb 数据,方法和前面一样,再用 caffe.NET()加载 RPN 网络,再使用 imdb_proposals() 得到 proposal,那我们就进入这个函数:
def imdb_proposals(net, imdb):"""Generate RPN proposals on all images in an imdb.在所有图片上生成proposal"""_t = Timer() #产生一个时钟对象imdb_boxes = [[] for _ in xrange(imdb.num_images)]for i in xrange(imdb.num_images):im = cv2.imread(imdb.image_path_at(i)) #读取图片_t.tic()imdb_boxes[i], scores = im_proposals(net, im) #在单张图像上获得proplsals,包括boxes和scores,注意RPN中的cls只做一个二分类任务_t.toc()print 'im_proposals: {:d}/{:d} {:.3f}s' \.format(i + 1, imdb.num_images, _t.average_time)if 0:dets = np.hstack((imdb_boxes[i], scores)) # from IPython import embed; embed()_vis_proposals(im, dets[:3, :], thresh=0.9)plt.show()return imdb_boxes
该函数的作用就是在所有的图片上生成 proposal,不过作者又嵌套了一个 im_proposals() 函数,即在一张图片上产生 proposals,进入 im_proposals() 函数中:
def im_proposals(net, im):"""Generate RPN proposals on a single image."""blobs = {}blobs['data'], blobs['im_info'] = _get_image_blob(im) #将图片转换为blob的格式net.blobs['data'].reshape(*(blobs['data'].shape)) #将网络中的blob对应的结构相应的进行修改net.blobs['im_info'].reshape(*(blobs['im_info'].shape))blobs_out = net.forward(data=blobs['data'].astype(np.float32, copy=False),im_info=blobs['im_info'].astype(np.float32, copy=False)) #进行一次前向传播scale = blobs['im_info'][0, 2] #获得缩放比例boxes = blobs_out['rois'][:, 1:].copy() / scale #这个产生的boxes是对应与原图的尺寸scores = blobs_out['scores'].copy()return boxes, scores # 获取boxes和scores
首先用_get_image_blob() 函数将图片数据转换为 caffe 的 blob 格式,进入该函数:
def _get_image_blob(im):"""Converts an image into a network input.将输入的RGB图像转化为网络的输入格式Arguments:im (ndarray): a color image in BGR orderReturns:blob (ndarray): a data blob holding an image pyramidim_scale_factors (list): list of image scales (relative to im) usedin the image pyramid"""im_orig = im.astype(np.float32, copy=True) #实现变量类型转换,im_org是像素矩阵im_orig -= cfg.PIXEL_MEANS #减去像素均值im_shape = im_orig.shape #获得图像的像素尺寸im_size_min = np.min(im_shape[0:2])im_size_max = np.max(im_shape[0:2])processed_ims = []assert len(cfg.TEST.SCALES) == 1 #确保测试时只有一种图像尺寸target_size = cfg.TEST.SCALES[0] #图片的最短边im_scale = float(target_size) / float(im_size_min) #变换后的最短边除以原始图像的最短边得到缩放比例# Prevent the biggest axis from being more than MAX_SIZEif np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE: # 如果scale处理后的图像最大边大于要求的最大值MAX_SIZE,则修改放大比例im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,interpolation=cv2.INTER_LINEAR) #利用双线性插值进行图像缩放,按比例im_info = np.hstack((im.shape[:2], im_scale))[np.newaxis, :] #使用np.newaxis得到in_info,格式为【M,N,im_scale】processed_ims.append(im) #将调整好大小后的图片添加到processed_ims中# Create a blob to hold the input imagesblob = im_list_to_blob(processed_ims) #blob格式为(batch elem,channel,height,width)return blob, im_info
上述将缩放后的图像调用im_list_to_blob函数转换成blob格式的图像,进入该函数
def im_list_to_blob(ims):"""Convert a list of images into a network input.Assumes images are already prepared (means subtracted, BGR order, ...)."""max_shape = np.array([im.shape for im in ims]).max(axis=0)num_images = len(ims)blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),dtype=np.float32)for i in xrange(num_images):im = ims[i]blob[i, 0:im.shape[0], 0:im.shape[1], :] = im# Move channels (axis 3) to axis 1# Axis order will become: (batch elem, channel, height, width)channel_swap = (0, 3, 1, 2)blob = blob.transpose(channel_swap)return blob
最终得到的 blob 格式为 (batch elem , channel , height , width),im_info 格式为[M,N,im_scale],其中 im_scale 是缩放比例,原始图片输入 faster rcnn 中进行训练时都需要先缩放成统一的规格;再回到 im_proposals() 函数中,使用 net.forward()函数进行一次前向传播,获得blobs_out为计算得到的proposal,数据格式如下:
之后再回到 imdb_proposals()函数中,最后返回得到的 imdb_boxes, 即我们从 RPN 上产生的 proposals。再回到 rpn_generate()函数中,变量rpn_proposals值如下图,测试中用到了5张图,因此是一个长为5的list,每个list的结构为(2000,4)即每张图有2000个预测框。
接着就是将生成的 proposals 保存并传输到多线程中去供下一步训练使用,这个函数使命就暂时完成了;
Faster-rcnn 源码学习(二)相关推荐
- 【Faster R-CNN论文精度系列】从Faster R-CNN源码中,我们“学习”到了什么?
[Faster R-CNN论文精度系列] (如下为建议阅读顺序) 1[Faster R-CNN论文精度系列]从Faster R-CNN源码中,我们"学习"到了什么? 2[Faste ...
- faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)
转载自:faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.n ...
- Faster R-CNN源码中RPN的解析(自用)
参考博客(一定要看前面两个) 一文看懂Faster R-CNN 详细的Faster R-CNN源码解析之RPN源码解析 关于RPN一些我的想法 rpn的中心思想就是在了anchors了,如何产生anc ...
- faster rcnn源码解读(六)之minibatch
转载自:faster rcnn源码解读(六)之minibatch - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/ ...
- faster rcnn源码解读(五)之layer(网络里的input-data)
转载自:faster rcnn源码解读(五)之layer(网络里的input-data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668 ...
- faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说)
转载自:faster rcnn源码解读(四)之数据类型imdb.py和pascal_voc.py(主要是imdb和roidb数据类型的解说) - 野孩子的专栏 - 博客频道 - CSDN.NET ht ...
- faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py
转载自:faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u ...
- faster rcnn源码解读总结
转载自:faster rcnn源码解读总结 - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u010668907/article/details/519 ...
- 人工智能学习07--pytorch18--目标检测:Faster RCNN源码解析(pytorch)
参考博客: https://blog.csdn.net/weixin_46676835/article/details/130175898 VOC2012 1.代码的使用 查看pytorch中的fas ...
- 详细的Faster R-CNN源码解析之RPN源码解析
在阔别了将近三个月之后,笔者又准备更新博客了.对于前两个多月的未及时更新,笔者在此向大家表示歉意,请大家原谅. 本次博客的更新是关于Faster R-CNN的源码.首先说一下笔者为什么要更新Faste ...
最新文章
- IOS自定义表格UITableViewCell
- 深入解析Spring架构与设计原理-AOP
- Spark中常用的算法
- Scala闭包特性的一个测试
- POJ1236Network of Schools——强连通分量缩点建图
- 股票自动交易使用协议
- Go单测测试 — 数据库 CRUD 的 Mock 测试
- 为什么c语言一用windows.h就报错_C代码里面加一行网址依然可以运行,并不会报错,为何...
- 进阶04 4 Collection集合类+Iterator迭代器+增强for+泛型
- 设计原则 里氏替换原则
- 常用 ASCII 码整理
- opencv查看版本路径
- 中兴V889DRoot后可删和不可删
- 世界上第一次网络瘫痪 | 历史上的今天
- python列表del_python删除列表元素的三种方法(remove,pop,del)
- 谷歌浏览器清除dns缓存
- hashmap是单向链表吗_LRU(Least Recent Used) java 实现为这么采用HashMap+双向链表
- RK3308 WIFI驱动调试
- 骞云科技携手上海电力、兴业证券,双案例入选2022年CMP优秀案例
- android gps locationCb 数据
热门文章
- [Minitab]如何製作柏拉圖(Pareto chart)?
- UVC Extension Unit 相关资料整理
- android双卡切换流量代码,双卡双待手机流量怎么切换 方法有哪些【图文】
- 解决Android studio运行代码手机出现xxx keeps stopping
- python怎么爬虎牙_【python】虎牙直播爬虫项目
- 扯ruan蛋的房价,恶心死了我
- 零基础入门,资深吃货带你搞懂大数据
- 梅特勒-托利多 TCS-35 电子台秤
- 计算机内存占用过高,如果内存使用率过高怎么办? Win10计算机内存占用率高的原因和解决方案...
- CMMI3-CMMI5评估认证需要遵循七大原则