前言

MNIST数据集由250个不同的人手写而成,总共有7000张手写数据集。其中训练集有6000张,测试集有1000张。每张图片大小为28x28,或者说是由28x28个像素组成。这章打算用一个简单的模型进行手写字符识别。

MNIST

下载MNIST数据集的方式有很多,可以去MNIST官网下载,也可以用函数api下载
官网下载网页为:http://yann.lecun.com/exdb/mnist/,复制链接打开之后可以在网页中看到以下信息,下图圈起来的就是数据集。


本文采用的是通过pytorch的函数下载

from torchvision import datasets
# 下载训练集,测试集
traindataset = datasets.MNIST(root="./data/",train=True,download=False)
testdataset = datasets.MNIST(root="./data/",train=False,download=False)

接下来将数据集保存为图片。

# 查看手写数据集Mnist,保存图片集
import torchvision
from torchvision import datasets
import cv2
from tqdm import tqdm
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed def download_save_img(img_message,index,path,train=True):# 这里用的是opencv保存图片,img_message是一个tuple,#其中tuple[0]是图片类别,tuple[1]是PIL格式的图片,用opencv保存的需要转为numpy格式img = np.array(img_message[0])img_class = img_message[1]cv2.imwrite(path+str(img_class)+"_"+str("train" if train else "test")+str(index)+".jpg",img)results = []
traindataset = datasets.MNIST(root="./data/",train=True,download=False)
# 打印训练集数量
print(len(traindataset))
# 多线程下保存图片
with ThreadPoolExecutor(max_workers=None) as t:for index, img_message in enumerate(traindataset): # "./data/MNIST_ori_img/"results.append(t.submit(download_save_img, img_message, index, "./data/MNIST_ori_img/", train=True))for result in tqdm(as_completed(results),total=len(results),desc = "train"):pass
results = []
testdataset = datasets.MNIST(root="./data/",train=False,download=False)
print(len(testdataset))
with ThreadPoolExecutor(max_workers=None) as t:for index, img_message in enumerate(testdataset): # "./data/MNIST_ori_img/"results.append(t.submit(download_save_img, img_message, index, "./data/MNIST_ori_img/", train=False))for result in tqdm(as_completed(results),total=len(results),desc = "test"):pass

可以在文件所在目录下/data/ 查看手写数字原图,图片为黑白手写图集,具体可看下图

训练

构建一个由2个激活层和全连接层组成的模型
激活函数可以引入非线性因素,为什么要引入非线性因素呢,主要是因为我们所要解决的问题(识别手写字符)是一个非线性问题,引入非线性因素可以更有效地解决非线性问题。
全连接层主要作用是分类,将学到的“分布式特征表示”映射到样本标记空间

import torch
import torch.nn as nn
import numpy as npclass Cnn(nn.Module):def __init__(self, class_num):super(Cnn, self).__init__()self.flatten = nn.Flatten()self.relu = nn.ReLU()self.relu1 = nn.ReLU()self.linear_out = nn.Linear(784,class_num)def forward(self, x):x = self.flatten(x)x = self.relu(x)x = self.relu1(x)x = self.linear_out(x)return x
# 手写字符从0到9 总共有10个类别 实例化模型
cnn = Cnn(class_num = 10)
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch
from torchvision import datasets
import torchvisiondevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 设置batchsize
bs = 8
# 设置多进程加载数据
nw = 4
# 设置训练迭代
epoches = 10
# 加载数据集
train_set = datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))]),download=True)
train_loader = DataLoader(dataset=train_set,batch_size=bs, shuffle=True, num_workers=nw)
val_set = datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))]),download=True)
val_loader = DataLoader(dataset=val_set,batch_size=bs, shuffle=True, num_workers=nw)# 训练 使用Adam作为优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.1)
# 损失函数使用交叉熵损失函数
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(epoches):train_loss = 0train_correct = 0val_loss = 0val_correct = 0for inputs, labels in tqdm(train_loader):cnn.to(device).train()inputs = inputs.to(device)labels = labels.to(device)outputs = cnn(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs,labels)loss.backward()optimizer.step()## loss计算train_loss += loss.item() * inputs.size(0)train_correct += torch.sum(preds == labels.data)for inputs, labels in tqdm(val_loader):cnn.to(device).eval()inputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():outputs = cnn(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs,labels)## loss计算val_loss += loss.item() * inputs.size(0)val_correct += torch.sum(preds == labels.data)train_losses = train_loss / len(train_loader.dataset)train_acc = float(train_correct) / len(train_loader.dataset)valid_losses = val_loss / len(val_loader.dataset)valid_acc = float(val_correct) / len(val_loader.dataset)print("epoch: {},  train_loss is: {}, train_acc is: {}, val_loss: {}, val_acc: {}".format(epoch,train_losses,train_acc,valid_losses,valid_acc))

结果

不采用卷积神经网络的情况下,准确率在89%左右

手写数字识别_MNIST数据集相关推荐

  1. Pytorch实战1:LeNet手写数字识别 (MNIST数据集)

    版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作! Pytorch实战1:LeNet手写数字识别 (MNIST数据集) 实验环境: Pytorch 0.4.0 torchvision 0.2. ...

  2. Python 手写数字识别 MNIST数据集下载失败

    目录 一.MNIST数据集下载失败 1 失败的解决办法(经验教训): 2 亲测有效的解决方法: 一.MNIST数据集下载失败 场景复现:想要pytorch+MINIST数据集来实现手写数字识别,首先就 ...

  3. 智科模式识别期末大课设:多种方法对数据集进行手写数字识别(数据集:MINIST)

    0结课作业内容 (1)程序编写及报告. 请大家下载70000个样本的MNIST数据集("手写体数字70000.zip",28*28像素),60000个用于训练,10000个用于测试 ...

  4. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  5. 手写数字识别MNIST数据集下载百度网盘链接快速下载

    介绍 MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 下载 官方链接:http://ya ...

  6. 手写数字识别的数据集讲解

    CLASS torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None ...

  7. python实现lenet_吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集...

    importtensorflow as tf tf.reset_default_graph()#配置神经网络的参数 INPUT_NODE = 784OUTPUT_NODE= 10IMAGE_SIZE= ...

  8. 吴裕雄 python 神经网络——TensorFlow实现AlexNet模型处理手写数字识别MNIST数据集...

    import tensorflow as tf# 输入数据 from tensorflow.examples.tutorials.mnist import input_datamnist = inpu ...

  9. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

最新文章

  1. React 项目----内联样式style的使用 (12)
  2. NHibernate应用二:第一个NHibernate程序
  3. vsftpd登录报530
  4. 起步,停车——走好你的IT运维管理之路
  5. 生成二维码的 jQuery 插件:jquery.qrcode.js
  6. 码云怎么创建公开的仓库_使用码云或GitHub搭建简单的个人网站(补充hexo搭建博客)...
  7. Mysql库及表的基本概念、增删查改操作以及表的约束、多表联查
  8. 黄章爆料魅族16s/16s Plus更多信息 无线充电已做到24W
  9. pca各个向量之间的相关度_PCA算法原理及实现
  10. 保险行业持续扩展,巨杉数据库再次中标人保财险
  11. linux下常用vim命令
  12. 香农编码,哈夫曼编码与费诺编码的比较
  13. 基于android的希腊字母读音手机软件,希腊字母表app下载
  14. Windows---diskpart命令的使用
  15. matlab单位阶跃响应与单位脉冲响应,python 已知响应函数求单位阶跃响应或脉冲响应...
  16. 启动服务提示端口已存在的处理方法
  17. K - 链表的有序集合_Java
  18. R语言股票市场指数:ARMA-GARCH模型和对数收益率数据探索性分析
  19. 取带runas的一些优秀小工具介绍
  20. 微信小程序之流星雨个人页

热门文章

  1. Unity 判断目标是否在左边或右边
  2. MFC--DDV与DDX对比
  3. linux红帽安装qq,Linux如何安装QQ软件_Centos_redhat_ubuntu
  4. 我的围棋二十年――业余菜鸟的成长故事
  5. python聚宽量化_聚宽量化交易Portfolio与Context对象学习笔记
  6. linux中uniq c命令详解,linux uniq 命令整理
  7. 【代码随想录二刷】day 25 | 216.组合总和III 17.电话号码的字母组合
  8. 《白话统计》学习笔记之方差分析与变异分解
  9. PHP断点调试技术(Xdebug)-李明-专题视频课程
  10. Swift初始化(Initialization)