目录

1. 简介

2. 整体流程

3. Sampler和BatchSampler

3.1 Sampler

3.2 BatchSampler

4. DataLoader

4.1 DataLoader

4.2 _DataLoaderIter


1. 简介

本文将简介pytorch采样器Sampler数据加载器DataLoader,并解释在读取数据时每个batch形成的过程,附上部分源码解读。

了解这些能帮助我们更好地研究采样(sample)方法和模型训练。希望阅读后能让各位对数据批次产生的过程更加清晰。

让我们开始吧。

2. 整体流程

简要来说在pytorch中,Sampler负责决定读取数据时的先后顺序,DataLoader负责装载数据并根据Sampler提供的顺序安排数据,具体过程绘图和描述如下。

初始化DataLoader的时候需指定数据集Dataset(包括数据和标签),Sampler可选,没有Sampler时会根据是否打乱数据顺序(shuffle)分别采用顺序采样器(sequential sampler)和随机采样器(random sampler)。

第①步,Sampler首先根据Dataset的大小n形成一个可迭代的序号列表[0~n-1]。

第②步,BatchSampler根据DataLoader的batch_size参数将Sampler提供的序列划分成多个batch大小的可迭代序列组,drop_last参数决定是否保留最后一组。

第③步,兵分两路的Sampler(BatchSampler)和Dataset合二为一,在迭代读取DataLoader时,用BatchSampler中一个batch的编号查找Dataset中对应的数据和标签,读出一个batch数据。

数据批次的形成

举个例子。

假如数据集D={X,Y},其中数据X为[野兔在野外.png,野猫在野外.png,野猫在家.png,野狗在家.png,野狗在野外.png],标签Y为[0,1,1,2,2]

第①步,初始的序号列表为[0, 1, 2, 3, 4],使用RandomSampler采样,不重复(replacement==FALSE),得到了采样后的序号列表[3, 2, 1, 0, 4]

第②步:输入batch_size为2,drop_last为FALSE,所以用BatchSampler批次采样,形成列表[[3, 2], [1, 0], [4]];若drop_last为TRUE,则列表变为[[3, 2], [1, 0]]

第③步:迭代读取数据,根据序号从Dataset里找到相应数据和标签,如第一个batch为:

[[野狗在家.png, 野猫在家.png], [2, 1]]

以上就是形成一个batch数据的整个流程,下文将从代码角度深入介绍各个Class中的重要参数和函数。我是用较旧的pytorch版本(0.4.1.post2),也自己对照了一下1.7.0版本的代码。其中BatchSampler类基本一致,Sampler类去掉了__len__()方法,总的来说采样改动不大;DataLoader类主要是针对多线程做了很多优化,具体代码中也补充了大量注释,整体基础仍然是本文提到的几个方法。

3. Sampler和BatchSampler

3.1 Sampler

知乎上一篇文章对pytorch Sampler进行了很详细的讲解:一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

简要来说,Sampler类__init__()方法用于初始化采样算法,__iter__()方法用torch的random、multinomial方法实现随机和基于权重的采样并返回可迭代对象,__len__()是返回采样长度。

3.2 BatchSampler

参数:

sampler(Sampler类):输入的sampler

batch_size(int类):设定的批次大小

drop_last(bool类):是否弃掉不足batch_size大小的最后一个批次

重要函数:

__init__初始化各项参数

def __init__(self, sampler, batch_size, drop_last):# ...self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_last

__iter__循环读取sampler生成的序号列表,采样够batch_size大小后,返回batch,下一次清空batch继续采集。

def __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:# 通过yield返回,下一个iter时清空batch继续采集yield batchbatch = []# 如果不需drop最后一组返回最后一组if len(batch) > 0 and not self.drop_last:yield batch

__len__返回batch数量,如果drop最后一个,则序列长度对batch_size取整,否则加上一

def __len__(self):if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size

4. DataLoader

4.1 DataLoader

重要参数:

dataset(Dataset类):Dataset类型的输入数据,由数据和标签组成

batch_size(int类):同BatchSampler

shuffle(bool类):是否打乱数据顺序

sampler(Sampler类):同BatchSampler

batch_sampler(BatchSampler类)

drop_last(bool类):同BatchSampler

重要函数:

__init__中对参数关系中的互斥情况进行了排除,指定sampler并通过batch_sampler分出batch,

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):# ...# 互斥关系,指定了batch_sampler时,batch_size,shuffle,sampler和drop_last无效if batch_sampler is not None:if batch_size > 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')self.batch_size = Noneself.drop_last = None# 互斥关系,指定了sampler时,shuffle无效if sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if self.num_workers < 0:raise ValueError('num_workers option cannot be negative; ''use num_workers=0 to disable multiprocessing.')# 此处可以看出,shuffle与否其实还是靠sampler类型实现的# 当不指定sampler时,不shuffle就是顺序采样,shuffle就是随机采样if batch_sampler is None:if sampler is None:if shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)# 用batch_sampler对sampler产生的序列划分batchbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.sampler = samplerself.batch_sampler = batch_samplerself.__initialized = True

DataLoader的__iter__是在_DataLoaderIter类中实现的,该类也是整个迭代方法的核心

def __iter__(self):return _DataLoaderIter(self)

4.2 _DataLoaderIter

__init__初始化并指定了sampler_iter,即batch_sampler

def __init__(self, loader):self.dataset = loader.datasetself.collate_fn = loader.collate_fnself.batch_sampler = loader.batch_samplerself.num_workers = loader.num_workersself.pin_memory = loader.pin_memory and torch.cuda.is_available()self.timeout = loader.timeoutself.done_event = threading.Event()self.sample_iter = iter(self.batch_sampler)# ...

_get_batch读取数据,加入了连接超时的判断

def _get_batch(self):# 连接超时if self.timeout > 0:try:return self.data_queue.get(timeout=self.timeout)except queue.Empty:raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))else:return self.data_queue.get()

_DataLoaderIter在每次调用时会执行__next__方法返回下一个batch

def __next__(self):if self.num_workers == 0:  # same-process loadingindices = next(self.sample_iter)  # may raise StopIterationbatch = self.collate_fn([self.dataset[i] for i in indices])if self.pin_memory:batch = pin_memory_batch(batch)return batch# check if the next sample has already been generatedif self.rcvd_idx in self.reorder_dict:batch = self.reorder_dict.pop(self.rcvd_idx)return self._process_next_batch(batch)if self.batches_outstanding == 0:self._shutdown_workers()raise StopIterationwhile True:assert (not self.shutdown and self.batches_outstanding > 0)idx, batch = self._get_batch()self.batches_outstanding -= 1if idx != self.rcvd_idx:# store out-of-order samplesself.reorder_dict[idx] = batchcontinuereturn self._process_next_batch(batch)# 调用时执行__next__
next = __next__  # Python 2 compatibility

欢迎交流和指正。

[Pytorch] Sampler, DataLoader和数据batch的形成相关推荐

  1. dataset__getitem___一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  2. pytorch Dataset, DataLoader产生自定义的训练数据

    pytorch Dataset, DataLoader产生自定义的训练数据 目录 pytorch Dataset, DataLoader产生自定义的训练数据 1. torch.utils.data.D ...

  3. 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  4. 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). c ...

  5. PyTorch手把手自定义Dataloader读取数据

    PyTorch手把手自定义Dataloader读取数据 https://zhuanlan.zhihu.com/p/35698470 pytorch之dataloader深入剖析 https://www ...

  6. PyTorch 之 DataLoader

    DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...

  7. PyTorch主要组成模块 | 数据读入 | 数据预处理 | 模型构建 | 模型初始化 | 损失函数 | 优化器 | 训练与评估

    文章目录 一.深度学习任务框架 二.数据读入 三.数据预处理模块-transforms 1.数据预处理transforms模块机制 2.二十二种transforms数据预处理方法 1.裁剪 2. 翻转 ...

  8. Pytorch之Dataloader参数collate_fn研究

    前言 之前看了不到pytorch代码,对Dataloader的大部分参数都比较了解,今天看代码时,发现了一个参数collate_fn ,之前论文代码没怎么见过,也就自动忽略了,今天既然遇到了,就突然来 ...

  9. pytorch 定义torch类型数据_PyTorch 使用TorchText进行文本分类

    本教程演示如何在 torchtext 中使用文本分类数据集,包括 - AG_NEWS, - SogouNews, - DBpedia, - YelpReviewPolarity, - YelpRevi ...

最新文章

  1. 初始化CISCO路由器和交换机密码
  2. 卷积神经网络的实际意义
  3. u6系统服务器启动不了,u6链接不到服务器
  4. 浅谈auto与decltype函数的区别
  5. vue-codemirror基本用法:实现搜索功能、代码折叠功能、获取编辑器值及时验证
  6. java linkedlist 更新_Java填坑系列之LinkedList
  7. 计算机算法设计与分析 大学生电影节观影问题
  8. 万亿美元软件浪潮来临,开发者是核心!
  9. 从这6个方面,帮你轻松管理Chrome中保存的密码!
  10. MySql学习笔记【二、库相关操作】
  11. 磁盘和文件系统的管理
  12. 黑马vue实战项目-(六)商品列表组件的开发
  13. 群辉linux系统,[教程] 群晖VMM虚拟机安装Linux系统无法成功启动桌面的解决办法...
  14. 洛谷P1000 超级玛丽游戏c语言基础
  15. 小米6自动重启android,小米6充电重启怎么办?小米6充电自动重启解决方法介绍...
  16. mcafee杀毒软件编写规则时通配符使用方法
  17. 项目管理绝版秘籍——IT项目管理全套127个表格文档
  18. 英语流利说l4u1p2_L4-U1-P2-3 Vocabulary : Science 英语流利说 懂你英语
  19. Nginx 配置SSL 证书 cannot load certificate No such file or directory
  20. 云梦四时歌如何在电脑上玩 云梦四时歌模拟器教程

热门文章

  1. LTH7五脚芯片的完整方案图
  2. Oracle Primavera P6 单机版SQLite的使用(Professional)
  3. 人工智能中噪声数据的产生与处理方法详解
  4. Mac Iterm2各种用法和配置
  5. 【机器学习实验】用Python进行机器学习实验
  6. grpc stream的应用场景
  7. 关于OLSR协议中的MPR机制的阅读与理解
  8. 电脑显示没有被指定在上运行_.dll没有被指定在windows上运行怎么办
  9. html自动序号函数代码,自定义自动编号函数
  10. ESP32 Tensorflow Lite (二)TensorFlow Lite Hello World