一、总体介绍

1.1 什么是机器识别手写数字?

1.2 MNIST数据集是什么?

(1)该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。
(2)数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。
(3)MNIST数据集包含四个部分:
训练集图像:train-images-idx3-ubyte.gz(9.9MB,包含60000个样本)
训练集标签:train-labels-idx1-ubyte.gz(29KB,包含60000个标签)
测试集图像:t10k-images-idx3-ubyte.gz(1.6MB,包含10000个样本)
测试集标签:t10k-labels-idx1-ubyte.gz(5KB,包含10000个标签)
下载地址MNIST

1.3 手写字体的识别流程

( 1)定义超参数;
(2〉构建transforms,主要是对图像做变换;
(3)下载、加载数据集MNIST;
(4)构建网络模型;
(5)定义训练方法;
(6)定义测试方法;
(7)开始训练模型,输出预测结果;

二、代码实现

2.1 导入相应的库

# 1 加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms

2.2 定义超参数

# 2 定义超参数
BATCH_SIZE = 16  # 每批处理的数据
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 是否用GPU还是CPU训练
EPOCHS = 10 # 训练数据集的轮次

2.3 构建pipeline,对图像做处理

# 3 构建pipeline,对图像做处理
pipeline = transforms.Compose([transforms.ToTensor(),# 将图片转换成tensortransforms.Normalize((0.1307,),(0.3081,)) # 正则化降低模型复杂度
])

2.4 下载、加载数据

# 4 下载、加载数据
from torch.utils.data import DataLoader# 下载数据集
train_set = datasets.MNIST("data", train=True, download=False, transform=pipeline)test_set = datasets.MNIST("data", train=False, download=False, transform=pipeline)# 加载数据
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # 顺序打乱shuffle=Truetest_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

2.5 查看一下数据集图片(可跳过)

## 插入代码,显示MNIST中的图片
with open("./data/MNIST/raw/train-images-idx3-ubyte","rb") as f:file = f.read()
imagel = [int(str(item).encode('ascii'),16) for item in file[16 : 16+784]]
print(imagel)
import cv2
import numpy as npimagel_np = np.array(imagel, dtype=np.uint8).reshape(28, 28, 1)print(imagel_np.shape)
cv2.imwrite("gigit.jpg", imagel_np)

2.6 构建网络模型

# 5 构建网络模型
class Digit(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, 5) # 1:灰度图片的通道, 10:输出通道, 5:kernel 5x5self.conv2 = nn.Conv2d(10, 20, 3) # 10:输入通道, 20:输出通道, 3:kernel 3x3self.fc1 = nn.Linear(20*10*10, 500) # 20*10*10:输入通道, 500:输出通道self.fc2 = nn.Linear(500, 10) # 500:输入通道, 10:输出通道#前向传播def forward(self, x):input_size = x.size(0)  # batch_sizex = self.conv1(x) # 输入:batch*1*28*28, 输出:batch*10*24*24 (28 - 5 + 1 = 24)x = F.relu(x)  # 激活函数,保持shape不变, 输出batch*10*24*24x = F.max_pool2d(x, 2, 2) # 输入:batch*10*24*24 输出:batch*10*12*12x = self.conv2(x) # 输入:batch*10*12*12, 输出:batch*20*10*10 (12 - 3 + 1 = 10)x = F.relu(x)x = x.view(input_size, -1) # 拉平, -1 自动计算维度,20*10*10 = 2000x = self.fc1(x) # 输入:batch*2000,输出:batch*500x = F.relu(x)x = self.fc2(x) # 输入:batch*500,输出:batch*10output = F.log_softmax(x, dim=1) # 计算分类后,每个数字的概率值return output

2.6 定义优化器(更新参数,是训练测试结果达到最优值)

model = Digit().to(DEVICE)optimizer = optim.Adam(model.parameters())

2.7 定义训练方法

# 7 定义训练方法
def train_model(model, device, train_loader, optimizer, epoch):# 模型训练model.train()for batch_index, (data, target) in enumerate(train_loader):# 部署到DEVICE上去data, target = data.to(device), target.to(device)# 梯度初始化为0optimizer.zero_grad()# 训练后的结果output = model(data)# 计算损失loss = F.cross_entropy(output, target) # 交叉熵用于分类比较多的情况# 反向传播loss.backward()# 参数优化optimizer.step()if batch_index % 3000 == 0:print("Train Epoch : {} \t Loss : {:.6f}".format(epoch, loss.item()))

2.8 定义测试方法

# 8 定义测试方法
def test_model(model, device, test_loader):# 模型验证model.eval()# 正确率correct = 0.0# 测试损失test_loss = 0.0with torch.no_grad(): # 不会计算梯度,也不会反向传播for data, target in test_loader:# 部署到device上data, target = data.to(device), target.to(device)# 测试数据output = model(data)# 计算测试损失test_loss += F.cross_entropy(output, target).item()# 找到概率值最大的下标pred = output.max(1, keepdim=True)[1] # 值,索引# pred = torch.max(output, dim=1)# pred = output.argmax(dim=1)#累计正确的值correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print("Test — Average loss : {:.4f}, Accuracy : {:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))

2.9 调用 方法7 / 8

# 9 调用 方法7 / 8
for epoch in range(1, EPOCHS +1):train_model(model, DEVICE, train_loader, optimizer, epoch)test_model(model, DEVICE, test_loader)

2.10 结果

三、总结

本文观看b站的“唐国梁Tommy”up主的轻松学 PyTorch 手写字体识别 MNIST,讲解十分详细,推荐观看。

MNIST手写体数字识别数据集相关推荐

  1. Tensorflow 改进的MNIST手写体数字识别

    上篇简单的Tensorflow解决MNIST手写体数字识别可扩展性并不好.例如计算前向传播的函数需要将所有的变量都传入,当神经网络的结构变得复杂.参数更多时,程序的可读性变得非常差.而且这种方式会导致 ...

  2. Tensorflow解决MNIST手写体数字识别

    这里给出的代码是来自<Tensorflow实战Google深度学习框架>,以供参考和学习. 首先这个示例应用了几个基本的方法: 使用随机梯度下降(batch) 使用Relu激活函数去线性化 ...

  3. 基于MNIST手写体数字识别--含可直接使用代码【Python+Tensorflow+CNN+Keras】

    基于MNIST手写体数字识别--[Python+Tensorflow+CNN+Keras] 1.任务 2.数据集分析 2.1 数据集总体分析 2.2 单个图片样本可视化 3. 数据处理 4. 搭建神经 ...

  4. 全连神经网络的经典实战--MNIST手写体数字识别

    mnist数据集 MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:它也包含每一张图片对应的标签,告诉我们这个是数字几.比如,上面这四张图片的标签分别是5,0,4,1. 在本章中,我们 ...

  5. 计算机视觉:mnist手写体数字识别

    一.mnist数据描述 MNIST数据集是28×28像素的灰度手写数字图片,其中数字的范围从0到9 具体如下所示(参考自Tensorflow官方文档): 二.原理   受Hubel和Wiesel对猫视 ...

  6. 支持向量机(SVM)实现MNIST手写体数字识别

    一.SVM算法简述 支持向量机即Support Vector Machine,简称SVM.一听这个名字,就有眩晕的感觉.支持(Support).向量(Vector).机器(Machine),这三个毫无 ...

  7. 随机森林算法(RandomForest)实现MNIST手写体数字识别

    一.准备: 第三方库 sklearn 二.代码: # -*- coding: utf-8 -*- # @Time : 2018/8/21 9:35 # @Author : Barry # @File ...

  8. 基于TensorFlow的手写体数字识别

    目录 一.MNIST数据集介绍 二.原理 2.1.卷积神经网络简介( convolutional neural network 简称CNN) 2.1.1卷积运算过程 2.1.2滑动的步长 2.1.3卷 ...

  9. keras框架下的深度学习(一)手写体数字识别

    文章目录 前言 一.keras的介绍及其操作使用 二.手写题数字识别 1.介绍 2.对数据的预处理 3.搭建网络框架 4.编译 5.循环训练 6.测试训练的网络模 7.总代码 三.附:梯度下降算法 1 ...

最新文章

  1. 渗透知识-SQL注入
  2. linux下如何查看主机是否安装了ftp server
  3. java中逗号怎么加_Java中如何将字符串从右至左每三位加一逗号
  4. java非堆内存_java – 监视JVM的非堆内存使用情况
  5. jdbc mysql api_JDBC Api详解
  6. Java IO ---学习笔记(数据流)
  7. Retrofit2源码分析(一)
  8. udp linux 获取本机ip
  9. pwdx与netstat、lsof结合查找进程号是由哪个程序启动的
  10. PyTorch中文教程 | (4) 迁移学习教程
  11. iOS--在线搜索苹果 App Store 应用商店
  12. 全国各地网站备案的通过时间表
  13. 基础LSB算法的matlab实现
  14. 戴尔t420服务器重装系统教程,教你联想t420系统重装步骤
  15. matlab 矩阵分行标准化,matlab矩阵标准化
  16. 糯米网电子商务模式:上线当天销售额600多万元的缘由
  17. Android开发技巧——Camera拍照功能
  18. ubuntu18.04 安装flash。
  19. 什么是大数据4v 指的是哪四个
  20. zcmu——1601: 卡斯丁狗去挖矿(01背包-三维数组)

热门文章

  1. 香港大学计算机科学学制,香港大学CS 计算机科学专业解析
  2. 作为程序员,外包到底值不值得去呢
  3. 【go】mac下brew升级golang
  4. php 2038年,PHP 处理大于2038年以后的日期
  5. 中国工程院院士评选结果公布,阿里王坚当选
  6. python库吐血整理
  7. 报名倒计时 | 有道技术沙龙,聊聊明星语音背后的故事
  8. 好书收藏:读书知多少
  9. Java 17的这些新特性,Java迈入新时代
  10. GoogleMap获取地图中心点位置信息