转自:https://blog.csdn.net/simple_the_best/article/details/75267863

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

不妨新建一个文件夹 – mnist, 将数据集下载到 mnist 以后, 解压即可:

图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法.

import os

import struct

import numpy as np

def load_mnist(path, kind='train'):

"""Load MNIST data from `path`"""

labels_path = os.path.join(path,

'%s-labels-idx1-ubyte'

% kind)

images_path = os.path.join(path,

'%s-images-idx3-ubyte'

% kind)

with open(labels_path, 'rb') as lbpath:

magic, n = struct.unpack('>II',

lbpath.read(8))

labels = np.fromfile(lbpath,

dtype=np.uint8)

with open(images_path, 'rb') as imgpath:

magic, num, rows, cols = struct.unpack('>IIII',

imgpath.read(16))

images = np.fromfile(imgpath,

dtype=np.uint8).reshape(len(labels), 784)

return images, labels

load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片). load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).

第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:

magic, n = struct.unpack('>II', lbpath.read(8))

labels = np.fromfile(lbpath, dtype=np.uint8)

为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type] [value] [description]

0000 32 bit integer 0x00000801(2049) magic number (MSB first)

0004 32 bit integer 60000 number of items

0008 unsigned byte ?? label

0009 unsigned byte ?? label

........

xxxx unsigned byte ?? label

The labels values are 0 to 9.

通过使用上面两行代码, 我们首先读入 magic number, 它是一个文件协议的描述, 也是在我们调用 fromfile 方法将字节读入 NumPy array 之前在文件缓冲中的 item 数(n). 作为参数值传入 struct.unpack 的 >II 有两个部分:

>: 这是指大端(用来定义字节是如何存储的); 如果你还不知道什么是大端和小端, Endianness 是一个非常好的解释. (关于大小端, 更多内容可见<>)

I: 这是指一个无符号整数.

通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.

为了了解 MNIST 中的图片看起来到底是个啥, 让我们来对它们进行可视化处理. 从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(

nrows=2,

ncols=5,

sharex=True,

sharey=True, )

ax = ax.flatten()

for i in range(10):

img = X_train[y_train == i][0].reshape(28, 28)

ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])

ax[0].set_yticks([])

plt.tight_layout()

plt.show()

我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.

此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:

fig, ax = plt.subplots(

nrows=5,

ncols=5,

sharex=True,

sharey=True, )

ax = ax.flatten()

for i in range(25):

img = X_train[y_train == 7][i].reshape(28, 28)

ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])

ax[0].set_yticks([])

plt.tight_layout()

plt.show()

执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:

另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集. 但是, 有一点要说明, CSV 的文件格式将会占用更多的磁盘空间, 如下所示:

train_img.csv: 109.5 MB

train_labels.csv: 120 KB

test_img.csv: 18.3 MB

test_labels: 20 KB

如果我们打算保存这些 CSV 文件, 在将 MNIST 数据集加载入 NumPy array 以后, 我们应该执行下列代码:

np.savetxt('train_img.csv', X_train,

fmt='%i', delimiter=',')

np.savetxt('train_labels.csv', y_train,

fmt='%i', delimiter=',')

np.savetxt('test_img.csv', X_test,

fmt='%i', delimiter=',')

np.savetxt('test_labels.csv', y_test,

fmt='%i', delimiter=',')

一旦将数据集保存为 CSV 文件, 我们也可以用 NumPy 的 genfromtxt 函数重新将它们加载入程序中:

X_train = np.genfromtxt('train_img.csv',

dtype=int, delimiter=',')

y_train = np.genfromtxt('train_labels.csv',

dtype=int, delimiter=',')

X_test = np.genfromtxt('test_img.csv',

dtype=int, delimiter=',')

y_test = np.genfromtxt('test_labels.csv',

dtype=int, delimiter=',')

不过, 从 CSV 文件中加载 MNIST 数据将会显著发给更长的时间, 因此如果可能的话, 还是建议你维持数据集原有的字节格式.

怎么改mnist数据的标签_详解 MNIST 数据集相关推荐

  1. 怎么改mnist数据的标签_【Pytorch】多个数据集联合读取

    深度学习好比炼丹,框架就是丹炉,网络结构及算法就是单方,而数据集则是原材料.现在世面上很多炼丹手册都是针对单一数据集进行炼丹,有了这些手册我们就能够很容易进行炼丹,但为了练好丹,我们常常收集各种公开的 ...

  2. mysql从挂了数据怎么恢复_详解MySQL误操作后怎样进行数据恢复

    一.开启binlog. 首先查看binlog是否开启 mysql> show variables like "log_bin"; +---------------+----- ...

  3. vue 修改模板{{}}标签_详解Vue 动态添加模板的几种方法

    以下方法只适用于 Vue1.0 版本,推荐系数由高到低排列. 通常我们会在组件里的 template 属性定义模板,或者是在*.vue文件里的 template 标签里写模板.但是有时候会需要动态生成 ...

  4. java获取mysql数据定时执行_详解SpringBoot 创建定时任务(配合数据库动态执行)...

    序言:创建定时任务非常简单,主要有两种创建方式:一.基于注解(@Scheduled) 二.基于接口(SchedulingConfigurer). 前者相信大家都很熟悉,但是实际使用中我们往往想从数据库 ...

  5. 数据存储方式_详解西门子S7-200PLC的数据区

    (一)数字量输入和输出映象区 1.输入映象寄存器(数字量输入映象区)(I) 数字量输入映象区是S7-200CPU为输入端信号状态开辟的一个存储区.输入映像寄存器的标识符为I,在每个扫描周期的开始,CP ...

  6. JSON数据构造及解析详解

    JSON数据构造及解析详解 1.JSON格式数据长啥样? 2.JSON简介 JSON(Javascript Object Notation)是一种轻量级的数据交换格式,易于阅读和编写,也易于机器解析和 ...

  7. docker导入MySQL文件_Docker容器中Mysql数据的导入/导出详解

    前言 Mysql数据的导入导出我们都知道一个mysqldump命令就能够解决,但如果是运行在docker环境下的mysql呢? 解决办法其实还是用mysqldump命令,但是我们需要进入docker的 ...

  8. python接入excel_使用python将excel数据导入数据库过程详解

    因为需要对数据处理,将excel数据导入到数据库,记录一下过程. 使用到的库:xlrd 和 pymysql (如果需要写到excel可以使用xlwt) 直接丢代码,使用python3,注释比较清楚. ...

  9. html标签非成对,深入document.write()与HTML4.01的非成对标签的详解

    深入document.write()与HTML4.01的非成对标签的详解 (一)HTML4.01中的非成对标签: 注释标签: 严格来讲不算HTML标签的:文档声明标签 设置页面元信息的:标签 设置网页 ...

  10. Echarts数据可视化series-bar柱形图详解,开发全解+完美注释

    全栈工程师开发手册 (作者:栾鹏) Echarts数据可视化开发代码注释全解 Echarts数据可视化开发参数配置全解 6大公共组件详解(点击进入): title详解. tooltip详解.toolb ...

最新文章

  1. qt试用1(Eclipse+cdt+Qt)
  2. qml demo分析(threading-线程任务)
  3. 怎样将英文html文件转换成中文乱码,解决html导出pdf中文乱码问题的正确姿势
  4. Windows下安装tensorflow-gpu/cpu教程
  5. Python教程:Python内置数据结构之双向队列!
  6. 测试点解析:1049 数列的片段和_12行代码AC
  7. 游戏ai 行为树_游戏AI –行为树简介
  8. ERROR 1130 (HY000): Host ‘192.168.3.238‘ is not allowed to connect to this MySQL server
  9. Java经典设计模式(2):七大结构型模式(附实例和详解)
  10. 2019年PAT甲级冬季考试真题及参考答案
  11. BigGAN高保真自然图像合成的大规模GAN训练
  12. 巴斯大学计算机世界专业排名,2019上海软科世界一流学科排名计算机科学与工程专业排名巴斯大学排名第301-400...
  13. ISO/OSI七层网络参考模型、TCP/IP四层网络模型和教学五层网络模型
  14. opencv 骨架提取/图片细化 代码
  15. android pcm文件大小_Android中的PCM设备
  16. 人工智能的软件研发管理系统
  17. 国产“芯”时代 盘点国内十大IC卡制卡企业
  18. Altium designer 备注手册
  19. 盘点2022年电视行业:科技与美学的战场三星缔造“生活方式”的全新价值
  20. 微信小程序中通过两点经纬度计算距离

热门文章

  1. FP-growth算法,fpgrowth算法详解
  2. linux下安装fortran90教程,linux 安装fortran 90 --zz
  3. 机器人学基础(一):空间描述与坐标变换
  4. Linux 系统设置静态ip地址
  5. qq linux五笔输入法,qq五笔输入法
  6. matlab工作区导入多个文件,MATLAB可以直接把Excel文件中的数据导入工作区中
  7. Oracle下载及安装超详细教程
  8. 淘宝店铺装修旺铺基础版全屏轮播代码效果1920PX海报
  9. 天敏盒子系统停止服务器,天敏网络机顶盒今天怎么停服了?
  10. 使用JAVA基础语法做一个简易的发票管理系统