将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)

  • 一、下载MNIST数据集(使用urllib.request.urlretrieve()函数)
  • 二、打开下载得到的.gz压缩文件(使用gzip.open()函数)并导入NumPy数组(使用np.frombuffer()函数)
  • 三、完整实例(能直接运行):
    • 可能遇到的问题:

一、下载MNIST数据集(使用urllib.request.urlretrieve()函数)

  • os.path.exists(path)可以判断是否存在以path为地址的文件。
  • urllib.request.urlretrieve(url, filename)可以将网络地址为url的文件复制到本地地址为filename的文件。

例如:

# mnist数据集的4个文件
key_file = {'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz','test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'}
for _ in key_file.keys():# 如果当前地址中不存在这个文件就将这个文件下载if not os.path.exists(key_file[_]):urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_])

ps:如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。

二、打开下载得到的.gz压缩文件(使用gzip.open()函数)并导入NumPy数组(使用np.frombuffer()函数)

  • gzip.open(filename, mode)函数可以以mode的方式打开文件名为filename的.gz压缩文件。
  • numpy.frombuffer(buffer, dtype=None, offset=0)函数可以跳过buffer缓冲区最前面的offset个字节把buffer缓冲区的数据以dtype的格式读取转化为NumPy数组。

例如:

key_file = {'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz','test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'}
dataset = {}
with gzip.open(key_file[_], 'rb') as f:dataset[_] = np.frombuffer(f.read(), np.uint8, offset=16 if _ == 'train_img' or _ == 'test_img' else 8)

train_img和test_img的压缩包里,前16个字节是用于验证数据集是否完整的,不是图片数据,所以跳过这16个字节。而train_label和test_label的压缩包中,是前8个字节。所以这里用if条件判断后使用不同的offset值。

三、完整实例(能直接运行):

import urllib.request
import gzip
import numpy as np
import os
import pickledef load_mnist(normalize=True, flatten=True, one_hot_label=False):# 用dataset字典保存由4个文件读取得到的np数组dataset = {}# 若不存在pkl文件,下载文件导入numpy数组,并生成pkl文件if not os.path.exists('mnist.pkl'):# MNIST数据集的4个文件key_file = {'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz','test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'}# 下载文件并导入numpy数组for _ in key_file.keys():print('Downloading ' + key_file[_] + '...')urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_])  # 下载文件print('Download finished!')# 用二进制只读方式打开.gz文件with gzip.open(key_file[_], 'rb') as f:# img文件前16个字节不是img数据,跳过读取;label文件前8个不是label数据,跳过读取dataset[_] = np.frombuffer(f.read(), np.uint8,offset=16 if _ == 'train_img' or _ == 'test_img' else 8)if _ == 'train_img' or _ == 'test_img':dataset[_] = dataset[_].reshape(-1, 1, 28, 28)# 生成mnist.pklprint('Creating pickle file ...')with open('mnist.pkl', 'wb') as f:pickle.dump(dataset, f, -1)print('Create finished!')# 若存在pkl文件,把pkl文件内容导入numpy数组else:with open('mnist.pkl', 'rb') as f:dataset = pickle.load(f)# 标准化处理if normalize:for _ in ('train_img', 'test_img'):dataset[_] = dataset[_].astype(np.float32) / 255.0# one_hot_label处理if one_hot_label:for _ in ('train_label', 'test_label'):t = np.zeros((dataset[_].size, 10))for idx, row in enumerate(t):row[dataset[_][idx]] = 1dataset[_] = t# 展平处理if flatten:for _ in ('train_img', 'test_img'):dataset[_] = dataset[_].reshape(-1, 784)# 返回np数组return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])if __name__ == '__main__':(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=False, one_hot_label=True)print(x_train.shape)print(t_train.shape)print(x_test.shape)print(t_test.shape)

运行结果:

ps:第一次运行因为要下载文件会比较慢,后面几次就很快的。

可能遇到的问题:

  • 如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。
  • 如果遇到 No module named ‘…’ 的问题,在命令行使用pip install <这个缺少的模块的名称> 即可。
  • 如果遇到EOFError: Compressed file ended before the end-of-stream marker was reached的问题,是压缩文件被破坏或者不完整的原因,把下载到的.gz文件删除,重新运行程序即可。

本实例来自于,由[日]斋藤康毅所著的《深度学习入门:基于Python的理论与实现》。

将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)相关推荐

  1. 深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐

    深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐 书籍封面 1-图灵网站下载 书里也说了,可以图灵网站下载https://www.ituring.com.cn/book/ ...

  2. MindSpore手写数字识别初体验,深度学习也没那么神秘嘛

    摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...

  3. 卷积神经网络(CNN)之MNIST手写数字数据集的实现

    MNIST数据集是一个非常经典的手写数字识别的数据集,本人很多文章都是拿这个数据集来做示例,MNIST的具体介绍与用法可以参阅: MNIST数据集手写数字识别(一)https://blog.csdn. ...

  4. matlab 对mnist手写数字数据集进行判决分析_人工智能TensorFlow(十四)MINIST手写数字识别...

    MNIST是一个简单的视觉计算数据集,它是像下面这样手写的数字图片: MNIST 每张图片还额外有一个标签记录了图片上数字是几,例如上面几张图的标签就是:5.0.4.1. MINIST数据 MINIS ...

  5. 用Python实现BP神经网络识别MNIST手写数字数据集(带GUI)

    概述 计算机神经网络则是人工智能中最为基础的也是较为重要的部分,它使用深度学习的方式模拟了人的神经元的工作,是一种全新的计算方法.本文的目标就是通过学习神经网络的相关知识,了解并掌握BP神经网络的实现 ...

  6. matlab 对mnist手写数字数据集进行判决分析_Python神经网络编程:手写数字的数据集MNIST...

    识别人的笔迹这个问题相对复杂,也非常模糊,因此这是一种检验人工智能的理想挑战.这不像进行大量数字相乘那样明确清晰. 让计算机准确区分图像中包含的内容,有时也称之为图像识别问题.科学家对这个问题进行了几 ...

  7. mnist手写数字数据集_mnist手写数据集(1. 加载与可视化)

    >>欢迎 点赞,留言,收藏加关注<< 1. 模型构建的步骤: 在构建AI模型时,一般有以下主要步骤:准备数据.数据预处理.划分数据集.配置模型.训练模型.评估优化.模型应用,如 ...

  8. MNIST手写数字数据集格式,如何读取MNIST数据集?

    数据集下载地址:http://yann.lecun.com/exdb/mnist/ TRAINING SET LABEL FILE (train-labels-idx1-ubyte):[offset] ...

  9. python sklearn.datasets.fetch_mldata MNIST手写数字数据集无法获取, 报错 Function fetch_mldata is deprecated 的解决办法

    解决方法: 直接从GitHub下载MNIST数据集 参考文章: scikit-learn使用fetch_mldata无法下载MNIST数据集问题解决方法 https://blog.csdn.net/c ...

最新文章

  1. 给 COLA 做减法:应用架构中的“弯弯绕设计”
  2. [翻译]应用程序池和应用程序域的区别
  3. java如何脱离ide运行_如何脱离IDE使用自己的jar包?
  4. python dataframe遍历_在pandas中遍历DataFrame行的实现方法
  5. uniapp手写_【uniapp 开发】手绘签名组件
  6. html网页制作中的问题,网页制作中注意应用HTML标签的问题
  7. java如何实现游戏暂停和恢复_Android:游戏循环暂停/恢复问题
  8. 小米wifi怎么创建虚拟服务器,小米路由器玩法:一键安装LLMP 建自己的网站
  9. 计算机辅助翻译入门第十章课后答案,计算机辅助翻译入门
  10. 三因子两水平doe_温故而知新 | DOE实验设计学习系列之(三):多因子DOE的魅力 (附视频)...
  11. C#开发串口调试助手的详细教程
  12. 浑浑噩噩10年,入坑软件测试,6年干到测试leader,非科班的我也能当程序员!
  13. L44. 通配符匹配
  14. 《大型网站技术架构-核心原理与案例分析》(李智慧 著)第2章-大型网站架构模式
  15. CentOS7非桌面版关闭休眠和设置关闭盖子不休眠(server)
  16. word默认文字环绕方式是什么_在Word 2010文档中设置图片文字环绕方式
  17. 论文阅读:机器学习模型可解释性方法、应用与安全研究综述
  18. python中元组拆包_Python 元组拆包示例(Tuple Unpacking)
  19. MongoDB 启动参数
  20. ParameterAttribute的区别

热门文章

  1. java 金额转换 元转分 分转元
  2. 代数方程与差分方程模型(二):原子弹爆炸的能量估计
  3. 三菱a系列motion软体_三菱M70第四轴追加参数设定一览表
  4. die查壳工具 使用教程
  5. IDM下载百度网盘文件,获取百度网盘文件url地址,破解
  6. 男女RatingBar
  7. 传导干扰测试(0.15~30MHz)
  8. 最大数字字符串(leetCode179)
  9. 机器学习处理数据为什么把连续性特征离散化、离散值、无监督、有监督用处
  10. D.引水工程 【最小生成树+超级源点】