前言:

  在目标检测的深度网络中最后一个步骤就是RoI层,其中RoI Pooling会实现将RPN提取的各种形状的边框进行池化,从而形成统一尺度的特征层,这一工程中将涉及到ROIAlign操作。Pool中的Scale是一个数组,代表原始图片变换到FPN的各个特征层需要的变换比例,比如到Stage2是1/4, 以此类推。其代码详解为:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nnfrom maskrcnn_benchmark.layers import ROIAlignfrom .utils import catclass LevelMapper(object):"""Determine which FPN level each RoI in a set of RoIs should map to basedon the heuristic in the FPN paper.""""""LevelMapper函数的作用是获得某个特征区域将会从网络的那一层特征上进行提取,面积越大的目标区往往会在高层进行提取,小目标则在低层卷基层上进行特征提取。本函数的主要目标就是确定某个目标最好从那一层上进行提取。实现FPN论文里的公式"""def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):"""Arguments:k_min (int)k_max (int)canonical_scale (int)canonical_level (int)eps (float)"""# k_min是进行FPN的最低层网络在第几层,一般为2,表示FPN从第2层开始self.k_min = k_min# k_max是进行FPN的最高层网络在第几层,一般为5,表示FPN到第5层结束self.k_max = k_max# s0表示原始图像的边长为多大,以便确定目标是相对大还是小。这是参考imagenet预训练模型中的图片都是边长为224.如有必要,参数要调节self.s0 = canonical_scale# FPN层数self.lvl0 = canonical_level# 防止目标区域过小self.eps = epsdef __call__(self, boxlists):"""Arguments:boxlists (list[BoxList])"""# Compute level ids# 计算目标区域边长s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))# Eqn.(1) in FPN paper# 计算FPN论文里的公式1target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))# 吧target_lvls缩小到正确的范围target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)return target_lvls.to(torch.int64) - self.k_minclass Pooler(nn.Module):"""Pooler for Detection with or without FPN.It currently hard-code ROIAlign in the implementation,but that can be made more generic later on.Also, the requirement of passing the scales is not strictly necessary, as theycan be inferred from the size of the feature map / size of original image,which is available thanks to the BoxList."""def __init__(self, output_size, scales, sampling_ratio):"""Arguments:output_size (list[tuple[int]] or list[int]): output size for the pooled region输出特征的大小scales (list[float]): scales for each Pooler # 获得参与FPN的最低层sampling_ratio (int): sampling ratio for ROIAlign 每个bin内高和宽方向的采样率,论文中默认的是2.即每个bin采样2*2=4"""super(Pooler, self).__init__()# 按照不同的尺度构造池化层poolers = []for scale in scales:poolers.append(ROIAlign(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio))self.poolers = nn.ModuleList(poolers)self.output_size = output_size# get the levels in the feature map by leveraging the fact that the network always# downsamples by a factor of 2 at each level.# 获得参与FPN的最低层lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()# 获得参与FPN的最高层lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()self.map_levels = LevelMapper(lvl_min, lvl_max)# 转换成roi的格式def convert_to_roi_format(self, boxes):concat_boxes = cat([b.bbox for b in boxes], dim=0)device, dtype = concat_boxes.device, concat_boxes.dtypeids = cat([torch.full((len(b), 1), i, dtype=dtype, device=device)for i, b in enumerate(boxes)],dim=0,)rois = torch.cat([ids, concat_boxes], dim=1)return roisdef forward(self, x, boxes):"""Arguments:x (list[Tensor]): feature maps for each levelboxes (list[BoxList]): boxes to be used to perform the pooling operation.Returns:result (Tensor)"""# 得到提取特征的层的个数num_levels = len(self.poolers)rois = self.convert_to_roi_format(boxes)if num_levels == 1:return self.poolers[0](x[0], rois)# 得到目标特征应该映射到的最有的层levels = self.map_levels(boxes)# 获得roi个数num_rois = len(rois)# 获得通道数num_channels = x[0].shape[1]# 获得输出大小output_size = self.output_size[0]# 获得特征的数据类型和它所在的设备dtype, device = x[0].dtype, x[0].device# 初始化返回数据result = torch.zeros((num_rois, num_channels, output_size, output_size),dtype=dtype,device=device,)for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):# 获得所有应该从同一特征层提取特征的roiidx_in_level = torch.nonzero(levels == level).squeeze(1)# 或者这些roi的编号rois_per_level = rois[idx_in_level]# 将大小相似的这些目标特征送入到特定同一个特征层进行池化,得到相应的结果result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)return resultdef make_pooler(cfg, head_name):# 获得输出特征图的大小resolution = cfg.MODEL[head_name].POOLER_RESOLUTION# 获得参与FPN的最低层scales = cfg.MODEL[head_name].POOLER_SCALES# 每个bin内高和宽方向的采样率,论文中默认的是2.即每个bin采样2 * 2 = 4sampling_ratio = cfg.MODEL[head_name].POOLER_SAMPLING_RATIO# 获得池化层pooler = Pooler(output_size=(resolution, resolution),scales=scales,sampling_ratio=sampling_ratio,)return pooler

maskrcnn_benchmark 代码详解之 poolers.py相关推荐

  1. maskrcnn_benchmark 代码详解之 roi_box_predictors.py

    前言: 在对RPN预测到的边框进行进一步特征提取后,需要对边框进行预测,得到边框的类别和位置大小信息.这一操作在maskrcnn_benchmark中由roi_box_predictors.py完成, ...

  2. maskrcnn_benchmark 代码详解之 roi_box_feature_extractors.py

    前言: 在经过RPN层之后,网络会生成多个预测边框(proposal), 这时候需要对这些边框进行RoI池化,使之成为尺度一致的特征.接下来就需要对这些特征进行进一步的特征提取,这就需要用到roi_b ...

  3. maskrcnn_benchmark 代码详解(更新中...)

    前言: maskrcnn_benchmark是faceboock公司编写的一套用于目标检索的框架,该框架集成了目前用到的大部分使用深度卷积网络来进行目标检测的模型,其中包括Fast RCNN, Fas ...

  4. maskrcnn-benchmar 代码详解之 fpn.py

    前言 FPN网络主要应用于多层特征提取,使用多尺度的特征层来进行目标检测,可以利用不同的特征层对于不同大小特征的敏感度不同,将他们充分利用起来,以更有利于目标检测,在maskrcnn benchmar ...

  5. maskrcnn_benchmark 代码详解之 boxlist_ops.py

    前言: 与Bounding Box有关的操作有很多,例如对边框列表进行非极大线性抑制.去除过小的边框.计算边框之间的Iou以及对两个边框列表进行合并等操作.在maskrcnn_benchmark中,这 ...

  6. yolov3代码详解(七)

    Pytorch | yolov3代码详解七 test.py test.py from __future__ import divisionfrom models import * from utils ...

  7. yolov5的detect.py代码详解

    目标检测系列之yolov5的detect.py代码详解 前言 哈喽呀!今天又是小白挑战读代码啊!所写的是目标检测系列之yolov5的detect.py代码详解.yolov5代码对应的是官网v6.1版本 ...

  8. yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数

    yolov5-5.0版本代码详解----augmentations.py的augment_hsv函数 1.用途 图片的hsv色域增强模块 2.调用位置 在datasets.py的LoadImagesA ...

  9. 【Image captioning】Show, Attend, and Tell 从零到掌握之三--train.py代码详解

    [Image captioning]Show, Attend, and Tell 从零到掌握之三–train.py代码详解 作者:安静到无声 个人主页 作者简介:人工智能和硬件设计博士生.CSDN与阿 ...

最新文章

  1. JMC | 药物发现中的迁移学习
  2. Mac OS X 创新卡关三年,唯一看得出版本不同之处是「预设桌布」
  3. 二 关于s5p4418 无线wifi模块出现SDIO读写错误的解决方法
  4. SpringBoot高级-消息-@RabbitListener@EnableRabbit
  5. 顺序表查找+折半查找(二级)
  6. 1059 C语言竞赛 (ID映射编号映射字符串)
  7. python窗口动态实时显示时间_量化交易实时动态监视系统(纯Python,只需要浏览器就能用)-直接GitHub开源可下载...
  8. 序列化和反序列化(JSON、protobuf)
  9. HDU 2674 N!Again
  10. 国际h2真假u盘测试软件,u盘速度测试软件H2TEST
  11. uniapp文件体积超过 500KB报错
  12. 数据库表同义词mysql修改_SQLServer中同义词Synonym的用法
  13. 地月距离竟然如此遥远
  14. RT-Thread ENV工具 pkgs --upgrade 报错:open .config failed
  15. 超炫的3D特效相册功能android
  16. win7 查看网络计算机和设备,WIN7 网络发现已关闭 网络计算机和设备不可见
  17. web性能优化 JS/CSS CDN加速公共库
  18. 【深度学习-机器学习】分类度量指标 : 正确率、召回率、灵敏度、特异度,ROC曲线、AUC等
  19. 网络编程基础 --> 网络通信机理、报文与协议、套接字通信预备
  20. 小程序助力博物馆餐厅,用“艾”打造品牌

热门文章

  1. python 实现华安信达论坛自动登录
  2. 四级作文万能套用模板
  3. 想要删除视频中不需要的片头片尾怎么操作
  4. python模拟声音输出_声音的输入输出
  5. unity5.6.6版本_Unity 5.3的新版本,Linux的新游戏以及更多开放式游戏新闻
  6. mac securecrt程序无响应_IT人员必备工具SecureCRT介绍及一些实用小技巧
  7. 数据结构-第二期——链表(Python)
  8. Qt软键盘-发送按键事件
  9. java实现方差分析(ANOVA)
  10. 人生苦短,我用Manjaro || 愿你Manjaro半天,归来仍是Deepin