Spatial Transformer Networks(STN)-论文笔记

  • 论文: Spatial Transformer Networks
  • 作者:Max Jaderberg Karen Simonyan Andrew Zisserman Koray Kavukcuoglu
  • code1:https://github.com/oarriaga/STN.keras
  • code2:https://github.com/kevinzakka/spatial-transformer-network

1. 问题提出

  1. CNN在图像分类中取得了显著的成效,主要是得益于 CNN 的深层结构具有 :平移不变性、缩小不变性\color{red}平移不变性、缩小不变性平移不变性、缩小不变性;还对缺失的空间不变性(spatiallyinvariance)\color{red}空间不变性(spatially\;invariance)空间不变性(spatiallyinvariance)做了相应的实验。

    • 平移不变性平移不变性平移不变性主要是由于 Pooling 层 和 步长不为1的卷积层 的存在带来的。实际上主要是池化层的作用:

      • 层越多,越深,池化核或卷积核越大,空间不变性也越强;
      • 但是随之而来的问题是局部信息丢失,所以这些层越多准确率肯定是下降的,所以主流的CNN分类网络一般都很深,但是池化核都比较小,比如2×2。
    • 缩小不变性缩小不变性缩小不变性主要是通过降采样来实现的。降采样比例要根据数据集调整,找到合适的降采样比例,才能保证准确率的情况下,有较强的空间不变性。
      • 比如ResNet,GoogLeNet,VGG,FCN,这些网络的总降采样比例一般是 16或32,基本没有见过 64倍,128倍或者更高倍数的降采样(会损失局部信息降低准确率),也很少见到 2倍或者4倍的降采样比例(空间不变性太弱,泛化能力不好)。
  2. 空间不变性(spatiallyinvariance)\color{red}空间不变性(spatially\;invariance)空间不变性(spatiallyinvariance)这些不变性的本质就是图像处理的经典手段:空间变换,又服从于同一方法:坐标矩阵的仿射变换。因此DeepMind设计了SpatialTransformerNetworks\color{red}Spatial\;Transformer\;NetworksSpatialTransformerNetworks(简称STN),目的就是显式地赋予网络对于以上各项变换(transformation)的不变性(invariance) .

2. 图像处理技巧

2.1 仿射变化

主要是要处理(2×3)(2\times 3)(2×3)的变换矩阵:
Tθ=[θ11θ12θ13θ21θ22θ23](2.1)\mathcal{T}_{\theta} = \begin{bmatrix} \theta _{11} & \theta _{12} & \theta _{13} \\ \theta _{21} & \theta _{22} & \theta _{23} \end{bmatrix}\tag{2.1}Tθ​=[θ11​θ21​​θ12​θ22​​θ13​θ23​​](2.1)

  • 平移:
    [10θ1301θ23][xy1]=[x+θ13y+θ23](2.2)\left[\begin{array}{ccc} 1 & 0 & \theta_{13} \\0 & 1 & \theta_{23} \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{l} x+\theta_{13} \\y+\theta_{23} \end{array}\right]\tag{2.2}[10​01​θ13​θ23​​]⎣⎡​xy1​⎦⎤​=[x+θ13​y+θ23​​](2.2)

  • 缩放:
    [θ11000θ220][xy1]=[θ11xθ22y](2.3)\left[\begin{array}{ccc} \theta_{11} & 0 & 0 \\0 & \theta_{22} & 0 \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{l} \theta_{11} x \\\theta_{22} y\end{array}\right]\tag{2.3}[θ11​0​0θ22​​00​]⎣⎡​xy1​⎦⎤​=[θ11​xθ22​y​](2.3)

  • 旋转:
    对于旋转操作,设绕原点顺时针旋转 α\alphaα 度,坐标仿射矩阵为:
    [cos⁡(α)sin⁡(α)0−sin⁡(α)cos⁡(α)0][xy1]=[cos⁡(α)x+sin⁡(α)y−sin⁡(α)x+cos⁡(α)y](2.4)\left[\begin{array}{ccc} \cos (\alpha) & \sin (\alpha) & 0 \\ -\sin (\alpha) & \cos (\alpha) & 0 \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{c}\cos (\alpha) x+\sin (\alpha) y \\-\sin (\alpha) x+\cos (\alpha) y \end{array}\right]\tag{2.4}[cos(α)−sin(α)​sin(α)cos(α)​00​]⎣⎡​xy1​⎦⎤​=[cos(α)x+sin(α)y−sin(α)x+cos(α)y​](2.4)

    由于图像的坐标不是中心坐标系,通常需要做Normalization,把坐标调整到[−1,1][-1,1][−1,1]。这样,就绕图像中心旋转了。

2.2 逆向坐标映射

假设fixed image 的坐标点是[xtar,ytar][x^{tar}, y^{tar}][xtar,ytar],source iamge 的坐标点是[xsour,ysour][x^{sour}, y^{sour}][xsour,ysour],则一般的坐标映射可以表示为:
[θ11θ12θ13θ21θ22θ23][xsourysour1]=[xtarytar1](2.5)\begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta _{21} & \theta _{22} & \theta _{23} \end{bmatrix}\begin{bmatrix} x^{sour} \\ y^{sour} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{tar} \\ y^{tar} \\1 \end{bmatrix}\tag{2.5}[θ11​θ21​​θ12​θ22​​θ13​θ23​​]⎣⎡​xsourysour1​⎦⎤​=⎣⎡​xtarytar1​⎦⎤​(2.5)

逆向坐标映射表示为(θ′\theta'θ′ and θ\thetaθ are different):
[θ11′θ12′θ13′θ21′θ22′θ23′][xtarytar1]=[xsourysour1](2.6)\begin{bmatrix} \theta'_{11} & \theta'_{12} & \theta'_{13} \\ \theta' _{21} & \theta' _{22} & \theta' _{23} \end{bmatrix}\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{sour} \\ y^{sour} \\1 \end{bmatrix}\tag{2.6}[θ11′​θ21′​​θ12′​θ22′​​θ13′​θ23′​​]⎣⎡​xtarytar1​⎦⎤​=⎣⎡​xsourysour1​⎦⎤​(2.6)
STN采用逆向映射,因为:target image 是固定的,正向的插值过程,都是引用像素坐标是浮点数,相对来说很难插值;对应逆向变换,得到的Source坐标是浮点数,用Source像素插值更加便捷

2.3 双线性插值

  • 一个[1,10]图像放大10倍问题,我们需要将10个像素,扩展到为100的数轴上,整个图像应该有100个像素。
    但其中90个对应Source图的坐标是非整数的,是不存在的,如果我们用黑色(RGB(0,0,0))填充,此时图像是惨不忍睹的。所以需要对缺漏的像素进行插值,利用图像数据的局部性近似原理,取邻近像素做平均生成。

  • 双线性插值是一个兼有质量与速度的方法:

  • 插值一般表达式:
    Vic=∑nH∑mWUnmck(xis−m;Φx)k(yis−n;Φy)∀i∈[1…H′W′]∀c∈[1…C](2.7)V_{i}^{c}=\sum_{n}^{H} \sum_{m}^{W} U_{n m}^{c} k\left(x_{i}^{s}-m ; \Phi_{x}\right) k\left(y_{i}^{s}-n ; \Phi_{y}\right) \forall i \in\left[1 \ldots H^{\prime} W^{\prime}\right] \forall c \in[1 \ldots C]\tag{2.7}Vic​=n∑H​m∑W​Unmc​k(xis​−m;Φx​)k(yis​−n;Φy​)∀i∈[1…H′W′]∀c∈[1…C](2.7)

    • UnmcU_{n m}^{c}Unmc​ 是输入feature map上第 ccc 个通道上坐标为 (n,m)(n, m)(n,m) 的像素值;
    • VicV_{i}^{c}Vic​ 是输出 feature map上第 ccc 个通道上坐标为 (xit,yit)\left(x_{i}^{t}, y_{i}^{t}\right)(xit​,yit​) 的像素值;
    • k()k()k() 表示插值核函数;
    • Φx,Φy\Phi x, \Phi yΦx,Φy 代表 x\mathrm{x}x 和 y\mathrm{y}y 方向的揷值核函数的参数;
    • H,WH, WH,W 输入UUU的尺寸;
    • H′,W′H^{\prime}, W^{\prime}H′,W′ 输出VVV的尺寸;
  • 双线性插值的公式:
    Vic=∑nH∑mWUnmcmax⁡(0,1−∣xis−m∣)max⁡(0,1−∣yis−n∣)(2.8)V_{i}^{c}=\sum_{n}^{H} \sum_{m}^{W} U_{n m}^{c} \max \left(0,1-\left|x_{i}^{s}-m\right|\right) \max \left(0,1-\left|y_{i}^{s}-n\right|\right)\tag{2.8}Vic​=n∑H​m∑W​Unmc​max(0,1−∣xis​−m∣)max(0,1−∣yis​−n∣)(2.8)
    这个插值核函数做的是利用UUU中离 当前源坐标 (xis,yis)\left(x_{i}^{s}, y_{i}^{s}\right)(xis​,yis​) (小数坐标) 最近的 4个整数坐标 (n,m)(n, m)(n,m) 处的像素值做双线性插值然后拷贝到VVV中的 (xit,yit)\left(x_{i}^{t}, y_{i}^{t}\right)(xit​,yit​) 坐 标处。


3. 整体框架

3.1 整体描述

Spatial Transformer Networks的结构,主要的部分—共有三个,它们的功能和名称如下:

  • Localisationnet\color{blue}Localisation\;netLocalisationnet(参数预测):
    是自己定义的网络,它输入UUU,输出变化参数θ\thetaθ,这个参数用来映射UUU和VVV的坐标关系(公式(2.1))。
  • Gridgenerator\color{green}Grid\;generatorGridgenerator(坐标映射):
    根据VVV中的坐标点和变化参数θ\thetaθ,计算出UUU中的坐标点(公式(2.6))。

    • 这里是因为VVV的大小是先定义好的,当然可以得到VVV的所有坐标点,而填充VVV中每个坐标点的像素值的时候,要从UUU中去取,所以根据VVV中每个坐标点和变化参数θ\thetaθ进行运算,得到一个坐标。
    • 在sampler中就是根据这个坐标去UUU中找到像素值,这样子来填充VVV。
  • Sampler\color{gray}SamplerSampler(像素的采集):

    根据Grid generator得到的一系列坐标和原图UUU(因为像素值要从UUU中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。

3.2 基本结构与前向传播

  • DeepMind为了描述这个空间变换层,首先添加了坐标网格计算的概念,即:

    • 对应输入源特征图像素的坐标网格——Sampling Grid,保存着(xSource,ySource)(x^{Source},y^{Source})(xSource,ySource)
    • 对应输出源特征图像素的坐标网格——Regluar Grid ,保存着(xTarget,yTarget)(x^{Target},y^{Target})(xTarget,yTarget)
  1. Localisationnet\color{blue}Localisation\;netLocalisationnet(参数预测):对应着初始化的6个参数。
  2. Gridgenerator\color{green}Grid\;generatorGridgenerator(坐标映射):对应着图中的①②。
    Tθ(Gi)[xtarytar1]=[θ11′θ12′θ13′θ21′θ22′θ23′][xtarytar1]=[xsourysour1],wherei=1,2,3,4..,H∗W(3.1)\mathcal{T}_{\theta}(G_i)\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix} = \begin{bmatrix} \theta'_{11} & \theta'_{12} & \theta'_{13} \\ \theta' _{21} & \theta' _{22} & \theta' _{23} \end{bmatrix}\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{sour} \\ y^{sour} \\1 \end{bmatrix}, where\;i=1,2,3,4..,H∗W\tag{3.1}Tθ​(Gi​)⎣⎡​xtarytar1​⎦⎤​=[θ11′​θ21′​​θ12′​θ22′​​θ13′​θ23′​​]⎣⎡​xtarytar1​⎦⎤​=⎣⎡​xsourysour1​⎦⎤​,wherei=1,2,3,4..,H∗W(3.1)
  3. Sampler\color{gray}SamplerSampler(像素的采集):对应着图中的③④。

3.3 梯度流动与反向传播

添加空间变换层之后,梯度流动变得有趣,如图:

  1. 后流(①):
    ErrorGradientError\;GradientErrorGradient →……→∂Next∂Vic\rightarrow \ldots \ldots \rightarrow \frac{\partial N e x t}{\partial V_{i}^{c}}→……→∂Vic​∂Next​
    这是Back Propagation从后层继承的动力源泉,没有它,你就不可能完成Back Propagation。
  2. 里流(②):
    {∂Vic∂xiS→∂xiS∂θ∂Vic∂yiS→∂yiS∂θ(3.3)\left\{\begin{aligned} \frac{\partial V_{i}^{c}}{\partial x_{i}^{S}} \rightarrow \frac{\partial x_{i}^{S}}{\partial \theta} \\ \frac{\partial V_{i}^{c}}{\partial y_{i}^{S}} \rightarrow \frac{\partial y_{i}^{S}}{\partial \theta} \end{aligned}\right.\tag{3.3}⎩⎪⎪⎪⎨⎪⎪⎪⎧​∂xiS​∂Vic​​→∂θ∂xiS​​∂yiS​∂Vic​​→∂θ∂yiS​​​(3.3)
  • 个人对这股流的最好描述就是: 一江春水流进了小黑屋。
  • 是的,你没有看错,这股流根本就没有流到网络开头,而是在定位网络处就断流了。 由此来看,定位网络就好像是在主网络旁侧偷建的小黑屋,是一个违章湕筑。
  • 所以也无怪乎作者说,定位网络直接変成了一个回归模型,因为更新完参数,流就断了,独立于主网络。
  1. 前流(③):
    ∂Vic∂Unmi→∂Unmi∂Previous (3.4)\frac{\partial V_{i}^{c}}{\partial U_{n m}^{i}} \rightarrow \frac{\partial U_{n m}^{i}}{\partial \text { Previous }}\tag{3.4}∂Unmi​∂Vic​​→∂ Previous ∂Unmi​​(3.4)
    这是Back Propagation传宗接代的根本保障,没有它,Back Propagation就断子绝孙了。

3.4 局部梯度

论文中多次出现[局部梯度] (Sub-Gradient) 的概念。采样核函数,是不连续的,不能如下直接求导:
g=∂Vic∂θ(3.5)g=\frac{\partial V_{i}^{c}}{\partial \theta}\tag{3.5} g=∂θ∂Vic​​(3.5)
而应该是分两步,先对 xiS、xiSx_{i}^{S} 、 x_{i}^{S}xiS​、xiS​ 求局部梯度: ∂Vic∂xic、∂Vic∂yic\frac{\partial V_{i}^{c}}{\partial x_{i}^{c}} 、 \frac{\partial V_{i}^{c}}{\partial y_{i}^{c}}∂xic​∂Vic​​、∂yic​∂Vic​​ ,后有:
{g=∂Vic∂xiS⋅∂xiS∂θg=∂Vic∂yiS⋅∂yiS∂θ(3.6)\left\{\begin{aligned} g=\frac{\partial V_{i}^{c}}{\partial x_{i}^{S}} \cdot \frac{\partial x_{i}^{S}}{\partial \theta} \\ g=\frac{\partial V_{i}^{c}}{\partial y_{i}^{S}} \cdot \frac{\partial y_{i}^{S}}{\partial \theta} \end{aligned}\right.\tag{3.6} ⎩⎪⎪⎪⎨⎪⎪⎪⎧​g=∂xiS​∂Vic​​⋅∂θ∂xiS​​g=∂yiS​∂Vic​​⋅∂θ∂yiS​​​(3.6)
有趣的是,对于Theano这种目动求导的 Tools,局部梯度可以直接被忽视。
因为Theano的Tensor机制,会聪明地讨论并且解离非连续函数,追踪每一个可导子式,即便你用了作者们的优雅的采样函数, Tensor.grad函数也能精确只对许出的4个点求导,所以在Theano里讨论非连续函数和局部梯度,是会贻笑大方的。


4. 实验

4.1 Distorted MNIST

这个试验的数据集 是 MNIST,不过与原版的MNIST 不同,这个数据集对图片上的数字做了各种形变操作,比如平移,扭曲,放缩,旋转等。

  • 不同形变操作的简写表示:

    • 旋转:rotation ( R),
    • 旋转+缩放+平移:rotation, scale and translation (RTS),
    • 投影变换:projective transformation ( P),
    • 弹性变形:elastic warping (E) – note that elastic warping is destructive and can not be inverted in some cases.
  • 文章将 Spatial Transformer 模块嵌入到 两种主流的分类网络,FCN和CNN中(ST-FCN 和 ST-CNN )。Spatial Transformer 模块嵌入位置在图片输入层与后续分类层之间。
  • 试验也测试了不同的变换函数对结果的影响:
    • 仿射变换:affine transformation (Aff),
    • 投影变换:projective transformation (Proj),
    • 薄板样条变换:16-point thin plate spline transformation (TPS)

其中CNN的模型与 LeNet是一样的,包含两个池化层。为了公平,所有的网络变种都只包含 3 个可学习参数的层,总体网络参数基本一致,训练策略也相同。

  • 左侧:不同的形变策略以及不同的 Spatial Transformer网络变种与 baseline的对比;
  • 右侧:一些CNN分错,但是ST-CNN分对的样本
    - (a ):输入
    - (b ):Spatial Transformer层 的 源坐标(Tθ(G) )可视化结果
    - (c ):Spatial Transformer层输出
  • 很明显:ST-CNN优于CNN, ST-FCN优于FCN,说明Spatial Transformer确实增加了 空间不变性
  • FCN中由于没有 池化层,所以FCN的空间不变性不如CNN,所以FCN效果不如CNN
  • ST-FCN效果可以达到CNN程度,说明Spatial Transformer确实增加了 空间不变性
  • ST-CNN效果优于ST-FCN,说明 池化层 确实对 增加 空间不变性很重要
  • 在 Spatial Transformer 中使用 plate spline transformation (TPS) 变换效果是最好的
  • Spatial Transformer 可以将歪的数字扭正
  • Spatial Transformer 在输入图片上确定的attention区域很明显利于后续分类层分类,可以更加有效地减少分类损失

4.2 Street View House Numbers

Street View House Numbers是一个真实的 街景门牌号 数据集,共200k张图片,每张图片包含1-5个数字 ,数字都有形变。

  • baseline character sequence CNN model :11层,5个softmax层输出对应位置的预测序列
  • STCNN Single :在输入层添加一个Spatial Transformer
  • ST-CNN Multi :前四层,每一层都添加一个Spatial Transformer 见下面 tabel 2 右侧
  • localisation networks 子网络:两层32维的全连接层
  • 使用仿射变换和双线性插值

结果:

参考

  1. arleyzhang:基础DL模型-STN-Spatial Transformer Networks-论文笔记
  2. Spatial Transformer Networks笔记
  3. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了
  4. Spatial Transformer Networks
  5. 论文笔记:空间变换网络(Spatial Transformer Networks)

【论文笔记-5】Spatial Transformer Networks(STN)相关推荐

  1. 论文阅读:Spatial Transformer Networks

    文章目录 1 概述 2 模型说明 2.1 Localisation Network 2.2 Parameterised Sampling Grid 3 模型效果 参考资料 1 概述 CNN的机理使得C ...

  2. Deformable ConvNets--Part2: Spatial Transfomer Networks(STN)

    转自:https://blog.csdn.net/u011974639/article/details/79681455 Deformable ConvNet简介 关于Deformable Convo ...

  3. 【论文学习】STN —— Spatial Transformer Networks

    Paper:Spatial Transformer Networks 这是Google旗下 DeepMind 大作,最近学习人脸识别,这篇paper提出的STN网络可以代替align的操作,端到端的训 ...

  4. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了

    Spatial Transformer Networks https://blog.jiangzhenyu.xyz/2018/10/06/Spatial-Transformer-Networks/ 2 ...

  5. Spatial Transformer Networks(STN)详解

    目录 1.STN的作用 1.1 灵感来源 1.2 什么是STN? 2.STN网络架构![在这里插入图片描述](https://img-blog.csdnimg.cn/20190908104416274 ...

  6. Spatial Transformer Networks 论文解读

    paper title:Spatial Transformer Networks paper link: https://arxiv.org/pdf/1506.02025.pdf oral or de ...

  7. Spatial Transformer Networks(STN)

    详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_多元思考力-CSDN博客_stn

  8. 注意力机制——Spatial Transformer Networks(STN)

    Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形.旋转等几何变换的鲁棒性.STN 可以在端到端的训练 ...

  9. Paper:《Spatial Transformer Networks空间变换网络》的翻译与解读

    Paper:<Spatial Transformer Networks空间变换网络>的翻译与解读 导读:该论文提出了空间变换网络的概念.主要贡献是提出了空间变换单元(Spatial Tra ...

  10. 行为识别论文笔记|TSN|Temporal Segment Networks: Towards Good Practices for Deep Action Recognition

    行为识别论文笔记|TSN|Temporal Segment Networks: Towards Good Practices for Deep Action Recognition Temporal ...

最新文章

  1. java播放器使用教程_java 实现音乐播放器的简单实例
  2. oracle共享时监听,Oracle监听---共享连接参数配置介绍
  3. 解决Extjs中Combobox显示值和真实值赋值问题
  4. 电脑入门基础教程_ARM入门最好的文章------转载一位资身工程师的入门心得
  5. 激光雷达和相机联合标定 之 开源代码和软件汇总 (2004-2021)
  6. JavaScript对象类型Object
  7. Javaspring 7-13课 Spring Bean
  8. html背景设置为彩色,CSS3 彩色网格背景
  9. 微信支付宝个人免签约即时到帐接口开发附demo
  10. 2023年华南理工大学运筹学与控制论上岸前辈备考经验
  11. android 7 zip压缩文件,7-zip怎么把大文件压缩到最小
  12. 我的世界热力膨胀JAVA_我的世界TE4教程热力膨胀能源炉的合成与使用数据
  13. 小熊派移植 TencentOS-tiny+M26/EC20+MQTT 对接腾讯云平台IoThub
  14. ubuntu系统如何连接到服务器,远程ubuntu系统怎么连接到服务器
  15. UCI计算机工程必修专业课,UCI大学尖端专业学科盘点
  16. 人工智能发展的三个热潮
  17. Python回归预测建模实战-随机梯度下降法预测房价(附源码和实现效果)
  18. 【线代】特征值、惯性指数、标准型、规范型的关系?等价、相似与合同?
  19. 数学图形(1.8) 圆外旋轮线
  20. AllWinner T113 wifi tools交叉编译

热门文章

  1. 洛谷 P4578 [FJOI2018] Upc6605 福建OI2018 所罗门王的宝藏
  2. 前台离岗提示语_酒店客房温馨提示怎么写 酒店前台温馨提示语
  3. 管理员三权分立是什么意思?
  4. 一台电脑绿色安装多个版本google Chorme方案
  5. 2019备考[嵌入式系统设计师]之基础知识
  6. idea Translation插件翻译失败。TKK: 更新 TKK 失败,请检查网络连接
  7. python爬虫之使用urllib模块实现有道翻译功能
  8. Groovy(二)groovy基础
  9. K8S CRD 资源对象删除不掉
  10. 专访丨华为云GaussDB苏光牛:发挥生态优势,培养应用型DBA