之前的文章简要介绍了Faster-RCNN等物体检测的算法,本文将从代码角度详细分析介绍Faster-RCNN的实现。本文使用的代码参考了chenyuntc的实现,代码的位置看这里。需要注意的是,本文使用的框架是Pytorch。

数据载入

数据载入部分的代码主要见./data/dataset.py中的类DatasetTestDataset

数据载入部分的逻辑如下:

  1. 从VOC数据集中获得img, bbox, label
  2. img, bbox进行放缩(放缩的目的是让图片处于合适的大小,这样预先指定锚框才有意义)
  3. img进行标准化正则处理
  4. 如果是训练阶段,将img翻转以增加训练数据

网络结构

FasterRCNN的网络结构如下图所示:

FasterRCNN结构的代码主要见./model.faster_rcnn.py,其结构包含三大部分:

  1. 预训练的CNN模型 decom_vgg16
  2. rpn网络RegionProposalNetwork
  3. roi及以上网络VGG16RoIHead

下面,将以放缩后大小为[1, 3, 600, 800]的图片为例针对每个部分分别介绍。图像类别共计21类(包含背景)。

预训练的CNN模型

该部分代码见./model/vgg16.py

输入:图片,大小[1, 3, 600, 800]
输出:特征图features,大小[1, 512, 37, 50]


其逻辑如下:

  1. 载入预先训练好的CNN模型VGG16。
  2. 将模型拆分为两部分extractor, classifier。其中,extractor的参数固定。
  3. 图片通过extractor可以得到特征图features。根据extractor中池化参数可知图像通过extractor缩小了16倍。

rpn网络

该部分代码见./model/rpn.py

输入:特征图features,大小[1, 512, 37, 50]
输出:

  • rpn_locs:rpn对位置的修正,大小[1, 16650, 4]
  • rpn_scores :rpn判断区域前景背景,大小[1, 16650, 2]
  • rois:rpn筛选出的roi的位置,大小[300, 4]
  • roi_indices:rpn筛选出的roi对应的图片索引,大小[300]
  • anchor:原图像的锚点,大小[16650, 4]

其中,16650是放缩后的图像所产生的所有锚点(37*50*9),每个锚点都对应了一个rp。通过 rpn_scores以及nms可以得到筛选后的大小为300的roi。


其逻辑如下:

  1. 对特征图features以基准长度为16、选择合适的ratiosscales取基准锚点anchor_base。(选择长度为16的原因是图片大小为600*800左右,基准长度16对应的原图区域是256*256,考虑放缩后的大小有128*128,512*512比较合适)
  2. 根据anchor_base在原图上获得anchors
  3. 对特征图features采用卷积得到rpn_locsrpn_scores
  4. 根据anchorsrpn_locs获得修正后的rp
  5. rp进一步修正获得roisroi_indices,修正包括超出边界的部分截断、移除太小的、nms。

roi及以上网络

该部分代码见./model/roi_module.py

输入:

  • features:特征图,大小[1, 512, 37, 50]
  • rois:rpn筛选出的roi的位置,大小[300, 4]
  • roi_indices:rpn筛选出的roi对应的图片索引,大小[300]

输出:

  • roi_cls_locsroi位置的修正,大小[300, 84]
  • roi_scoresroi各类的分数,大小[300, 21]

其逻辑如下:

  1. 通过RoIPooling2D将大小不同的roi变成大小一致,得到pooling后的特征,大小为[300, 512, 7, 7]
  2. 接入预训练的CNN模型引入的classifier
  3. 分别接入全连接得到roi_cls_locsroi_scores

训练

训练部分的代码主要见./trainer/trainer.py中的FasterRCNNTrainer中的train_step函数。

训练部分的核心是loss如何求取。

loss求取前网络的步骤如下:

  1. 预训练CNN特征提取:输入imgextractor获得features
  2. rpn网络得到roi:输入featuresrpn获得rpn_locs, rpn_scores, rois, roi_indices, anchor
  3. 抽样roi:输入roisbboxlabelProposalTargetCreator获得sample_roi, gt_roi_loc, gt_roi_label。该步骤的含义是得到正负例比例和位置合适的roi
  4. head网络得到roi的位置修正与分数:输入features,sample_roi,sample_roi_index得到roi_cls_loc, roi_score

各个loss求取的方式如下:

  1. rpn_loc_loss:已知rpn_loc,需要先根据anchorbbox得到真实的gt_rpn_locgt_rpn_label。该处loss的计算只考虑前景,所以根据rpn_loc,gt_rpn_loc,gt_rpn_label计算L1-LOSS即可。
  2. rpn_cls_loss:根据rpn_scoregt_rpn_label计算二分类的交叉熵即可。
  3. roi_loc_loss:已知roi_loc,在sample roi的过程中已获得gt_roi_loc, gt_roi_label。根据roi_loc,gt_roi_loc,gt_roi_label计算L1-LOSS即可。
  4. roi_cls_loss:根据roi_scoregt_roi_label计算多分类的交叉熵即可。

整体的loss为以上各loss相加求和。

测试

训练部分的代码主要见./model/faster_rcnn.py中的FasterRCNNTrainer中的predict函数。

其步骤如下:

  1. 图片预处理
  2. 预训练CNN特征提取:输入imgextractor获得features
  3. rpn网络得到roi:输入featuresrpn获得rpn_locs, rpn_scores, rois, roi_indices, anchor
  4. head网络得到roi的位置修正与分数:输入features,rois,roi_indices得到roi_cls_loc, roi_score
  5. 得到图片预测的bbox:输入roi_cls_locroi_scorerois,采用nms等方法得到预测的bbox

FasterRCNN代码解读相关推荐

  1. nsga2代码解读python_代码资料

    faster RCNN TensorFlow版本: 龙鹏:[技术综述]万字长文详解Faster RCNN源代码(一) buptscdc:tensorflow 版faster rcnn代码理解(1) l ...

  2. 200行代码解读TDEngine背后的定时器

    作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...

  3. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  4. Unet论文解读代码解读

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 网络 架构: a.U-net建立在FCN的网络架构上,作者修改并扩大了这个网络框架,使其能够使用很少 ...

  5. Lossless Codec---APE代码解读系列(二)

    APE file 一些概念 APE代码解读系列(一) APE代码解读系列(三) 1. 先要了解APE compression level APE主要有5level, 分别是: CompressionL ...

  6. RT-Thread 学习笔记(五)—— RTGUI代码解读

    ---恢复内容开始--- RT-Thread 版本:2.1.0 RTGUI相关代码解读,仅为自己学习记录,若有错误之处,请告知maoxudong0813@163.com,不胜感激! GUI流程: ma ...

  7. vins 解读_代码解读 | VINS 视觉前端

    AI 人工智能 代码解读 | VINS 视觉前端 本文作者是计算机视觉life公众号成员蔡量力,由于格式问题部分内容显示可能有问题,更好的阅读体验,请查看原文链接:代码解读 | VINS 视觉前端 v ...

  8. BERT:代码解读、实体关系抽取实战

    目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...

  9. shfflenetv2代码解读

    shufflenetv2代码解读 目录 shufflenetv2代码解读 概述 shufflenetv2网络结构图 shufflenetv2架构参数 shufflenetv2代码细节分析 概述 shu ...

最新文章

  1. 替换ubuntu 源
  2. 世博展示生态化住宅 物联网助推智能化家居
  3. boost::log::sinks用法的测试程序
  4. js基本数据类型和复杂数据类型的区别
  5. python sort、sorted 高级用法
  6. SVN插件版本过低1.6的已经不兼容现在新版的eclipse 了用 1.8X的吧
  7. Chart.js学习
  8. 微信又出隐藏“技能”,一夜之间朋友圈刷屏了
  9. nginx 安装配置指南
  10. caffe实践程序1——mnist任务总结
  11. IO-01. 表格输出(5)
  12. Julia :HDF5数据文件读写与更新
  13. 微信小程序实现图片翻转效果
  14. 全连接神经网络的二分类问题
  15. 【协程】MyCoroutine轻量级协程框架代码详细剖解
  16. 笔记:Linux系统调用在文件中的分布情况
  17. 计算机组装与维修专用周报告,《计算机组装与维护专用周》实习报告.doc
  18. 如何合理选择 PLC
  19. 计算机主机恢复上电检查,电脑故障维修判断指导总结
  20. SQLi lab: Equivalent to information schema on Oracle

热门文章

  1. Cefsharp生成的项目在自己电脑上能打开,其他电脑上不行,提示找不到指定文件cefsharp.core.dll
  2. winform程序打包EXE三种方式
  3. 基于点线特征避免单目视觉SLAM的退化
  4. 深度学习在机器人视觉中的局限与优势(综述)
  5. SA-SSD:阿里达摩院最新3D检测力作(CVPR2020)
  6. Nat. Methods | 学习微生物与代谢产物之间相互作用的神经网络
  7. Bioinformatics|基于知识图谱嵌入的药物靶标发现
  8. 在R中子集化数据框的5种方法
  9. AGAT|GTF/GFF文件处理工具
  10. JIPB:白洋组综述根系微生物组群落构建及其对植物适应性的贡献