1. torch.utils.data.DataLoader类:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
  • 作用:是加载数据的核心,返回可迭代的数据。
  • PyTorch中数据读取的一个重要接口torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口。
  • 该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成V;ariable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要。

  • 参数:
* dataset (Dataset): 加载数据的数据集
* batch_size (int, optional): 每批加载多少个样本
* shuffle (bool, optional): 设置为“真”时,在每个epoch对数据打乱.(默认:False)
* sampler (Sampler, optional): 定义从数据集中提取样本的策略,返回一个样本
* batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time 返回一批样本. 与atch_size, shuffle, sampler和 drop_last互斥.
* num_workers (int, optional): 用于加载数据的子进程数。0表示数据将在主进程中加载​​。(默认:0)
* collate_fn (callable, optional): 合并样本列表以形成一个 mini-batch.  # callable可调用对象
* pin_memory (bool, optional): 如果为 True, 数据加载器会将张量复制到 CUDA 固定内存中,然后再返回它们.
* drop_last (bool, optional): 设定为 True 如果数据集大小不能被批量大小整除的时候, 将丢掉最后一个不完整的batch,(默认:False).
* timeout (numeric, optional): 如果为正值,则为从工作人员收集批次的超时值。应始终是非负的。(默认:0)
* worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None).# num_workers  为0的话表示:数据导入在主进程中进行;其他大于0的数表示:通过多个进程来导入数据,可以加快数据导入速度。

self.num_workers等于0的情况,也就是不采用多进程进行数据读取。先通过indices = next(self.sample_iter)获取长度为batch size的列表:indices,这个列表的每个值表示一个batch中每个数据的index,每执行一次next操作都会读取一批长度为batch size的indices列表。             然后通过self.collate_fn函数将batch size个tuple(每个tuple长度为2,其中第一个值是数据,Tensor类型,第二个值是标签,int类型)封装成一个list,这个list长度为2,两个值都是Tensor,一个是batch size个数据组成的FloatTensor,另一个是batch size个标签组成的LongTensor。所以简单讲self.collate_fn函数就是将batch size个分散的Tensor封装成一个Tensor。

self.num_workers语句是针对多进程或单进程的情况进行初始化,如果不是设置为多进程读取数据,那么就不需要这些初始化操作,后面会介绍单进程数据读取。

通过multiprocessing.SimpleQueue()类创建了一个简单的队列对象。multiprocessing.Process类就是构造进程的类,这里根据设定的进程数来启动,然后赋值给self.workers。接下来的一个for循环就通过调用start方法依次启动self.workers中的进程。

如果设置为多进程读取数据,那么就会采用队列的方式来读,如果不是采用多进程来读取数据,那就采用普通方式来读

2. DataLoader类源代码:

先看看__init__中的几个重要的输入,也就是参数,参数上面已经解释过了。

在__init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用

BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都;是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如: 
train_data=torch.utils.data.DataLoader(...) 
for i, (input, target) in enumerate(train_data): 
... 
就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

class DataLoader(object):
"""Data loader. Combines a dataset and a sampler, and providessingle- or multi-process iterators over the dataset.Arguments:dataset (Dataset): dataset from which to load the data.batch_size (int, optional): how many samples per batch to load(default: 1).shuffle (bool, optional): set to ``True`` to have the data reshuffledat every epoch (default: False).sampler (Sampler, optional): defines the strategy to draw samples fromthe dataset. If specified, ``shuffle`` must be False.batch_sampler (Sampler, optional): like sampler, but returns a batch ofindices at a time. Mutually exclusive with batch_size, shuffle,sampler, and drop_last.num_workers (int, optional): how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)collate_fn (callable, optional): merges a list of samples to form a mini-batch.pin_memory (bool, optional): If ``True``, the data loader will copy tensorsinto CUDA pinned memory before returning them.drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,if the dataset size is not divisible by the batch size. If ``False`` andthe size of dataset is not divisible by the batch size, then the last batchwill be smaller. (default: False)timeout (numeric, optional): if positive, the timeout value for collecting a batchfrom workers. Should always be non-negative. (default: 0)worker_init_fn (callable, optional): If not None, this will be called on eachworker subprocess with the worker id as input, after seeding and before dataloading. (default: None)
"""def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):self.dataset = datasetself.batch_size = batch_sizeself.num_workers = num_workersself.collate_fn = collate_fnself.pin_memory = pin_memoryself.drop_last = drop_lastself.timeout = timeoutself.worker_init_fn = worker_init_fnif timeout < 0:raise ValueError('timeout option should be non-negative')if batch_sampler is not None:if batch_size > 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler is mutually exclusive with ''batch_size, shuffle, sampler, and drop_last')if sampler is not None and shuffle:raise ValueError('sampler is mutually exclusive with shuffle')if self.num_workers < 0:raise ValueError('num_workers cannot be negative; ''use num_workers=0 to disable multiprocessing.')if batch_sampler is None:if sampler is None:if shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)batch_sampler = BatchSampler(sampler, batch_size, drop_last)self.sampler = samplerself.batch_sampler = batch_samplerdef __iter__(self):return DataLoaderIter(self)def __len__(self):return len(self.batch_sampler)

原博客:https://blog.csdn.net/u014380165/article/details/79058479

PyTorch源码解读之torch.utils.data.DataLoader相关推荐

  1. PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...

  2. PyTorch 源码解读之 torch.serialization torch.hub

    作者 | 123456 来源 | OpenMMLab 编辑 | 极市平台 导读 本文解读基于PyTorch 1.7版本,对torch.serialization.torch.save和torch.hu ...

  3. 【Torch】Dataloader torch.utils.data.DataLoader全面详实概念理解

    目录 1.torch.utils.data.DataLoader概念介绍 2.torch.utils.data.DataLoader参数介绍 3 案例体会 DataLoader:[batch_size ...

  4. 阅读源码-理解torch.utils.data、torch.utils.data.Dataset、torch.utils.data.DataLoader的工作方式

    文章目录 目标 Dataset DataLoader 应用 Dataset DataLoader 测试 知识点 Python splitlines()方法 python filter()函数 暂时先写 ...

  5. Pytorch源码解读——DataLoader模块

    torch/utils/data/_utils/dataloader.py 通常在使用pytorch训练神经网络时,DataLoader模块是整个网络训练过程中的基础前提且尤为重要,其主要作用是根据传 ...

  6. PyTorch 源码解读之分布式训练了解一下?

    来源丨商汤学术   编辑丨极市平台 本文由浅入深讲解 torch.distributed 这一并行计算包的概念,实现细节和应用方式,并带大家快速入门 PyTorch 分布式训练. 0 前言 由于大规模 ...

  7. Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解

    DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式.详情也可参考本博主的另一篇关于torch.utils.data.DataLoader(https://blog.csdn ...

  8. PyTorch 源码解读之 BN SyncBN:BN 与 多卡同步 BN 详解

    目录 1. BatchNorm 原理 2. BatchNorm 的 PyTorch 实现 2.1 _NormBase 类 2.1.1 初始化 2.1.2 模拟 BN forward 2.1.3 run ...

  9. 分布式训练PyTorch 源码解读

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:商汤 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 0 前 ...

  10. PyTorch—torch.utils.data.DataLoader 数据加载类

    文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...

最新文章

  1. python不知道错在哪里怎么办_python怎么处理错误和异常
  2. 大数据测试环境服务器硬件推荐配置_服务器托管和服务器租用的区别
  3. kafka 集群_10分钟搭建单机Kafka集群
  4. 【IOS学习基础】OC类的相关
  5. 机器学习之决策树熵信息增量求解算法实现
  6. Java方法中的参数太多,第3部分:构建器模式
  7. windows故障转移群集和mysql_Windows 2016 无域故障转移群集部署方法 超详细图文教程...
  8. Windows OS上安装运行Apache Kafka教程
  9. 什么是php的ast结构,什么是AST?Vue源码中AST语法树的解析
  10. u盘ios刻录_win10 iso刻录到u盘操作教程
  11. JavaScript红宝书、犀牛书(2本)简介
  12. 使用 VLD 检测内存泄漏
  13. 一些可以参考的文档集合5
  14. 高中数学知识点总结:函数零点经典例题解题技巧与方法总结
  15. PLC温室大棚自动控制系统
  16. NOIP模拟赛 czy的后宫5
  17. 什么是商业智能(BI),就看这篇文章足够了
  18. Res2Net: A New Multi-scale Backbone Architecture
  19. 反甩锅成功后思考——RST 报文
  20. js获取PC设备信息,js获取手机设备信息,最全

热门文章

  1. python之sklearn
  2. 七月算法机器学习 8 信息论、最大熵模型与EM算法
  3. Python3入门机器学习经典算法与应用 第3章 Numpy数组的合并与分割
  4. Atitit 命令行执行springboot程序 目录 1.1. 执行spel表达式,调用app main,获取context 1 1.2. 直接在Application main函数内执行 1
  5. 目录 1. 常见mime类型 1 1.1. 2.1.1. Type application 2 2.1.2. Type audio 22.1.3. Type image 32.1.4. Type t
  6. Atitit. Attilax软件研发and开发之道 1. 基本语言 3 2. 标准化库api 3 3. Ied与代码编写 调试 3 4. ui 3 5. 通讯 3 6. 第三方库 3 7. 数据
  7. Atiitt 提升复用性之道 项目成本之道 Atitit 代码复用的理解attilax总结 1. 复用分类 1 1.1. 类库侧重代码重用,框架侧重设计重用 2 2. 文档与索引体系 2 3
  8. Atitit.prototype-base class-based  基于“类” vs 基于“原型”
  9. paip.提升用户体验-----c++ gcc 命令在notepad++扩展中的配置..
  10. bbs与BLOG与SNS在区别