stn在mnist上的实现

个人博客 - https://cxy-sky.github.io/

代码参考来源:PyTorch框架实战系列(3)——空间变换器网络STN_Daniel Yuz的博客-CSDN博客

理论:Pytorch中的仿射变换(affine_grid)_liangbaqiang的博客-CSDN博客

详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn

​ 图片显示用的是matplotlib,自己没下opencv.

CNN

import torch
from torch import nn, optimclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.cnn = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=4),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3),)self.linear = nn.Sequential(nn.Dropout2d(0.5),nn.Linear(512, 10))def forward(self, x):x = self.cnn(x)x = x.view(x.size()[0], -1)# print(x.size())x = self.linear(x)return xif __name__ == '__main__':model = CNN()x = torch.rand(1, 1, 28, 28)print(model)y = model(x)print(y)

STN

import torch
from torch import nnclass STN(nn.Module):def __init__(self):super(STN, self).__init__()self.location_cov = nn.Sequential(nn.Conv2d(1, 8, kernel_size=7),nn.ReLU(),nn.MaxPool2d(2, stride=2),nn.Conv2d(8, 10, kernel_size=5),nn.ReLU(),nn.MaxPool2d(2, stride=2),)self.localization_linear = nn.Sequential(nn.Linear(in_features=10 * 3 * 3, out_features=32),nn.ReLU(),nn.Linear(in_features=32, out_features=2 * 3))self.localization_linear[2].weight.data.zero_()self.localization_linear[2].bias.data.copy_(torch.tensor([1, 0, 0,0, 1, 0], dtype=torch.float))self.cnn = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=4),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3),)self.linear = nn.Sequential(nn.Dropout2d(0.5),nn.Linear(512, 10))def stn(self, x):x2 = self.location_cov(x)x2 = x2.view(x2.size()[0], -1)x2 = self.localization_linear(x2)theta = x2.view(x2.size()[0], 2, 3)grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)x = nn.functional.grid_sample(x, grid, align_corners=True)return xdef forward(self, x):x = self.stn(x)x = self.cnn(x)x = x.view(x.size()[0], -1)x = self.linear(x)return xif __name__ == '__main__':x = torch.rand(1, 1, 28, 28)model = STN()print(model)print(model(x))

train

import numpy as np
import torch
from torchvision import transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from PIL import Image
from torch import nn, optimfrom stn.CNN import CNN
from stn.STN import STNdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 数据处理
transform = transforms.Compose([transforms.RandomRotation(45),transforms.ToTensor(),transforms.Normalize((0.5), (0.5))
]
)train_data = torchvision.datasets.MNIST('../data/mnist',download=True,train=True,transform=transform)test_data = torchvision.datasets.MNIST('../data/mnist',download=True,train=False,transform=transform, )train_loader = torch.utils.data.DataLoader(train_data,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,batch_size=64,shuffle=True)data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8)
imgs = imgs.numpy().transpose(1, 2, 0)
imgs = imgs * 0.5 + 0.5
plt.imshow(imgs)
plt.show()# model = CNN()
model = STN()
model = model.to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(params=model.parameters(), lr=0.001)loss = 0
train_acc_count = []
test_acc_count = []
train_loss = []
test_loss = []def train(epoch):for i in range(epoch):for index, data in enumerate(train_loader):imgs = data[0].to(device)labels = data[1].to(device)outputs = model(imgs).to(device)loss = loss_fun(outputs, labels)loss.backward()opt_fun.step()opt_fun.zero_grad()if index % 100 == 0:print("第{}轮,第{}次,loss为:{}".format(i + 1, index, loss.item()))train_loss.append(loss.item())def test():test_count = 0.for imgs, labels in test_loader:with torch.no_grad():outputs = model(imgs.to(device)).to(device)test_acc_count = (torch.max(outputs, dim=1)[1] == labels.to(device)).sum().item()test_count = labels.size()[0]print("测试集准确率{}".format(test_acc_count / test_count))if __name__ == '__main__':# 设置随机数种子np.random.seed(1)torch.manual_seed(1)torch.cuda.manual_seed_all(1)# 保证每次结果一样torch.backends.cudnn.deterministic = Truetrain(10)test()sava_path = '../model/mnistStn.pth'torch.save(model.state_dict(), sava_path)plt.plot(train_loss)plt.show()

showImage

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as pltfrom stn.STN import STNtransform = transforms.Compose([transforms.RandomRotation(45),transforms.ToTensor(),transforms.Normalize((0.5), (0.5))
]
)train_data = torchvision.datasets.MNIST('../data/mnist',download=True,train=True,transform=transform)train_loader = torch.utils.data.DataLoader(train_data,batch_size=64,shuffle=True)data_iter = iter(train_loader)
imgs, labels = next(data_iter)
pre = torchvision.utils.make_grid(imgs, 8)
pre = pre.numpy().transpose(1, 2, 0)
pre = pre * 0.5 + 0.5
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')model = STN()
model.load_state_dict(torch.load('../model/mnistStn.pth'))
now = model.stn(imgs).detach()
now = torchvision.utils.make_grid(now, 8)
now = now.numpy().transpose(1, 2, 0)
now = now * 0.5 + 0.5
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')plt.show()

train,epoch=10

​ 展示transom后的图片,还是感觉很神奇

stn在mnist上的实现相关推荐

  1. pytorch学习笔记(2):在MNIST上实现一个CNN

    参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A 这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN.我们会基于上一 ...

  2. Keras在mnist上的CNN实践,并且自定义loss函数曲线图

    使用keras实现CNN,直接上代码: from keras.datasets import mnist from keras.models import Sequential from keras. ...

  3. 采用SVM实现实现MNIST手写体分类,数据下载链接在http://yann.lecun.com/exdb/mnist/上。上传源码和实现结果,语言不限。

    基于OpenCV的MNIST手写体分类 简介 实验要求 实验环境 OpenCV的配置 总体概览 在python中绘制 开始上手OpenCV 查看完整内容 简介 MNIST 数据集来自美国国家标准与技术 ...

  4. pytorch空间变换网络

    pytorch空间变换网络 本文将学习如何使用称为空间变换器网络的视觉注意机制来扩充网络.可以在DeepMind paper 有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化关注的概 ...

  5. PyTorch 系列教程之空间变换器网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 在本教程中,您将学习如何使用称为空间变换器网络的视觉注意机制来扩充 ...

  6. PyTorch官方教程中文版:Pytorch之图像篇

    微调基于 torchvision 0.3的目标检测模型 """ 为数据集编写类 """ import os import numpy as ...

  7. 比用Pytorch框架快200倍!0.76秒后,笔记本上的CNN就搞定了MNIST | 开源

    博雯 发自 凹非寺 量子位 报道 | 公众号 QbitAI 在MNIST上进行训练,可以说是计算机视觉里的"Hello World"任务了. 而如果使用PyTorch的标准代码训练 ...

  8. 笔记本上的CNN搞定了MNIST

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:量子位 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 在M ...

  9. Deformable ConvNets--Part2: Spatial Transfomer Networks(STN)

    转自:https://blog.csdn.net/u011974639/article/details/79681455 Deformable ConvNet简介 关于Deformable Convo ...

  10. 长篇自动驾驶技术综述论文(上)

    长篇自动驾驶技术综述论文(上) A Survey of Autonomous Driving: Common Practices and Emerging Technologies Ekim Yurt ...

最新文章

  1. 支付宝移动支付文档url
  2. C++继承与派生(原理归纳)
  3. c语言动态申请函数,C语言用malloc函数申请二维动态数组
  4. 安全警报:Oracle 2018一月号安全补丁修复由来已久安全漏洞
  5. ElementUI 添加修改提示成功后 如何刷新表格数据展示
  6. C++设计模式之适配器模式
  7. python 打印表格边框_python设置表格边框的具体方法
  8. LeetCode 231. 2的幂
  9. 华为不做黑寡妇,开源编译器,与友商共建安卓性能
  10. android中设置lmargin简书,超详细React Native实现微信好友/朋友圈分享功能-Android/iOS双平台通用...
  11. java list 取几个字段组装成map_java.util.concurrent 并发包诸类概览
  12. web视图引擎框架对比
  13. 常用C/C++预处理指令详解
  14. Java基础篇:为Box类添加一个方法
  15. 初识Redis educoder
  16. 联想服务器系统如何备份软件,联想笔记本如何使用系统自带备份/还原功能进行备份与还原系统...
  17. MacBook雷电3接口失灵不可用
  18. iOS 如何获取手机型号、系统版本、电池电量
  19. java 登录失效时间_详谈Java设置session超时(失效)的时间
  20. 美国贝勒大学计算机科学专业怎么样,贝勒大学专业排名一览(含历年专业排名信息,USNEWS美国大学排名版)...

热门文章

  1. quilt 工具增加 patch 方法
  2. [Steam]成就游戏销量乐观
  3. js获取月的第一天、最后一天
  4. PhD Debate-11 预告 | 回顾与展望神经网络的后门攻击与防御
  5. PS:将一个图片变成圆形
  6. Jenkins插件安装和系统配置
  7. 震惊!让90%的程序员一看就会的入门级AI项目!
  8. 股票买卖问题-含手续费
  9. 互联网公益陷入信任危机,智慧公益能否力挽狂澜?
  10. 微服务Http健康检查