MNIST手写数字识别教程

本文仅仅放出该教程的代码
具体教程请看 Pytorch入门——手把手教你MNIST手写数字识别

import torch
import torchvision
from tqdm import tqdm
import matplotlib#By: Elwin https://editor.csdn.net/md?not_checkout=1&articleId=112980305class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.model = torch.nn.Sequential(#The size of the picture is 28x28torch.nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size = 2,stride = 2),#The size of the picture is 14x14torch.nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size = 2,stride = 2),#The size of the picture is 7x7torch.nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),torch.nn.ReLU(),torch.nn.Flatten(),torch.nn.Linear(in_features = 7 * 7 * 64,out_features = 128),torch.nn.ReLU(),torch.nn.Linear(in_features = 128,out_features = 10),torch.nn.Softmax(dim=1))def forward(self,input):output = self.model(input)return outputdevice = "cuda:0" if torch.cuda.is_available() else "cpu"
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean = [0.5],std = [0.5])])BATCH_SIZE = 256
EPOCHS = 10
trainData = torchvision.datasets.MNIST('./data/',train = True,transform = transform,download = True)
testData = torchvision.datasets.MNIST('./data/',train = False,transform = transform)trainDataLoader = torch.utils.data.DataLoader(dataset = trainData,batch_size = BATCH_SIZE,shuffle = True)
testDataLoader = torch.utils.data.DataLoader(dataset = testData,batch_size = BATCH_SIZE)
net = Net()
print(net.to(device))lossF = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())history = {'Test Loss':[],'Test Accuracy':[]}
for epoch in range(1,EPOCHS + 1):processBar = tqdm(trainDataLoader,unit = 'step')net.train(True)for step,(trainImgs,labels) in enumerate(processBar):trainImgs = trainImgs.to(device)labels = labels.to(device)net.zero_grad()outputs = net(trainImgs)loss = lossF(outputs,labels)predictions = torch.argmax(outputs, dim = 1)accuracy = torch.sum(predictions == labels)/labels.shape[0]loss.backward()optimizer.step()processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f" % (epoch,EPOCHS,loss.item(),accuracy.item()))if step == len(processBar)-1:correct,totalLoss = 0,0net.train(False)with torch.no_grad():for testImgs,labels in testDataLoader:testImgs = testImgs.to(device)labels = labels.to(device)outputs = net(testImgs)loss = lossF(outputs,labels)predictions = torch.argmax(outputs,dim = 1)totalLoss += losscorrect += torch.sum(predictions == labels)testAccuracy = correct/(BATCH_SIZE * len(testDataLoader))testLoss = totalLoss/len(testDataLoader)history['Test Loss'].append(testLoss.item())history['Test Accuracy'].append(testAccuracy.item())processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f" % (epoch,EPOCHS,loss.item(),accuracy.item(),testLoss.item(),testAccuracy.item()))processBar.close()matplotlib.pyplot.plot(history['Test Loss'],label = 'Test Loss')
matplotlib.pyplot.legend(loc='best')
matplotlib.pyplot.grid(True)
matplotlib.pyplot.xlabel('Epoch')
matplotlib.pyplot.ylabel('Loss')
matplotlib.pyplot.show()matplotlib.pyplot.plot(history['Test Accuracy'],color = 'red',label = 'Test Accuracy')
matplotlib.pyplot.legend(loc='best')
matplotlib.pyplot.grid(True)
matplotlib.pyplot.xlabel('Epoch')
matplotlib.pyplot.ylabel('Accuracy')
matplotlib.pyplot.show()torch.save(net,'./model.pth')

Pytorch入门——MNIST手写数字识别代码相关推荐

  1. Pytorch实现mnist手写数字识别

    2020/6/29 Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别.可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能. Mnist数据集如图 ...

  2. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  3. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  4. 基于Pytorch的MNIST手写数字识别实现(含代码+讲解)

    说明:本人也是一个萌新,也在学习中,有代码里也有不完善的地方.如果有错误/讲解不清的地方请多多指出 本文代码链接: GitHub - Michael-OvO/mnist: mnist_trained_ ...

  5. 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

    LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...

  6. Pytorch+CNN+MNIST手写数字识别实战

    文章目录 1.MNIST 2.数据预处理 2.1相关包 2.2数据载入和预处理 3.网络结构 4.优化器.损失函数.网络训练以及可视化分析 4.1定义优化器 4.2网络训练 4.3可视化分析 5.测试 ...

  7. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  8. 深度学习入门项目:PyTorch实现MINST手写数字识别

    完整代码下载[github地址]:https://github.com/lmn-ning/MNIST_PyTorch.git 目录 一.MNIST数据集介绍及下载地址 二.代码结构 三.代码 data ...

  9. 使用PYTORCH复现ALEXNET实现MNIST手写数字识别

    网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构: 网络结构较为简单,共有五个卷积层和三个全连接层,原 ...

最新文章

  1. html frameset
  2. 一个Ext2+SWFUpload做的图片上传对话框
  3. 写了个Linux包过滤防火墙
  4. 面试字节跳动,我被怼了……
  5. mysql MHA高可用架构安装
  6. 计算机组成原理实验报告西华大学,计算机组成原理实验报告算术逻辑运算单元实验...
  7. mysql优化参数设置_MySQL服务优化参数设置参考
  8. Shell中的特殊字符
  9. html获取节点属性,JS操作属性节点(非常详细)
  10. JAVA day08 接口(interface),多态,instanceof
  11. 统一返回码,返回结果实体类
  12. 大数据分析,在中国找个身高1米7年入20万的老公,到底有多难?
  13. 多传感器融合算法,基于Lidar,Radar,Camera算法
  14. PS2022 安装教程
  15. SpringCloud入门实例
  16. 蓝牙笔记《蓝牙技术基础》
  17. vivado+vscode
  18. PTA-整除光棍(C语言)
  19. 动态系统建模与分析_伯德图
  20. JavaScript对象与JSON格式的转换

热门文章

  1. uniapp开发h5下载pdf
  2. 飞熊观察:ChatGPT不是取代元宇宙,而是丰富元宇宙内容
  3. vuepress(五)部署到github.io
  4. visual studio 加入zen-codding
  5. 利用正方教务漏洞抓取正方教务学生照片
  6. 大数据在开发的过程中,主要会遇到哪些难点?
  7. 汇顶gt系列触摸屏幕方向调试
  8. windows备份环境变量(导入导出)
  9. win10禁用自动更新(批处理方法)巨简单,又能解决已禁用Windows Update还是更新的问题
  10. 织梦cms在线生成html,织梦CMS标签生成器