Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解
DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式。详情也可参考本博主的另一篇关于torch.utils.data.DataLoader
(https://blog.csdn.net/qq_36653505/article/details/83351808)的讨论。
在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?
下面就研究一下:
先看看 dataloader.py
源码是怎么写的(VS中按F12跳转到该脚本)
__init__
(构造函数)中的几个重要的属性:
1、dataset:(数据类型 dataset)
输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。
2、batch_size:(数据类型 int)
每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。
3、shuffle:(数据类型 bool)
洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
4、collate_fn:(数据类型 callable,没见过的类型)
将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)
5、batch_sampler:(数据类型 Sampler)
批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。
6、sampler:(数据类型 Sampler)
采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。
7、num_workers:(数据类型 Int)
工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。
8、pin_memory:(数据类型 bool)
内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
9、drop_last:(数据类型 bool)
丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
10、timeout:(数据类型 numeric)
超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
11、worker_init_fn(数据类型 callable,没见过的类型)
子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据
从DataLoader类的属性定义中可以看出,这个类的作用就是实现数据以什么方式输入到什么网络中。
代码一般是这么写的:
定义学习集 DataLoader
train_data = torch.utils.data.DataLoader(各种设置…)
将数据喂入神经网络进行训练
for i, (input, target) in enumerate(train_data):
循环代码行…
如果全部采用默认设置输入数据,数据就是一行一行按顺序输入到神经网络。如果对数据的输入有特殊要求。
比如:想打乱一下数据的排序,可以设置 shuffle(洗牌)为True;
比如:想数据是一捆的输入,可以设置 batch_size 的数目;
比如:想随机抽取的模式输入,可以设置 sampler 或 batch_sampler。如何定义抽样规则,可以看sampler.py脚本。这里不是重点;
比如:像多线程输入,可以设置 num_workers 的数目;
其他的就不太懂了,以后实际应用时碰到特殊要求再研究吧。
DataLoader类中还有3个函数:
def __setattr__(self, attr, val):if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):raise ValueError('{} attribute should not be set after {} is ''initialized'.format(attr, self.__class__.__name__))super(DataLoader, self).__setattr__(attr, val)def __iter__(self):return _DataLoaderIter(self)def __len__(self):return len(self.batch_sampler)
关键是第二个函数,
_DataLoaderIter 又是一个类,被一起写在DataLoader.py文件中。
主要是用来处理各种设置如何运作的,这里就不管那么多啦。
最后,如果要导入自己各种古灵精怪的数据,就要看看 DataSet 又是如何操作的。
torch.utils.data主要包括以下三个类:
- class torch.utils.data.Dataset
其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder. - class torch.utils.data.sampler.Sampler(data_source)
参数: data_source (Dataset) – dataset to sample from
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度. - class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
参考
https://blog.csdn.net/rogerfang/article/details/82291464
Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解相关推荐
- 【Torch】Dataloader torch.utils.data.DataLoader全面详实概念理解
目录 1.torch.utils.data.DataLoader概念介绍 2.torch.utils.data.DataLoader参数介绍 3 案例体会 DataLoader:[batch_size ...
- PyTorch—torch.utils.data.DataLoader 数据加载类
文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...
- 深度学习之“制作自定义数据”--torch.utils.data.DataLoader重写构造方法。
深度学习之"制作自定义数据"–torch.utils.data.DataLoader重写构造方法. 前言: 本文讲述重写torch.utils.data.DataLoader类 ...
- 2021.08.24学习内容torch.utils.data.DataLoader以及CUDA与GPU的关系
pytorch数据加载: ①totchvision 的包,含有支持加载类似Imagenet,CIFAR10,MNIST 等公共数据集的数据加载模块 torchvision.datasets impor ...
- torch.utils.data.DataLoader 详解
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
- 阅读源码-理解torch.utils.data、torch.utils.data.Dataset、torch.utils.data.DataLoader的工作方式
文章目录 目标 Dataset DataLoader 应用 Dataset DataLoader 测试 知识点 Python splitlines()方法 python filter()函数 暂时先写 ...
- torch.utils.data.DataLoader()的使用
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集.在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据.直至把所有的数据都抛出.就是做一个数据的初始化. 官网上 ...
- linux top VIRT RES SHR SWAP DATA内存参数详解
Linux top VIRT RES SHR SWAP DATA内存参数详解 其实很早之前就想开博客,写一写码农几年自己积攒下来的知识与见解.看过很多文章有过很多感触,有些收获很值得梳理一下认真思考反 ...
- torch.utils.data.DataLoader()到底是什么作用?
就是数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集.在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据.直至把所有的数据都抛出.就是做一个数据的初始化. 生 ...
最新文章
- iOS开发8:使用Tool Bar切换视图
- 01-Vue博客后台管理页面框架搭建
- [C# 设计模式] Adapter - 适配器模式(两种)
- UA STAT687 线性模型理论I 线性模型概述
- ST新一代烧写工具 STM32CubeProgrammer
- 苹果maccmsv10和redis memcached缓存的若干问题解决!
- CCF202112-2 序列查询新解
- Python安装pycryptodome密码库
- 推荐几个代码自动生成器,神器!!!
- ad建集成库_AD16创建集成库的步骤
- 爱好-超级IP:超级IP
- Multi-task中的多任务loss平衡问题
- LTE-PCC SCC
- 厦门大学计算机保研学校,厦门大学计算机科学系(专业学位)计算机技术保研
- R^2(可决系数)为负分析
- abaqus python实例_abaqus Python实例-操作excel文件
- OSPF特殊区域TOTAL STUB配置实验
- 苹果系统 虚拟机_大连win10远程双系统重装电脑维修7苹果笔记本安装做虚拟机服务mac8...
- STM32单片机-低功耗设置
- [转]iPhone 港版和美版,有锁版和无锁版的区别?
热门文章
- 测试一个显示器有拖影的软件,让“瑕疵”原形毕露,教你检测游戏显示器!
- 郑州大学计算机系王院长,我院成功承办河南省第十二届ACM大学生程序设计竞赛...
- flex-gow 的用法
- 《秒懂EXCEL》重点复习笔记01
- (Python)LeetCode1386:安排电影院座位
- 大数据笔记--SparkSQL(第一篇)
- 【Python019--函数与过程】
- mac安装quicklook命令
- 文章向大家介绍安卓逆向,解决app抓包抓不到的问题,主要包括安卓逆向,解决app抓包抓不到的问题使用实例、应用技巧
- PaddleDetection目标检测之水果检测(下)(yolov3_mobilenet_v1)