目录

MNIST是什么?

tf.keras.datasets.mnist

tf.keras.datasets.mnist.load_data()读取的是什么?

load_data()函数的大体原理

将读取的mnist数据集中的数据转为浮点数并归一化


TensorFlow 2.9的零零碎碎(二)-TensorFlow 2.9的零零碎碎(六)都是围绕使用TensorFlow 2.9在MNIST数据集上训练和评价模型来展开。

Python环境3.8。

代码调试都用的PyCharm。

MNIST是什么?

MNIST是手写数字数据集,由6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小(如下图),这些图片是采集的不同的人手写从0到9的数字。

tf.keras.datasets.mnist

tf.keras.datasets.mnist的定义在keras.datasets.mnist模块中。

代码很简单

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

tf.keras.datasets.mnist.load_data()读取的是什么?

我们已经知道了mnist的实现是在keras.datasets.mnist,主要就是load_data()函数,用于读取mnist数据集,load_data()函数的源码如下

@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):"""Loads the MNIST dataset.This is a dataset of 60,000 28x28 grayscale images of the 10 digits,along with a test set of 10,000 images.More info can be found at the[MNIST homepage](http://yann.lecun.com/exdb/mnist/).Args:path: path where to cache the dataset locally(relative to `~/.keras/datasets`).Returns:Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.**x_train**: uint8 NumPy array of grayscale image data with shapes`(60000, 28, 28)`, containing the training data. Pixel values rangefrom 0 to 255.**y_train**: uint8 NumPy array of digit labels (integers in range 0-9)with shape `(60000,)` for the training data.**x_test**: uint8 NumPy array of grayscale image data with shapes(10000, 28, 28), containing the test data. Pixel values rangefrom 0 to 255.**y_test**: uint8 NumPy array of digit labels (integers in range 0-9)with shape `(10000,)` for the test data.Example:```python(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()assert x_train.shape == (60000, 28, 28)assert x_test.shape == (10000, 28, 28)assert y_train.shape == (60000,)assert y_test.shape == (10000,)```License:Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,which is a derivative work from original NIST datasets.MNIST dataset is made available under the terms of the[Creative Commons Attribution-Share Alike 3.0 license.](https://creativecommons.org/licenses/by-sa/3.0/)"""origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'path = get_file(path,origin=origin_folder + 'mnist.npz',file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')with np.load(path, allow_pickle=True) as f:  # pylint: disable=unexpected-keyword-argx_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']return (x_train, y_train), (x_test, y_test)

TensorFlow和Keras其实在代码注释上做得都非常好,看代码注释就能看出一个函数是什么以及怎么用

load_data()函数的大体原理

mnist.load_data()函数访问上面的网址,下载mnist数据集的文件,保存为mnist.npz,路径在xxx\.keras\datasets\mnist.npz

npy

The .npy format is the standard binary file format in NumPy for persisting a single arbitrary NumPy array on disk. The format stores all of the shape and dtype information necessary to reconstruct the array correctly even on another machine with a different architecture. The format is designed to be as simple as possible while achieving its limited goals.

也就是将numpy生成的数组保存为二进制格式数据。

npz

The .npz format is the standard format for persisting multiple NumPy arrays on disk. A .npz file is a zip file containing multiple .npy files, one for each array.

也就是将多个数组保存到一个文件,且保存为二进制格式。

1个npz中可以有多个npy

用np.load读取这个文件,np就是numpy,这个文件里包含4个数组:

x_train、y_train、x_test、y_test

读取之后将这4个数组返回(x_train, y_train), (x_test, y_test)

所以这也就是为什么我写代码的时候总是要写(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train数据一览

y_train数据一览

x_test数据一览

y_test数据一览

将读取的mnist数据集中的数据转为浮点数并归一化

读取的mnist数据集中的数据的取值范围都是0-255

归一化的目的就是使得预处理的数据被限定在一定的范围内(比如[0,1]或者[-1,1]),从而消除奇异样本数据导致的不良影响。归一化的目的就是使得预处理的数据被限定在一定的范围内(比如[0,1]或者[-1,1]),从而消除奇异样本数据的存在会引起训练时间增大,同时也可能导致无法收敛,因此,当存在奇异样本数据时,在进行训练之前需要对预处理数据进行归一化;反之,不存在奇异样本数据时,则可以不进行归一化。奇异样本数据导致的不良影响。

x_train和y_train都是numpy数组,且为整形,直接用“/=”会报错:“numpy.core._exceptions.UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') to dtype('uint8') with casting rule 'same_kind'”

这是Numpy内部机制导致的

x_train /= 255.0
y_train /= 255.0

正确写法

x_train = x_train.astype(np.float)
x_train /= 255.0
y_train = y_train.astype(np.float)
y_train /= 255.0

或者

x_train, y_train = x_train / 255.0, y_train / 255.0

得到的结果如下

TensorFlow 2.9的零零碎碎(二)-读取MNIST数据集相关推荐

  1. TensorFlow读取MNIST数据集错误的问题

    TensorFlow读取mnist数据集错误的问题 运行程序出现"URLError"错误的问题 可能是服务器或路径的原因,可以自行下载数据集后,将数据集放到代码所在的文件夹下,并将 ...

  2. 读取mnist数据集方法大全(train-images-idx3-ubyte.gz,train-labels.idx1-ubyte等)(python读取gzip文件)

    文章目录 gzip包 keras读取mnist数据集 本地读取mnist数据集 下载数据集 解压读取 方法一 方法二 gzip包读取 读取bytes数据 注:import导入的包如果未安装使用pip安 ...

  3. python 读取 MNIST 数据集,并解析为图片文件

    python 读取 MNIST 数据集,并解析为图片文件 MNIST 是 Yann LeCun 收集创建的手写数字识别数据集,训练集有 60,000 张图片,测试集有 10,000 张图片.数据集链接 ...

  4. Tensorflow 笔记 XIII——“百无聊赖”:深挖 mnist 数据集与 fashion-mnist 数据集的读取原理,经典数据的读取你真的懂了吗?

    文章目录 数据集简介 Mnist 出门右转 Fashion-Mnist 数据集制作需求来源 写给专业的机器学习研究者 获取数据 类别标注 读取原理 原理获取 TRAINING SET LABEL FI ...

  5. 深度学习之利用TensorFlow实现简单的卷积神经网络(MNIST数据集)

    卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度学习 ...

  6. MNIST手写数字数据集格式,如何读取MNIST数据集?

    数据集下载地址:http://yann.lecun.com/exdb/mnist/ TRAINING SET LABEL FILE (train-labels-idx1-ubyte):[offset] ...

  7. Python读取MNIST数据集

    MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/ 读取MINST数据集第一张图像并显示 # coding=utf-8 import numpy as np ...

  8. 十分钟搞懂Pytorch如何读取MNIST数据集

    前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧- 正文 在阅读教程书籍<深度学习入门之Pytorch>时,文中是如此加载MNIST手写数字训练集的: ...

  9. 用Numpy读取MNIST数据集(附已经读取完成的mat文件)

    MNIST是常用的手写字符数据集 可以在杨立昆的网站下载此数据集:http://yann.lecun.com/exdb/mnist/ 该数据集的存储方式比较特殊,是用二进制格式存储的,以下是网站对数据 ...

  10. 利用numpy读取mnist数据集

    读取并分析如下四个文件 'train-images-idx3-ubyte' 'train-labels-idx1-ubyte' 't10k-images-idx3-ubyte' 't10k-label ...

最新文章

  1. shell编程系列22--shell操作数据库实战之shell脚本与MySQL数据库交互(增删改查)
  2. mos管电路_【鼎阳硬件智库原创︱电源】 MOS管驱动电路的设计
  3. UA SIE545 优化理论基础1 凸分析3 凸集与凸包
  4. 前端学习(3170):react-hello-react之实现底部功能
  5. Dart基础学习02--变量及内置类型
  6. 大数据学习(07)--MapReduce
  7. [转载]堆排序(HeapSort) Java实现
  8. 论文阅读:A Progressive Architecture With Knowledge Review Network for Salient Object Detection
  9. 输入aAZut,输出bBAvu
  10. VS2015编译VS2013工程文件出错
  11. springmuvc如何设置jsp的input跳转_小程序有链接吗?如何获取小程序的链接?
  12. UE4 遮挡剔除文档
  13. Python chardet模块
  14. JavaGUI版聊天室
  15. 关于自动拼接地图算法
  16. 趣图:程序员的鄙视链/图
  17. 为了革命 保护眼睛 !—— 眼科专家配置的色调
  18. 计算机临时桌面是怎么回事,如何解决电脑开机后桌面空白问题?
  19. lazada发货_lazada怎么发货?
  20. Error in created hook: “SyntaxError: Unexpected token u in JSON at position 0“

热门文章

  1. 视频压缩中IPB帧概念
  2. 边境的悍匪—机器学习实战:第七章 集成学习和随机森林
  3. 计算机高级筛选操作步骤,【EXCLE表格中根据特定的条件进行高级筛选】计算机excel高级筛选步骤...
  4. 网站使用微信网页授权,qq登录
  5. 暴风影音下载|暴风影音播放器下载
  6. 解决Vscode使用LeetCode报错Failed to test the solution. Please open the output channel for details.
  7. vmware linux虚拟机中添加硬盘
  8. LOI2504 [HAOI2006]聪明的猴子
  9. ibatis的isequal_ibatIS中的isNotNull、isEqual、isEmpty
  10. c语言char10是什么意思,c语言char是什么意思