文章目录

  • RoIAlign 的用处
  • RoIAlign 计算原理
    • 双线性插值(Bilinear Interpolation)
  • pytorch中的实现
    • 简单示例
    • 在FasterRCNN中的使用示例
  • 参考链接

RoIAlign 的用处

RoIAlign 用于将任意尺寸感兴趣区域的特征图,都转换为具有固定尺寸 H×W 的小特征图。

与RoI pooling一样,其基本原理是将 h×wh×wh×w 的特征划分为 H×WH×WH×W 网格,每个格子是大小近似为 h/H×w/Wh/H×w/Wh/H×w/W 的子窗口 ,然后将每个子窗口中的值最大池化到相应的输出网格单元中。想复习RoI pooling概念的可以看这篇。

RoIAlign 其实就是更精确版本的 RoIPooling,用双线性插值取代了RoIPooling中的直接取整的操作。

下面用一个具体图例看下 RoIAlign 计算原理。

RoIAlign 计算原理

输入一个feature map,对于每个不同尺寸的proposed region,需要转换成固定大小 H×WH×WH×W的 feature map,H和W是这一层的超参数。

黑色粗框部分是一个 7×57×57×5 大小的 proposed region,首先切分成 H×WH×WH×W 个sections(这里以2x2为例):

对每个section采样四个区域,用红色×表示其中心位置:

每个section中四个红色×的值,由双线性插值计算:

对每个 section 中四个值进行 max pooling,输出结果:

就是我们所需要的固定大小输出了。

这个固定大小输出可以通过全连接的层,用于边界框回归和分类,常用于检测和分割模型中。

双线性插值(Bilinear Interpolation)

借用下图从视觉上来理解双线性插值,黑点上的双线期插值是附近四个点的加权和,权值是四个点对应的颜色矩形在总面积中的占比。比如左上角黄点 (x1,y2)(x_1,y_2)(x1​,y2​) 对应的是右下较大的黄色矩阵面积。

pytorch中的实现

RoIAlign在pytorch中的实现是torchvision.ops.RoIAlign,torchvision.ops中实现的是计算机视觉中特定的operators。

class: torchvision.ops.RoIAlign(output_size, spatial_scale, sampling_ratio)

  • output_size (int or Tuple[int, int]) – 输出大小,用 (height, width) 表示。
  • spatial_scale (float) – 将输入坐标映射到框坐标的比例因子。默认值1.0。
  • sampling_ratio (int) – 插值网格中用于计算每个合并输出bin的输出值的采样点数目。如果> 0,则恰好使用sampling_ratio x sampling_ratio网格点。如果<= 0,则使用自适应数量的网格点(计算为cell (roi_width / pooled_w),同样计算高度)。默认值1。

torchvision.ops.roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1)

  • input (Tensor[N, C, H, W]) – 输入张量
  • boxes (Tensor[K, 5] or List[Tensor[L, 4]]) – 区域包围框以 (x1,y1,x2,y2)(x1, y1, x2, y2)(x1,y1,x2,y2) 形式表示。如果输入的是单个tensor,第一列表示batch index;如果输入是一个tensor List,每个tensor对应batch中的第iii个元素的方框。

简单示例

import torch
import torchvision# 创建RoIAlign层
pooler = torchvision.ops.RoIAlign(output_size=2,sampling_ratio=2,spatial_scale=5)# 输入一个 8x8 的feature:
inputTensor = torch.rand(1,1,8,8)

inputTensor类似如下:

再创建一个box:

box =  torch.tensor([[0.0,0.375,0.875,0.625]]) output = pooler(inputTensor,[box])#shape:[1, 1, 2, 2]

输出结果:

在FasterRCNN中的使用示例

import torchvision
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios=((0.5, 1.0, 2.0),))# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size=7,sampling_ratio=2)# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,num_classes=2,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)

参考链接

https://zhuanlan.zhihu.com/p/59692298
https://zhuanlan.zhihu.com/p/73138740
https://pytorch.org/docs/1.2.0/torchvision/ops.html
https://pytorch.org/docs/1.2.0/_modules/torchvision/ops/roi_align.html

图解 RoIAlign 以及在 PyTorch 中的使用(含代码示例)相关推荐

  1. java事件绑定,Java编程GUI中的事件绑定代码示例

    程序绑定的概念: 绑定指的是一个方法的调用与方法所在的类(方法主体)关联起来.对java来说,绑定分为静态绑定和动态绑定:或者叫做前期绑定和后期绑定 静态绑定: 在程序执行前方法已经被绑定,此时由编译 ...

  2. 使用PyTorch进行知识蒸馏的代码示例

    随着机器学习模型的复杂性和能力不断增加.提高大型复杂模型在小数据集性能的一种有效技术是知识蒸馏,它包括训练一个更小.更有效的模型来模仿一个更大的"教师"模型的行为. 在本文中,我们 ...

  3. 【Groovy】闭包 Closure ( 闭包调用 与 call 方法关联 | 接口中定义 call() 方法 | 类中定义 call() 方法 | 代码示例 )

    文章目录 总结 一.接口中定义 call() 方法 二.类中定义 call() 方法 三.完整代码示例 总结 在 实例对象后使用 " () " 括号符号 , 表示调用该实例对象的 ...

  4. 【Groovy】集合遍历 ( 操作符重载 | 集合中的 “ + “ 操作符重载 | 集合中的 “ - “ 操作符重载 | 代码示例 )

    文章目录 一.集合中的 " + " 操作符重载 二.集合中的 " - " 操作符重载 三.完整代码示例 一.集合中的 " + " 操作符重载 ...

  5. 【Android NDK 开发】Kotlin 语言中使用 NDK ( 创建支持 Kotlin 的 NDK 项目 | Kotlin 语言中使用 NDK 要点 | 代码示例 )

    文章目录 一.创建支持 Kotlin 的 NDK 项目 二.Kotlin 语言中使用 NDK 要点 1.加载动态库 2.声明 ndk 方法 3.Project 下的 build.gradle 配置 4 ...

  6. php 使用dataview,echarts如何优化数据视图dataView中的样式(代码示例)

    本篇文章给大家带来的内容是关于echarts如何优化数据视图dataView中的样式(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 在使用echart过程中,toolbo ...

  7. 如何使用功能性JavaScript编写经典游戏Snake并在浏览器中播放-完整的代码示例教程

    Remember the game Snake that came pre-installed on every Nokia phone back in the 1990s? You steered ...

  8. Java客户端操作zookeeper:获取及修改节点中的数据内容代码示例

  9. python将list转换为迭代器代码_python中的迭代器附带代码示例

    迭代的概念 迭代就是执行重复的特定的任务,知道任务完成为止 相当于我们盖房子,今天添一块砖,明天加一块瓦,直到房子盖完为止.这里每天的工作就是一次迭代 (1.)可迭代对象 a.可以直接作用于for-i ...

最新文章

  1. MySQL存储引擎的介绍
  2. Linux下python升级安装步骤
  3. 献给开发者的大礼--打造CSDN论坛专用阅读器(电脑报2006年11月6日 第44期)
  4. mysql数据库维护_维护MySQL数据库表
  5. 15M安装包就能玩《原神》,带你了解云游戏背后的技术秘密
  6. 函数式编程的Java编码实践:利用惰性写出高性能且抽象的代码
  7. linux重定向输出命令
  8. 实战篇:教你建设企业销售分析系统
  9. python在哪里写代码-程序员面试被要求手写代码,你与顶级程序员的差别在哪?...
  10. drools规则引擎 java_Drools规则引擎的使用总结
  11. Dual Thrust 策略
  12. 使用python将视频按照帧转为图片
  13. U盘提示''这张磁盘有写保护''修复工具
  14. shell自动部署k8s集群:新增加的work node节点加入k8s集群
  15. 北京的互联网公司有哪些?
  16. android 小米手机播放短小音频无声音问题
  17. Excel中如何快速输入☑和☒
  18. 论文“Structure-from-Motion Revisited” 对ISFM改进的理解
  19. 泰勒级数定义及相关展开式
  20. Mac下安装whistle

热门文章

  1. web.py开发web 第一章 Hello World
  2. 解决python中TypeError: not enough arguments for format stringj
  3. CentOS 5.8 Zimbra邮件系统安装与配置
  4. 双向链表(不带头结点)
  5. JavaScript 中遍历对象的属性 1
  6. Nacos Spring 快速开始
  7. Java 11新特性
  8. C#委托和事件实现观察者模式
  9. C#LeetCode刷题-线段树
  10. C#LeetCode刷题之#896-单调数列(Monotonic Array)