[Pytorch] Sampler, DataLoader和数据batch的形成
目录
1. 简介
2. 整体流程
3. Sampler和BatchSampler
3.1 Sampler
3.2 BatchSampler
4. DataLoader
4.1 DataLoader
4.2 _DataLoaderIter
1. 简介
本文将简介pytorch采样器Sampler和数据加载器DataLoader,并解释在读取数据时每个batch形成的过程,附上部分源码解读。
了解这些能帮助我们更好地研究采样(sample)方法和模型训练。希望阅读后能让各位对数据批次产生的过程更加清晰。
让我们开始吧。
2. 整体流程
简要来说在pytorch中,Sampler负责决定读取数据时的先后顺序,DataLoader负责装载数据并根据Sampler提供的顺序安排数据,具体过程绘图和描述如下。
初始化DataLoader的时候需指定数据集Dataset(包括数据和标签),Sampler可选,没有Sampler时会根据是否打乱数据顺序(shuffle)分别采用顺序采样器(sequential sampler)和随机采样器(random sampler)。
第①步,Sampler首先根据Dataset的大小n形成一个可迭代的序号列表[0~n-1]。
第②步,BatchSampler根据DataLoader的batch_size参数将Sampler提供的序列划分成多个batch大小的可迭代序列组,drop_last参数决定是否保留最后一组。
第③步,兵分两路的Sampler(BatchSampler)和Dataset合二为一,在迭代读取DataLoader时,用BatchSampler中一个batch的编号查找Dataset中对应的数据和标签,读出一个batch数据。
举个例子。
假如数据集D={X,Y},其中数据X为[野兔在野外.png,野猫在野外.png,野猫在家.png,野狗在家.png,野狗在野外.png],标签Y为[0,1,1,2,2]
第①步,初始的序号列表为[0, 1, 2, 3, 4],使用RandomSampler采样,不重复(replacement==FALSE),得到了采样后的序号列表[3, 2, 1, 0, 4]
第②步:输入batch_size为2,drop_last为FALSE,所以用BatchSampler批次采样,形成列表[[3, 2], [1, 0], [4]];若drop_last为TRUE,则列表变为[[3, 2], [1, 0]]
第③步:迭代读取数据,根据序号从Dataset里找到相应数据和标签,如第一个batch为:
[[野狗在家.png, 野猫在家.png], [2, 1]]
以上就是形成一个batch数据的整个流程,下文将从代码角度深入介绍各个Class中的重要参数和函数。我是用较旧的pytorch版本(0.4.1.post2),也自己对照了一下1.7.0版本的代码。其中BatchSampler类基本一致,Sampler类去掉了__len__()方法,总的来说采样改动不大;DataLoader类主要是针对多线程做了很多优化,具体代码中也补充了大量注释,整体基础仍然是本文提到的几个方法。
3. Sampler和BatchSampler
3.1 Sampler
知乎上一篇文章对pytorch Sampler进行了很详细的讲解:一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
简要来说,Sampler类__init__()方法用于初始化采样算法,__iter__()方法用torch的random、multinomial方法实现随机和基于权重的采样并返回可迭代对象,__len__()是返回采样长度。
3.2 BatchSampler
参数:
sampler(Sampler类):输入的sampler
batch_size(int类):设定的批次大小
drop_last(bool类):是否弃掉不足batch_size大小的最后一个批次
重要函数:
__init__初始化各项参数
def __init__(self, sampler, batch_size, drop_last):# ...self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_last
__iter__循环读取sampler生成的序号列表,采样够batch_size大小后,返回batch,下一次清空batch继续采集。
def __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:# 通过yield返回,下一个iter时清空batch继续采集yield batchbatch = []# 如果不需drop最后一组返回最后一组if len(batch) > 0 and not self.drop_last:yield batch
__len__返回batch数量,如果drop最后一个,则序列长度对batch_size取整,否则加上一
def __len__(self):if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size
4. DataLoader
4.1 DataLoader
重要参数:
dataset(Dataset类):Dataset类型的输入数据,由数据和标签组成
batch_size(int类):同BatchSampler
shuffle(bool类):是否打乱数据顺序
sampler(Sampler类):同BatchSampler
batch_sampler(BatchSampler类)
drop_last(bool类):同BatchSampler
重要函数:
__init__中对参数关系中的互斥情况进行了排除,指定sampler并通过batch_sampler分出batch,
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):# ...# 互斥关系,指定了batch_sampler时,batch_size,shuffle,sampler和drop_last无效if batch_sampler is not None:if batch_size > 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')self.batch_size = Noneself.drop_last = None# 互斥关系,指定了sampler时,shuffle无效if sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if self.num_workers < 0:raise ValueError('num_workers option cannot be negative; ''use num_workers=0 to disable multiprocessing.')# 此处可以看出,shuffle与否其实还是靠sampler类型实现的# 当不指定sampler时,不shuffle就是顺序采样,shuffle就是随机采样if batch_sampler is None:if sampler is None:if shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)# 用batch_sampler对sampler产生的序列划分batchbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.sampler = samplerself.batch_sampler = batch_samplerself.__initialized = True
DataLoader的__iter__是在_DataLoaderIter类中实现的,该类也是整个迭代方法的核心
def __iter__(self):return _DataLoaderIter(self)
4.2 _DataLoaderIter
__init__初始化并指定了sampler_iter,即batch_sampler
def __init__(self, loader):self.dataset = loader.datasetself.collate_fn = loader.collate_fnself.batch_sampler = loader.batch_samplerself.num_workers = loader.num_workersself.pin_memory = loader.pin_memory and torch.cuda.is_available()self.timeout = loader.timeoutself.done_event = threading.Event()self.sample_iter = iter(self.batch_sampler)# ...
_get_batch读取数据,加入了连接超时的判断
def _get_batch(self):# 连接超时if self.timeout > 0:try:return self.data_queue.get(timeout=self.timeout)except queue.Empty:raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))else:return self.data_queue.get()
_DataLoaderIter在每次调用时会执行__next__方法返回下一个batch
def __next__(self):if self.num_workers == 0: # same-process loadingindices = next(self.sample_iter) # may raise StopIterationbatch = self.collate_fn([self.dataset[i] for i in indices])if self.pin_memory:batch = pin_memory_batch(batch)return batch# check if the next sample has already been generatedif self.rcvd_idx in self.reorder_dict:batch = self.reorder_dict.pop(self.rcvd_idx)return self._process_next_batch(batch)if self.batches_outstanding == 0:self._shutdown_workers()raise StopIterationwhile True:assert (not self.shutdown and self.batches_outstanding > 0)idx, batch = self._get_batch()self.batches_outstanding -= 1if idx != self.rcvd_idx:# store out-of-order samplesself.reorder_dict[idx] = batchcontinuereturn self._process_next_batch(batch)# 调用时执行__next__
next = __next__ # Python 2 compatibility
欢迎交流和指正。
[Pytorch] Sampler, DataLoader和数据batch的形成相关推荐
- dataset__getitem___一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- pytorch Dataset, DataLoader产生自定义的训练数据
pytorch Dataset, DataLoader产生自定义的训练数据 目录 pytorch Dataset, DataLoader产生自定义的训练数据 1. torch.utils.data.D ...
- 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). c ...
- PyTorch手把手自定义Dataloader读取数据
PyTorch手把手自定义Dataloader读取数据 https://zhuanlan.zhihu.com/p/35698470 pytorch之dataloader深入剖析 https://www ...
- PyTorch 之 DataLoader
DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...
- PyTorch主要组成模块 | 数据读入 | 数据预处理 | 模型构建 | 模型初始化 | 损失函数 | 优化器 | 训练与评估
文章目录 一.深度学习任务框架 二.数据读入 三.数据预处理模块-transforms 1.数据预处理transforms模块机制 2.二十二种transforms数据预处理方法 1.裁剪 2. 翻转 ...
- Pytorch之Dataloader参数collate_fn研究
前言 之前看了不到pytorch代码,对Dataloader的大部分参数都比较了解,今天看代码时,发现了一个参数collate_fn ,之前论文代码没怎么见过,也就自动忽略了,今天既然遇到了,就突然来 ...
- pytorch 定义torch类型数据_PyTorch 使用TorchText进行文本分类
本教程演示如何在 torchtext 中使用文本分类数据集,包括 - AG_NEWS, - SogouNews, - DBpedia, - YelpReviewPolarity, - YelpRevi ...
最新文章
- 初始化CISCO路由器和交换机密码
- 卷积神经网络的实际意义
- u6系统服务器启动不了,u6链接不到服务器
- 浅谈auto与decltype函数的区别
- vue-codemirror基本用法:实现搜索功能、代码折叠功能、获取编辑器值及时验证
- java linkedlist 更新_Java填坑系列之LinkedList
- 计算机算法设计与分析 大学生电影节观影问题
- 万亿美元软件浪潮来临,开发者是核心!
- 从这6个方面,帮你轻松管理Chrome中保存的密码!
- MySql学习笔记【二、库相关操作】
- 磁盘和文件系统的管理
- 黑马vue实战项目-(六)商品列表组件的开发
- 群辉linux系统,[教程] 群晖VMM虚拟机安装Linux系统无法成功启动桌面的解决办法...
- 洛谷P1000 超级玛丽游戏c语言基础
- 小米6自动重启android,小米6充电重启怎么办?小米6充电自动重启解决方法介绍...
- mcafee杀毒软件编写规则时通配符使用方法
- 项目管理绝版秘籍——IT项目管理全套127个表格文档
- 英语流利说l4u1p2_L4-U1-P2-3 Vocabulary : Science 英语流利说 懂你英语
- Nginx 配置SSL 证书 cannot load certificate No such file or directory
- 云梦四时歌如何在电脑上玩 云梦四时歌模拟器教程
热门文章
- LTH7五脚芯片的完整方案图
- Oracle Primavera P6 单机版SQLite的使用(Professional)
- 人工智能中噪声数据的产生与处理方法详解
- Mac Iterm2各种用法和配置
- 【机器学习实验】用Python进行机器学习实验
- grpc stream的应用场景
- 关于OLSR协议中的MPR机制的阅读与理解
- 电脑显示没有被指定在上运行_.dll没有被指定在windows上运行怎么办
- html自动序号函数代码,自定义自动编号函数
- ESP32 Tensorflow Lite (二)TensorFlow Lite Hello World