导读

pytorch为了方便实现STN,里面封装了affine_gridgrid_sample两个高级API。对STN不太了解的同学可以参考这篇详细解读Spatial Transformer Networks(STN)

其实STN的作用是想让CNN具备平移、旋转、缩放、剪切不变性,虽然说CNN中的Pooling可以让网络具备一点平移不变性,但这毕竟是隐性的,如果能让网络直接具备这样的能力岂不是更好。

如果对图像处理有了解的同学也许听过仿射变换这个名词,我们只需要通过变换矩阵θ\thetaθ(由6个参数组成)就能实现上面的这些功能,如果对仿射变换不了解的同学可以参考我的这篇一文搞懂仿射变换

STN也是因为受到这个启发而诞生的,那么我们如何将这种能力嵌入到CNN中呢?这便是STN需要解决的问题

STN简介

上面引用的文章中已经详细介绍了STN网络,我这里总结概括一下

  • Localisation net

Localisation net模块通过CNN提取图像的特征来预测变换矩阵θ\thetaθ

  • Grid generator

Grid generator模块就是利用Localisation net模块回归出来的θ\thetaθ参数来对图片中的位置进行变换,输入图片到输出图片之间的变换,需要特别注意的是这里指的是图片像素所对应的位置

例如:如果此时θ\thetaθ参数功能是实现图片的平移变换(向右平移1,),输入图片上的坐标(1,1),那对应输出图片上的坐标的(2,1),也就是说输入图片上(1,1)对应的像素值等于输出图片上(2,1)对应的像素值。在变换的时候必然会遇到当输入图片的位置变换到输出图片上是如果位置出现小数怎么办?

  • Sampler

Sampler就是用来解决Grid generator模块变换出现小数位置的问题的。针对这种情况,STN采用的是双线性插值(Bilinear Interpolation),下面我们来介绍一下这个算法

上图中(x,y)(x,y)(x,y)是变换后输出图像上的位置,带下标的坐标位置表示的是与(x,y)(x,y)(x,y)在输入图像对应的四个相邻的坐标。上面的坐标满足下面的关系
x1−x0=1y1−y0=1x_1-x_0 = 1\\ y1-y_0 = 1 x1​−x0​=1y1−y0​=1
根据双线性插值的原则距离相邻点近的坐标占的比重越大,所以(x,y)(x,y)(x,y)对应的像素值为,我们用f(x,y)f(x,y)f(x,y)表示点(x,y)(x,y)(x,y)所对应的像素值
f(x,y)=(x1−x)(y1−y)f(x0,y0)+(x−x0)(y1−y)f(x1,y0)=+(x−x0)(y−y0)f(x1,y1)+(x1−x)(y−y0)f(x0,y1)\begin{aligned} f(x,y) &= (x_1-x)(y1-y)f(x_0,y_0)+(x-x_0)(y_1-y)f(x_1,y_0)\\ &=+(x-x_0)(y-y_0)f(x_1,y_1)+(x_1-x)(y-y_0)f(x_0,y_1) \end{aligned} f(x,y)​=(x1​−x)(y1−y)f(x0​,y0​)+(x−x0​)(y1​−y)f(x1​,y0​)=+(x−x0​)(y−y0​)f(x1​,y1​)+(x1​−x)(y−y0​)f(x0​,y1​)​

STN层的实现

  • pytorch的实现

通过pytorchaffine_gridgrid_sample可以很容易实现STN的后两个模块

from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),img_tensor.unsqueeze(0).size(),align_corners=True)
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),grid,align_corners=True)plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.title("original image")plt.subplot(1,2,2)
plt.imshow(output[0].numpy().transpose(1,2,0))
plt.title("stn transform image")plt.show()

  • numpy的实现

我们通过numpy来实现STN的后两个模块,来帮助大家更好的理解STN

class Grid_sample(object):def affine_grid(self,theta,img_size):if len(img_size) != 2:assert("img_size size must is 2")num_batch = np.shape(theta)[0]img_w,img_h = img_size#将图片位置归一化到(-1,1)x = np.linspace(-1.0,1.0,img_w)y = np.linspace(-1.0,1.0,img_h)#组合x和y获取到图片的位置坐标x_t,y_t = np.meshgrid(x,y)x_t_flat = np.reshape(x_t,[-1])y_t_flat = np.reshape(y_t,[-1])#创建一个图片的位置数组ones = np.ones_like(x_t_flat)sampling_grid = np.stack([x_t_flat,y_t_flat,ones])sampling_grid = np.expand_dims(sampling_grid,axis=0)sampling_grid = np.tile(sampling_grid,np.stack([num_batch,1,1]))#计算变换后的图片位置batch_grids = np.matmul(theta,sampling_grid)batch_grids = np.reshape(batch_grids,[num_batch,2,img_h,img_w])return batch_gridsdef bilinear_sampler(self,img,batch_grids):if (batch_grids.shape) != 4:assert("batch_grids shape is must equal 4")#获取变换后图片位置的x和y轴的坐标位置x = batch_grids[:, 0, :, :]y = batch_grids[:, 1, :, :]img_w,img_h = img.shape[:2]max_x = img_w - 1max_y = img_h - 1#将变换后的坐标位置固定到(0,w/h-1)x = 0.5 * ((x+1.0)*(max_x-1))y = 0.5 * ((y+1.0)*(max_y-1))#将坐标位置取整,便于从输入图片中获取位置对应的像素值x0 = np.floor(x).astype(np.int)x1 = x0 + 1y0 = np.floor(y).astype(np.int)y1 = y0 + 1#防止坐标越界x0 = np.clip(x0,0,max_x)x1 = np.clip(x1,0,max_x)y0 = np.clip(y0,0,max_y)y1 = np.clip(y1,0,max_y)#根据坐标位置,取像素值Ia = img[y0,x0,:]Ib = img[y1,x0,:]Ic = img[y0,x1,:]Id = img[y1,x1,:]wa = np.expand_dims((x1-x)*(y1-y),axis=3)wb = np.expand_dims((x1-x)*(y-y0),axis=3)wc = np.expand_dims((x-x0)*(y1-y),axis=3)wd = np.expand_dims((x-x0)*(y-y0),axis=3)#利用双线性插值计算变换后的像素值out = wa*Ia + wb*Ib + wc*Ic + wd*Idreturn outgrid_sampler = Grid_sample()
img = np.array(Image.open("img/test.jpg"))
img_h,img_w = img.shape[:2]
theta = np.array([[[1, 0, 0.1], [0, 1, 0.2]]],dtype=np.float)
theta = np.expand_dims(theta,axis=0)batch_grids = grid_sampler.affine_grid(theta,(img_w,img_h))
out = grid_sampler.bilinear_sampler(img,batch_grids)plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.array(img))
plt.title("original image")plt.subplot(1, 2, 2)
plt.imshow(out[0].astype(np.uint8))
plt.title("stn transform image")plt.show()


下一篇文章我们介绍如何将STN模块插入到CNN中

通俗易懂的Spatial Transformer Networks(STN)(一)相关推荐

  1. Deformable ConvNets--Part2: Spatial Transfomer Networks(STN)

    转自:https://blog.csdn.net/u011974639/article/details/79681455 Deformable ConvNet简介 关于Deformable Convo ...

  2. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了

    Spatial Transformer Networks https://blog.jiangzhenyu.xyz/2018/10/06/Spatial-Transformer-Networks/ 2 ...

  3. Spatial Transformer Networks(STN)详解

    目录 1.STN的作用 1.1 灵感来源 1.2 什么是STN? 2.STN网络架构![在这里插入图片描述](https://img-blog.csdnimg.cn/20190908104416274 ...

  4. Spatial Transformer Networks(STN)

    详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_多元思考力-CSDN博客_stn

  5. 【论文学习】STN —— Spatial Transformer Networks

    Paper:Spatial Transformer Networks 这是Google旗下 DeepMind 大作,最近学习人脸识别,这篇paper提出的STN网络可以代替align的操作,端到端的训 ...

  6. 注意力机制——Spatial Transformer Networks(STN)

    Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形.旋转等几何变换的鲁棒性.STN 可以在端到端的训练 ...

  7. 论文阅读:Spatial Transformer Networks

    文章目录 1 概述 2 模型说明 2.1 Localisation Network 2.2 Parameterised Sampling Grid 3 模型效果 参考资料 1 概述 CNN的机理使得C ...

  8. 空间转换网络(Spatial Transformer Networks)

    空间转换网络(Spatial Transformer Networks) 普通的CNN能够显示的学习平移不变性,以及隐式的学习旋转不变性,但attention model 告诉我们,与其让网络隐式的学 ...

  9. Spatial Transformer Networks 论文解读

    paper title:Spatial Transformer Networks paper link: https://arxiv.org/pdf/1506.02025.pdf oral or de ...

  10. Paper:《Spatial Transformer Networks空间变换网络》的翻译与解读

    Paper:<Spatial Transformer Networks空间变换网络>的翻译与解读 导读:该论文提出了空间变换网络的概念.主要贡献是提出了空间变换单元(Spatial Tra ...

最新文章

  1. matlab外部接口简介
  2. 用 GStreamer 简化 Linux 多媒体开发
  3. TB级微服务海量日志监控平台
  4. 实现一个网易云音乐的 BottomSheetDialog
  5. SQLSERVER导入导出文本文件
  6. Java:数列排序 给定一个长度为n的数列,将这个数列按从小到大的顺序排列。1<=n<=200
  7. leetcode 427. Construct Quad Tree | 427. 建立四叉树(分治法)
  8. python 异步 生产者 消费者_python 线程通信 生产者与消费者
  9. 手机按键中控运行思路的个人理解
  10. c 通过jni调用java_使用c通过jni调用java
  11. 操作系统之计算机系统概述:7、操作系统的体系结构
  12. wait和waitpid
  13. 企业级多用户发卡平台源码
  14. 【笔记】马克思主义哲学(二)-- 唯物论
  15. 装饰器模式Decorate
  16. Win10开机取消微软登录密码
  17. 海思hi3516dv300音频调节总结
  18. 虚拟主机可以运行java_下面哪种类型的文件可以在Java虚拟机中运行( ).
  19. 2023智能家电、智能家居解决方案与技术论坛——CAEE
  20. Android Studio 1.3RC版 build加速

热门文章

  1. 美国10大计算机软件,美国计算机软件工程专业研究生排名
  2. 什么是计算机的超级用户账号,administrator是什么意思
  3. 三十、动名词短语 2
  4. 我的世界java版幻翼_见到幻翼的方式是熬夜?这几个被忽略了
  5. 常见随机变量的数学期望和方差
  6. Illumination Normalization Based on Weber’s Law With Application to Face Recognition
  7. HTML5+CSS大作业——学生个人博客(5页) 大学生个人博客网页作品 网页设计作业模板 学生网页制作源代码下载
  8. java 图片与base64相互转化
  9. 《当时只道是寻常》——安意如——品纳兰容若《饮水词》
  10. phpmyadmin mysql配置_phpmyadmin配置方式