重点在第二部分的构建数据通道和第三部分的加载数据集

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。

Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。

而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

在绝大部分情况下,用户只需实现Dataset的__len__方法和__getitem__方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

一,Dataset和DataLoader概述

1,获取一个batch数据的步骤

让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。

(假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m)

1,首先我们要确定数据集的长度n。

结果类似:n = 1000。

2,然后我们从0到n-1的范围中抽样出m个数(batch大小)。

假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]

3,接着我们从数据集中去取这m个数对应下标的元素。

拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]

4,最后我们将结果整理成两个张量作为输出。

拿到的结果是两个张量,类似batch = (features,labels),

其中 features = torch.stack([X[1],X[4],X[8],X[9]])

labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

2,Dataset和DataLoader的功能分工

上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。

第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader的 sampler和 batch_sampler参数指定的。

sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。

第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 __getitem__方法实现的。

第4个步骤的逻辑由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。

3,Dataset和DataLoader的主要接口

伪代码,实际应用意义不大

import torch
class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):self.dataset = datasetself.sampler =torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler = torch.utils.data.BatchSamplerself.sample_iter = self.batch_sampler(self.sampler(range(len(dataset))),batch_size = batch_size,drop_last = drop_last)def __next__(self):indices = next(self.sample_iter)batch = self.collate_fn([self.dataset[i] for i in indices])return batch

二,使用Dataset创建数据集

Dataset创建数据集常用的方法有:

使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。

使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。

继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过

torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。

调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

1,根据Tensor创建数据集

  1. 头文件:
import numpy as np
import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split
  1. 根据Tensor创建数据集
from sklearn import datasets
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
  1. 分割成训练集和预测集
n_train = int(len(ds_iris)*0.8)
n_valid = len(ds_iris) - n_train
ds_train,ds_valid = random_split(ds_iris,[n_train,n_valid])
  1. 使用DataLoader加载数据集
dl_train,dl_valid = DataLoader(ds_train,batch_size = 8),DataLoader(ds_valid,batch_size = 8)#查看数据集
for features,labels in dl_train:print(features,labels)break
  1. 演示加法运算符(+)的合并作用
ds_data = ds_train + ds_validprint('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_valid))
print('len(ds_train+ds_valid) = ',len(ds_data))print(type(ds_data))

2,根据图片目录创建图片数据集

  1. 头文件:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
  1. 图片加载:
from PIL import Image
img = Image.open('./data/cat.jpeg')
  1. 随机数值翻转
transforms.RandomVerticalFlip()(img)
  1. 随机旋转
transforms.RandomRotation(45)(img)
  1. 定义图片增强操作
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), #随机水平翻转transforms.RandomVerticalFlip(), #随机垂直翻转transforms.RandomRotation(45),  #随机在45度角度内旋转transforms.ToTensor() #转换成张量]
) transform_valid = transforms.Compose([transforms.ToTensor()]
)
  1. 根据图片目录创建数据集
ds_train = datasets.ImageFolder("./data/cifar2/train/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./data/cifar2/test/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())print(ds_train.class_to_idx)
  1. 使用DataLoader加载数据集
#注意:windows用户要把num_workers去掉,容易报错
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)for features,labels in dl_train:print(features.shape)print(labels.shape)break

三,使用DataLoader加载数据集

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

DataLoader的函数签名

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)

一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers, drop_last这五个参数,其他参数使用默认值即可。
dataset : 数据集
batch_size: 批次大小
shuffle: 是否乱序
sampler: 样本采样函数,一般无需设置。
batch_sampler: 批次采样函数,一般无需设置。
num_workers: 使用多进程读取数据,设置的进程数。
collate_fn: 整理一个批次数据的函数。
pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
timeout: 加载一个数据批次的最长等待时间,一般无需设置。
worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。

#构建输入数据管道
ds = TensorDataset(torch.arange(1,50))
dl = DataLoader(ds,batch_size = 10,shuffle= True,num_workers=2,drop_last = True)
#迭代数据
for batch, in dl:print(batch)

Dataset和DataLoader构建数据通道相关推荐

  1. 速成pytorch学习——6天Dataset和DataLoader

    Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...

  2. 5 torch.utils.data (Dataset,TensorDataset,DataLoader)

    文章目录 一.DataLoader(数据预处理) 1.DataLoader :(构建可迭代的数据装载器) 2.输出:DataLoader 的输出包含:数据和标签 二.TensorDataset(数据预 ...

  3. Pytorch的Dataset和DataLoader

    还是拿来自学用 多谢多谢  对于torch本人无人讨论 拿来主义的 多谢理解 0,Dataset和DataLoader功能简介 Pytorch通常使用Dataset和DataLoader这两个工具类来 ...

  4. 《Pytorch学习指南》- Dataset和Dataloader用法详解

    目录 前言 DataSet DataLoader 数据构建 1. 创建Dataset 类 :sparkles: 2. 读取数据 :ambulance: 3. 返回数据 :zap: 读取数据 :art: ...

  5. 编写transformers的自定义pytorch训练循环(Dataset和DataLoader解析和实例代码)

    文章目录 一.Dataset和DataLoader加载数据集 1.torch.utils.data 2. 加载数据流程 3. Dataset 4. dataloader类及其参数 5. dataloa ...

  6. Pytorch自定义Dataset和DataLoader去除不存在和空的数据

    Pytorch自定义Dataset和DataLoader去除不存在和空的数据 [源码GitHub地址]:https://github.com/PanJinquan/pytorch-learning-t ...

  7. (第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

    前言:在深度学习中,数据的预处理是第一步,pytorch提供了非常规范的处理接口,本文将针对处理过程中的一些问题来进行说明,本文所针对的主要数据是图像数据集. 本文的案例来源于车道线语义分割,采用的数 ...

  8. PyTorch 入门实战(三)——Dataset和DataLoader

    承接上一篇:PyTorch 入门实战(二)--Variable 对于Dataset,博主也有着自己的理解: 关于Pytorch中dataset的迭代问题(这就是为什么我们要使用dataloader的原 ...

  9. 【小白学PyTorch】3.浅谈Dataset和Dataloader

    文章目录: 1 Dataset基类 2 构建Dataset子类 2.1 __Init__ 2.2 __getitem__ 3 dataloader 1 Dataset基类 PyTorch 读取其他的数 ...

最新文章

  1. host文件修改后无法保存的问题
  2. 推荐一款代码神器,代码量至少省一半!
  3. app 崩溃测试 (转:CSDN 我去热饭)
  4. drf3 Serializers 序列化组件
  5. confirm弹框修改按钮确认取消为是否
  6. 前端:JS/17/前篇总结(JS程序的基本语法,变量),数据类型-变量的类型(数值型,字符型,布尔型,未定义型,空型),数据类型转换,typeof()判断数据类型,从字符串提取整数或浮点数的函数
  7. c语言程序设计上机考试题,C语言程序设计上机考试题目汇编..doc
  8. “编程能力差,90% 输在了数学上!”CTO:多数程序员都是瞎努力!
  9. read -p 命令--shell 脚本
  10. Ubuntu配置及美化篇
  11. 数电基础-数字电路芯片种类
  12. 最新苹果CMS对接千月版本-畅视影视(V9.3开源)已搭建测试版
  13. 单词风暴2009免费分享版
  14. mysql创建索引视图_mysql中创建视图、索引
  15. 六轴机器人光机_六轴机器人主要用到哪些传感器?
  16. CSS布局的三种方式
  17. Oracle安装步骤(记录)
  18. PyQt5  PyQt5-tools 安装
  19. 罗雪娟(Luo Xuejuan)
  20. 今年95后很狂阿里P7晒出工资单:狠补了两眼泪汪汪,真香...

热门文章

  1. Linux平台上SQLite数据库教程(一)——终端使用篇
  2. OpenCV基础知识 图像
  3. 【java图文趣味版】数组元素的访问与遍历
  4. shell编程题(三)
  5. Leecode 69. x 的平方根
  6. access、strtol函数的使用(后者为C库函数)
  7. iOS 应用内跳转到appstore里下载
  8. HDU 2204 Eddy's爱好(容斥原理)
  9. Leetcode: LRU Cache
  10. mysql 清空表的两种方法