目录

  • 数据准备
  • 网络模型
  • 完整实现

数据准备

torch.utils.data.Datasets是PyTorch用来表示数据集的类,它是用PyTorch进行手写数字识别的关键。
下面是加载mnist数据集并对其可视化的代码

from torchvision import datasets
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt
import numpy as npmnist = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',train=True,download=True)for i, j in enumerate(np.random.randint(0,len(mnist),(10,))):data, label = mnist[j]plt.subplot(2,5,i+1)plt.imshow(data)

这段代码中,首先实例化了Datasets对象mnist ,datasets.MNIST能够自动下载数据集保存到本地磁盘的root位置(自己设置),参数train默认为True,用于控制加载的数据集是训练集还是测试集。for循环中,使用len(mnist)调用了__len__方法,使用mnist[j]调用了__getitem__方法(在我们自己建立数据集时,需要继承Dataset,并且覆写__len__和__getitem__两个方法)。最后两行代码绘制了MNIST手写数字数据集。
运行代码,查看数据集的部分数据可视化结果

在变量浏览器中可以看到有关变量的内容,其中mnist就是实例化的数据集对象,它包含6000张图像内容,在for循环过程中,data读取的是28×28的Image类型图像,label是该图像对应的标签,也就是图像上表示的数字。

由于数据预处理是非常重要的步骤,所以PyTorch提供了torchvision.transforms用于处理数据及数据增强。在这里我们使用了torchvision.transforms.ToTensor将PIL Image或者numpy.ndarray类型的数据转换为Tensor,并且它会将数据从【0,255】映射到【0,1】。torchvision.transforms.Normalize会将数据标准化,加速模型在训练中的收敛速率。在使用中,可利用torchvision.transforms.Compose将多个transforms组合在一起,被包含的transforms会顺序执行。
数据流程处理准备完善后开始读取用于训练的数据,torch.utils.data.DataLoader提供了迭代数据、随机抽取数据、批量化数据。
下面的代码中实例化了mnist对象,定义了transforms方法trans并将其使用到数据集的实例化过程中,用于对数据集进行处理。然后定义了函数imshow,函数体的的第一行代码将数据从标准化的数据中恢复,第二行代码将Tensor类型转换为ndarray,这样才可以用matplotlib绘制出来,绘制的结果如下图所示。函数体最后一行使用transpose函数将矩阵维度从(C,W,H)转换为(W,H,C),这样才符合正常的通道顺序。

import torchvision
from torchvision import datasets,transforms
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt
import numpy as np
import torchtrans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])mnist = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',train=True,download=True,transform=trans)def imshow(img):img = img * 0.3081 + 0.1307npimg = img.numpy()plt.imshow(np.transpose(npimg, (1,2,0)))dataloader = torch.utils.data.DataLoader(mnist, batch_size=4, shuffle=True, num_workers=0)
images, labels = next(iter(dataloader))imshow(torchvision.utils.make_grid(images))

运行代码得到的变量情况及绘制结果如下图

网络模型

下面构建用于识别手写数字的神经网络模型

import torch.nn as nnclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.inputlayer = nn.Sequential(nn.Linear(28*28, 256),nn.ReLU(),nn.Dropout(0.2))self.hiddenlayer = nn.Sequential(nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.2))self.outlayer = nn.Sequential(nn.Linear(256, 10))def forward(self, x):#将输入图像拉伸为一维向量x = x.view(x.size(0), -1)x = self.inputlayer(x)x = self.hiddenlayer(x)x = self.outlayer(x)return x

通过nn.Module对象看到其网络结构,如下

In [12]: print(MLP())
MLP((inputlayer): Sequential((0): Linear(in_features=784, out_features=256, bias=True)(1): ReLU()(2): Dropout(p=0.2, inplace=False))(hiddenlayer): Sequential((0): Linear(in_features=256, out_features=256, bias=True)(1): ReLU()(2): Dropout(p=0.2, inplace=False))(outlayer): Sequential((0): Linear(in_features=256, out_features=10, bias=True))
)

完整实现

准备好数据和模型后,就可以进行训练模型了。下面分别定义了数据处理和加载流程、模型、优化器、损失函数以及用准确率评估模型能力。训练过程持续10个epoch

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
from torchvision import datasets,transforms
import matplotlib.pylab as plt#数据处理和加载
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',train=True,download=True,transform=trans)
mnist_val = datasets.MNIST(root='C:/Users/WSY/Desktop/用pytorch进行手写数字识别',train=False,download=True,transform=trans)trainloader = DataLoader(mnist_train, batch_size=16, shuffle=True, num_workers=0)
valloader = DataLoader(mnist_val, batch_size=16, shuffle=True, num_workers=0)#定义网络模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.inputlayer = nn.Sequential(nn.Linear(28*28, 256),nn.ReLU(),nn.Dropout(0.2))self.hiddenlayer = nn.Sequential(nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.2))self.outlayer = nn.Sequential(nn.Linear(256, 10))def forward(self, x):#将输入图像拉伸为一维向量x = x.view(x.size(0), -1)x = self.inputlayer(x)x = self.hiddenlayer(x)x = self.outlayer(x)return xmodel = MLP()#优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)#损失函数
celoss = nn.CrossEntropyLoss()#计算准确率
def accuracy(pred, target):pred_label = torch.argmax(pred, 1)correct = sum(pred_label == target).to(torch.float)return correct, len(pred)acc = {'train':[], 'val':[]}
loss_all = {'train':[], 'val':[]}for epoch in tqdm(range(10)):#设置为验证模式model.eval()numer_val, denumer_val, loss_tr = 0., 0., 0.with torch.no_grad():for data, target in valloader:output = model(data)loss = celoss(output, target)loss_tr += loss.datanum, denum = accuracy(output, target)numer_val += numdenumer_val += denum#设置为训练模式model.train()numer_tr, denumer_tr, loss_val = 0., 0., 0.for data, target in trainloader:optimizer.zero_grad()output = model(data)loss = celoss(output, target)loss_val += loss.dataloss.backward()optimizer.step()num, denum = accuracy(output, target)numer_tr += numdenumer_tr += denumloss_all['train'].append(loss_tr/len(trainloader))loss_all['val'].append(loss_val/len(valloader))acc['train'].append(numer_tr/denumer_tr)acc['val'].append(numer_val/denumer_val)

运行完成,如下

设计到的变量情况如下


查看模型训练迭代过程的损失图像

plt.plot(loss_all['train'])
plt.plot(loss_all['val'])


查看训练迭代过程的准确率图

plt.plot(acc['train'])
plt.plot(acc['val'])


O V E R !

用PyTorch进行手写数字识别相关推荐

  1. 使用Pytorch实现手写数字识别

    使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...

  2. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  3. pytorch实现手写数字识别_送源码!人工智能实现:识别图片中的手写数字,值得收藏...

    作者|小林同学 关注<高手杰瑞>,每天有不一样的实用小教程发布哦! 哈喽,大家好我是杰瑞.今天我给大家带来一个用机器学习的方法来实现手写数字识别的教程,就像C语言中输出的那一行" ...

  4. 使用PyTorch进行手写数字识别,在20 k参数中获得99.5%的精度。

    In this article we'll build a simple convolutional neural network in PyTorch and train it to recogni ...

  5. Pytorch CNN 手写数字识别 0-9

    使用的软件是pycharm 环境是在anaconda下创的虚拟环境pytorch 整个过程大体为,在画板手写数字,用python代码实现手写数字的批量生成,定义超参数,创建数据集包括训练集和数据集,创 ...

  6. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  7. pytorch实现手写数字识别_Paddle和Pytorch实现MNIST手写数字集识别对比

    一.简介 1. Paddle PaddlePaddle是百度自主研发的集深度学习核心框架.工具组件和服务平台为一体的技术领先.功能完备的开源深度学习平台,有全面的官方支持的工业级应用模型,涵盖自然语言 ...

  8. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  9. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

最新文章

  1. c#导出包含图片的word文档
  2. [ CodeForces 865 D ] Buy Low Sell High
  3. PIVOT 和 UNPIVOT 命令的SQL Server版本
  4. vmware中linux无法动态获取dhcp解决方法
  5. C#关于MSMQ通过HTTP远程发送专有队列消息的问题
  6. 实战渗透之一个破站日一天
  7. linux学习教程(一)(安装篇)centos7没有安装ifconfig命令的解决方法
  8. springboot-web进阶(三)——统一异常处理
  9. markdown生成html不出效果,mdeditor: 简单markdown编辑器,同步预览html效果。不依赖任何插件,使用简单,原创,造轮子中。。。更新中。。。...
  10. Jenkins+Ant自动布署war
  11. python无限锁屏_定时锁屏程序,Python祝你原理猝死!
  12. 调试错误,请回到请求来源地,重新发起请求。 错误代码 insufficient-isv-permissions 错误原因: ISV权限不足,建议在开发者中心检查对应功能是否已经添加
  13. 学习笔记 Tianmao 篇 fresco 图片缓存加载框架
  14. CocosCreator之构建web版时自动使用模板文件
  15. Android ANR的trace文件基本信息解读
  16. 变焦光学系统工作原理及初始结构设计方法
  17. python的opencv库使用gpu加速_Python跳一跳:使用Cython加速opencv像素级访问
  18. 十大监控工具,值得一试
  19. 【聆思CSK6 视觉AI开发套件试用】体验头肩检测和手势识别最全教程
  20. 8万ta煤焦油加氢(8400ha)工艺设计

热门文章

  1. 极限编程-拥抱变化阅读感想(一)
  2. 牛客网:智力题+判断推理+数量关系(1)
  3. [转载]刘峰获“区块链60人”2020赋能中国区块链创新人物奖
  4. 类ku6未注册域名分享
  5. win10禁用USB恢复USB
  6. 记腾讯的暑期实习面试
  7. android 判断是白天还是晚上,然后设置地图模式
  8. 机械结构_day12
  9. 万向球头的锁紧结构图_联动锁紧球关节万向杆的制作方法
  10. 【项目分享】基于AB32和RT-Thread的墨水屏智能日历