前言:

  在经过RPN层之后,网络会生成多个预测边框(proposal), 这时候需要对这些边框进行RoI池化,使之成为尺度一致的特征。接下来就需要对这些特征进行进一步的特征提取,这就需要用到roi_box_feature_extractors.py。roi_box_feature_extractors.py定义了三种不同的特种提取方式:ResNet卷基层方式、MIL全连接方式以及使用多个卷基层组+全连接方式。其代码详解为:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.nn import functional as Ffrom maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.modeling.make_layers import make_fc# 使用ResNet50的Conv5层来提取roi特征
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module):def __init__(self, config, in_channels):super(ResNet50Conv5ROIFeatureExtractor, self).__init__()# resolution为roi pooling之后特征图的大小,一般为7resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION# 获得原始图到特征图的比例函数,比如原始图到Res50的stage2是1/4scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES# sampling_ratio即采样率,指的是锚点大小与池化之后特征图的大小比例。一般情况下不指定sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO# 初始化池化类,内含ROIAlign函数pooler = Pooler(output_size=(resolution, resolution),scales=scales,sampling_ratio=sampling_ratio,)# 获得Stage5的模板,用以构造Resnet的Stage5stage = resnet.StageSpec(index=4, block_count=3, return_features=False)# 构造Resnet的Stage5层网络结构head = resnet.ResNetHead(block_module=config.MODEL.RESNETS.TRANS_FUNC,stages=(stage,),num_groups=config.MODEL.RESNETS.NUM_GROUPS,width_per_group=config.MODEL.RESNETS.WIDTH_PER_GROUP,stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1,stride_init=None,res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS,dilation=config.MODEL.RESNETS.RES5_DILATION)# 将参数复制给私有变量self.pooler = poolerself.head = headself.out_channels = head.out_channelsdef forward(self, x, proposals):x = self.pooler(x, proposals)x = self.head(x)return x# todo 采用MLP的全连接网络结构来提取ROI特征
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module):"""Heads for FPN for classification用于分类的FPN层模型"""def __init__(self, cfg, in_channels):super(FPN2MLPFeatureExtractor, self).__init__()# resolution为roi pooling之后特征图的大小,一般为7resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION# 获得原始图到特征图的比例函数,比如原始图到Res50的stage2是1/4scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES# sampling_ratio即采样率,指的是锚点大小与池化之后特征图的大小比例。一般情况下不指定sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO# 初始化池化类,内含ROIAlign函数pooler = Pooler(output_size=(resolution, resolution),scales=scales,sampling_ratio=sampling_ratio,)# 输入层大小为把每一个元素拉成一个向量,为全连接层input_size = in_channels * resolution ** 2# MLP的全连接输出层通道数representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM# 标明是否使用GNuse_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN# 将参数复制给私有变量self.pooler = poolerself.fc6 = make_fc(input_size, representation_size, use_gn)self.fc7 = make_fc(representation_size, representation_size, use_gn)self.out_channels = representation_sizedef forward(self, x, proposals):x = self.pooler(x, proposals)x = x.view(x.size(0), -1)x = F.relu(self.fc6(x))x = F.relu(self.fc7(x))return x# 由多个堆叠的卷基层来对RoI Pooling后的特征进行进一步特征加工
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPNXconv1fcFeatureExtractor")
class FPNXconv1fcFeatureExtractor(nn.Module):"""Heads for FPN for classification"""def __init__(self, cfg, in_channels):super(FPNXconv1fcFeatureExtractor, self).__init__()# resolution为roi pooling之后特征图的大小,一般为7resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION# 获得原始图到特征图的比例函数,比如原始图到Res50的stage2是1/4scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES# sampling_ratio即采样率,指的是锚点大小与池化之后特征图的大小比例。一般情况下不指定sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO# 初始化池化类,内含ROIAlign函数pooler = Pooler(output_size=(resolution, resolution),scales=scales,sampling_ratio=sampling_ratio,)self.pooler = pooler# 标明是否使用GNuse_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN# 申明本卷基层的输出通道数conv_head_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM# 指定卷积层的个数num_stacked_convs = cfg.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS# 是否采用空洞卷积dilation = cfg.MODEL.ROI_BOX_HEAD.DILATION# 初始化多个卷积层的模型xconvs = []# 循环添加多个卷基层for ix in range(num_stacked_convs):xconvs.append(nn.Conv2d(in_channels,conv_head_dim,kernel_size=3,stride=1,padding=dilation,dilation=dilation,bias=False if use_gn else True))in_channels = conv_head_dimif use_gn:# 添加GNxconvs.append(group_norm(in_channels))# 每一个卷基层后添加一个激活层xconvs.append(nn.ReLU(inplace=True))# 将这个卷基层组加入到模型当中self.add_module("xconvs", nn.Sequential(*xconvs))# 初始化模型参数for modules in [self.xconvs,]:for l in modules.modules():if isinstance(l, nn.Conv2d):torch.nn.init.normal_(l.weight, std=0.01)if not use_gn:torch.nn.init.constant_(l.bias, 0)# 添加全连接层input_size = conv_head_dim * resolution ** 2representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIMself.fc6 = make_fc(input_size, representation_size, use_gn=False)self.out_channels = representation_sizedef forward(self, x, proposals):x = self.pooler(x, proposals)x = self.xconvs(x)x = x.view(x.size(0), -1)x = F.relu(self.fc6(x))return x# todo 实例化roi边框特征提取方式的类
def make_roi_box_feature_extractor(cfg, in_channels):# 用参数里指定的特征提取方式来实例化相应的类或者函数func = registry.ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR]return func(cfg, in_channels)

maskrcnn_benchmark 代码详解之 roi_box_feature_extractors.py相关推荐

  1. maskrcnn_benchmark 代码详解之 roi_box_predictors.py

    前言: 在对RPN预测到的边框进行进一步特征提取后,需要对边框进行预测,得到边框的类别和位置大小信息.这一操作在maskrcnn_benchmark中由roi_box_predictors.py完成, ...

  2. maskrcnn_benchmark 代码详解之 poolers.py

    前言: 在目标检测的深度网络中最后一个步骤就是RoI层,其中RoI Pooling会实现将RPN提取的各种形状的边框进行池化,从而形成统一尺度的特征层,这一工程中将涉及到ROIAlign操作.Pool ...

  3. maskrcnn_benchmark 代码详解(更新中...)

    前言: maskrcnn_benchmark是faceboock公司编写的一套用于目标检索的框架,该框架集成了目前用到的大部分使用深度卷积网络来进行目标检测的模型,其中包括Fast RCNN, Fas ...

  4. maskrcnn-benchmar 代码详解之 fpn.py

    前言 FPN网络主要应用于多层特征提取,使用多尺度的特征层来进行目标检测,可以利用不同的特征层对于不同大小特征的敏感度不同,将他们充分利用起来,以更有利于目标检测,在maskrcnn benchmar ...

  5. maskrcnn_benchmark 代码详解之 boxlist_ops.py

    前言: 与Bounding Box有关的操作有很多,例如对边框列表进行非极大线性抑制.去除过小的边框.计算边框之间的Iou以及对两个边框列表进行合并等操作.在maskrcnn_benchmark中,这 ...

  6. yolov3代码详解(七)

    Pytorch | yolov3代码详解七 test.py test.py from __future__ import divisionfrom models import * from utils ...

  7. yolov5的detect.py代码详解

    目标检测系列之yolov5的detect.py代码详解 前言 哈喽呀!今天又是小白挑战读代码啊!所写的是目标检测系列之yolov5的detect.py代码详解.yolov5代码对应的是官网v6.1版本 ...

  8. yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数

    yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数 1.用途 图片的hsv色域增强模块 2.调用位置 在datasets.py的LoadImagesA ...

  9. 【Image captioning】Show, Attend, and Tell 从零到掌握之三--train.py代码详解

    [Image captioning]Show, Attend, and Tell 从零到掌握之三–train.py代码详解 作者:安静到无声 个人主页 作者简介:人工智能和硬件设计博士生.CSDN与阿 ...

最新文章

  1. excel最常用的八个函数_Excel中最常用的快捷键
  2. easyDarwin--开源流媒体实现
  3. 信息服务器怎么填写,如何设定服务器信息
  4. JAVA Drp项目实战—— Unable to compile class for JSP 一波三折
  5. iteritems()与items()
  6. java终结器_Java的终结器仍然存在
  7. batchnorm and relu_日本AND荷重传感器
  8. 怎么输出一个二维数组_LeetCode54与59,一个口诀教会你旋转二维数组
  9. mulitp request
  10. 你有遇到过最没良心的人吗?
  11. 第十二章UML与Rational Rose 软件
  12. 披着“云”衣裳的狗——搜狗输入法“云”版本尝鲜记
  13. Qt绘 —— QPixmap 的使用
  14. [论文笔记] Fusion++: VolumetricObject-LevelSLAM
  15. 特征工程和数据预处理常用工具和方法
  16. 今天你代言了吗?WPS版“陈欧体”引热议
  17. ARCGIS中如何实现点集之间的两两连线
  18. 关于服务器ftp服务器设置基本步骤及注意要点
  19. HBuilder开发词典app(一)--基本页面布局
  20. 查看电脑开机记录和时间

热门文章

  1. 【python 手机号码归属地】手机号码归属地获取
  2. python 使用listdir 遍历目录
  3. 第 17 节 字段、属性、索引器、常量
  4. 《网络是怎样连接的》一书读后感
  5. 树莓派下载,卸载软件
  6. go每日新闻(2022-06-14)——一文告诉你Go 1.19都有哪些新特性
  7. ArcGIS教程:ArcGIS符号库制作
  8. PS技术之如何写弧形文字
  9. 支付宝批量转账系统解析
  10. 青软集团联合桂林理工大学共建的大数据产业学院成功揭牌