目录

介绍

加载数据集

迭代和可视化数据集

自定义数据集

__init__

__len__

__getitem__

使用 DataLoaders 准备训练数据

遍历 DataLoader

参考


介绍

理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。 PyTorch 提供了两个数据模块:torch.utils.data.DataLoader 和 torch.utils.data.Dataset,

允许使用预加载的数据集以及自己的数据。 Dataset 存储样本及其对应的标签,DataLoader 在 Dataset 周围包装了一个可迭代对象,可以轻松访问样本。

PyTorch 域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集是 torch.utils.data.Dataset 的子类,并实现了特定于特定数据的功能。 它们可用于对模型进行原型设计和基准测试。 可以在此处找到它们: Image Datasets, Text Datasets, and Audio Datasets.

加载数据集

Fashion-MNIST数据集 由 60,000 个训练示例和 10,000 个测试示例组成。 每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。我们通过加载此数据集演示加载数据集的操作。

我们使用以下参数加载 FashionMNIST 数据集:

  • root 是存储训练/测试数据的路径,
  • train 指定训练或测试数据集,
  • download=True 如果数据在根目录下不可用,则从 Internet 下载数据。
  • transform 和 target_transform 指定特征和标签转换

注:transforms.ToTensor()函数的作用是将原始的PILImage格式或者numpy.array格式的数据格式化为可被pytorch快速处理的张量类型。
输入模式为(L、LA、P、I、F、RGB、YCbCr、RGBA、CMYK、1)的PIL Image 或 numpy.ndarray (形状为H x W x C)数据范围是[0, 255] 到一个 Torch.FloatTensor,其形状 (C x H x W) 在 [0.0, 1.0] 范围内。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)out:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

迭代和可视化数据集

我们可以像列表一样手动索引数据集:training_data[index]。

我们使用 matplotlib 来可视化我们训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

自定义数据集

自定义数据集类必须实现三个函数:__init__、__len__ 和 __getitem__。 以FashionMNIST为例; FashionMNIST 图像存储在目录 img_dir 中,它们的标签分别存储在 CSV 文件 annotations_file 中。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

__init__

__init__ 函数在实例化 Dataset 对象时运行一次。 始化化图像、注释文件和两种转换的目录。

labels.csv 文件如下所示:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

__len__

__len__ 函数返回数据集中的样本数。

def __len__(self):return len(self.img_labels)

__getitem__

__getitem__ 函数从给定索引 idx 的数据集中加载并返回一个样本。 根据索引识别图像在磁盘上的位置,使用 read_image函数 将其转换为张量,从 self.img_labels 中的 csv 数据中检索相应的标签,调用它们的转换函数(如果适用),并返回张量图像 和元组中的相应标签。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

使用 DataLoaders 准备训练数据

检索数据集的特征(Dataset)并一次标记一个样本。 在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过拟合,并使用 Python 的多通道处理(multiprocessing)来加速数据检索。

DataLoader 是一个迭代器,它通过一个简单的 API 为我们实现这种功能。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

遍历 DataLoader

我们已经将该数据集加载到 DataLoader 中,并且可以根据需要遍历数据集。 下面的每次迭代都会返回一批 train_features 和 train_labels(分别包含 batch_size=64 个特征和标签)。 因为我们指定了 shuffle=True,所以在我们遍历所有批次之后,数据会被打乱(为了更精准地控制数据加载顺序,请查看 Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")out:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2

官方文档:

torch.squeeze

python 函数 iter()

python 函数 next()

不想看官方文档,就看这个:(11条消息) next()函数___泡泡茶壶的博客-CSDN博客_next()

参考

(10条消息) pytorch数据处理之 transforms.ToTensor()解释_菜根檀的博客-CSDN博客_totensor()

[pytorch学习笔记] 3.Datasets Dataloaders相关推荐

  1. pytorch学习笔记(2):在MNIST上实现一个CNN

    参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A 这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN.我们会基于上一 ...

  2. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  3. Pytorch学习笔记7——自定义数据集

    Pytorch学习笔记7--自定义数据集 1.读取数据 首先继承自torch.utils.data.Dataset 重写len与getitem train就用train数据集,test就用test数据 ...

  4. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  5. 莫烦pytorch学习笔记5

    莫烦pytorch学习笔记5 1 自编码器 2代码实现 1 自编码器 自编码,又称自编码器(autoencoder),是神经网络的一种,经过训练后能尝试将输入复制到输出.自编码器(autoencode ...

  6. # PyTorch学习笔记(15)--神经网络模型训练实战

    PyTorch学习笔记(15)–神经网络模型训练实战     本博文是PyTorch的学习笔记,第15次内容记录,主要是以一个实际的例子来分享神经网络模型的训练和测试的完整过程. 目录 PyTorch ...

  7. PyTorch学习笔记(13)--现有网络模型的使用及修改

    PyTorch学习笔记(13)–现有网络模型的使用及修改     本博文是PyTorch的学习笔记,第13次内容记录,主要介绍如何使用现有的神经网络模型,如何修改现有的网络模型. 目录 PyTorch ...

  8. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  9. pytorch学习笔记(二):gradien

    pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030

最新文章

  1. C#获得文件版本信息及只读文件的删除
  2. hashcode的作用_看似简单的hashCode和equals面试题,竟然有这么多坑!
  3. mysql page header_MySQL系列:innodb源码分析之page结构解析
  4. Mybatis之SqlSession简析
  5. 正式环境docker部署hyperf_使用docker搭建hyperf环境连接mysql
  6. MVC源码解析 - 配置注册 / 动态注册 HttpModule
  7. 加密软件漏洞评测系统_惠州上线软件产品登记测试企业
  8. linux svn服务的维护,Linux服务器搭建svn环境方法详解_网站服务器运行维护,Linux,svn...
  9. allure 测试报告本地打开_自动化测试报告太丑?信息实用的Allure Report测试报告拯救你...
  10. Android n multi-window多窗口支持
  11. PingFang SC 字体
  12. php mud游戏源码,mud手游源码,mud安卓端源码,谁与争锋mud源码:关于MUD纯文字游戏架设(回答得好加分100)(开源mud游戏框架)-南开游戏网...
  13. Imbalance data——数据不平衡问题
  14. 2021软科大学排名爬虫程序
  15. c# 计算圆锥的体积_用C#如何编写程序计算球,圆柱和圆锥的表面积和体积?
  16. ods转html android,SAS--output delivery system--ods html
  17. 单机版斗地主游戏源代码,纯JS编写的斗地主单机版小游戏源代码
  18. 宾夕法尼亚大学在线计算机硕士,宾夕法尼亚大学计算机与信息科学研究生录取条件有哪些?...
  19. 【Python高级语法】——生成器(generator)
  20. 高房楼噪音测试软件,噪音测试房制作

热门文章

  1. SpringBoot返回date日期格式化,解决返回为TIMESTAMP时间戳格式或8小时时间差
  2. [openstack][keystone]架构分析
  3. css零到一基础教程009:CSS HSL 颜色
  4. 证券市场基础知识(二)——股票、债券、基金
  5. 【SpringBoot】application配置文件及注入
  6. IP数据库的比较和选择
  7. 2021 河北取证比武决赛个人赛 题解 入侵溯源
  8. 深度学习框架PyTorch入门与实践:第九章 AI诗人:用RNN写诗
  9. 单元测试的重要性【转自”至简李云“博客】
  10. 如何合理规划每日时间