maskrcnn_benchmark 代码详解之 roi_box_predictors.py
前言:
在对RPN预测到的边框进行进一步特征提取后,需要对边框进行预测,得到边框的类别和位置大小信息。这一操作在maskrcnn_benchmark中由roi_box_predictors.py完成,该文件实现了两种预测类:直接进行预测以及先池化再预测。其代码详解如下:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.modeling import registry
from torch import nn# todo 现将预测边框的特征进行池化,再使用边框预测结构和边框回归结构来预测边框的类别以及边框的坐标偏差值
@registry.ROI_BOX_PREDICTOR.register("FastRCNNPredictor")
class FastRCNNPredictor(nn.Module):def __init__(self, config, in_channels):super(FastRCNNPredictor, self).__init__()# 当输入层的通道为空时报错assert in_channels is not None# 输入层的通道数num_inputs = in_channels# 得到基准边框的类别数,一般都要加上一类为背景num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES# 对输入层特征先进行池化self.avgpool = nn.AdaptiveAvgPool2d(1)# 创建用于预测边框类别的网络结构:线性链接层,类别数×输入层通道数self.cls_score = nn.Linear(num_inputs, num_classes)# 当模式为方式为CLS_AGNOSTIC_BBOX_REG时,只回归2类bounding box,即前景和背景,否则为实际类别数num_bbox_reg_classes = 2 if config.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes# 创建用于预测边框回归的网络结构self.bbox_pred = nn.Linear(num_inputs, num_bbox_reg_classes * 4)# 初始化各种参数nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)nn.init.constant_(self.cls_score.bias, 0)nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001)nn.init.constant_(self.bbox_pred.bias, 0)def forward(self, x):x = self.avgpool(x)x = x.view(x.size(0), -1)cls_logit = self.cls_score(x)bbox_pred = self.bbox_pred(x)return cls_logit, bbox_pred# todo 直接使用边框预测结构和边框回归结构来预测边框的类别以及边框的坐标偏差值
@registry.ROI_BOX_PREDICTOR.register("FPNPredictor")
class FPNPredictor(nn.Module):def __init__(self, cfg, in_channels):super(FPNPredictor, self).__init__()# 得到基准边框的类别数,一般都要加上一类为背景num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES# 输入层的通道数representation_size = in_channels# 创建用于预测边框类别的网络结构:线性链接层,类别数×输入层通道数self.cls_score = nn.Linear(representation_size, num_classes)# 当模式为方式为CLS_AGNOSTIC_BBOX_REG时,只回归2类bounding box,即前景和背景,否则为实际类别数num_bbox_reg_classes = 2 if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes# 创建用于预测边框回归的网络结构self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)# 初始化各种参数nn.init.normal_(self.cls_score.weight, std=0.01)nn.init.normal_(self.bbox_pred.weight, std=0.001)for l in [self.cls_score, self.bbox_pred]:nn.init.constant_(l.bias, 0)def forward(self, x):if x.ndimension() == 4:assert list(x.shape[2:]) == [1, 1]x = x.view(x.size(0), -1)scores = self.cls_score(x)bbox_deltas = self.bbox_pred(x)return scores, bbox_deltas# todo 实例化边框预测的类
def make_roi_box_predictor(cfg, in_channels):# 使用参数中指定的边框预测类来对边框进行预测func = registry.ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR]return func(cfg, in_channels)
maskrcnn_benchmark 代码详解之 roi_box_predictors.py相关推荐
- maskrcnn_benchmark 代码详解之 roi_box_feature_extractors.py
前言: 在经过RPN层之后,网络会生成多个预测边框(proposal), 这时候需要对这些边框进行RoI池化,使之成为尺度一致的特征.接下来就需要对这些特征进行进一步的特征提取,这就需要用到roi_b ...
- maskrcnn_benchmark 代码详解之 poolers.py
前言: 在目标检测的深度网络中最后一个步骤就是RoI层,其中RoI Pooling会实现将RPN提取的各种形状的边框进行池化,从而形成统一尺度的特征层,这一工程中将涉及到ROIAlign操作.Pool ...
- maskrcnn_benchmark 代码详解(更新中...)
前言: maskrcnn_benchmark是faceboock公司编写的一套用于目标检索的框架,该框架集成了目前用到的大部分使用深度卷积网络来进行目标检测的模型,其中包括Fast RCNN, Fas ...
- maskrcnn-benchmar 代码详解之 fpn.py
前言 FPN网络主要应用于多层特征提取,使用多尺度的特征层来进行目标检测,可以利用不同的特征层对于不同大小特征的敏感度不同,将他们充分利用起来,以更有利于目标检测,在maskrcnn benchmar ...
- maskrcnn_benchmark 代码详解之 boxlist_ops.py
前言: 与Bounding Box有关的操作有很多,例如对边框列表进行非极大线性抑制.去除过小的边框.计算边框之间的Iou以及对两个边框列表进行合并等操作.在maskrcnn_benchmark中,这 ...
- yolov3代码详解(七)
Pytorch | yolov3代码详解七 test.py test.py from __future__ import divisionfrom models import * from utils ...
- yolov5的detect.py代码详解
目标检测系列之yolov5的detect.py代码详解 前言 哈喽呀!今天又是小白挑战读代码啊!所写的是目标检测系列之yolov5的detect.py代码详解.yolov5代码对应的是官网v6.1版本 ...
- yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数
yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数 1.用途 图片的hsv色域增强模块 2.调用位置 在datasets.py的LoadImagesA ...
- 【Image captioning】Show, Attend, and Tell 从零到掌握之三--train.py代码详解
[Image captioning]Show, Attend, and Tell 从零到掌握之三–train.py代码详解 作者:安静到无声 个人主页 作者简介:人工智能和硬件设计博士生.CSDN与阿 ...
最新文章
- 用 Flask 来写个轻博客 (29) — 使用 Flask-Admin 实现后台管理 SQLAlchemy
- 自定义组件 点击空白处隐藏
- Python notes
- 关于node.js和C交互的方法
- Django之models
- 80% 的 Android 应用正使用加密流量!
- 初级第四旬06— 回向与发愿试题
- 2018-04-08椭圆曲线测试程序
- Python 解决面试题47 不用加减乘除做加法
- swing简单的打字游戏源码
- tomcat中的日志配置
- excel导入,用反射匹配字段名
- python 字符串(二)
- linux安装razer鼠标驱动
- 猫哥教你写爬虫 008--input()函数
- button loading indicators
- voipdiscount免费拨打全球电话(无需手机注册)
- 叠加等边三角形的绘制 python_叠_叠是什么意思_叠字怎么读_叠的含义_叠字组词-新东方在线字典...
- php notice undefined variable,PHP错误提示,Notice: Undefined variable
- PyCharm 社区版(Community)能不能商用?