MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

不妨新建一个文件夹 – mnist, 将数据集下载到 mnist 以后, 解压即可:

图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法.

import os
import struct
import numpy as npdef load_mnist(path, kind='train'):"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind)with open(labels_path, 'rb') as lbpath:magic, n = struct.unpack('>II',lbpath.read(8))labels = np.fromfile(lbpath,dtype=np.uint8)with open(images_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)return images, labels

load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片). load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).

第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:

magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)

为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

通过使用上面两行代码, 我们首先读入 magic number, 它是一个文件协议的描述, 也是在我们调用 fromfile 方法将字节读入 NumPy array 之前在文件缓冲中的 item 数(n). 作为参数值传入 struct.unpack>II 有两个部分:

  • >: 这是指大端(用来定义字节是如何存储的); 如果你还不知道什么是大端和小端, Endianness 是一个非常好的解释. (关于大小端, 更多内容可见<<深入理解计算机系统 – 2.1 节信息存储>>)
  • I: 这是指一个无符号整数.

通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.

为了了解 MNIST 中的图片看起来到底是个啥, 让我们来对它们进行可视化处理. 从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制:

import matplotlib.pyplot as pltfig, ax = plt.subplots(nrows=2,ncols=5,sharex=True,sharey=True, )ax = ax.flatten()
for i in range(10):img = X_train[y_train == i][0].reshape(28, 28)ax[i].imshow(img, cmap='Greys', interpolation='nearest')ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.

此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:

fig, ax = plt.subplots(nrows=5,ncols=5,sharex=True,sharey=True, )ax = ax.flatten()
for i in range(25):img = X_train[y_train == 7][i].reshape(28, 28)ax[i].imshow(img, cmap='Greys', interpolation='nearest')ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:

另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集. 但是, 有一点要说明, CSV 的文件格式将会占用更多的磁盘空间, 如下所示:

  • train_img.csv: 109.5 MB
  • train_labels.csv: 120 KB
  • test_img.csv: 18.3 MB
  • test_labels: 20 KB

如果我们打算保存这些 CSV 文件, 在将 MNIST 数据集加载入 NumPy array 以后, 我们应该执行下列代码:

np.savetxt('train_img.csv', X_train,fmt='%i', delimiter=',')
np.savetxt('train_labels.csv', y_train,fmt='%i', delimiter=',')
np.savetxt('test_img.csv', X_test,fmt='%i', delimiter=',')
np.savetxt('test_labels.csv', y_test,fmt='%i', delimiter=',')

一旦将数据集保存为 CSV 文件, 我们也可以用 NumPy 的 genfromtxt 函数重新将它们加载入程序中:

X_train = np.genfromtxt('train_img.csv',dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',dtype=int, delimiter=',')

不过, 从 CSV 文件中加载 MNIST 数据将会显著发给更长的时间, 因此如果可能的话, 还是建议你维持数据集原有的字节格式.

参考:
- Book , Python Machine Learning.

详解 MNIST 数据集相关推荐

  1. Pytorch入门--详解Mnist手写字识别

    1 什么是Mnist?         Mnist是计算机视觉领域中最为基础的一个数据集. MNIST数据集(Mixed National Institute of Standards and Tec ...

  2. 【实际操作】DenseFusion复现过程详解-YCB-Video数据集

    DenseFusion系列代码全讲解目录:[DenseFusion系列目录]代码全讲解+可视化+计算评估指标_Panpanpan!的博客-CSDN博客 这些内容均为个人学习记录,欢迎大家提出错误一起讨 ...

  3. Netflix Prize数据集详解及数据集下载链接

    Netflix数据集包含了1999.12.31-2005.12.31期间匿名客户提供的超过一亿部电影平级.这个数据集大约给出了480189个用户和17770部电影评级.数据集中的详细信息如下图所示: ...

  4. tensorflow入门数据集:mnist详解

    文章目录 python处理二进制 mnist介绍 mnist显示 方法一:读取解压后的原始文件 方法二:使用TensorFlow封装代码读取 需求一:同时显示图片和标签,验证图片和标签一一对应 需求二 ...

  5. 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%

    基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...

  6. python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...

  7. 全卷积神经网路【U-net项目实战】LUNA 2016 数据集详解

    文章目录 1.LUNA 2016 数据集详解 2.mdh数据格式详解 3.python读取mdh的方法 4.annotations.csv坐标转换 5.LUNA16数据集肺结节显示 1.LUNA 20 ...

  8. LUNA 2016 数据集详解

    LUNA 2016 数据集详解 LUNA16数据集的由来 LUNA 2016 数据集来自2016年LUng Nodule Analysis比赛,这里是其官方网站. LUNA16数据集是最大公用肺结节数 ...

  9. ILSVRC2015_VID数据集详解

    数据集下载地址:http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz 总说: 数据集包括3862 snippets用于训练,55 ...

最新文章

  1. Maven报错解决:Element 'dependency' cannot have character [children], because the type's content type is
  2. 学习笔记(四)——JavaScript(一)
  3. [转]jquery的一个模板引擎-zt
  4. 算法 --- 记一道面试dp算法题
  5. setTimeout详解
  6. antd request 通过jsessionid传参数_Umi-request源码阅读
  7. python比较三个数_python经典练习题(三)
  8. html玫瑰花效果代码,html5渲染3D玫瑰花情人节礼物js特效代码
  9. ./mysql-bin.index_MySQL 启动报错:File ./mysql-bin.index not found (Errcode: 13)
  10. 一文了解授信审批策略及流程
  11. 普通人如何月入10万
  12. python爬取qq好友网络状态_Python爬虫实战----爬取QQ空间好友说说并生成词云(超详细)...
  13. PHP 进程间通信——消息队列(msg_queue)
  14. 仿Android 5.0 侧滑菜单按钮动画 以及侧滑菜单联动
  15. html如何改成花体英文字体,花体英文在线转换
  16. USB3300速度调试
  17. 带宽与响应速度的关系
  18. 判断移动终端是安卓还是iOS
  19. 《西部世界》会成真吗? 人类如何避免被机器人干掉的未来?
  20. 网络安全[脚本小子] -- SSI注入

热门文章

  1. 【pynq-z2】初始配置
  2. Baiduman的经历
  3. kali命令行连接wifi
  4. python range函数返回的是什么,python中range函数用法是什么
  5. 神经网络 -- 百科
  6. Blazor_WASM之1:Blazor概述
  7. linux busybox路径,BusyBox构建根文件系统
  8. Ogre 材质与材质脚本
  9. 轻松访问Google Chrome浏览器的特殊页面
  10. 使用 Vue.js 构建 VS Code 扩展