说明:本人也是一个萌新,也在学习中,有代码里也有不完善的地方。如果有错误/讲解不清的地方请多多指出

本文代码链接:

GitHub - Michael-OvO/mnist: mnist_trained_model with torch

明确任务目标:

使用pytorch作为框架使用mnist数据集训练一个手写数字的识别

换句话说:输入为

输出: 0

比较简单直观

1. 环境搭建

需要安装Pytorch, 具体过程因系统而异,这里也就不多赘述了

具体教程可以参考这个视频 (这个系列的P1是环境配置)

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili【已完结!!!已完结!!!2021年5月31日已完结】本系列教程,将带你用全新的思路,快速入门PyTorch。独创的学习思路,仅此一家。个人公众号:我是土堆各种资料,请自取。代码:https://github.com/xiaotudui/PyTorch-Tutorial蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码https://www.bilibili.com/video/BV1hE411t7RN?share_source=copy_web

2. 基本导入

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import time
import matplotlib.pyplot as plt
import random
from numpy import argmax

不多解释,导入各种需要的包

3. 基本参数定义

#Basic Params-----------------------------
epoch = 1
learning_rate = 0.01
batch_size_train = 64
batch_size_test = 1000
gpu = torch.cuda.is_available()
momentum = 0.5

epoch是整体进行几批训练

learning rate 为学习率

随后是每批训练数据大小和测试数据大小

gpu是一个布尔值,方便没有显卡的同学可以不用cuda加速,但是有显卡的同学可以优先使用cuda

momentum 是动量,避免找不到局部最优解的尴尬情况

这些都是比较基本的网络参数

4. 数据加载

使用Dataloader加载数据,如果是第一次运行将会从网上下载数据

如果下载一直不行的话也可以从官方直接下载并放入./data目录即可

​​​​​​MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

(有4个包都需要下载)

#Load Data-------------------------------
train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True,transform=torchvision.transforms.Compose([                  torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([  torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)
test_data_size = len(test_loader)

5. 网络定义

接下来是重中之重

网络的定义

这边的网络结构参考了这张图:

有了结构图,代码就很好写了, 直接对着图敲出来就好了

非常建议使用sequential直接写网络结构,会方便很多

#Define Model----------------------------class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(in_features=3136, out_features=128),nn.Linear(in_features=128, out_features=10),)def forward(self, x):return self.model(x)if gpu:net = Net().cuda()
else:net = Net()

6.损失函数和优化器

交叉熵和SGD(随机梯度下降)

另外为了方便记录训练情况可以使用TensorBoard的Summary Writer

#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()
else:loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))

7. 开始训练

#Train---------------------------------total_train_step = 0def train(epoch):global total_train_steptotal_train_step = 0for data in train_loader:imgs,targets = dataif gpu:imgs,targets = imgs.cuda(),targets.cuda()optimizer.zero_grad()outputs = net(imgs)loss = loss_fn(outputs,targets)loss.backward()optimizer.step()if total_train_step % 200 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, total_train_step, train_data_size,100. * total_train_step / train_data_size, loss.item()))writer.add_scalar('loss', loss.item(), total_train_step)total_train_step += 1#Test---------------------------------def test():correct = 0total = 0with torch.no_grad():for data in test_loader:imgs,targets = dataif gpu:imgs,targets = imgs.cuda(),targets.cuda()outputs = net(imgs)_,predicted = torch.max(outputs.data,1)total += targets.size(0)correct += (predicted == targets).sum().item()print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total))return correct/total#Run----------------------------------for i in range(1,epoch+1):print("-----------------Epoch: {}-----------------".format(i))train(i)test()writer.add_scalar('test_accuracy', test(), total_train_step)#save modeltorch.save(net,'model/mnist_model.pth')print('Saved model')writer.close()

注意这里必须要先在同一文件夹下创建一个叫做model的文件夹!!!不然模型目录将找不到地方保存!!!会报错!!!

Train函数作为训练,Test函数作为测试

注意每次训练需要梯度清零

模型测试时要写with torch.no_grad()

运行的过程如果有GPU加速会很快,运行结果应该如下

正确率也还算是可以,一个epoch就能跑到98,如果不满意或者想调epoch次数可以在basic params区域直接进行修改

8. 模型验证和结果展示

小细节很多

首先是抽取样本的时候需要考虑转cuda的问题

其次如果直接将样本扔到网络里dimension不对,需要reshape

需要对结果进行argmax处理,因为结果是一个向量(有10个features,分别代表每个数字的概率),argmax会找到最大概率并输出模型的预测结果

使用matplotlib画图

#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")
model.eval()
print(model)fig = plt.figure(figsize=(20,20))
for i in range(20):#随机抽取20个样本index = random.randint(0,test_data_size)data = test_loader.dataset[index]#注意Cuda问题if gpu:img = data[0].cuda()else:img = data[0]#维度不对必须要reshapeimg = torch.reshape(img,(1,1,28,28))with torch.no_grad():output = model(img)#plot the image and the predicted numberfig.add_subplot(4,5,i+1)#一定要取Argmax!!!plt.title(argmax(output.data.cpu().numpy()))plt.imshow(data[0].numpy().squeeze(),cmap='gray')
plt.show()

运行结果如下:

效果还是很不错的

至此我们就完成了一整个模型训练,保存,导入,验证的基本流程。

完整代码

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import time
import matplotlib.pyplot as plt
import random
from numpy import argmax#Basic Params-----------------------------
epoch = 1
learning_rate = 0.01
batch_size_train = 64
batch_size_test = 1000
gpu = torch.cuda.is_available()
momentum = 0.5#Load Data-------------------------------
train_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True,transform=torchvision.transforms.Compose([                  torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_train, shuffle=True)test_loader = DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([  torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_test, shuffle=True)train_data_size = len(train_loader)
test_data_size = len(test_loader)#Define Model----------------------------class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(in_features=3136, out_features=128),nn.Linear(in_features=128, out_features=10),)def forward(self, x):return self.model(x)if gpu:net = Net().cuda()
else:net = Net()#Define Loss and Optimizer----------------if gpu: loss_fn = nn.CrossEntropyLoss().cuda()
else:loss_fn = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)#Define Tensorboard-------------------writer = SummaryWriter(log_dir='logs/{}'.format(time.strftime('%Y%m%d-%H%M%S')))#Train---------------------------------total_train_step = 0def train(epoch):global total_train_steptotal_train_step = 0for data in train_loader:imgs,targets = dataif gpu:imgs,targets = imgs.cuda(),targets.cuda()optimizer.zero_grad()outputs = net(imgs)loss = loss_fn(outputs,targets)loss.backward()optimizer.step()if total_train_step % 200 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, total_train_step, train_data_size,100. * total_train_step / train_data_size, loss.item()))writer.add_scalar('loss', loss.item(), total_train_step)total_train_step += 1#Test---------------------------------def test():correct = 0total = 0with torch.no_grad():for data in test_loader:imgs,targets = dataif gpu:imgs,targets = imgs.cuda(),targets.cuda()outputs = net(imgs)_,predicted = torch.max(outputs.data,1)total += targets.size(0)correct += (predicted == targets).sum().item()print('Test Accuracy: {}/{} ({:.0f}%)'.format(correct, total, 100.*correct/total))return correct/total#Run----------------------------------for i in range(1,epoch+1):print("-----------------Epoch: {}-----------------".format(i))train(i)test()writer.add_scalar('test_accuracy', test(), total_train_step)#save modeltorch.save(net,'model/mnist_model.pth')print('Saved model')writer.close()#Evaluate---------------------------------model = torch.load("./model/mnist_model.pth")
model.eval()
print(model)fig = plt.figure(figsize=(20,20))
for i in range(20):#random numberindex = random.randint(0,test_data_size)data = test_loader.dataset[index]if gpu:img = data[0].cuda()else:img = data[0]img = torch.reshape(img,(1,1,28,28))with torch.no_grad():output = model(img)#plot the image and the predicted numberfig.add_subplot(4,5,i+1)plt.title(argmax(output.data.cpu().numpy()))plt.imshow(data[0].numpy().squeeze(),cmap='gray')
plt.show()

基于Pytorch的MNIST手写数字识别实现(含代码+讲解)相关推荐

  1. 基于K210的MNIST手写数字识别

    基于K210的MNIST手写数字识别 项目已开源链接: Github. 硬件平台 采用Maixduino开发板 在sipeed官方有售 软件平台 使用MaixPy环境进行单片机的编程 官方资源可在这里 ...

  2. Pytorch实现mnist手写数字识别

    2020/6/29 Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别.可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能. Mnist数据集如图 ...

  3. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  4. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  5. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

  6. 1、基于Keras、Mnist手写数字识别数据集构建全连接(FC)神经网络训练模型

    文章目录 前言 一.MNIST数据集是什么? 二.构建神经网络训练模型 1.导入库 2.载入数据 3.数据处理 4.创建模型 5.编译模型 6.训练模型 7.评估模型 三.总代码 前言 提示: 1.本 ...

  7. 深度学习案例之基于 CNN 的 MNIST 手写数字识别

    一.模型结构 本文只涉及利用Tensorflow实现CNN的手写数字识别,CNN的内容请参考:卷积神经网络(CNN) MNIST数据集的格式与数据预处理代码input_data.py的讲解请参考 :T ...

  8. 深度学习入门实例——基于keras的mnist手写数字识别

    本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...

  9. 【手写数字识别】基于Lenet网络实现手写数字识别附matlab代码

    1 内容介绍 当今社会,人工智能得到快速发展,而模式识 别作为人工智能的一个重要应用领域也得到了飞 速发展,它利用计算机通过计算的方法根据样本的 特征对样本进行分类,其中的光学字符识别技术受 到广大研 ...

最新文章

  1. SQL Server 中@@IDENTITY的用法
  2. update empty content to text instance - where is B mode changed to D by frame
  3. linux的自定义input,linux键值到Android键值的转换与自定义
  4. rust建的怎么拆除_罗志祥私建泳池已拆除,后续还将接受物业的监督与教育
  5. 16.What is pass in Python?
  6. 结对-结对编项目作业名称-设计文档
  7. Struts2.0 xml文件的配置(package,namespace,action)
  8. java飞机订票系统课程设计_JAVA数据结构课程设计,航空订票系统求助
  9. 简直太强,把任意图片设置为鼠标指针
  10. 快递公司面单纸张标准
  11. 黑客第二课:脱屌第一步(主要讲unix-like系统的初步知识)
  12. 移动流量转赠给好友_中国移动怎样转赠手机流量?月结流量用不完怎办
  13. 非常全面的贝叶斯网络介绍 非常多的例子说明
  14. webstorm-主题和配色
  15. 词云生成库WordCloud详解(一):概述、ImageColorGenerator类
  16. idea中ctrl+shift+f(在文件中查找)失效,全图文解决方案
  17. qq申诉网站无法接到服务器,为什么我qq申诉不成功 - 卡饭网
  18. LeetCode 344.Reverse String
  19. 各国际会议的影响因子
  20. Xcode 11的问题及 Xcode 11 beta 1和beta 2 版下载链接, 官方下载后上传到百度网盘的.

热门文章

  1. java登陆注册 mysql_Java+mysql用户注册登录功能
  2. 机器学习中常用的几何距离测量和统计距离测量方法总结
  3. java升序排列数组_java 数组升序排列
  4. 用行列式的定义方法求解n阶行列式的值(C++)
  5. 焱融全闪存储轻松构建百亿私募量化投研平台
  6. 2020 dns排名_《2020年全球DNS威胁报告》:DNS攻击平均损失高达92万美元
  7. vlookup多项匹配_Excel 怎样用VLOOKUP匹配多列数据/excle全部筛选匹配
  8. 11.8版本更新公告:灵罗娃娃 格温登场
  9. Could not open a connection to SQL Server [53]
  10. iphone8引发的AR大事件