faster rcnn中RPN网络源码分析(pytorch)
最近刚入坑检测,初步看了RGB大佬的faster rcnn文章,再看看源码
本次分析的源码是陈云大佬pytorch版本的GITHUB地址
上一张输入输出图
一、forward
主文件./model/region_proposal_network.py
- rpn_scores & rpn_locs
input : feature maps
output : rpn_scores 、 rpn_locs
(1)、feature maps过 n_anchor * 2个卷积核得到每个anchor的前景背景的分类得分rpn_locs
(2)、feature maps过 n_anchor * 4个卷积核得到每个anchor的中心点坐标和宽高的尺度变换比值rpn_locs
#rpn中初始化定义的Layer
self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
self.score = nn.Conv2d(mid_channels, n_anchor * 2, 1, 1, 0) #n_anchor * 2 ,作为每个anchor的前景背景的分类得分,二分类所以*2
self.loc = nn.Conv2d(mid_channels, n_anchor * 4, 1, 1, 0) #n_anchor * 4,为每个anchor的中心坐标偏移比例和宽、高的各自的尺寸比
#rpn前传网络中连接
h = F.relu(self.conv1(x)) #x是extractor提取的feature maps
rpn_locs = self.loc(h)
rpn_scores = self.score(h)
- anchor
# 生成anchor_base,即feature map第一个点对应的anchors
#使用的函数(path:./model/utils/bbox_tool.py)
def generate_anchor_base(base_size=16, ratios=[0.5, 1, 2],anchor_scales=[8, 16, 32])
#以(8,8)为中心,长宽比分别为[0.5, 1, 2],面积分别为16*16 *[8, 16, 32],共9个anchor
input : feature map size、feat_stride、anchor_base
output : anchor
# height, width为feature map 尺寸
# feat_stride为image 和 feature map 尺寸比
# 将anchor_base经过平移变换和等比变换,得到对应于image的anchors
# 使用的函数(path:./model/region_proposal_network.py)
def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width)
下面3、4主体都在ProposalCreator类(path:./model/utils/creator_tool.py)的def __call__
中实现
- 未经过滤的roi
input : rpn_locs、anchor
output : 未经过滤的roi
通过rpn_locs(dy,dx,dh,dw)对anchor做坐标变换,得到未经过滤的roi。变换公式如下
#x,y,w,h为anchor的中心点坐标(x,y)及宽w、高h
#x',y',w',h'为未经过滤的roi的中心点坐标(x',y')及宽w'、高h'
dx = (x' - x) / w
dy = (y' - y) / h
dw = ln(w' / w) #使用ln是为了是w'/w为一个正数
dh = ln(h' / h)
#使用的函数(path:./model/utils/bbox_tool.py)
#src_bbox : anchor
#loc : rpn_locs
def loc2bbox(src_bbox, loc)
- roi
input : 未经过滤的roi、rpn_scores
output : roi
#1、尺寸过滤(只保留尺寸大于阈值的roi)
min_size = self.min_size * scale
hs = roi[:, 2] - roi[:, 0]
ws = roi[:, 3] - roi[:, 1]
keep = np.where((hs >= min_size) & (ws >= min_size))[0]
roi = roi[keep, :]
score = score[keep]#2、前期前景得分过滤(只保留得分最高的n_pre_nms 个roi)
order = score.ravel().argsort()[::-1]
if n_pre_nms > 0:order = order[:n_pre_nms]
roi = roi[order, :]#3、非极大值抑制
keep = non_maximum_suppression(cp.ascontiguousarray(cp.asarray(roi)),thresh=self.nms_thresh)#4、后期前景得分过滤(只保留得分最高的n_post_nms 个roi)
if n_post_nms > 0:keep = keep[:n_post_nms]
roi = roi[keep]
至此前传基本就结束了,之后就是将roi送入roi pooling…
二、backward
主文件./trainer.py
- gt_rpn_label
input : bbox、anchor
output : gt_rpn_label
# 主要就是计算bbox个各个anchor的iou值
# 使用的函数(path:./model/utils/creator_tool.py)
# bbox :目标真实位置
def _create_label(self, inside_index, anchor, bbox)...# iou小于阈值为背景label[max_ious < self.neg_iou_thresh] = 0...# iou大于阈值为前景label[max_ious >= self.pos_iou_thresh] = 1...# 前景个数大于n_pos,就将n_pos后的忽略,即label=-1if len(pos_index) > n_pos:disable_index = np.random.choice(pos_index, size=(len(pos_index) - n_pos), replace=False)label[disable_index] = -1...# 背景个数大于n_neg,就将n_neg后的忽略,即label=-1if len(neg_index) > n_neg:disable_index = np.random.choice(neg_index, size=(len(neg_index) - n_neg), replace=False)label[disable_index] = -1
gt_rpn_loc
input : bbox、anchor
output : gt_rpn_loc
# 使用的函数(path:./model/utils/bbox_tool.py)
# src_bbox : anchor
# dst_bbox: bbox
# 利用每个anchor分别与bbox计算(公式见一中的3),得到每个anchor的中心点坐标和宽高的尺度变换比值的GT
def bbox2loc(src_bbox, dst_bbox)
- rpn_cls_loss
input : gt_rpn_label、rpn_scores
output : rpn_cls_loss
# 交叉熵
rpn_cls_loss = F.cross_entropy(rpn_score, gt_rpn_label.cuda(), ignore_index=-1)
- rpn_loc_loss
input : gt_rpn_loc、rpn_locs、gt_rpn_label
output : rpn_loc_loss
#使用的函数(path:./trainer.py)
def _fast_rcnn_loc_loss(pred_loc, gt_loc, gt_label, sigma)
至此反向传播需要的两项loss就计算结束了,之后就是和roi pooling的两项loss 求和,再反向传播
faster rcnn中RPN网络源码分析(pytorch)相关推荐
- 【朝花夕拾】Android自定义View篇之(六)Android事件分发机制(中)从源码分析事件分发机制...
前言 转载请注明,转自[https://www.cnblogs.com/andy-songwei/p/11039252.html]谢谢! 在上一篇文章[[朝花夕拾]Android自定义View篇之(五 ...
- 【朝花夕拾】Android自定义View篇之(六)Android事件分发机制(中)从源码分析事件分发逻辑及经常遇到的一些“诡异”现象
前言 转载请注明,转自[https://www.cnblogs.com/andy-songwei/p/11039252.html]谢谢! 在上一篇文章[[朝花夕拾]Android自定义View篇之(五 ...
- zipline中benchmarks.py源码分析
zipline中benchmarks源码分析 1 benchmark 基准数据 2 get_benchmark_returns_from_file 从文件中获取基准数据 3 BenchmarkSpec ...
- Anchor和RPN的浅薄理解(三)-mmdetection中Anchor生成源码分析
在 MMDetection 中,RPN 网络使用 AnchorGenerator 类生成 Anchor,在config文件中 AnchorGenerator 的默认设置如下: anchor_gener ...
- faster rcnn中rpn的anchor,sliding windows,proposals的理解
一直对faster rcnn里的rpn以及下图中的上面的那部分的区别不太理解,今天看到了知乎里面的回答,感觉有点明白了,特此记录 作者:马塔 链接:https://www.zhihu.com/ques ...
- 力引导算法深入理解及其在d3.js中实现的源码分析
中学时最喜欢的学科是物理,大学误打误撞读了计算机.最近在做图计算的相关工作,图的可视化中有一个非常重要的算法:"力引导算法",这个算法的原理居然就是最简单的粒子间的作用力,真是没想 ...
- 深度学习目标检测系列:faster RCNN实现|附python源码
目标检测一直是计算机视觉中比较热门的研究领域,有一些常用且成熟的算法得到业内公认水平,比如RCNN系列算法.SSD以及YOLO等.如果你是从事这一行业的话,你会使用哪种算法进行目标检测任务呢?在我寻求 ...
- SF中DispSync.cpp源码分析
源码位置位于: /frameworks/native/services/surfaceflinger/DispSync.cpp 先来看下构造方法: 关键是初始化了DispSyncThread线程变量, ...
- Java 8 中 GZIPInputStream 类源码分析
这是<水煮 JDK 源码>系列 的第4篇文章,计划撰写100篇关于JDK源码相关的文章 GZIPInputStream 类位于 java.util.zip 包下,继承于 InflaterI ...
最新文章
- 阿里云移动测试平台MQC移动测试沙龙第3期【北京站】
- python从零开始系列连载_技术 | Python从零开始系列连载(一)
- java 死循环排查_java应用死循环排查方法或查找程序消耗资源的线程方法(面试)...
- pgd 游戏教程 基地
- go 变量大写_28. 一文了解Go语言中编码规范
- 设计模式之二-Proxy模式
- js获取baseurl
- 数据结构C语言严蔚敏版(第二版)超详细笔记附带课后习题
- 游戏王抽卡模拟器(概率计算器)
- 《麦肯锡方法》第9章 头脑风暴-思维导图
- [Ubuntu 18.04][CPU]MindSpore V1.0源码安装初体验(直播结束)
- 计算机与现代社会英语作文,高一英语作文,科技以下是题目:众所周知,科技在现代社会和生活中扮演着越来越重要的角色,但科技同时也是一把双刃剑,在它璀璨...
- Null check operator used on a null value
- MindManager2022安装使用教程
- 佳片有约|《第六感生死缘》:生如夏花,死若秋叶的爱恋
- 【智能制造】周宏仁:智能制造的三个支点;全球制造业新趋势
- ssh免密登录服务器
- 单片机-结构体函数指针高级使用方法
- 未来计算图鉴:十年后的计算长什么模样?
- 全国企业信息网站地址
热门文章
- 机器学习之泰坦尼克号实战
- 李宏毅机器学习(23)
- 随手记——Linux中C语言调用shell指令的三种方式
- Word常用快捷指令
- 用Java编辑员工信息_编写一个函数来显示基于Oracle中特定部门的员工信息?
- 服务器执行到这里就停住不动了Initializing Spring root WebApplicationContext
- unity3D的面试题
- 【密码算法 之六】CCM 浅析
- 《Web前端设计与开发-HTML+CSS+JavaScript+HTML 5+jQuery》-漫步时尚广场代码-1
- 小程序onPageScroll上滑显示,下滑隐藏