迭代器

理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键。

在 DatasetSampler 和 DataLoader 这三个类中都会用到 python 抽象类的魔法方法,包括__len__(self)__getitem__(self) 和 __iter__(self)

  • __len__(self): 定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数
  • __getitem__(self): 定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作
  • __iter__(self): 定义当迭代容器中的元素时的行为

迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典,这些数据结构都支持迭代操作。

实现迭代器的魔法方法有两个:__iter__(self) 和 __next__(self)

一个容器如果是迭代器,那就必须实现 __iter__(self) 魔法方法,这个方法实际上是返回是一个迭代器(通常是迭代器本身)。接下来重点要实现的是 __next__(self) 魔法方法,因为它决定了迭代的规则。

class Fibs:def __init__(self, n=20):self.a = 0self.b = 1self.n = ndef __iter__(self):return selfdef __next__(self):self.a, self.b = self.b, self.a + self.bif self.a > self.n:raise StopIterationreturn self.afibs = Fibs()
for each in fibs:print(each)# 输出
# 1 1 2 3 5 8 13

一般来说,迭代器满足以下几种特性:

  • 迭代器是⼀个对象
  • 迭代器可以被 next() 函数调⽤,并返回⼀个值
  • 迭代器可以被 iter() 函数调⽤,并返回一个迭代器(可以是自身)
  • 连续被 next() 调⽤时依次返回⼀系列的值
  • 如果到了迭代的末尾,则抛出 StopIteration 异常
  • 迭代器也可以没有末尾,只要被 next() 调⽤,就⼀定会返回⼀个值
  • Python 中, next() 内置函数调⽤的是对象的 next() ⽅法
  • Python 中, iter() 内置函数调⽤的是对象的 iter() ⽅法
  • ⼀个实现了迭代器协议的的对象可以被 for 语句循环迭代直到终⽌

了解了什么是迭代器后,我们就可以开始解读 torch.utils.data 模块

对于 torch.utils.data 而言,重点是其 DatasetSamplerDataLoader 模块,辅以 collatefetchpin_memory 等组件对特定功能予以支持。

1 Dataset

Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。

Dataset 共有 Map-style datasets 和 Iterable-style datasets 两种:

1.1 Map-style dataset

torch.utils.data.Dataset

它是一种通过实现 __getitem__() 和 __len()__ 来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx] 访问 idx 对应的数据。

通常我们使用 Map-style 类型的 dataset 居多,其数据接口定义如下:

class Dataset(Generic[T_co]):# Generic is an Abstract base class for generic types.def __getitem__(self, index) -> T_co:raise NotImplementedErrordef __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':return ConcatDataset([self, other])

PyTorch 中所有定义的 Dataset 都是其子类。

对于一般计算机视觉任务,我们通常会在其中进行一些 resize, crop, flip 等预处理的操作

值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点在其源码中也有注释加以体现。

1.2 Iterable-style dataset

torch.utils.data.IterableDataset

它是一种实现 __iter__() 来获取数据的 Dataset,这种类型的数据集特别适用于以下情况:随机读取代价很大甚至不大可能,且 batch size 取决于获取的数据。其接口定义如下:

class IterableDataset(Dataset[T_co]):def __iter__(self) -> Iterator[T_co]:raise NotImplementedErrordef __add__(self, other: Dataset[T_co]):return ChainDataset([self, other])

特别地,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)

1.3 其他 Dataset

除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类

  • torch.utils.data.ConcatDataset: 用于连接多个 ConcatDataset 数据集
  • torch.utils.data.ChainDataset : 用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用
  • torch.utils.data.Subset: 用于获取指定一个索引序列对应的子数据集
class Subset(Dataset[T_co]):dataset: Dataset[T_co]indices: Sequence[int]def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:self.dataset = datasetself.indices = indicesdef __getitem__(self, idx):return self.dataset[self.indices[idx]]def __len__(self):return len(self.indices)
  • torch.utils.data.TensorDataset: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。
class TensorDataset(Dataset):def __init__(self, *tensor):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in tensorsdef __len__(self):return self.tensors[0].size(0)

2 Sampler

torch.utils.data.Sampler 负责提供一种遍历数据集所有元素索引的方式。可支持用户自定义,也可以用 PyTorch 提供的,基类接口定义如下:

lass Sampler(Generic[T_co]):r"""Base class for all Samplers.Every Sampler subclass has to provide an :meth:`__iter__` method, providing away to iterate over indices of dataset elements, and a :meth:`__len__` methodthat returns the length of the returned iterators... note:: The :meth:`__len__` method isn't strictly required by:class:`~torch.utils.data.DataLoader`, but is expected in anycalculation involving the length of a :class:`~torch.utils.data.DataLoader`."""def __init__(self, data_source: Optional[Sized]) -> None:passdef __iter__(self) -> Iterator[T_co]:raise NotImplementedError

特别地__len__() 方法不是必要的,但是当 DataLoader 需要计算 len() 的时候必须定义,这点在其源码中也有注释加以体现。

同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类

  • torch.utils.data.SequentialSampler : 顺序采样样本,始终按照同一个顺序
  • torch.utils.data.RandomSampler: 可指定有无放回地,进行随机采样样本元素
  • torch.utils.data.SubsetRandomSampler: 无放回地按照给定的索引列表采样样本元素
  • torch.utils.data.WeightedRandomSampler: 按照给定的概率来采样样本。样本元素来自 [0,…,len(weights)-1] , 给定概率(权重)
  • torch.utils.data.BatchSampler: 在一个batch中封装一个其他的采样器, 返回一个 batch 大小的 index 索引
  • torch.utils.data.DistributedSample: 将数据加载限制为数据集子集的采样器。与 torch.nn.parallel.DistributedDataParallel 结合使用。 在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递

3 DataLoader

torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以设置 loading order, batch size, pin memory 等加载参数。其接口定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

对于每个参数的含义,以下给出一个表格进行对应介绍:

attribute meaning default value type
dataset 加载数据的数据集   Dataset
batch_size 每个 batch 加载多少个样本 1 int
shuffle 设置为 True 时,调用 RandomSampler 进行随机索引 False bool
sampler 定义从数据集中提取样本的策略
如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥)
None Sampler, Iterable
batch_sampler 和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引
其和 batch_size, shuffle 等参数是互斥的
None Sampler, Iterable
num_workers 要用于数据加载的子进程数,0 表示将在主进程中加载数据 0 int
collate_fn 在将 Map-style datase t 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batch None callable
pin_memory 如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中 False bool
drop_last 设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。
如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小
False bool
timeout 如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数
超过这个时间还没读取到数据的话就会报错
0 numeric
worker_init_fn 如果不为 None,它将会被每个 worker 子进程调用,
以 worker id ([0, num_workers - 1] 内的整形) 为输入
None callable
prefetch_factor 每个 worker 提前加载 的 sample 数量 2 int
persistent_workers 如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成 False bool

从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能

  • 支持加载 map-style 和 iterable-style 的 dataset,主要涉及到的参数是 dataset
  • 自定义数据加载顺序,主要涉及到的参数有 shufflesamplerbatch_samplercollate_fn
  • 自动把数据整理成batch序列,主要涉及到的参数有 batch_sizebatch_samplercollate_fndrop_last
  • 单进程和多进程的数据加载,主要涉及到的参数有 num_workersworker_init_fn
  • 自动进行锁页内存读取 (memory pinning),主要涉及到的参数 pin_memory
  • 支持数据预加载,主要涉及的参数 prefetch_factor

3.1 三者关系 (Dataset, Sampler, Dataloader)

通过以上介绍的三者工作内容不难推出其内在关系:

  1. 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
  2. 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
  3. 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shufflebatch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。

总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。

pytorch源码解析2——数据处理torch.utils.data相关推荐

  1. PyTorch源码解析--torchvision.transforms(数据预处理、数据增强)

    PyTorch框架中有一个很常用的包:torchvision torchvision主要由3个子包构成:torchvision.datasets.torchvision.models.torchvis ...

  2. yolov3之pytorch源码解析_springmvc源码架构解析之view

    说在前面 前期回顾 sharding-jdbc源码解析 更新完毕 spring源码解析 更新完毕 spring-mvc源码解析 更新完毕 spring-tx源码解析 更新完毕 spring-boot源 ...

  3. SSD PyTorch源码解析

    0. 引言 0.1 代码来源 代码来源:https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Detection/SSD ...

  4. pytorch源码解析:Python层 pytorchmodule源码

    尝试使用了pytorch,相比其他深度学习框架,pytorch显得简洁易懂.花时间读了部分源码,主要结合简单例子带着问题阅读,不涉及源码中C拓展库的实现. 一个简单例子 实现单层softmax二分类, ...

  5. 【PyTorch】torch.utils.data.Dataset 介绍与实战

    文章目录 一.前言 二.torch.utils.data.Dataset 是什么 1. 干什么用的? 2. 长什么样子? 三.通过继承 torch.utils.data.Dataset 定义自己的数据 ...

  6. MyBatis源码- SqlSession门面模式 selectList 源码解析

    文章目录 Pre 工程概览 pom.xml mybatis-config.xml UserMapper 测试类 selectList 源码解析 附 SQL log4j.properties app.p ...

  7. PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...

  8. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

  9. [源码解析] PyTorch 流水线并行实现 (1)--基础知识

    [源码解析] PyTorch 流水线并行实现 (1)–基础知识 文章目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 ...

最新文章

  1. Tensorflow安装与测试
  2. 【科普】国内外高质量数据科学竞赛平台有哪些?
  3. openSUSE中启用apache mod_rewrite
  4. 配置sql server 2000以允许远程访问
  5. jQuery向未来的元素添加事件处理程序(绑定事件)
  6. 施一公谈自己35岁和53岁的区别
  7. 交付铁三角的故事之兵戎相见
  8. 字节跳动斩获支付牌照欲建金融帝国,技术实力配得上野心吗?
  9. Dalsa线扫相机SDK下载和安装
  10. 重启报错_AFAB折旧计提报错:科目xxxxx要求一个成本会计分配 及重启问题
  11. 关于代理。谢谢方志朋
  12. MSP430FR5994LannchPad开发笔记之三:MSP430的IO复用以及如何去获取IO复用功能
  13. 微型计算机常用显示器,专业显示器只买某卓?那是你没见识过这款专业显示器的厉害...
  14. 基于神经网络的指纹识别,指纹比对技术何时出现
  15. 反客为主:巧妙用grldr冒名顶替ntldr引导XP/Ubuntu
  16. Mermaid知识点总结3 - Flowchart 2
  17. 新媒体运营教程:短视频剧本创作技巧
  18. windows提权常用系统漏洞与对应的补丁编号
  19. 干电池电量采集_干电池电量的检测方法,干电池的常用保存方法
  20. 面向对象编程实验课随笔(承继下的构造函数和析构函数)

热门文章

  1. SAP License:SAP顾问是如何炼成的——SAP到底是什么?
  2. 风控报表大全(全面触及)
  3. 《如何搭建小微企业风控模型》第八节 反欺诈策略 节选
  4. pl/sql实现打印九九乘法表
  5. jquery click点击事件重复执行多次
  6. Java中怎么控制线程訪问资源的数量
  7. 有道词典Mac版崩溃信息
  8. python全栈开发_day20_加密模块和excel操作模块以及xml
  9. 获取浏览器语言的解决方案
  10. oracle sqlLoader 批量导入工具使用说明