Spatial Transformer Network

  • Spatial Transformer Network

    • Spatial Transformers
    • PyTorch Source Code
    • Reference

The reference for this article is the paper "Spatial Transformer Networks".

Convolutional neural networks define an exceptionally powerful class of models, but they are still limited by the lack of spatial invariance to the input data in a computationally and parameter-efficient way. The paper introduces a new learnable module, the spatial transformer network, which explicitly allows data to be spatially transformed within the network. This differentiable module can be inserted into existing convolutional architectures, giving neural networks the ability to actively transform feature maps in space, conditioned on the feature map itself, without extra training supervision or changes to the optimisation process. The authors show that spatial transformers let models learn invariance to translation, scaling, rotation, and more general warping, achieving strong results on several benchmarks.


Spatial Transformers

The operation of a spatial transformer can be divided into three parts: 1) the localisation network; 2) the grid generator; 3) the sampler.

The localisation network is a network that regresses the transformation parameters θ. Its input is the feature map; after a series of hidden layers (fully connected or convolutional, followed by a regression layer) it outputs the spatial transformation parameters. θ can take different forms: to realise a 2D affine transformation, θ is a 6-dimensional (2×3) output. In general, the size of θ depends on the type of transformation.

\theta = f_{loc}(U)
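As a concrete illustration, here is a minimal localisation network in PyTorch (a sketch only; the layer sizes assume 28×28 single-channel inputs and are illustrative assumptions, not prescribed by the paper). The final linear layer regresses the six affine parameters, with its bias initialised to the identity transform so training starts from the untransformed input:

import torch
import torch.nn as nn

class LocNet(nn.Module):
    """Minimal localisation network f_loc: regresses a 2x3 affine theta."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(True))
        self.fc = nn.Linear(10 * 3 * 3, 2 * 3)  # regression layer -> 6 parameters
        # Initialise to the identity transform: theta = [[1, 0, 0], [0, 1, 0]]
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, u):  # u: (N, 1, 28, 28) input feature map
        theta = self.fc(self.features(u).view(-1, 10 * 3 * 3))
        return theta.view(-1, 2, 3)  # one 2x3 affine matrix per sample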

The grid generator uses the predicted transformation parameters to construct a sampling grid: the set of points in the input feature map at which the output is to be sampled. What the grid generator really produces is a mapping T_θ. Let (x_i^s, y_i^s) be the coordinates of each pixel of the input feature map U, and (x_i^t, y_i^t) the coordinates of each pixel of the output feature map V. If the spatial transformation T_θ is a 2D affine transformation, the correspondence between (x_i^s, y_i^s) and (x_i^t, y_i^t) can be written as:

\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
= \mathcal{T}_\theta(G_i)
= \mathbf{A}_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
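In PyTorch this mapping is computed by F.affine_grid, with the target coordinates normalised to [-1, 1]. A small sketch (the batch size, channel count, 28×28 size, and the particular θ are arbitrary assumptions):

import torch
import torch.nn.functional as F

# theta: one 2x3 affine matrix A_theta per sample; here, identity plus a shift
theta = torch.tensor([[[1.0, 0.0, 0.25],
                       [0.0, 1.0, 0.00]]])
# One (x_s, y_s) source coordinate per target pixel, in normalised [-1, 1] coords
grid = F.affine_grid(theta, size=(1, 1, 28, 28), align_corners=False)
print(grid.shape)  # torch.Size([1, 28, 28, 2])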
The sampler takes the sampling grid together with the input feature map and produces the output: the feature map after the transformation has been applied. With a bilinear sampling kernel, the output is:

V_i^c = \sum_n^{H} \sum_m^{W} U^c_{nm} \max(0, 1-|x_i^s-m|) \max(0, 1-|y_i^s-n|)
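The max(0, 1-|·|) factors are just the bilinear interpolation kernel: only the (at most) four pixels surrounding (x_i^s, y_i^s) receive nonzero weight. Below is a direct, unvectorised transcription of the formula in pixel coordinates, a sketch for a single-channel map U (F.grid_sample performs the same computation batched, in normalised coordinates):

import torch

def sample_bilinear(U, xs, ys):
    """V = sum_{n,m} U[n, m] * max(0, 1-|xs-m|) * max(0, 1-|ys-n|)."""
    H, W = U.shape
    m = torch.arange(W, dtype=U.dtype)           # column indices m
    n = torch.arange(H, dtype=U.dtype)           # row indices n
    wx = torch.clamp(1 - (xs - m).abs(), min=0)  # weights along x, shape (W,)
    wy = torch.clamp(1 - (ys - n).abs(), min=0)  # weights along y, shape (H,)
    return (wy[:, None] * wx[None, :] * U).sum()

U = torch.arange(12, dtype=torch.float).reshape(3, 4)
print(sample_bilinear(U, xs=torch.tensor(1.5), ys=torch.tensor(0.5)))  # tensor(3.5000)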

This completes the forward pass. The only slight difference from ordinary networks is that the STN contains a sampling (interpolation) step guided by a specific grid. On reflection, though, the pooling we use routinely is also a form of sampling (interpolation), just with a rather special grid.

Since the network must be trained, backpropagation of the loss has to work through it. For a custom sampler, the backward formulas must be derived by hand.
The derivatives of the output with respect to the sampler's inputs are:

\frac{\partial V_i^c}{\partial U^c_{nm}} = \sum_n^{H} \sum_m^{W} \max(0, 1-|x_i^s-m|)\max(0, 1-|y_i^s-n|) \\
\frac{\partial V_i^c}{\partial x^s_i} = \sum_n^{H} \sum_m^{W} U^c_{nm} \max(0, 1-|y_i^s-n|)
\begin{cases} 0, & \text{if } |m-x_i^s|\ge 1 \\ 1, & \text{if } m\ge x_i^s \\ -1, & \text{if } m\lt x_i^s \end{cases} \\
\frac{\partial V_i^c}{\partial y^s_i} = \sum_n^{H} \sum_m^{W} U^c_{nm} \max(0, 1-|x_i^s-m|)
\begin{cases} 0, & \text{if } |n-y_i^s|\ge 1 \\ 1, & \text{if } n\ge y_i^s \\ -1, & \text{if } n\lt y_i^s \end{cases}
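Since every factor in the sampler is piecewise linear, these (sub)gradients are exactly what autograd computes; in PyTorch nothing needs to be implemented by hand. As a sanity check, the analytic gradients of grid_sample can be compared against finite differences. A sketch, assuming double precision (which gradcheck requires); it should pass for generic random coordinates, since the kernel's kinks at integer pixel positions are hit with probability zero:

import torch
import torch.nn.functional as F

U = torch.randn(1, 1, 8, 8, dtype=torch.double, requires_grad=True)
grid = (torch.rand(1, 8, 8, 2, dtype=torch.double) * 2 - 1).requires_grad_()
ok = torch.autograd.gradcheck(
    lambda u, g: F.grid_sample(u, g, align_corners=False), (U, grid))
print(ok)  # True: analytic dV/dU and dV/d(x_s, y_s) match numerical gradients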
The derivative of the output with respect to the grid generator depends on the particular transformation used, but in general it follows from the chain rule:

\frac{\partial V_i^c}{\partial \theta} =
\begin{pmatrix}
\frac{\partial V_i^c}{\partial x^s_i} \cdot \frac{\partial x^s_i}{\partial \theta} \\
\frac{\partial V_i^c}{\partial y^s_i} \cdot \frac{\partial y^s_i}{\partial \theta}
\end{pmatrix}
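In an STN built from affine_grid and grid_sample, this chain rule is again handled end to end by autograd: gradients flow through the sampler into the grid, and through the grid generator into θ. A minimal sketch (the tensor sizes are arbitrary assumptions):

import torch
import torch.nn.functional as F

theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]], requires_grad=True)
U = torch.randn(1, 1, 8, 8)
grid = F.affine_grid(theta, size=(1, 1, 8, 8), align_corners=False)  # grid generator
V = F.grid_sample(U, grid, align_corners=False)                      # sampler
V.sum().backward()
print(theta.grad.shape)  # torch.Size([1, 2, 3]): dV/dtheta via the chain rule above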
Assembling these parts together yields a complete STN, as the full source below shows.

PyTorch Source Code

# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about the spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, it can crop a region of interest, scale and correct
the orientation of an image. It can be a useful mechanism because CNNs
are not invariant to rotation and scale and more general affine
transformations.

One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # interactive mode

######################################################################
# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)

######################################################################
# Depicting spatial transformer networks
# --------------------------------------
#
# Spatial transformer networks boil down to three main components:
#
# -  The localization network is a regular CNN which regresses the
#    transformation parameters. The transformation is never learned
#    explicitly from this dataset; instead the network learns automatically
#    the spatial transformations that enhance the global accuracy.
# -  The grid generator generates a grid of coordinates in the input
#    image corresponding to each pixel from the output image.
# -  The sampler uses the parameters of the transformation and applies
#    it to the input image.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. Note::
#    We need the latest version of PyTorch that contains
#    affine_grid and grid_sample modules.


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2, stride=2), nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2, stride=2), nn.ReLU(True))

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2))

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)
        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)

######################################################################
# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network is
# learning the classification task in a supervised way. At the same time
# the model is learning STN automatically in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


# A simple test procedure to measure the STN performance on MNIST.
def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss (reduction='sum' replaces the
            # deprecated size_average=False)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

######################################################################
# Visualizing the STN results
# ---------------------------
#
# Now, we will inspect the results of our learned visual attention
# mechanism.
#
# We define a small helper function in order to visualize the
# transformations while training.


def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp


# We want to visualize the output of the spatial transformer layer
# after the training: we visualize a batch of input images and
# the corresponding transformed batch using STN.
def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(test_loader))[0].to(device)
        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))
        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')
        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')


for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()

Reference

[1] Paper notes: Spatial Transformer Networks
[2] Spatial Transformer Networks Tutorial
