Experimental Environment

  • Windows 10 + Anaconda + Jupyter Notebook
  • PyTorch 1.1.0
  • Python 3.7
  • GPU (optional)

About the MNIST Dataset

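MNIST contains 60,000 training images and 10,000 test images of handwritten digits, each 28x28 pixels. It is often called the "Hello World" of computer vision. The CNN used in this post reaches about 99% accuracy on the MNIST test set. Let's get into the hands-on part.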

Importing Packages

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__  # in a notebook cell, this displays the installed PyTorch version

Defining Hyperparameters

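BATCH_SIZE = 512  # samples per mini-batch
EPOCHS = 20       # number of passes over the training set
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when one is available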

Dataset

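We use the MNIST dataset that ships with torchvision and create two separate DataLoaders, one for the training data and one for the test data. If you have already downloaded the dataset, you can set download to False.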

transform = transforms.Compose([
    transforms.ToTensor(),                      # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST training-set mean and std
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transform),
    batch_size=BATCH_SIZE, shuffle=False)  # evaluation does not need shuffling
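ToTensor scales 8-bit pixel values from [0, 255] to [0.0, 1.0]. Normalize then subtracts the mean (0.1307) and divides by the standard deviation (0.3081). The pure-Python sketch below reproduces that per-pixel arithmetic. The helper name normalize_pixel is hypothetical and is not a torchvision API.

```python
MEAN, STD = 0.1307, 0.3081  # the mean/std passed to transforms.Normalize above


def normalize_pixel(p: int, mean: float = MEAN, std: float = STD) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one 8-bit pixel."""
    x = p / 255.0            # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - mean) / std  # Normalize: zero-center, then rescale


print(round(normalize_pixel(0), 4))    # black background: -0.4242
print(round(normalize_pixel(255), 4))  # full-intensity stroke: 2.8215
```

After normalization, the inputs therefore range from roughly -0.42 to 2.82 instead of 0 to 1.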

Defining the Network

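The network has two convolutional layers and two linear (fully connected) layers. The final output has 10 dimensions, one for each of the digits 0-9.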

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1, 28, 28)  output: (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10, 12, 12) output: (20, 10, 10)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)              # batch size B
        out = self.conv1(x)              # (B, 10, 24, 24)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)    # (B, 10, 12, 12)
        out = self.conv2(out)            # (B, 20, 10, 10)
        out = F.relu(out)
        out = out.view(in_size, -1)      # flatten to (B, 2000)
        out = self.fc1(out)              # (B, 500)
        out = F.relu(out)
        out = self.fc2(out)              # (B, 10)
        out = F.log_softmax(out, dim=1)  # log-probabilities over the 10 classes
        return out
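The shape comments in ConvNet follow the usual output-size formula for a convolution or pooling layer: out = (in - kernel + 2*padding) // stride + 1. nn.Conv2d defaults to stride 1 and padding 0, and max_pool2d(out, 2, 2) uses kernel 2 and stride 2.

The pure-Python sketch below re-derives the 20*10*10 input size of fc1 and counts the parameters in each layer. The helper names are hypothetical.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer (no dilation)."""
    return (size - kernel + 2 * padding) // stride + 1


def conv_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k + c_out  # weights + biases


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out          # weights + biases


s = out_size(28, 5)           # conv1: 28 -> 24
s = out_size(s, 2, stride=2)  # max_pool2d(2, 2): 24 -> 12
s = out_size(s, 3)            # conv2: 12 -> 10
flat = 20 * s * s             # features fed into fc1
print(s, flat)                # 10 2000

params = {
    "conv1": conv_params(1, 10, 5),
    "conv2": conv_params(10, 20, 3),
    "fc1": linear_params(flat, 500),
    "fc2": linear_params(500, 10),
}
print(params, sum(params.values()))
```

The total comes to 1,007,590 parameters. fc1 alone holds 1,000,500 of them, so nearly all of the model's size sits in the first linear layer rather than in the convolutions.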

Instantiating the Network

model = ConvNet().to(DEVICE)  # move the network to DEVICE (the GPU when available)
optimizer = optim.Adam(model.parameters())  # Adam optimizer with default settings (lr=1e-3)
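optim.Adam(model.parameters()) uses PyTorch's defaults: lr=1e-3, betas=(0.9, 0.999), eps=1e-8 and no weight decay. To give a rough idea of what optimizer.step() does for each parameter, here is a scalar sketch of the textbook Adam update (no weight decay, no amsgrad). The function adam_step is hypothetical and skips the tensor-level details of torch.optim.Adam.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t counts steps from 1."""
    m = beta1 * m + (1 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = 0.5, 0.0, 0.0
p, m, v = adam_step(p, grad=4.0, m=m, v=v, t=1)
print(p)  # about 0.499
```

On the first step, m_hat / sqrt(v_hat) is about 1 in magnitude, so the parameter moves by about lr whatever the gradient's scale. This per-parameter normalization is one reason Adam often works without tuning the learning rate.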

Defining the Training Function

def train(model, device, train_loader, optimizer, epoch):
    model.train()  # switch to training mode
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients from the previous step
        output = model(data)               # forward pass: log-probabilities
        loss = F.nll_loss(output, target)  # negative log-likelihood loss
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the parameters
        if (batch_idx + 1) % 30 == 0:      # log every 30 batches
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
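forward() already ends with log_softmax, so the training loop uses F.nll_loss. This loss takes the negative log-probability of the target class and averages it over the batch (reduction='mean' is the default). Together the two steps equal cross-entropy on the raw scores, so applying F.cross_entropy to the fc2 output would be an equivalent alternative.

Below is a small pure-Python check of that identity. The logits are made-up values, and all helper names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of scores."""
    mx = max(logits)
    log_sum = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs_batch, targets):
    """Mean negative log-likelihood, like F.nll_loss with reduction='mean'."""
    losses = [-row[t] for row, t in zip(log_probs_batch, targets)]
    return sum(losses) / len(losses)


def cross_entropy(logits_batch, targets):
    """Cross-entropy computed directly from raw scores."""
    total = 0.0
    for row, t in zip(logits_batch, targets):
        total += math.log(sum(math.exp(z) for z in row)) - row[t]
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # hypothetical: 2 samples x 3 classes
targets = [0, 2]
a = nll_loss([log_softmax(r) for r in logits], targets)
b = cross_entropy(logits, targets)
print(a, b)  # the two values agree up to floating-point rounding
```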

Defining the Test Function

def test(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum the losses over the batch
            pred = output.max(1, keepdim=True)[1]  # index of the highest log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
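In test(), output.max(1, keepdim=True)[1] is the argmax over the 10 classes, and the accuracy is printed with {:.0f}, which rounds to a whole percent. That rounding matters when reading the results: 9859 correct out of 10000 is 98.59% but prints as 99%.

A plain-Python sketch of the same bookkeeping follows. The helper names and sample rows are hypothetical.

```python
def argmax(row):
    """Index of the largest value, like output.max(1)[1] for one row."""
    return max(range(len(row)), key=lambda i: row[i])


def count_correct(outputs, targets):
    return sum(1 for row, t in zip(outputs, targets) if argmax(row) == t)


# Two hypothetical rows of log-probabilities over 3 classes.
outputs = [[-0.1, -2.5, -3.0], [-1.9, -0.3, -2.2]]
print(count_correct(outputs, [0, 2]))  # 1 (only the first row is right)

# The accuracy format used in test() rounds to a whole percent.
print('{:.0f}%'.format(100. * 9859 / 10000))  # 99%
print('{:.2f}%'.format(100. * 9859 / 10000))  # 98.59%
```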

Training

for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)

Results

Train Epoch: 1 [14848/60000 (25%)]  Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]  Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]  Loss: 0.128060

Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)

Train Epoch: 2 [14848/60000 (25%)]  Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]  Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]  Loss: 0.068078

Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)

Train Epoch: 3 [14848/60000 (25%)]  Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]  Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]  Loss: 0.068404

Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)

Train Epoch: 4 [14848/60000 (25%)]  Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]  Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]  Loss: 0.036493

Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)

Train Epoch: 5 [14848/60000 (25%)]  Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]  Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]  Loss: 0.018023

Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)

Train Epoch: 6 [14848/60000 (25%)]  Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]  Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]  Loss: 0.040412

Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)

Train Epoch: 7 [14848/60000 (25%)]  Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]  Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]  Loss: 0.010722

Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)

Train Epoch: 8 [14848/60000 (25%)]  Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]  Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]  Loss: 0.011223

Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)

Train Epoch: 9 [14848/60000 (25%)]  Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]  Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]  Loss: 0.039870

Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)

Train Epoch: 10 [14848/60000 (25%)]  Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]  Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]  Loss: 0.024884

Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)

Train Epoch: 11 [14848/60000 (25%)]  Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]  Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]  Loss: 0.003402

Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)

Train Epoch: 12 [14848/60000 (25%)]  Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]  Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]  Loss: 0.008548

Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)

Train Epoch: 13 [14848/60000 (25%)]  Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]  Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]  Loss: 0.002313

Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)

Train Epoch: 14 [14848/60000 (25%)]  Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]  Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]  Loss: 0.005952

Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)

Train Epoch: 15 [14848/60000 (25%)]  Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]  Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]  Loss: 0.002555

Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)

Train Epoch: 16 [14848/60000 (25%)]  Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]  Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]  Loss: 0.002339

Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)

Train Epoch: 17 [14848/60000 (25%)]  Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]  Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]  Loss: 0.002084

Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)

Train Epoch: 18 [14848/60000 (25%)]  Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]  Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]  Loss: 0.000543

Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)

Train Epoch: 19 [14848/60000 (25%)]  Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]  Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]  Loss: 0.000275

Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)

Train Epoch: 20 [14848/60000 (25%)]  Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]  Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]  Loss: 0.000467

Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
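The sample counts in the training log follow from the logging code in train():

- With 60000 training images and BATCH_SIZE=512, there are ceil(60000/512) = 118 batches per epoch. The last batch holds only 96 images.
- A line is printed when (batch_idx + 1) % 30 == 0, that is, at batch_idx 29, 59 and 89.
- The printed count is batch_idx * len(data), so it is the number of samples seen before the logged batch, not after it.

This pure-Python sketch reproduces the three progress entries of each epoch:

```python
import math

N_TRAIN, BATCH_SIZE, LOG_EVERY = 60000, 512, 30

num_batches = math.ceil(N_TRAIN / BATCH_SIZE)  # len(train_loader)
last_batch = N_TRAIN - (num_batches - 1) * BATCH_SIZE

entries = []
for batch_idx in range(num_batches):
    if (batch_idx + 1) % LOG_EVERY == 0:
        # Same expressions as the print() in train(); every logged batch is a full one.
        seen = batch_idx * BATCH_SIZE
        pct = '{:.0f}'.format(100. * batch_idx / num_batches)
        entries.append('[{}/{} ({}%)]'.format(seen, N_TRAIN, pct))

print(num_batches, last_batch)  # 118 96
print(entries)  # ['[14848/60000 (25%)]', '[30208/60000 (50%)]', '[45568/60000 (75%)]']
```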

Summary

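A typical project workflow goes like this: find a dataset, preprocess the data, and define the model. Then set the hyperparameters and train and test the model. Finally, use the results to tune the hyperparameters or revise the model.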
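Here is one concrete reading of the log above. Training loss keeps falling toward zero. Test loss, however, bottoms out at 0.0286 in epoch 10 and then drifts back up to roughly 0.04. That pattern is consistent with mild overfitting and suggests trying fewer epochs.

The sketch below picks the epoch with the lowest test loss and the epoch with the most correct predictions. The values are copied from the results. It is only an illustration: in practice this kind of choice would normally be made on a held-out validation split, not on the test set.

```python
# Per-epoch test results copied from the log above (epochs 1..20).
test_loss = [0.0992, 0.0599, 0.0416, 0.0361, 0.0344, 0.0350, 0.0287, 0.0299,
             0.0341, 0.0286, 0.0377, 0.0311, 0.0293, 0.0335, 0.0357, 0.0400,
             0.0353, 0.0380, 0.0339, 0.0396]
correct = [9690, 9816, 9859, 9873, 9875, 9879, 9892, 9904, 9890, 9901,
           9894, 9908, 9905, 9897, 9894, 9893, 9895, 9894, 9905, 9894]

# Epoch numbers are 1-based, list indices are 0-based.
best_loss_epoch = min(range(len(test_loss)), key=test_loss.__getitem__) + 1
best_acc_epoch = max(range(len(correct)), key=correct.__getitem__) + 1

print(best_loss_epoch, test_loss[best_loss_epoch - 1])    # 10 0.0286
print(best_acc_epoch, correct[best_acc_epoch - 1] / 100)  # 12 99.08
```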
