FasterRCNN代码解读
之前的文章简要介绍了Faster-RCNN等物体检测的算法,本文将从代码角度详细分析介绍Faster-RCNN的实现。本文使用的代码参考了chenyuntc的实现,代码的位置看这里。需要注意的是,本文使用的框架是Pytorch。
数据载入
数据载入部分的代码主要见./data/dataset.py
中的类Dataset
与TestDataset
。
数据载入部分的逻辑如下:
- 从VOC数据集中获得
img, bbox, label
- 将
img, bbox
进行放缩(放缩的目的是让图片处于合适的大小,这样预先指定锚框才有意义) - 将
img
进行标准化正则处理 - 如果是训练阶段,将
img
翻转以增加训练数据
网络结构
FasterRCNN的网络结构如下图所示:
FasterRCNN结构的代码主要见./model.faster_rcnn.py
,其结构包含三大部分:
- 预训练的CNN模型
decom_vgg16
- rpn网络
RegionProposalNetwork
- roi及以上网络
VGG16RoIHead
下面,将以放缩后大小为[1, 3, 600, 800]
的图片为例针对每个部分分别介绍。图像类别共计21类(包含背景)。
预训练的CNN模型
该部分代码见./model/vgg16.py
。
输入:图片,大小[1, 3, 600, 800]
输出:特征图features
,大小[1, 512, 37, 50]
其逻辑如下:
- 载入预先训练好的CNN模型VGG16。
- 将模型拆分为两部分
extractor
,classifier
。其中,extractor
的参数固定。 - 图片通过
extractor
可以得到特征图features
。根据extractor
中池化参数可知图像通过extractor
缩小了16倍。
rpn网络
该部分代码见./model/rpn.py
。
输入:特征图features
,大小[1, 512, 37, 50]
输出:
rpn_locs
:rpn对位置的修正,大小[1, 16650, 4]
rpn_scores
:rpn判断区域前景背景,大小[1, 16650, 2]
rois
:rpn筛选出的roi的位置,大小[300, 4]
roi_indices
:rpn筛选出的roi对应的图片索引,大小[300]
anchor
:原图像的锚点,大小[16650, 4]
其中,16650是放缩后的图像所产生的所有锚点(37*50*9),每个锚点都对应了一个rp。通过 rpn_scores
以及nms可以得到筛选后的大小为300的roi。
其逻辑如下:
- 对特征图
features
以基准长度为16、选择合适的ratios
和scales
取基准锚点anchor_base
。(选择长度为16的原因是图片大小为600*800左右,基准长度16对应的原图区域是256*256,考虑放缩后的大小有128*128,512*512比较合适) - 根据
anchor_base
在原图上获得anchors
。 - 对特征图
features
采用卷积得到rpn_locs
和rpn_scores
- 根据
anchors
和rpn_locs
获得修正后的rp
- 对
rp
进一步修正获得rois
和roi_indices
,修正包括超出边界的部分截断、移除太小的、nms。
roi及以上网络
该部分代码见./model/roi_module.py
。
输入:
features
:特征图,大小[1, 512, 37, 50]
rois
:rpn筛选出的roi的位置,大小[300, 4]
roi_indices
:rpn筛选出的roi对应的图片索引,大小[300]
输出:
roi_cls_locs
:roi
位置的修正,大小[300, 84]
roi_scores
:roi
各类的分数,大小[300, 21]
其逻辑如下:
- 通过
RoIPooling2D
将大小不同的roi
变成大小一致,得到pooling后的特征,大小为[300, 512, 7, 7]
- 接入预训练的CNN模型引入的
classifier
- 分别接入全连接得到
roi_cls_locs
、roi_scores
训练
训练部分的代码主要见./trainer/trainer.py
中的FasterRCNNTrainer
中的train_step
函数。
训练部分的核心是loss如何求取。
loss求取前网络的步骤如下:
- 预训练CNN特征提取:输入
img
到extractor
获得features
- rpn网络得到roi:输入
features
到rpn
获得rpn_locs
,rpn_scores
,rois
,roi_indices
,anchor
- 抽样roi:输入
rois
,bbox
,label
到ProposalTargetCreator
获得sample_roi
,gt_roi_loc
,gt_roi_label
。该步骤的含义是得到正负例比例和位置合适的roi
。 - head网络得到roi的位置修正与分数:输入
features
,sample_roi
,sample_roi_index
得到roi_cls_loc
,roi_score
各个loss求取的方式如下:
rpn_loc_loss
:已知rpn_loc
,需要先根据anchor
和bbox
得到真实的gt_rpn_loc
和gt_rpn_label
。该处loss的计算只考虑前景,所以根据rpn_loc
,gt_rpn_loc
,gt_rpn_label
计算L1-LOSS即可。rpn_cls_loss
:根据rpn_score
和gt_rpn_label
计算二分类的交叉熵即可。roi_loc_loss
:已知roi_loc
,在sample roi的过程中已获得gt_roi_loc
,gt_roi_label
。根据roi_loc
,gt_roi_loc
,gt_roi_label
计算L1-LOSS即可。roi_cls_loss
:根据roi_score
和gt_roi_label
计算多分类的交叉熵即可。
整体的loss为以上各loss相加求和。
测试
训练部分的代码主要见./model/faster_rcnn.py
中的FasterRCNNTrainer
中的predict
函数。
其步骤如下:
- 图片预处理
- 预训练CNN特征提取:输入
img
到extractor
获得features
- rpn网络得到roi:输入
features
到rpn
获得rpn_locs
,rpn_scores
,rois
,roi_indices
,anchor
- head网络得到roi的位置修正与分数:输入
features
,rois
,roi_indices
得到roi_cls_loc
,roi_score
- 得到图片预测的bbox:输入
roi_cls_loc
、roi_score
、rois
,采用nms等方法得到预测的bbox
。
FasterRCNN代码解读相关推荐
- nsga2代码解读python_代码资料
faster RCNN TensorFlow版本: 龙鹏:[技术综述]万字长文详解Faster RCNN源代码(一) buptscdc:tensorflow 版faster rcnn代码理解(1) l ...
- 200行代码解读TDEngine背后的定时器
作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...
- 装逼一步到位!GauGAN代码解读来了
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...
- Unet论文解读代码解读
论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 网络 架构: a.U-net建立在FCN的网络架构上,作者修改并扩大了这个网络框架,使其能够使用很少 ...
- Lossless Codec---APE代码解读系列(二)
APE file 一些概念 APE代码解读系列(一) APE代码解读系列(三) 1. 先要了解APE compression level APE主要有5level, 分别是: CompressionL ...
- RT-Thread 学习笔记(五)—— RTGUI代码解读
---恢复内容开始--- RT-Thread 版本:2.1.0 RTGUI相关代码解读,仅为自己学习记录,若有错误之处,请告知maoxudong0813@163.com,不胜感激! GUI流程: ma ...
- vins 解读_代码解读 | VINS 视觉前端
AI 人工智能 代码解读 | VINS 视觉前端 本文作者是计算机视觉life公众号成员蔡量力,由于格式问题部分内容显示可能有问题,更好的阅读体验,请查看原文链接:代码解读 | VINS 视觉前端 v ...
- BERT:代码解读、实体关系抽取实战
目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...
- shfflenetv2代码解读
shufflenetv2代码解读 目录 shufflenetv2代码解读 概述 shufflenetv2网络结构图 shufflenetv2架构参数 shufflenetv2代码细节分析 概述 shu ...
最新文章
- 替换ubuntu 源
- 世博展示生态化住宅 物联网助推智能化家居
- boost::log::sinks用法的测试程序
- js基本数据类型和复杂数据类型的区别
- python sort、sorted 高级用法
- SVN插件版本过低1.6的已经不兼容现在新版的eclipse 了用 1.8X的吧
- Chart.js学习
- 微信又出隐藏“技能”,一夜之间朋友圈刷屏了
- nginx 安装配置指南
- caffe实践程序1——mnist任务总结
- IO-01. 表格输出(5)
- Julia :HDF5数据文件读写与更新
- 微信小程序实现图片翻转效果
- 全连接神经网络的二分类问题
- 【协程】MyCoroutine轻量级协程框架代码详细剖解
- 笔记:Linux系统调用在文件中的分布情况
- 计算机组装与维修专用周报告,《计算机组装与维护专用周》实习报告.doc
- 如何合理选择 PLC
- 计算机主机恢复上电检查,电脑故障维修判断指导总结
- SQLi lab: Equivalent to information schema on Oracle
热门文章
- Cefsharp生成的项目在自己电脑上能打开,其他电脑上不行,提示找不到指定文件cefsharp.core.dll
- winform程序打包EXE三种方式
- 基于点线特征避免单目视觉SLAM的退化
- 深度学习在机器人视觉中的局限与优势(综述)
- SA-SSD:阿里达摩院最新3D检测力作(CVPR2020)
- Nat. Methods | 学习微生物与代谢产物之间相互作用的神经网络
- Bioinformatics|基于知识图谱嵌入的药物靶标发现
- 在R中子集化数据框的5种方法
- AGAT|GTF/GFF文件处理工具
- JIPB:白洋组综述根系微生物组群落构建及其对植物适应性的贡献