Spatial Transformer Networks(STN)-代码实现

  • pytorch为了方便实现STN,里面封装了affine_gridgrid_sample两个高级API。

  • STN的基本步骤是:

  • Localisationnet\color{blue}Localisation\;netLocalisationnet(参数预测):
    Localisation net模块通过CNNCNNCNN提取图像的特征来预测变换矩阵θ\thetaθ

  • Gridgenerator\color{green}Grid\;generatorGridgenerator(坐标映射):
    Grid generator模块就是利用Localisation net模块回归出来的θ\thetaθ参数来对图片中的位置进行变换,输入图片到输出图片之间的变换,需要特别注意的是这里指的是图片像素所对应的位置。

  • Sampler\color{gray}SamplerSampler(像素的采集):
    Sampler就是用来解决Grid generator模块变换出现小数位置的问题的。针对这种情况,STN采用的是双线性插值(Bilinear Interpolation),下面我们来介绍一下这个算法

1. STN层的实现

from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt#读取图片
img = Image.open("img/test.jpg")
#将图片转换为torch tensor
img_tensor = transforms.ToTensor()(img)#定义平移变换矩阵
#0.1表示将图片向左平移图片宽的百分比
#0.2表示将图片向上平移图片高的百分比
theta = torch.tensor([[1,0,0.1],[0,1,0.2]],dtype=torch.float)
#根据变换矩阵来计算变换后图片的对应位置
grid = F.affine_grid(theta.unsqueeze(0),img_tensor.unsqueeze(0).size(),align_corners=True)
#默认使用双向性插值,可以通过mode参数设置
output = F.grid_sample(img_tensor.unsqueeze(0),grid,align_corners=True)plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(img))
plt.title("original image")plt.subplot(1,2,2)
plt.imshow(output[0].numpy().transpose(1,2,0))
plt.title("stn transform image")plt.show()

2. STN+CNN

当输入图片通过STN模块之后获得变换后的图片,然后我们再将变换后的图片输入到CNNCNNCNN网络中,通过损失函数计算losslossloss,然后计算梯度更新θ\thetaθ参数,最终STN模块会学习到如何矫正图片。

2.1 参数设置

config.py

import argparsedef parse_args():parse = argparse.ArgumentParser("config stn args")parse.add_argument("--lr",default=0.01,type=float,help="learning rate")parse.add_argument("--epoch_nums",default=20,type=int,help="iterated epochs")parse.add_argument("--use_stn",default=True,type=bool,help="whether to use STN module")parse.add_argument("--batch_size",default=64,type=int,help="batch size")parse.add_argument("--use_eval",default=True,type=bool,help="whether to evaluate")parse.add_argument("--use_visual",default=True,type=bool,help="visual STN transform image")parse.add_argument("--use_gpu",default=True,type=bool,help="whether to use GPU")parse.add_argument("--show_net_construct",default=False,type=bool,help="print net construct info")return parse.parse_args()

2.2 加载数据

DataLoader.py

import torch
from torchvision import datasets,transforms
import numpy as npdef get_dataloader(batch_size):# 加载数据集# 如果GPU可用就用GPU,否则用CPUdevice = torch.device("cuda" if torch.cuda.is_available()else "cpu")# 加载训练集train_dataloader = torch.utils.data.DataLoader(datasets.MNIST(root="D:\PyCharm\PyCharm_Project\STN", train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=batch_size, shuffle=True)# 加载测试集test_dataloader = torch.utils.data.DataLoader(datasets.MNIST(root="D:\PyCharm\PyCharm_Project\STN", train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=batch_size, shuffle=True)return train_dataloader,test_dataloaderdef tensor_to_array(img_tensor):img_array = img_tensor.numpy().transpose((1,2,0))mean = np.array([0.485,0.456,0.406])std = np.array([0.229,0.224,0.225])img_array = std * img_array + meanimg = np.clip(img_array,0,1)return img

2.3 定义网络

Net.py

import torch
import torch.nn as nn
import torch.nn.functional as Fclass STN_Net(nn.Module):def __init__(self,use_stn=True):super(STN_Net, self).__init__()self.conv1 = nn.Conv2d(1,10,kernel_size=5)self.conv2 = nn.Conv2d(10,20,kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320,50)self.fc2 = nn.Linear(50,10)#用来判断是否使用STNself._use_stn = use_stn#localisation net#从输入图像中提取特征#输入图片的shape为(-1,1,28,28)self.localization = nn.Sequential(#卷积输出shape为(-1,8,22,22)nn.Conv2d(1,8,kernel_size=7),#最大池化输出shape为(-1,1,11,11)nn.MaxPool2d(2,stride=2),nn.ReLU(True),#卷积输出shape为(-1,10,7,7)nn.Conv2d(8,10,kernel_size=5),#最大池化层输出shape为(-1,10,3,3)nn.MaxPool2d(2,stride=2),nn.ReLU(True))#利用全连接层回归\theta参数self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3,32),nn.ReLU(True),nn.Linear(32,2*3))self.fc_loc[2].weight.data.zero_()self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0],dtype=torch.float))def stn(self,x):#提取输入图像中的特征xs = self.localization(x)xs = xs.view(-1,10*3*3)#回归theta参数theta = self.fc_loc(xs)theta = theta.view(-1,2,3)#利用theta参数计算变换后图片的位置grid = F.affine_grid(theta,x.size())#根据输入图片计算变换后图片位置填充的像素值x = F.grid_sample(x,grid)return xdef forward(self,x):#使用STN模块if self._use_stn:x = self.stn(x)#利用STN矫正过的图片来进行图片的分类#经过conv1卷积输出的shape为(-1,10,24,24)#经过max pool的输出shape为(-1,10,12,12)x = F.relu(F.max_pool2d(self.conv1(x),2))#经过conv2卷积输出的shape为(-1,20,8,8)#经过max pool的输出shape为(-1,20,4,4)x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x = x.view(-1,320)x = F.relu(self.fc1(x))x = F.dropout(x,training=self.training)x = self.fc2(x)return F.log_softmax(x,dim=1)

2.4 训练模型

train.py

import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimdef train(net,epoch_nums,lr,train_dataloader,per_batch,device):#使用训练模式net.train()#选择梯度下降优化算法optimizer = optim.SGD(net.parameters(),lr=lr)#训练模型for epoch in range(epoch_nums):for batch_idx,(data,label) in enumerate(train_dataloader):data,label = data.to(device),label.to(device)optimizer.zero_grad()pred = net(data)loss = F.nll_loss(pred,label)loss.backward()optimizer.step()if batch_idx % per_batch == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_dataloader.dataset),100. * batch_idx / len(train_dataloader), loss.item()))

2.5 评估模型

evaluate.py

import torch
import torch.nn.functional as Fdef evaluate(net,test_dataloader,device):with torch.no_grad():#使用评估模式net.eval()eval_loss = 0eval_acc = 0for data,label in test_dataloader:data,label = data.to(device),label.to(device)pred = net(data)eval_loss += F.nll_loss(pred,label,size_average=False).item()pred_label = pred.max(1,keepdim=True)[1]eval_acc += pred_label.eq(label.view_as(pred_label)).sum().item()eval_loss /= len(test_dataloader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(eval_loss, eval_acc, len(test_dataloader.dataset),100. * eval_acc / len(test_dataloader.dataset)))

2.6 可视化

Visualize.py

import torch,torchvision
import matplotlib.pyplot as plt
from DataLoader import tensor_to_arraydef visualize_stn(net,dataloader,device):with torch.no_grad():data = next(iter(dataloader))[0].to(device)input_tensor = data.cpu()t_input_tensor = net.stn(data).cpu()in_grid = tensor_to_array(torchvision.utils.make_grid(input_tensor))out_grid = tensor_to_array(torchvision.utils.make_grid(t_input_tensor))f,axarr = plt.subplots(1,2)axarr[0].imshow(in_grid)axarr[0].set_title("input images")axarr[1].imshow(out_grid)axarr[1].set_title("stn transformed images")plt.show()

2.7 主函数

MAIN.py

import torch
from Net import STN_Net
from Visualize import visualize_stn
from train import train
from config import parse_args
from DataLoader import get_dataloader
from evaluate import evaluateif __name__ == "__main__":args = parse_args()if args.use_gpu and torch.cuda.is_available():device = "cuda"else:device = "cpu"#加载数据集train_loader,test_loader = get_dataloader(args.batch_size)#创建网络net = STN_Net(args.use_stn).to(device)#训练模型train(net,args.epoch_nums,args.lr,train_loader,args.batch_size,device)if args.use_eval:#评估模型evaluate(net,test_loader,device)if args.use_visual:#可视化展示效果visualize_stn(net,test_loader,device)
Test set: Average loss: 0.0423, Accuracy: 9868/10000 (99%)

参考

  1. 通俗易懂的Spatial Transformer Networks(STN)(一)
  2. 通俗易懂的Spatial Transformer Networks(STN)(二)
  3. SPATIAL TRANSFORMER NETWORKS TUTORIAL

Spatial Transformer Networks(STN)-代码实现相关推荐

  1. Deformable ConvNets--Part2: Spatial Transfomer Networks(STN)

    转自:https://blog.csdn.net/u011974639/article/details/79681455 Deformable ConvNet简介 关于Deformable Convo ...

  2. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了

    Spatial Transformer Networks https://blog.jiangzhenyu.xyz/2018/10/06/Spatial-Transformer-Networks/ 2 ...

  3. Spatial Transformer Networks(STN)详解

    目录 1.STN的作用 1.1 灵感来源 1.2 什么是STN? 2.STN网络架构![在这里插入图片描述](https://img-blog.csdnimg.cn/20190908104416274 ...

  4. 注意力机制——Spatial Transformer Networks(STN)

    Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形.旋转等几何变换的鲁棒性.STN 可以在端到端的训练 ...

  5. Spatial Transformer Networks(STN)

    详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_多元思考力-CSDN博客_stn

  6. 【论文学习】STN —— Spatial Transformer Networks

    Paper:Spatial Transformer Networks 这是Google旗下 DeepMind 大作,最近学习人脸识别,这篇paper提出的STN网络可以代替align的操作,端到端的训 ...

  7. 论文阅读:Spatial Transformer Networks

    文章目录 1 概述 2 模型说明 2.1 Localisation Network 2.2 Parameterised Sampling Grid 3 模型效果 参考资料 1 概述 CNN的机理使得C ...

  8. 空间转换网络(Spatial Transformer Networks)

    空间转换网络(Spatial Transformer Networks) 普通的CNN能够显示的学习平移不变性,以及隐式的学习旋转不变性,但attention model 告诉我们,与其让网络隐式的学 ...

  9. Spatial Transformer Networks 论文解读

    paper title:Spatial Transformer Networks paper link: https://arxiv.org/pdf/1506.02025.pdf oral or de ...

  10. Paper:《Spatial Transformer Networks空间变换网络》的翻译与解读

    Paper:<Spatial Transformer Networks空间变换网络>的翻译与解读 导读:该论文提出了空间变换网络的概念.主要贡献是提出了空间变换单元(Spatial Tra ...

最新文章

  1. 从城市大脑到世界数字大脑 构建人类协同发展的超级智能平台
  2. Nginx负载均衡、ssl原理,生成ssl密钥对,配置Nginxssl
  3. Android根据URL下载文件保存到SD卡
  4. python3.6安装ipython_centos6.5下安装python3.6、pip、ipython
  5. Date和Calendar
  6. Deep Image Prior:深度卷积网络先天就理解自然图像
  7. 排土场、弃渣场和尾矿库的区别
  8. SharePlex数据库复制解决方案
  9. Valine-1.4.4新版本尝鲜+个性制定(表情包、qq头像、UI样式)
  10. python 抽签程序_抽签程序源码
  11. 阿里云首席安全科学家吴翰清的思考:弹性安全网络,构建下一代安全的互联网...
  12. java物联网第三天 智慧农业物联网
  13. ictclas4j java_ictclas4j 分词工具包 安装流程
  14. 20135323符运锦---信息安全系统设计基础第一周学习总结
  15. web前端面试中常问的问题
  16. 华为GaussDB数据库10个知识点,第7个你知道么?
  17. 音视频播放疑难杂症分析和解决思路
  18. HHDBCS及HHDESK的资源加密功能
  19. MERGE操作学习总结
  20. JVM(java堆)

热门文章

  1. hosts文件在哪里,怎样配置修改hosts文件
  2. 《SPSS从入门到精通》笔记(1-4章)
  3. Home Assistant -ESP Home接入homekit篇
  4. Java——TCP/IP超详细总结
  5. 多元相关与回归分析及R使用
  6. nominal和ordinal 数据处理中四种基本数据类型
  7. 用html做祝福语朋友,朋友关怀祝福语(把祝福语化作露带给你)
  8. Babylongjs-纹理
  9. javaWeb文档预览之office转pdf(附详细代码)
  10. 苹果与深圳唯冠的未了之战