PyTorch data loading: torch.utils.data.DataLoader
Source code walkthrough
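The difference between an iterator and an iterable in Python: an iterable is any object that can be looped over (it implements `__iter__`), while an iterator is the object that actually produces the elements one at a time. An iterator does not need to know how many elements there are in total. Each call to `next()` returns the following element, and `StopIteration` is raised when nothing is left. Anything that can be used in a `for` loop is an iterable, because the loop calls `iter()` on it to obtain an iterator. Containers such as list, tuple and dict are iterable and have a fixed size, but they are not iterators themselves.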
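Here is a small plain-Python sketch that makes the distinction concrete. `Countdown` and `CountdownIterator` are made-up names for illustration. The key point is that the iterable hands out a fresh iterator on every `iter()` call. DataLoader's own `__iter__` in the source below follows the same pattern: it builds a new loader iterator each time it is called.

```python
class Countdown:
    """Iterable: every call to __iter__ hands out a fresh iterator."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        return CountdownIterator(self.start)


class CountdownIterator:
    """Iterator: remembers its position and advances one step per next()."""

    def __init__(self, current):
        self.current = current

    def __iter__(self):
        return self  # an iterator is also iterable and returns itself

    def __next__(self):
        if self.current <= 0:
            raise StopIteration  # tells for loops and list() to stop
        self.current -= 1
        return self.current + 1


nums = [1, 2, 3]
print(hasattr(nums, "__next__"))   # False: a list is iterable but not an iterator
it = iter(nums)                    # ask the list for an iterator
print(next(it), next(it))          # 1 2
print(list(it), list(it))          # [3] []: the iterator is exhausted after one pass

cd = Countdown(3)
print(list(cd), list(cd))          # [3, 2, 1] [3, 2, 1]: the iterable can be reused
```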

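Practically every model trained with PyTorch goes through this interface (unless the user rewrites…). Its purpose is to take a custom Dataset and package its samples into batches of batch_size Tensors for the training loop that follows. How it does this depends on the batch size, on whether to shuffle, and on how samples are drawn (the sampler).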
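The sketch below is a minimal usage example. It assumes a made-up TensorDataset of 10 samples with 3 features each and shows how batch_size, shuffle and drop_last shape the batches that the loader returns.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy map-style dataset (invented for illustration): 10 samples, 3 features each,
# plus one integer label per sample.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size=4 without drop_last: two full batches and one partial batch of 2.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = []
for x, y in loader:             # each batch is already collated into (x, y) Tensors
    batch_sizes.append(x.shape[0])
print(batch_sizes, len(loader))                                   # [4, 4, 2] 3

# drop_last=True discards the incomplete final batch.
loader_drop = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
print([x.shape[0] for x, _ in loader_drop], len(loader_drop))     # [4, 4] 2
```

shuffle=True only changes the order. The batch sizes stay the same every epoch; what changes is which samples land in which batch.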

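A note on yield: a function that contains yield becomes a generator function, and calling it returns a generator, which is an iterator. Like return, yield hands a value back to the caller. Unlike return, it pauses the function instead of ending it, and execution resumes right after the yield on the next call to next().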
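Generators are a natural fit for grouping indices into batches. PyTorch's BatchSampler, which the DataLoader constructor below creates, also uses yield in its `__iter__`. The following is a simplified stand-alone sketch in that spirit, not the library code, and `batch_indices` is a hypothetical name. Each `next()` runs the function only up to the next yield and then pauses it with its local state intact.

```python
def batch_indices(sampler, batch_size, drop_last):
    """Group indices from `sampler` into lists of `batch_size` (illustrative sketch)."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch          # hand one batch to the caller, then pause here
            batch = []           # execution resumes here on the next next() call
    if batch and not drop_last:
        yield batch              # the final, incomplete batch


gen = batch_indices(range(10), batch_size=3, drop_last=False)
print(next(gen))                 # [0, 1, 2]: only runs until the first yield
print(list(gen))                 # [[3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), 3, drop_last=True)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```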

Here is the DataLoader source code:

    class DataLoader(object):
        r"""
        Data loader. Combines a dataset and a sampler, and provides an iterable over
        the given dataset.

        The :class:`~torch.utils.data.DataLoader` supports both map-style and
        iterable-style datasets with single- or multi-process loading, customizing
        loading order and optional automatic batching (collation) and memory pinning.

        See :py:mod:`torch.utils.data` documentation page for more details.

        Arguments:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: ``1``).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: ``False``).
            sampler (Sampler, optional): defines the strategy to draw samples from
                the dataset. If specified, :attr:`shuffle` must be ``False``.
            batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
                indices at a time. Mutually exclusive with :attr:`batch_size`,
                :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
            num_workers (int, optional): how many subprocesses to use for data
                loading. ``0`` means that the data will be loaded in the main process.
                (default: ``0``)
            collate_fn (callable, optional): merges a list of samples to form a
                mini-batch of Tensor(s).  Used when using batched loading from a
                map-style dataset.
            pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
                into CUDA pinned memory before returning them.  If your data elements
                are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
                see the example below.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: ``False``)
            timeout (numeric, optional): if positive, the timeout value for collecting a batch
                from workers. Should always be non-negative. (default: ``0``)
            worker_init_fn (callable, optional): If not ``None``, this will be called on each
                worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
                input, after seeding and before data loading. (default: ``None``)

        .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                     cannot be an unpicklable object, e.g., a lambda function. See
                     :ref:`multiprocessing-best-practices` on more details related
                     to multiprocessing in PyTorch.

        .. note:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                  When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                  ``len(dataset)`` (if implemented) is returned instead, regardless
                  of multi-process loading configurations, because PyTorch trust
                  user :attr:`dataset` code in correctly handling multi-process
                  loading to avoid duplicate data. See `Dataset Types`_ for more
                  details on these two types of datasets and how
                  :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
        """

        __initialized = False

        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, multiprocessing_context=None):
            torch._C._log_api_usage_once("python.data_loader")

            if num_workers < 0:
                raise ValueError('num_workers option should be non-negative; '
                                 'use num_workers=0 to disable multiprocessing.')

            if timeout < 0:
                raise ValueError('timeout option should be non-negative')

            self.dataset = dataset
            self.num_workers = num_workers
            self.pin_memory = pin_memory
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
            self.multiprocessing_context = multiprocessing_context

            # Arg-check dataset related before checking samplers because we want to
            # tell users that iterable-style datasets are incompatible with custom
            # samplers first, so that they don't learn that this combo doesn't work
            # after spending time fixing the custom sampler errors.
            if isinstance(dataset, IterableDataset):
                self._dataset_kind = _DatasetKind.Iterable
                # NOTE [ Custom Samplers and `IterableDataset` ]
                #
                # `IterableDataset` does not support custom `batch_sampler` or
                # `sampler` since the key is irrelevant (unless we support
                # generator-style dataset one day...).
                #
                # For `sampler`, we always create a dummy sampler. This is an
                # infinite sampler even when the dataset may have an implemented
                # finite `__len__` because in multi-process data loading, naive
                # settings will return duplicated data (which may be desired), and
                # thus using a sampler with length matching that of dataset will
                # cause data lost (you may have duplicates of the first couple
                # batches, but never see anything afterwards). Therefore,
                # `Iterabledataset` always uses an infinite sampler, an instance of
                # `_InfiniteConstantSampler` defined above.
                #
                # A custom `batch_sampler` essentially only controls the batch size.
                # However, it is unclear how useful it would be since an iterable-style
                # dataset can handle that within itself. Moreover, it is pointless
                # in multi-process data loading as the assignment order of batches
                # to workers is an implementation detail so users can not control
                # how to batchify each worker's iterable. Thus, we disable this
                # option. If this turns out to be useful in future, we can re-enable
                # this, and support custom samplers that specify the assignments to
                # specific workers.
                if shuffle is not False:
                    raise ValueError(
                        "DataLoader with IterableDataset: expected unspecified "
                        "shuffle option, but got shuffle={}".format(shuffle))
                elif sampler is not None:
                    # See NOTE [ Custom Samplers and IterableDataset ]
                    raise ValueError(
                        "DataLoader with IterableDataset: expected unspecified "
                        "sampler option, but got sampler={}".format(sampler))
                elif batch_sampler is not None:
                    # See NOTE [ Custom Samplers and IterableDataset ]
                    raise ValueError(
                        "DataLoader with IterableDataset: expected unspecified "
                        "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
            else:
                self._dataset_kind = _DatasetKind.Map

            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')

            if batch_sampler is not None:
                # auto_collation with custom batch_sampler
                if batch_size != 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                batch_size = None
                drop_last = False
            elif batch_size is None:
                # no auto_collation
                if shuffle or drop_last:
                    raise ValueError('batch_size=None option disables auto-batching '
                                     'and is mutually exclusive with '
                                     'shuffle, and drop_last')

            if sampler is None:  # give default samplers
                if self._dataset_kind == _DatasetKind.Iterable:
                    # See NOTE [ Custom Samplers and IterableDataset ]
                    sampler = _InfiniteConstantSampler()
                else:  # map-style
                    if shuffle:
                        sampler = RandomSampler(dataset)
                    else:
                        sampler = SequentialSampler(dataset)

            if batch_size is not None and batch_sampler is None:
                # auto_collation without custom batch_sampler
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)

            self.batch_size = batch_size
            self.drop_last = drop_last
            self.sampler = sampler
            self.batch_sampler = batch_sampler

            if collate_fn is None:
                if self._auto_collation:
                    collate_fn = _utils.collate.default_collate
                else:
                    collate_fn = _utils.collate.default_convert

            self.collate_fn = collate_fn
            self.__initialized = True
            self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        @property
        def multiprocessing_context(self):
            return self.__multiprocessing_context

        @multiprocessing_context.setter
        def multiprocessing_context(self, multiprocessing_context):
            if multiprocessing_context is not None:
                if self.num_workers > 0:
                    if not multiprocessing._supports_context:
                        raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                         'support for different start methods')

                    if isinstance(multiprocessing_context, string_classes):
                        valid_start_methods = multiprocessing.get_all_start_methods()
                        if multiprocessing_context not in valid_start_methods:
                            raise ValueError(
                                ('multiprocessing_context option '
                                 'should specify a valid start method in {}, but got '
                                 'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
                        multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

                    if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                        raise ValueError(('multiprocessing_context option should be a valid context '
                                          'object or a string specifying the start method, but got '
                                          'multiprocessing_context={}').format(multiprocessing_context))
                else:
                    raise ValueError(('multiprocessing_context can only be used with '
                                      'multi-process loading (num_workers > 0), but got '
                                      'num_workers={}').format(self.num_workers))

            self.__multiprocessing_context = multiprocessing_context

        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))

            super(DataLoader, self).__setattr__(attr, val)

        def __iter__(self):
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                return _MultiProcessingDataLoaderIter(self)

        @property
        def _auto_collation(self):
            return self.batch_sampler is not None

        @property
        def _index_sampler(self):
            # The actual sampler used for generating indices for `_DatasetFetcher`
            # (see _utils/fetch.py) to read data at each time. This would be
            # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
            # We can't change `.sampler` and `.batch_sampler` attributes for BC
            # reasons.
            if self._auto_collation:
                return self.batch_sampler
            else:
                return self.sampler

        def __len__(self):
            if self._dataset_kind == _DatasetKind.Iterable:
                # NOTE [ IterableDataset and __len__ ]
                #
                # For `IterableDataset`, `__len__` could be inaccurate when one naively
                # does multi-processing data loading, since the samples will be duplicated.
                # However, no real use case should be actually using that behavior, so
                # it should count as a user error. We should generally trust user
                # code to do the proper thing (e.g., configure each replica differently
                # in `__iter__`), and give us the correct `__len__` if they choose to
                # implement it (this will still throw if the dataset does not implement
                # a `__len__`).
                #
                # To provide a further warning, we track if `__len__` was called on the
                # `DataLoader`, save the returned value in `self._len_called`, and warn
                # if the iterator ends up yielding more than this number of samples.
                length = self._IterableDataset_len_called = len(self.dataset)
                return length
            else:
                return len(self._index_sampler)
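The sketch below shows how the constructor wires its defaults together. It builds a few loaders over a toy dataset and inspects the attributes that `__init__` sets. The 10-element TensorDataset is an assumption made for illustration. `_SingleProcessDataLoaderIter` is a private class, so its name could differ in other PyTorch versions.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              SequentialSampler, TensorDataset)

dataset = TensorDataset(torch.arange(10))   # toy map-style dataset, 10 samples

# Defaults: shuffle=False -> SequentialSampler, wrapped in a BatchSampler.
plain = DataLoader(dataset, batch_size=3)
print(type(plain.sampler).__name__, type(plain.batch_sampler).__name__)
# SequentialSampler BatchSampler

# shuffle=True swaps in a RandomSampler; the BatchSampler wrapping stays the same.
shuffled = DataLoader(dataset, batch_size=3, shuffle=True)
print(type(shuffled.sampler).__name__)      # RandomSampler

# The index batches that the loader fetches and collates, one list per batch.
print(list(BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# For a map-style dataset, __len__ is len(self._index_sampler), i.e. the number of batches.
print(len(plain), len(DataLoader(dataset, batch_size=3, drop_last=True)))   # 4 3

# batch_size=None disables auto-collation: batch_sampler stays None, so
# _index_sampler is the plain sampler and len() counts single samples.
unbatched = DataLoader(dataset, batch_size=None)
print(unbatched.batch_sampler, len(unbatched))   # None 10

# num_workers=0, so __iter__ returns the single-process iterator.
print(type(iter(plain)).__name__)                # _SingleProcessDataLoaderIter
```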

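Three parts of the source signal a bad configuration with ValueError: the argument checks in `__init__`, the `multiprocessing_context` setter and the `__setattr__` guard. This sketch triggers each of them. `raises_value_error` and `Stream` are hypothetical helpers written only for this example. Exact error messages may vary between PyTorch versions, so the example relies only on the exception type.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, IterableDataset,
                              SequentialSampler, TensorDataset)

dataset = TensorDataset(torch.arange(10))   # toy map-style dataset


def raises_value_error(action):
    """Run `action`; return the ValueError message if one is raised, else None."""
    try:
        action()
    except ValueError as err:
        return str(err)
    return None


# sampler and shuffle are mutually exclusive.
print(raises_value_error(
    lambda: DataLoader(dataset, sampler=SequentialSampler(dataset), shuffle=True)))

# A custom batch_sampler cannot be combined with batch_size, shuffle, sampler or drop_last.
custom = BatchSampler(SequentialSampler(dataset), batch_size=5, drop_last=False)
print(raises_value_error(lambda: DataLoader(dataset, batch_sampler=custom, batch_size=4)))
print(len(DataLoader(dataset, batch_sampler=custom)))   # 2: batching is decided by custom

# multiprocessing_context is only accepted when num_workers > 0.
print(raises_value_error(lambda: DataLoader(dataset, multiprocessing_context="spawn")))

# After __init__, __setattr__ blocks changes to the core attributes.
loader = DataLoader(dataset, batch_size=4)
print(raises_value_error(lambda: setattr(loader, "batch_size", 8)))


class Stream(IterableDataset):
    """Hypothetical iterable-style dataset used only for this illustration."""

    def __iter__(self):
        return iter(range(5))


# shuffle (like sampler and batch_sampler) is rejected for an IterableDataset.
print(raises_value_error(lambda: DataLoader(Stream(), shuffle=True)))
```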