1 基本流程

首先熟知,pytorch 的数据加载到模型的操作顺序是这样的:

  • 创建一个 Dataset 对象
  • 创建一个 DataLoader 对象
  • 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):for img, label in dataloader:....

2 参数介绍

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False, sampler: Optional[Sampler[int]] = None,batch_sampler: Optional[Sampler[Sequence[int]]] = None,num_workers: int = 0, collate_fn: _collate_fn_t = None,pin_memory: bool = False, drop_last: bool = False,timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,multiprocessing_context=None, generator=None,*, prefetch_factor: int = 2,persistent_workers: bool = False):
  • dataset(Dataset): 传入的数据集

  • batch_size(int, optional): 每个batch有多少个样本

  • shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

  • sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

  • batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

  • num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

  • collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

  • pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

  • drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
    如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

  • timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

  • worker_init_fn (callable, optional): 用户定义的每个worker初始化的时候需要执行的函数。如果不是None, 则会以worker id[0, num_workers - 1]的每个子进程调用(在sedding后,数据加载前)

3 关于worker_init_fn的问题

如果程序一开始指定了各种seed,为了得到确定的结果。如

    random.seed(args.seed)np.random.seed(args.seed)torch.manual_seed(args.seed)torch.cuda.manual_seed(args.seed)

此时,在生成Dataloader时,如果要设置num_workers的数量大于0,使用多进程。则需要传入

    def worker_init_fn(worker_id):random.seed(args.seed + worker_id)

4 关于pin_memory的问题

pin_memory就是锁页内存,创建DataLoader时,设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。

主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。而显卡中的显存全部是锁页内存!

当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。

参考资料

  1. https://blog.csdn.net/g11d111/article/details/81504637
  2. https://blog.csdn.net/qq_40612314/article/details/114435334

Pytorch中DataLoader类相关推荐

  1. Pytorch中Dataloader踩坑:RuntimeError: DataLoader worker (pid(s) 6700, 10620) exited unexpectedly

    Pytorch中Dataloader踩坑 环境: 问题背景: 观察报错信息进行分析 根据分析进行修改尝试 总结 环境: 系统:windows10 Pytorch版本:1.5.1+cu101 问题背景: ...

  2. pytorch中DataLoader的num_workers参数详解与设置大小建议

    Q:在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的? train_loader = torch.utils.data.Data ...

  3. pytorch中Dataloader()中的num_workers设置问题

    pytorch中Dataloader()中的num_workers设置问题: 如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错: import numpy ...

  4. pytorch中dataloader的num_workers参数

    结论速递 在Windows系统中,num_workers参数建议设为0,在Linux系统则不需担心. 1 问题描述 在之前的任务超大图上的节点表征学习中,使用PyG库用于数据加载的DataLoader ...

  5. pytorch中DataLoader的num_workers

    Question 一直很迷, 在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的? 如果将num_worker设为0(也是默认值) ...

  6. Pytorch中Dataloader保存文件名

    转载自:https://gist.github.com/andrewjong/6b02ff237533b3b2c554701fb53d5c4d,本文只做个人记录学习使用,版权归原作者所有. impor ...

  7. python怎么设置随机数种子_Pytorch在dataloader类中设置shuffle的随机数种子方式

    如题:Pytorch在DataLoader类中设置shuffle的随机数种子方式 虽然实验结果差别不大,但是有时候也悬殊两个百分点 想要复现实验结果 发现用到随机数的地方就是DataLoader类中封 ...

  8. Pytorch中的dataset类——创建适应任意模型的数据集接口

    作为一个2年多的不资深keraser和tfer,被boss要求全员换成pytorch.不得不说,pytorch还是真香的.之前用keras,总会发现多GPU使用的情况下不太好,对计算资源的利用率不太高 ...

  9. pytorch中调整学习率的lr_scheduler机制

    pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...

最新文章

  1. 简单ajax类, 比较小, 只用ajax功能时, 可以考虑它
  2. Vijos P1449 字符串还原【密码】
  3. C++11 统一初始化(Uniform Initialization)
  4. 我同事狠心用 Python 3 ,刚开始就直接崩溃!你们试试......
  5. (转)Hibernate关联映射——一对多(多对一)
  6. matlab转dsp软件,matlab/simulink程序代写 DSP程序开发
  7. Android:Android SDK的下载与安装
  8. Java后端避坑——如何使用注解忽略掉JavaBean的属性值
  9. velocity include
  10. summernote 字体名字不显示_觉得 Windows 10 显示字体不好看吗? 教你轻松更换成 Mac 字体版本。...
  11. 《Perl语言入门》学习笔记
  12. python os模块
  13. 什么是PXE及PXE启动
  14. 最小外接矩形--最大内接矩形
  15. Flume之HDFS Sink 的参数解析及异常处理
  16. 增长黑客理论(AARRR)模型
  17. 改cpp[1] Vscode Hex Editor,在vscode中查看内存
  18. 如何使用Stack Overflow ?
  19. 【读书笔记】《曾国藩的正面与侧面(一)》
  20. PMP杂谈--快速记忆ITTO

热门文章

  1. Luogu1574 超级数
  2. 利用WINDOWS活动目录提供LDAP的方案
  3. java中 快捷键输入System.out.println();
  4. java枚举的例子_Java枚举例子
  5. java 网络传输中发送byte[]和接收到的不一致_为什么JAVA对象需要实现序列化?
  6. bool类型头文件_[C++基础入门] 2、数据类型
  7. esxi 7.0 封装瑞昱网卡驱动_小科普 | 无线网卡怎么选?
  8. php课程实验总结报告_PHP课程总结20161125
  9. switch日版有中文吗_任天堂switch国行和日版的区别
  10. c语言错误指导,c语言编程指导.pdf