import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 载入训练集transform=transforms.ToTensor(), # 把数据变成tensor类型download = True # 下载)
# 测试集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break
# 定义网络结构
class LSTM(nn.Module):def __init__(self):super(LSTM,self).__init__()# 初始化self.lstm = torch.nn.LSTM(input_size = 28, # 表示输入特征的大小hidden_size = 64, # 表示lstm模块的数量num_layers = 1, # 表示lstm隐藏层的层数batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature))self.out = torch.nn.Linear(in_features=64,out_features=10)self.softmax = torch.nn.Softmax(dim=1)def forward(self,x):# (batch,seq_len,feature)x = x.view(-1,28,28)# output:(batch,seq_len,hidden_size)包含每个序列的输出结果# 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers# h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果# c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果output,(h_n,c_n) = self.lstm(x)output_in_last_timestep = h_n[-1,:,:]x = self.out(output_in_last_timestep)x = self.softmax(x)return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():# 模型的训练状态model.train()for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 交叉熵代价函数out(batch,C:类别的数量),labels(batch)loss = mse_loss(out,labels)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():# 模型的测试状态model.eval()correct = 0 # 测试集准确率for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))correct = 0for i,data in enumerate(train_loader): # 训练集准确率# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):print("epoch:",epoch)train()test()

PyTorch基础-使用LSTM神经网络实现手写数据集识别-08相关推荐

  1. 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别

    "我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...

  2. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  3. 神经网络实现手写数字识别(MNIST)

    一.缘起 原本想沿着 传统递归算法实现迷宫游戏 --> 遗传算法实现迷宫游戏 --> 神经网络实现迷宫游戏的思路,在本篇当中也写如何使用神经网络实现迷宫的,但是研究了一下, 感觉有些麻烦不 ...

  4. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  5. 基于BP神经网络的手写数字识别

    基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...

  6. BP神经网络实现手写数字识别Python实现,带GUI手写画板

    BP神经网络实现手写数字识别 BP神经网络模型 用tkinter编写用于手写输入的画板 程序运行的效果截图 在B站看了一个机器学习基础的视频( 链接)后,发现到资料里面有一个用BP神经网络对手写数字进 ...

  7. 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)

    1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...

  8. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  9. 深度学习笔记:07神经网络之手写数字识别的经典实现

    神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字.由于网络需要用0-9一共十个数字中挑选出一个,所以我们的网 ...

最新文章

  1. C# CAD对象 构造时应把它的父对象也加进它的属性里
  2. [教程] [承風雅傳HSU]用ES4封裝Win7---ES4 Win7封裝教程(未完待續)
  3. EOS (3)系统特点
  4. ngrok服务器搭建
  5. PWN-PRACTICE-BUUCTF-10
  6. python面向对象实现简易银行管理员页面系统
  7. Pandas 文本数据方法 get( )
  8. 仓库装箱管理装箱发货,装箱扫描,装箱条码扫描系统成品装箱系统
  9. wgs84坐标系拾取工具_百度坐标(BD09)、国测局坐标(火星坐标,GCJ02)、和WGS84坐标系之间的转换的工具...
  10. windows下安装yarn
  11. 使用原生js 监听video 当前播放时间和是否点击了播放或者暂停按钮
  12. vue图片裁剪功能的实现
  13. 中国互联网惊呆老外?微信大数据揭露“无现金”真相
  14. health HEALTH_WARN;352 pgs degraded;352 pgs stuck unclean;352 pgs undersized;recovery 20/40 objects
  15. 黑色温敏性PNIPAM-AuNPs/CHOL-AuNPs纳米金粒修饰聚合物的制备过程
  16. 从四个维度谈谈如何做好团队管理
  17. QQ群设置里的“不提示消息只显示数目”与“接收不提示消息”的区别
  18. 全球与中国农用软管卷盘市场现状及未来发展趋势
  19. C语言计算程序运行时间简单实例
  20. 微信公众号服务器接收不到粉丝留言消息

热门文章

  1. 吉林高考成绩查询2021年几号公布,2021年吉林高考成绩查询时间及查分方式
  2. java list按照某个字段排序_java相关:List对象去重和按照某个字段排序的实现方法...
  3. 华为 泰山 服务器 操作系统安装,华为Taishan服务器安装CentOS7操作系统
  4. 人脸识别的python实现代码_手把手教你用1行代码实现人脸识别 --Python Face_recognition...
  5. 传智学员信息登记表html代码_IT兄弟连 HTML5教程 HTML5文字版面和编辑标签 使用HTML表格...
  6. 大连海事大学计算机调剂,大连海事大学2017年考研调剂信息
  7. mysql innodb 并行_关于MySQL8.0 InnoDB并行执行的详解
  8. 微信小程序 html modal,微信小程序参考微信设计规范做的modal模态框
  9. vue 子级拿值_vue 父组件通过$refs获取子组件的值和方法详解
  10. foreach用法_25个你不得不知道的数组reduce高级用法