前面的博文我们讲了LSTM的原理与分析,这一篇我们用pytorch类LSTM做测试

完整测试代码如下,用于进行MNIST数据集测试,主要学习LSTM类的输入输出维度。

这里定义的LSTM模型是用了三层深度模型,双向的,输出层增加了线性转换。

完整代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms# step 1:===========================================定义LSTM结构
# 双向网络,三层网络,带上最后输出的先线性层
class Rnn(nn.Module):def __init__(self, input_dim, hidden_dim, n_layer, n_classes):super(Rnn, self).__init__()# 这里把 batch_size 放在第一维度# 使用双向循环LSTMself.lstm = nn.LSTM(input_dim, hidden_dim, n_layer, batch_first=True, bidirectional=True)# 这个是网络最后的线性层self.classifier = nn.Linear(hidden_dim, n_classes)# 默认输入数据格式:# input(seq_len, batch_size, input_size)# h0(num_layers * num_directions, batch_size, hidden_size)# c0(num_layers * num_directions, batch_size, hidden_size)# 默认输出数据格式:# output(seq_len, batch_size, hidden_size * num_directions)# hn(num_layers * num_directions, batch_size, hidden_size)# cn(num_layers * num_directions, batch_size, hidden_size)# batch_first=True 在此条件下,batch_size是处在第一个维度的。def forward(self, input): # input [128, 28, 28]out, (h_n, c_n) = self.lstm(input)# x = out[:, -1, :] # 此时可以从out中获得最终输出的状态hx = h_n[-1, :, :]x = self.classifier(x)return x# 实例化网络对象, 输入数据的维度是28维度,隐藏层维度是10,3层网络,10个线性分类输出
lstmNet = Rnn(28, 10, 3, 10)# step 2:===========================================加载MNIST数据,并形成批量数据
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),
])trainset = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)testset = torchvision.datasets.MNIST(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)# step 3:===========================================定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(lstmNet.parameters(), lr=0.1, momentum=0.9)# step 3:===========================================定义损训练过程和测试过程
# Training
def train(epoch):print('\nEpoch: %d' % epoch)lstmNet.train()train_loss = 0correct = 0total = 0# inputs = [128, 1, 28, 28], targets = [128]for batch_idx, (inputs, targets) in enumerate(trainloader):optimizer.zero_grad()outputs = lstmNet(torch.squeeze(inputs, 1))loss = criterion(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()print(batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))def test(epoch):global best_acclstmNet.eval()test_loss = 0correct = 0total = 0with torch.no_grad():# inputs = [128, 1, 28, 28], targets = [128]for batch_idx, (inputs, targets) in enumerate(testloader):outputs = lstmNet(torch.squeeze(inputs, 1))loss = criterion(outputs, targets)test_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))for epoch in range(100):train(epoch)test(epoch)

测试结果还是比较好的,甚至高达99.5%(节选自输出打印):
89 Loss: 0.015 | Acc: 99.540% (11467/11520)
90 Loss: 0.015 | Acc: 99.536% (11594/11648)
91 Loss: 0.015 | Acc: 99.533% (11721/11776)
92 Loss: 0.015 | Acc: 99.538% (11849/11904)
93 Loss: 0.015 | Acc: 99.535% (11976/12032)

Pytorch《LSTM模型》相关推荐

  1. Pytorch《GAN模型生成MNIST数字》

    这里的代码都是,参考网上其他的博文学习的,今天是我第一次学习GAN,心情难免有些激动,想着赶快跑一个生成MNIST数字图像的来瞅瞅效果,看看GAN的神奇. 参考博文是如下三个: https://www ...

  2. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

  3. GAN 模型生成山水画,骗过半数观察者,普林斯顿大学本科生出品

    作者 | 高卫华 出品 | AI科技大本营 近年来,基于生成对抗网络GAN模型,图像生成领域实现了许多有趣的应用,尤其是在绘画创作方面. 英伟达曾在2019年提出一款名叫GauGAN的神经网络作图工具 ...

  4. 深度学习《GAN模型学习》

    前言:今天我们来一起学习下GAN神经网络,上一篇博文我先用pytorch运行了几个网上的代码例子,用于生成MNIST图片,现在我才反过来写GAN的学习,这次反了过来,效果也是很显而易见的,起码有个直观 ...

  5. GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

    有关条件GAN(cgan)的相关原理,可以参考: GAN系列之CGAN原理简介以及pytorch项目代码实现 其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏 一.数据集介绍,加载数据 依旧 ...

  6. 搭建简单GAN生成MNIST手写体

    Keras搭建GAN生成MNIST手写体 GAN简介 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前 ...

  7. 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

    1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...

  8. Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

    有关DCGAN实战的小例子之前已经更新过一篇,感兴趣的朋友可以点击查看 Pytorch 使用DCGAN生成MNIST手写数字 入门级教程 有关DCGAN的相关原理:DCGAN论文解读-----DCGA ...

  9. GAN掉人脸识别系统?GAN模型「女扮男装」

    文章来源 新智元 编辑:LRS [新智元导读]人脸识别技术最近又有新的破解方式!一位斯坦福的学生使用GAN模型生成了几张自己的图片,轻松攻破两个约会软件,最离谱的是「女扮男装」都识别不出来. 真的有人 ...

  10. pytorch学习之GAN生成MNIST手写数字

    0.简单介绍: 学深度学习的人必然知道,最基本的GAN模型由一个生成器 G 和判别器 D 组成.生成器用于生成假样本,判别器用于判断样本是真实的还是假的. 在整个训练过程中,生成器努力地让生成的图像更 ...

最新文章

  1. java中可用于定义成员常量_13秋北航《Java语言与面向对象程序设计》在线作业三辅导 …...
  2. 通过Java字节码发现有趣的内幕之String篇(上)(转)
  3. 使用FragmentTabHost和ViewPager实现仿微信主界面侧滑
  4. spyder matlab,将pycharm配置为matlab或者spyder的用法说明
  5. hdu1010 Tempter of the Bone
  6. 《linux操作系统》第06章在线测试,Linux系统管理一测试题-附答案.doc
  7. php 类 和 函数,PHP函数和类
  8. 阿里13篇论文入选数据库顶会!PolarDB技术被认为引领数据库发展方向
  9. MFC通过窗口标题获得窗口句柄
  10. C语言课程设计——工资管理系统
  11. docker安装nessus
  12. Unity3D官方案例--太空射击游戏总结
  13. 图像超分之——寻找两张图差异的区域
  14. 嵌入式数据库和数据库服务器的区别
  15. Android仿微信朋友圈九宫格图片展示自定义控件,支持缩放动画~
  16. 深度学习中滑动平均模型的作用、计算方法及tensorflow代码示例
  17. 江南大学计算机技术复试科目,江南大学计算机专硕考哪些科目
  18. 安卓手机备忘录怎么添加录音
  19. HTML 显示系统时间
  20. IBM造海水电池,“搅局”锂电池产业?

热门文章

  1. 【原创】线上环境 SYN flooding 问题排查
  2. 对弈类游戏的人工智能(3)--博弈树优化
  3. flex中DataGrid里使用itemRenderer后数据无法绑定到数据源的问题
  4. DEDE获得顶级栏目名称
  5. 【原】android获取设备基本信息
  6. 读书笔记2013第5本:《拖延心理学》
  7. VNC客户端连接MacOS时一闪而过的解决办法
  8. 面试:Websocket
  9. 你不知道的Chrome调试技巧
  10. Docker的今生前世,关于Docker的一些见解