前言

本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…

正文

在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的:

train_dataset = datasets.MNIST(root='./MNIST',train=True,transform=data_tf,download=True)

解释一下参数

datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。

train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集)

transform则是读入我们自己定义的数据预处理操作

download=True则是当我们的根目录(root)下没有数据集时,便自动下载。

如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:

其中我们所需要的文件主要在raw文件夹下

train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

接下来,书中是如此加载数据集的

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=5,shuffle=True)

由于DataLoader为Pytorch内部封装好的函数,所以对于它的调用方法需要自行去查阅。

我在最开始疑惑的点:传入的根目录在下载好数据集后,为MNIST下两个文件夹,而processed和raw文件夹下还有诸多文件,所以到底是如何读入数据的呢?所以我决定将数据集下载后,通过读取本地的MINIST数据集并进行装载。

首先,自定义数据类来继承和重写Dataset抽象类

class DealDataset(Dataset):"""读取数据、初始化数据"""def __init__(self, folder, data_name, label_name,transform=None):(train_set, train_labels) = self.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)'''load_data也是我们自定义的函数,用途:读取数据集中的数据 ( 图片数据+标签label'''def load_data(self,data_folder, data_name, label_name):with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return (x_train, y_train)

接下来,调用我们自定义的数据类来加载数据集

trainDataset = DealDataset('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(dataset=trainDataset,batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片shuffle=False,
)

通过这种方式便可以大概了解了读取数据集的过程。

接下来,我们来验证以下我们数据是否正确加载

# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:其实这里是用cv2.imshow来展示图片,但是我的代码是在jupyter notebook上写的,所以只能通过plt来代替加载。

数据加载成功~

深入探索

可以看到,在load_data函数中

y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

这个offset=8又是为啥呢?
我们进入MNIST数据集的官方页面进行查看

通过文档介绍,可以看到
offset的0000-0003是 magic number,所以跳过不读,
offset的0004-0007是items数目
接下来这些代表的就是标签

同理对于

x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train)

根据刚才的分析方法,也可以明白为什么offset=16了

完整代码

1.直接使用pytorch自带的mnist数据集加载

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as pltdata_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
)train_dataset = datasets.MNIST(root='./coding/learning/lrdata/MNIST',train=True,transform=data_tf,download=True)train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=5,shuffle=True)
# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:记得自己修改root根目录。

2.使用自定义的数据类加载本地MNIST数据集

import numpy as np
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import gzip
import os
import torchvision
import cv2
import matplotlib.pyplot as pltclass DealDataset(Dataset):"""读取数据、初始化数据"""def __init__(self, folder, data_name, label_name,transform=None):(train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)def load_data(data_folder, data_name, label_name):with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return (x_train, y_train)trainDataset = DealDataset('./coding/learning/lrdata/MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(dataset=trainDataset,batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片shuffle=False,
)# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

参考

1.《深度学习入门之Pytorch》- 廖星宇
2.使用Pytorch进行读取本地的MINIST数据集并进行装载
3.顺藤摸瓜-mnist数据集的补充

十分钟搞懂Pytorch如何读取MNIST数据集相关推荐

  1. html网页和cgi程序编程,十分钟搞懂什么是CGI

    原文:CGI Made Really Easy,在翻译的过程中,我增加了一些我在学习过程中找到的更合适的资料,和自己的一些理解.不能算是严格的翻译文章,应该算是我的看这篇文章的过程的随笔吧. CGI真 ...

  2. python数据分析建模-十分钟搞懂“Python数据分析”

    原标题:十分钟搞懂"Python数据分析" 引言:本文重点是用十分钟的时间帮读者建立Python数据分析的逻辑框架.其次,讲解"如何通过Python 函数或代码和统计学知 ...

  3. pearsonr() python_十分钟搞懂“Python数据分析”

    引言:本文重点是用十分钟的时间帮读者建立Python数据分析的逻辑框架.其次,讲解"如何通过Python 函数或代码和统计学知识来实现数据分析". 本次介绍的建模框架图分为六大版块 ...

  4. 十分钟搞懂JSON(JSON对象---JSON字符串---对象 之间的区别)

    好记性不如烂笔头,相信我,看了之后你会彻底搞懂JSON 前言:前天被JSON对象,JSON字符串,JAVA对象搞混了,不知道各自代表的意思,我就查了资料,总结为一篇博文. 另外我想List<Us ...

  5. 十分钟搞懂基-2 FFT原理及编程思想

    0.写在最前 写本文的目的一是为了帮人理清FFT算法思路,二是有几个疑问(在5总结部分提到)希望得到解答.看懂本文的基础:至少听说过.简单了解过傅里叶变换.离散傅里叶变换(DFT).基于时间抽取的基2 ...

  6. python中cgi到底是什么_十分钟搞懂什么是CGI(转)

    原文:CGI Made Really Easy,在翻译的过程中,我增加了一些我在学习过程中找到的更合适的资料,和自己的一些理解.不能算是严格的翻译文章,应该算是我的看这篇文章的过程的随笔吧. CGI真 ...

  7. 十分钟搞懂什么是CGI

    原文:CGI Made Really Easy,在翻译的过程中,我增加了一些我在学习过程中找到的更合适的资料,和自己的一些理解.不能算是严格的翻译文章,应该算是我的看这篇文章的过程的随笔吧. CGI真 ...

  8. 十分钟搞懂手机号码一键登录

    手机号码一键登录是最近两三年出现的一种新型应用登录方式,比之前常用的短信验证码登录又方便了不少.登陆时,应用首先向用户展示带有本机号码掩码的授权登录页面,用户点击"同意授权"的按钮 ...

  9. 干货!十分钟搞懂消息队列的选型

    大家好,我是程序员史迪仔. 消息队列重要吗?有必要学吗?当然重要! 想必你在面试或者工作的过程中,被问过以下问题: (1)为什么你们项目要用消息队列? (2)用了消息队列后有什么好处? (3)消息队列 ...

最新文章

  1. Solr部署如何启动
  2. Facebook暂停中国工具类应用广告
  3. idea的2020.2版本
  4. tcp 发送数据长度比预设缓存大_一文秒懂 TCP/IP实际五层结构(下篇)
  5. 今日代码(200612)--数据录入(python+mysql)
  6. 28 PP配置-生产车间控制-工序-定义报工屏幕默认值
  7. destoon入门实例与常见问题汇总
  8. Winform窗体中发送HTTP请求 手工发送HTTP请求主要是调用 System.Net的HttpWebResponse方法
  9. 通过meta进行重定向
  10. 浅谈相对定位与绝对定位
  11. mysql文件扩展名查询_如何通过MySQL查询获取文件的文件扩展名?
  12. mysql 临时表循环_在游标循环中查询临时表可以,但是结束循环后就无法查询了。...
  13. ASP.NET CORE中使用SESSION
  14. 2021 ICCV TIMI-Net 抠图网络论文笔记
  15. Linux Deamon函数
  16. xocde8打印出:Presenting view controllers on detached view controllers is discouraged SettingViewContro
  17. Jvav语言(0.1)版
  18. qt 模拟鼠标滑轮_【游戏流体力学基础及Unity代码(四)】用欧拉方程模拟无粘性染料之公式推导...
  19. 用uniapp实现微信小程序的电子签名效果
  20. 文件的下载(2)——解决下载文件名的乱码问题

热门文章

  1. ROOT在Ubuntu18中安装出现的问题小解
  2. ORACLE去重总结
  3. TP-link wr886N路由器上网时快时慢的解决办法
  4. IPsec典型配置举例 -采用IKE方式建立保护IPv4 报文的IPsec隧道
  5. kodbox 可道云上传大文件(超过4G的)
  6. Xp远程桌面连接win7的方法
  7. 高瓴资本创始人张磊:美团点评有大格局价值观 我们长期看好
  8. windows 下使用gpb生成erlang 代码
  9. Coding in GPB vs XML
  10. 常见网络模型——BA无标度网络(使用轮盘赌算法)(python)