pytorch rnn 实现手写字体识别

  • 构建 RNN 代码
  • 加载数据
  • 使用RNN 训练 和测试数据

构建 RNN 代码


import  torch
import   torch.nn  as  nn
from  torch.autograd  import  Variableimport  torch.utils.data  as  Dataimport   torchvisionimport   matplotlib.pyplot  as  plttorch.manual_seed(1)#batch size
BATCH_SIZE=50
#学习率
LR= 0.001
DOWNLOAD=False
#是否训练
TRAIN =Falseclass   RNN(nn.Module):def __init__(self):super(RNN,self).__init__()'''input_size:输入特征的数目hidden_size:隐层的特征数目num_layers:这个是模型集成的LSTM的个数 记住这里是模型中有多少个LSTM摞起来 一般默认就1个#batch_first: 输入数据的size为[batch_size, time_step, input_size]还是[time_step, batch_size, input_size]'''self.rnn= nn.LSTM(input_size=28,hidden_size=64,num_layers=3,batch_first=True #batch_first: 输入数据的size为[batch_size, time_step, input_size]还是[time_step, batch_size, input_size])self.out = nn.Linear(64,10)self.optimizer = torch.optim.Adam(self.parameters(),lr=LR)self.lossFunc= nn.CrossEntropyLoss()def forward(self,x):#x [ batch,28,28]r_out ,(h_n,h_c)= self.rnn(x,None)#r_out [50,28,64]   h_n=[1,50,64]  h_c =[1,50,64]#r_out  表示 每一次输入  28 个像素  输入了  50* 28 次#h_n 表示    每 28*28 为一次 记录  隐藏层 为 64 所以为  50,64  每28*28为一个记录 参数print(r_out.size(), h_n.size(),h_c.size())r_out = self.out(r_out[:,-1,:])return  r_outdef  lossFunction(self,predict ,batchY):loss = self.lossFunc(predict,batchY)self.optimizer.zero_grad()loss.backward()print("loss==",loss.data)self.optimizer.step()

加载数据

tranData =  torchvision.datasets.MNIST(root="d:/mnist/",train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD)testData = torchvision.datasets.MNIST(root="d:/mnist/",train=False
)trainLoader =  Data.DataLoader(dataset=tranData,batch_size=BATCH_SIZE,shuffle=True)# 为了节约时间, 我们测试时只测试前2000个
# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_x = Variable(torch.unsqueeze(testData.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = testData.test_labels[:2000]

使用RNN 训练 和测试数据


#构造RNN
myRNN= RNN()#训练数据
if(TRAIN):for  epoch in  range(3):for  step  ,(x,y) in enumerate(trainLoader):trainX = Variable(x.view(-1,28,28))print("trainX==",trainX.size())tranY = Variable(y)predict= myRNN(trainX)print("predict==",predict)myRNN.lossFunction(predict,tranY)torch.save(myRNN.state_dict(), "d:/mnist/rnn.pkl")
else:myRNN.load_state_dict(torch.load("d:/mnist/rnn.pkl"))#测试数据
testOut = myRNN(test_x[:20].view(-1,28,28))print("testOut==",testOut.size())
#预测值
testPredict = torch.max(testOut,1)[1]print("testPredict==", testPredict.size())
print(testPredict,test_y[:20])

pytorch rnn 实现手写字体识别相关推荐

  1. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

  2. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  3. PyTorch手写字体识别MNIST

    手写字体识别MNIST 1.准备工作 可以看这个老师的视频进行学习,讲解的非常仔细:视频学习 2.项目代码 2.1 导入模块 # 1.加载相关库 import torch import torch.n ...

  4. 第六讲 Keras实现手写字体识别分类

    一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...

  5. Android Studio编写一个手写字体识别程序

    1.activity_main.xml 的代码 <?xml version="1.0" encoding="utf-8"?> <LinearL ...

  6. 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)

    人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...

  7. Pytorch实现mnist手写数字识别

    2020/6/29 Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别.可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能. Mnist数据集如图 ...

  8. python手写字体程序_深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...

  9. 实验3 手写字体识别【机器学习】

    推荐 python实现手写数字识别(小白入门) 原文MNIST Handwritten Digit Recognition in PyTorch 翻译用PyTorch实现MNIST手写数字识别(非常详 ...

最新文章

  1. LeetCode:跳跃游戏【55】
  2. WPF实现聚光灯效果
  3. 集合——对象数组(引用数据类型数组)
  4. spring框架(六)之拦截器
  5. idea 导入到码云
  6. 解决Linux下DNS配置重启失效问题
  7. 多线程有几种实现方法?同步有几种实现方法?
  8. Atitit 微服务的一些理论 目录 1. 微服务的4个设计原则和19个解决方案 1 2. 微服务应用4个设计原则 1 2.1. AKF拆分原则 2 2.2. 前后端分离 2 2.3. 无状态服务
  9. 二级计算机vfp知识,全国计算机vfp二级考试
  10. iMeta | FSCapture报告录屏和视频剪辑(视频教程)
  11. 像素值与灰度值的区别与关系
  12. html如何调用less,LESS
  13. 虚拟路由器冗余协议——VRRP
  14. Android - scheme 一个app跳转另一个app、模块开发
  15. 关于PHP签名中的容易犯错问题记录
  16. Java用户注册手机短信验证码校验功能实现
  17. Android扫描车牌,车牌拍照识别SDK
  18. 0xC000005:Access Violation和指针强制转换问题
  19. 大数据实战 --- 淘宝用户行为数据分析
  20. 面向对象程序设计实验报告

热门文章

  1. 晴园直播(全球直播)订阅源+轻站+海阔小程序
  2. ubuntu linux下的C语言开发(进程创建)
  3. Git报错error: could not lock config file C:/Program Files/Git/mingw64/etc/gitconfig: Permission denie
  4. 多edittext监听变化的优化
  5. HTML基础的回顾复习(基本标签,简单的一个登陆验证)
  6. Windows Azure Marketplace DataMarket概述
  7. C#实现生产者与消费者关系
  8. 用正则表达式来判断手机号、地址、身份证号、邮箱等格式是否正确
  9. 地上有一个m行和n列的方格。一个机器人从坐标0,0的格子开始移动,每一次只能向左,右,上,下四个方向移动一格,但是不能进入行坐标和列坐标的数位之和大于k的格子。 例如,当k为18时,机器人能够进入方格
  10. php srs api,srs 身份认证