什么是MNIST数据集?
直接看图!!他就是一套手写数字的图片,他的规格是(1x28x28),单通道,横竖都是28个像素,下面是前20个图片。我们今天的目的就是用卷积神经网络去识别这些数字到底是多少。

话不多说,咱们边看代码边聊,相信你一定能和我一起看懂他。

import  torch
from torch.utils.data import DataLoader #我们要加载数据集的
from torchvision import transforms #数据的原始处理
from torchvision import datasets #pytorch十分贴心的为我们直接准备了这个数据集
import  torch.nn.functional as F#激活函数
import torch.optim as optimbatch_size = 64#我们拿到的图片是pillow,我们要把他转换成模型里能训练的tensor也就是张量的格式
transform = transforms.Compose([transforms.ToTensor()])#加载训练集,pytorch十分贴心的为我们直接准备了这个数据集,注意,即使你没有下载这个数据集
#在函数中输入download=True,他在运行到这里的时候发现你给的路径没有,就自动下载
train_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
#同样的方式加载一下测试集
test_dataset = datasets.MNIST(root='../data',  train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False,  batch_size=batch_size)#接下来我们看一下模型是怎么做的
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()#定义了我们第一个要用到的卷积层,因为图片输入通道为1,第一个参数就是1#输出的通道为10,kernel_size是卷积核的大小,这里定义的是5x5的self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)#看懂了上面的定义,下面这个你肯定也能看懂self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)#再定义一个池化层self.pooling = torch.nn.MaxPool2d(2)#最后是我们做分类用的线性层self.fc = torch.nn.Linear(320, 10)#下面就是计算的过程def forward(self, x):# Flatten data from (n, 1, 28, 28) to (n, 784)batch_size = x.size(0) #这里面的0是x大小第1个参数,自动获取batch大小#输入x经过一个卷积层,之后经历一个池化层,最后用relu做激活x = F.relu(self.pooling(self.conv1(x)))#再经历上面的过程x = F.relu(self.pooling(self.conv2(x)))#为了给我们最后一个全连接的线性层用#我们要把一个二维的图片(实际上这里已经是处理过的)20x4x4张量变成一维的x = x.view(batch_size, -1) # flatten#经过线性层,确定他是0~9每一个数的概率x = self.fc(x)return xmodel = Net()#实例化模型
#把计算迁移到GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)#定义一个损失函数,来计算我们模型输出的值和标准值的差距
criterion = torch.nn.CrossEntropyLoss()
#定义一个优化器,训练模型咋训练的,就靠这个,他会反向的更改相应层的权重
optimizer = optim.SGD(model.parameters(),lr=0.1,momentum=0.5)#lr为学习率def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):#每次取一个样本inputs, target = datainputs, target = inputs.to(device), target.to(device)#优化器清零optimizer.zero_grad()# 正向计算一下outputs = model(inputs)#计算损失loss = criterion(outputs, target)#反向求梯度loss.backward()#更新权重optimizer.step()#把损失加起来running_loss += loss.item()#每300次输出一下数据if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():#不用算梯度for data in test_loader:inputs, target = datainputs, target = inputs.to(device), target.to(device)outputs = model(inputs)#我们取概率最大的那个数作为输出_, predicted = torch.max(outputs.data, dim=1)total += target.size(0)#计算正确率correct += (predicted == target).sum().item()print('Accuracy on test set: %d %% [%d/%d]' % (100 * correct / total, correct, total))if __name__=='__main__':for epoch in range(10):train(epoch)if epoch % 10 == 9:test()

结果:
[1, 300] loss: 0.054
[1, 600] loss: 0.018
[1, 900] loss: 0.014
[2, 300] loss: 0.011
[2, 600] loss: 0.009
[2, 900] loss: 0.009
[3, 300] loss: 0.008
[3, 600] loss: 0.007
[3, 900] loss: 0.007
[4, 300] loss: 0.006
[4, 600] loss: 0.006
[4, 900] loss: 0.006
[5, 300] loss: 0.005
[5, 600] loss: 0.005
[5, 900] loss: 0.005
[6, 300] loss: 0.005
[6, 600] loss: 0.004
[6, 900] loss: 0.005
[7, 300] loss: 0.004
[7, 600] loss: 0.004
[7, 900] loss: 0.004
[8, 300] loss: 0.003
[8, 600] loss: 0.003
[8, 900] loss: 0.004
[9, 300] loss: 0.003
[9, 600] loss: 0.003
[9, 900] loss: 0.003
[10, 300] loss: 0.003
[10, 600] loss: 0.003
[10, 900] loss: 0.003
Accuracy on test set: 99 % [9903/10000]

如果你觉得我说的那些你还是听不懂,去看看我前面发布的文章呀!!!!
如果你觉得我说的太罗嗦了,可以复制下面的代码,和上面的一样,只不过少了注释

import  torch
from torch.utils.data import DataLoader
from torchvision import transforms #数据的原始处理
from torchvision import datasets
import  torch.nn.functional as F#激活函数
import torch.optim as optimbatch_size = 64
transform = transforms.Compose([transforms.ToTensor()])train_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../data',  train=False, download=True,  transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False,  batch_size=batch_size)class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)self.pooling = torch.nn.MaxPool2d(2)self.fc = torch.nn.Linear(320, 10)def forward(self, x):# Flatten data from (n, 1, 28, 28) to (n, 784)batch_size = x.size(0)x = F.relu(self.pooling(self.conv1(x)))x = F.relu(self.pooling(self.conv2(x)))x = x.view(batch_size, -1) # flattenx = self.fc(x)return xmodel = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.1,momentum=0.5)#lr为学习率def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, target = datainputs, target = inputs.to(device), target.to(device)optimizer.zero_grad()# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, target = datainputs, target = inputs.to(device), target.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, dim=1)total += target.size(0)correct += (predicted == target).sum().item()print('Accuracy on test set: %d %% [%d/%d]' % (100 * correct / total, correct, total))if __name__=='__main__':for epoch in range(10):train(epoch)if epoch % 10 == 9:test()

PyTorch学习(九)--用CNN模型识别手写数字数据集MNIST相关推荐

  1. 应用训练MNIST的CNN模型识别手写数字图片完整实例(图片来自网上)

    1 思考训练模型如何进行应用 通过CNN训练的MNIST模型如何应用来识别手写数字图片(图片来自网上)? 这个问题困扰了我2天,网上找的很多代码都是训练模型和调用模型包含在一个.py文件中,这样子每一 ...

  2. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  3. [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98%+

    [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98.8%+ 我们在博文,使用CNN做Kaggle比赛手写数字识别准确率99%+,在此基础之 ...

  4. CNN网络实现手写数字(MNIST)识别 代码分析

    CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import num ...

  5. 利用CNN进行手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/85884967 资源下载地址:https://download.csdn.net/downl ...

  6. CNN之手写数字识别(Handwriting Recognition)

    CNN之手写数字识别(Handwriting Recognition) 目录 CNN之手写数字识别(Handwriting Recognition) 1.常用的包 2.常见概念 3.手写数字识别器实现 ...

  7. Tensorflow.js||使用 CNN 识别手写数字

    Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...

  8. 基于CNN的手写数字识别

    基于CNN的手写数字识别 文章目录 基于CNN的手写数字识别 零. 写在之前 壹. 聊聊CNN 01. 什么是CNN 02. 为什么要有CNN 03. CNN模型 3.1 卷积层 3.2 池化层 3. ...

  9. 深度学习导论(5)手写数字识别问题步骤

    深度学习导论(5)手写数字识别问题步骤 手写数字识别分类问题具体步骤(Training an handwritten digit classification) 加载数据 显示训练集中的图片 定义神经 ...

  10. CNN识别手写数字-莫烦python

    搭建一个 CNN识别手写数字 前面跟着莫烦python/tensorflow教程完成了神经网络识别手写数字的代码,这一part是cnn识别手写数字的 import tensorflow as tf f ...

最新文章

  1. Struts2学习总结二
  2. 使用java+TestNG进行接口回归测试 1
  3. element-ui表单验证:用户名、密码、电话、邮箱
  4. Linux系统扩硬盘,Linux系统硬盘扩容
  5. 【恋上数据结构】贪心(最优装载、零钱兑换、0-1背包)、分治(最大连续子序列和、大数乘法)
  6. UICollectionViewController
  7. Debian 中使用apt-get update 出现NO_PUBKEY 解决方法
  8. AD库文件(元件库+封装库+3D模型)
  9. win7计算机图标 灰色不可选,win7系统aero主题灰色不可选怎么办|win7 aero灰色的解决方法...
  10. C语言运算符优先级(超详细)
  11. 【彻底学会】多级编号
  12. 铲雪车 骑马修栅栏 (欧拉路径和欧拉回路)
  13. 深度学习半自动标注_时下流行的深度学习数据标注工具
  14. oracle灾备同步_【oracle灾备方案系列】基于DDS的Oracle复制容灾方案(三)
  15. computer-06 其它
  16. 使用视频监控摄像头的语音对讲功能,在视频平台,配置符合GB28181协议国标视频对讲
  17. 云原生|斯人若彩虹,遇到方知有【Python代码实现】
  18. 数据库的完全依赖,部分依赖和传递依赖
  19. 软考高级系统架构设计师系列论文十三:论软件测试方法和工具的选用
  20. php 工厂模式运用实例,php工厂模式的实例

热门文章

  1. 笔试——字符串算法题——寻找最大回文子串
  2. 计算机vb输入框函数,VB基本函数大全
  3. Mac上配置svn diff为kdiff3
  4. OPNsense用户手册-用户管理
  5. App Store 上架审核指北【翻译】
  6. c语言 long int最大数,long整型的最大值跟处理器位数有关
  7. python if语句怎么结束_【Python】IF 条件语句总结
  8. android自动计时器,Android实现定时器的几种方法
  9. CSS中的字体背景和盒子模型
  10. [白话解析] Flink的Watermark机制