文章目录

  • CNN存在的问题
  • Spatial Transformer
  • 方法
    • Localisation Network
    • Parameterised Sampling Grid
    • Differentiable Image Sampling
    • 图片解释
  • 例程
  • 附录
    • 几种常用的线性变换
    • 双线性插值

CNN存在的问题

CNN定义了非常强大的分类模型,但是仍然受到缺乏在计算和参数效率上对输入数据空间不变性能力的限制。即,当输入图像因随机平移、缩放、旋转、混乱而失真时,CNN模型的分类准确率将会下降。

Spatial Transformer

它是对CNN的改进, 增加了一个Spatial Transformer 模块, 可以对网络内的数据进行空间操作(spatial manipulation). 这个模块可以插入到现有的CNN模块中, 使得网络能够主动的空间变换feature maps, 通过训练确定特定输入对应的空间变换
使用空间变换器的结果是模型能够学习到了对平移、缩放、旋转和更多通用的warping的不变性,得到最先进的性能.

它在这几个方面可以受益:

  1. 图像分类
  2. co-localisation(共同定位?), 给定一个包含相同但未知的类的不同实例的图像, 它可以被用于localise, … 不太理解这个地方2333
  3. spatial attention: spatial transformer可以用于需要注意力机制的任务

方法

The spatial transformer被分成3个部分, 第一个是localisation network, 它把feature map作为输入, 通过一系列隐层, 输出一些应该被用于spatial transformation的参数,
在第二部分 grid generator中, 这些被预测的参数被用于创造sampling grid, 这是一组点, 输入的map应该被这些点采样成transformed output
最后feature map和 sampling grid 作为sampler的输入, 产生在grid points从输入采样的输出map

总结来说:
它完成的是一个将输入特征图进行一定的变换的过程,而具体如何变换,是通过在训练过程中学习来的,更通俗地将,该模块在训练阶段学习如何对输入数据进行变换更有益于模型的分类,然后在测试阶段应用已经训练好的网络对输入数据进行执行相应的变换,从而提高模型的识别率。

Localisation Network

U∈RH∗W∗CU\text∈ R^{H*W*C}U∈RH∗W∗C: 输入特征图
θ\thetaθ: 被用在feature map上的 transformation TθT_{\theta}Tθ​ 的参数, θ=floc(U)\theta \text= f_{loc}(U)θ=floc​(U), thetathetatheta的大小依赖于transformation的类型, 比如对于二维仿射变换是6维度
对于仿射变换的相关知识参照附录

localisation network function floc()f_{loc}()floc​() 可以是任何形式, fc 或者CNN都行, 但是最后应该有个regression layer来产生θ\thetaθ

Parameterised Sampling Grid

该层利用Localisation 层输出的变换参数 θ\thetaθ,将输入的特征图进行变换

例如输出特征图上某一位置(xit,yit)(x^t_i, y^t_i)(xit​,yit​)根据变换参数 θθθ映射到输入特征图上某一位置(xis,yis)(x^s_i,y^s_i)(xis​,yis​),具体如下:

这里使用高度和宽度的归一化坐标


Differentiable Image Sampling

为了对输入feature map进行变换, 采样器需使用采样点 Tθ(G)T_{\theta}(G)Tθ​(G) 的集合与输入特征图U一起来生成采样的输出特征图, 输出公式如下:

Φx,Φy\Phi_x, \Phi_yΦx​,Φy​ 是一个通用的采样内核k()的参数,它定义了图像的插值(例如,双线性, 整数)。

UnmCU_{nm}^CUnmC​is the value at location (n;m) in channel c of the input
VicV_i^cVic​ is the output value for pixel i at location (xit;yit)(x^t_i; y^t_i )(xit​;yit​) in channel c

请注意,每个输入通道的采样是相同的,因此每个通道都以相同的方式进行转换(这保留了通道之间的空间一致性)

文章指出, 任何可以定义梯度的采样器都可以使用,比如:

整数采样核

双线性sampling kernel

对应的导数为:

图片解释


例程

在pytorch框架中, F.affine_grid 与 F.grid_sample(torch.nn.functional as F)联合使用来对图像进行变形。

F.affine_grid 根据形变参数产生sampling grid,F.grid_sample根据sampling grid对图像进行变形。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)# Spatial transformer localization-networkself.localization = nn.Sequential(nn.Conv2d(1, 8, kernel_size=7),nn.MaxPool2d(2, stride=2),nn.ReLU(True),nn.Conv2d(8, 10, kernel_size=5),nn.MaxPool2d(2, stride=2),nn.ReLU(True))# Regressor for the 3 * 2 affine matrixself.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32),nn.ReLU(True),nn.Linear(32, 3 * 2))# Initialize the weights/bias with identity transformationself.fc_loc[2].weight.data.zero_()self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))# Spatial transformer network forward functiondef stn(self, x):xs = self.localization(x)xs = xs.view(-1, 10 * 3 * 3)theta = self.fc_loc(xs)theta = theta.view(-1, 2, 3)grid = F.affine_grid(theta, x.size())x = F.grid_sample(x, grid)return xdef forward(self, x):# transform the inputx = self.stn(x)# Perform the usual forward passx = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)model = Net().to(device)



附录

此处参见仿射变换

几种常用的线性变换


  • 这里可以倒着理解, 比如从B点逆时针旋转, 就好算很多

双线性插值

在对图像进行仿射变换时,会出现一个问题,当原图像中某一点的坐标映射到变换后图像时,坐标可能会出现小数(如下图所示),而我们知道,图像上某一像素点的位置坐标只能是整数,那该怎么办?这时候双线性插值就起作用了。

双线性插值的基本思想是通过某一点周围四个点的灰度值来估计出该点的灰度值


在实现时我们通常将变换后图像上所有的位置映射到原图像计算(这样做比正向计算方便得多),即依次遍历变换后图像上所有的像素点,根据仿射变换矩阵计算出映射到原图像上的坐标(可能出现小数),然后用双线性插值,根据该点周围4个位置的值加权平均得到该点值。过程可用如下公式表示:


把R1, R2代入, 得:


因为 Q11,Q12,Q21,Q22Q_{11},Q_{12},Q_{21},Q_{22}Q11​,Q12​,Q21​,Q22​ 是相邻的四个点,所以 y2−y1=1,x2−x1=1y_2−y_1=1, x_2−x_1=1y2​−y1​=1,x2​−x1​=1,则上式可化简为:

论文阅读: Spatial transformer networks相关推荐

  1. Google DeepMind的新论文: Spatial Transformer Networks

    @金连文: 在CNN中引入Spatial Transformation 模块,自动学习变换参数,例如仿射变换的6个参数,从而能进行达到Rotation/translation invariant等识别 ...

  2. 论文阅读:Spatial Transformer Networks

    文章目录 1 概述 2 模型说明 2.1 Localisation Network 2.2 Parameterised Sampling Grid 3 模型效果 参考资料 1 概述 CNN的机理使得C ...

  3. CalibNet:Geometrically Supervised Extrinsic Calibration using 3D Spatial Transformer Networks阅读理解

    CalibNet:Geometrically Supervised Extrinsic Calibration using 3D Spatial Transformer Networks 无目标标定的 ...

  4. Spatial Transformer Networks 论文解读

    paper title:Spatial Transformer Networks paper link: https://arxiv.org/pdf/1506.02025.pdf oral or de ...

  5. 【论文学习】STN —— Spatial Transformer Networks

    Paper:Spatial Transformer Networks 这是Google旗下 DeepMind 大作,最近学习人脸识别,这篇paper提出的STN网络可以代替align的操作,端到端的训 ...

  6. Spatial Transformer Networks(STN)详解

    目录 1.STN的作用 1.1 灵感来源 1.2 什么是STN? 2.STN网络架构![在这里插入图片描述](https://img-blog.csdnimg.cn/20190908104416274 ...

  7. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了

    Spatial Transformer Networks https://blog.jiangzhenyu.xyz/2018/10/06/Spatial-Transformer-Networks/ 2 ...

  8. 论文阅读:Spectral Networks and Deep Locally Connected Networks on Graphs

    论文阅读:Spectral Networks and Deep Locally Connected Networks on Graphs 目录 Abstract 1 Introduction 1.1 ...

  9. Paper:《Spatial Transformer Networks空间变换网络》的翻译与解读

    Paper:<Spatial Transformer Networks空间变换网络>的翻译与解读 导读:该论文提出了空间变换网络的概念.主要贡献是提出了空间变换单元(Spatial Tra ...

最新文章

  1. NBT:超高速细菌基因组检索技术
  2. 委托学习总结(一)浅谈对C#委托理解
  3. 【python】闭包
  4. rmi远程代码执行漏洞_WebSphere 远程代码执行漏洞浅析(CVE20204450)
  5. 一文看懂哈夫曼树与哈夫曼编码
  6. 算法相关----最大公约数算法
  7. cad多个窗口并排显示_高版本CAD如何显示阵列窗口?
  8. 关于apache的重启
  9. 【学习记录贴】#3——校园二维和三维电子地图制作
  10. Android App拥有system权限
  11. Google搜索 - 世界各国Google网址大全
  12. navicat激活已过期
  13. ITOP4412开发板学习前的准备2 -- 安装ADB驱动
  14. 幼儿园清明节活动设计方案
  15. 使用计算机组成原理全加器设计,杭电计算机组成原理全加器设计实验1
  16. php 判断百度蜘蛛抓取,百度蜘蛛抓取不存在目录 对应的解决方法
  17. 推荐10篇2021年服装设计相关毕业论文文献
  18. 操作系统中磁盘调度算法详解
  19. svn update 出现skipped '.' 或skipped '目录名称'
  20. python爬取知乎回答并进行舆情分析:爬取数据部分

热门文章

  1. matlab 多项式相减,matlab多项式计算与数据处理
  2. 敲黑板,定积分也有换元和分部积分法!
  3. HTML5、CSS3进阶——渐变背景
  4. WinEdt 7.0生成的PDF文件,用Sumatra PDF打开后,一直提示更新Sumatra PDF版本!
  5. 树莓派Linux开机使用root自动运行python的pyqt文件
  6. 一流的匠人,必有一流的心性:工作是人生最尊贵、最重要、最有价值的行为
  7. C++ STL set容器
  8. 极值点、驻点、鞍点、拐点
  9. xp无法访问2012r2域计算机列表,新安装Windows 2012域控无法没有自动创建Sysvol netlogon共享--钉子-Exchange MVP...
  10. linux下进程管理的原理,Linux进程管理:supervisor和nohup原理及使用