-柚子皮-

nlp中的dataloader的使用

torch.utils.data.DataLoader中的参数:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable*, *optional) – merges a list of samples to form a mini-batch.
  • pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

返回值

返回值是一个实现了__iter__的对象,可以使用for循环进行迭代,或者转换成迭代器取第一条batch数据查看。

for循环进行迭代时返回的每条数据就是(batch_size,*)大小的。

常用操作

使用示例

self.data_loader = torch.utils.data.DataLoader(
            dataset=self.dataset, collate_fn=self.collate_fn,
            batch_size=batch_size, shuffle=if_shuffle, num_workers=args.num_workers)

batch数目

batch_num = len(train_dataloader)

获取dataset中的第一条数据

train_data_loader.dataset.__getitem__(0)

获取dataloader中batch中的第一条数据

def get_one_data(item_dict, i):
    return {k: v[i] for k, v in item_dict.items()}

print(get_one_data(next(iter(train_data_loader)), 1))

或者

for index, item_dict in enumerate(train_data_loader):
        print(get_one_data(item_dict, 1))
        exit()

自定义dataloader

Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作(比如每个batch中实际lengths,padding,cuda之类的)。

自定义collate_fn

因为dataloader是有batch_size参数的,我们可以通过自定义collate_fn=myfunction来设计数据收集的方式,意思是已经通过上面的Dataset类中的__getitem__函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数。

示例1:通过collate_fn进行解包

def myfunction(data):A,B,path,hop=zip(*data)print('A:',A," B:",B," path:",path," hop:",hop)raise Exception('utils collate_fun 147')return A,B,path,hop
for index,item in enumerate(dataloaders['train'])A,B,path.hop=item

Note: 需要在外面对dataloaders进行for调用,后再断点或者exit(),否则不会真正执行collate_fn,这样就不会print了。
示例2:nlp任务中,经常在collate_fn指定的函数里面做padding,将同一个batch中不一样长的句子padding成一样长。

def myfunction(data):src, tgt, original_src, original_tgt = zip(*data)src_len = [len(s) for s in src]src_pad = torch.zeros(len(src), max(src_len)).long()for i, s in enumerate(src):end = src_len[i]src_pad[i, :end] = torch.LongTensor(s[end-1::-1])tgt_len = [len(s) for s in tgt]tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()for i, s in enumerate(tgt):end = tgt_len[i]tgt_pad[i, :end] = torch.LongTensor(s)[:end]return src_pad, tgt_pad, \torch.LongTensor(src_len), torch.LongTensor(tgt_len), \original_src, original_tgt

一些问题

[为什么pytorch DataLoader在numpy数组和列表上的行为有所不同?]

1 import问题

使用torch.utils.data.DataLoader时,pycharm中无法直接点击进入代码。

2 num_workers设置过大问题

num_workers如果设置过大,资源不够,会出错:Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)
[运行tensorflow程序bug:Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)]

[python模块导入及属性:import]

[https://github.com/pytorch/pytorch/issues/41794]

from: -柚子皮-

ref: [https://www.jianshu.com/p/8ea7fba72673]

PyTorch:数据读取2 - Dataloader相关推荐

  1. Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

    Pytorch的数据读取主要包含三个类: Dataset DataLoader DataLoaderIter 这三者是一个依次封装的关系: 1.被装进2., 2.被装进3. Dataset类 Pyto ...

  2. 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  3. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  4. 使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

    使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档. JAX简介 JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本 ...

  5. PyTorch系列 (二): pytorch数据读取自制数据集并

    PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 ...

  6. 数据读取机制Dataloader和Dataset和Transforms

    人民币二分类模型 数据-模型-损失函数-优化器-迭代训练 数据收集 img label 数据划分 train valid test 数据读取 Dataloader [sampler-生成索引 data ...

  7. pytorch数据读取之Dataset与DataLoader

    1. 先前处理数据集的代码经常比较混乱并且难以维护 2. 数据集处理代码应该和训练代码解耦合,从而达到模块化和更好的可读性 因此,pytorch提出了两个数据处理类:DataLoader与Datase ...

  8. Pytorch数据读取加速方法

    1. 方法一:使用prefetcher class data_prefetcher():def __init__(self, loader):self.loader = iter(loader)sel ...

  9. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

  10. 图像数据读取及数据扩增方法

    Datawhale干货 作者:王程伟,Datawhale成员 本文为干货知识+竞赛实践系列分享,旨在理论与实践结合,从学习到项目实践.(零基础入门系列:数据挖掘/cv/nlp/金融风控/推荐系统等,持 ...

最新文章

  1. python爬虫Scrapy框架之增量式爬虫
  2. mysql8 安装_mysql 8.x 安装向导
  3. 【JAVA中级篇】线程池
  4. 【ClickHouse】查看数据库容量和表大小的方法(system.parts各种操作方法)
  5. rn代码与android,RN与原生通讯(安卓篇)
  6. 接收流信息---字符串
  7. OpenLayers 3加载矢量地图源
  8. 成功解决python.exe 无法找到入口 无法定位程序输入点
  9. 东营网站服务器部署,联通东营服务器dns地址
  10. 微生物组-扩增子16S分析和可视化(2022.10)
  11. 用java做考试管理系统,考试管理系统的开发实现(Java+Web)
  12. python 微信群发_用python写一个微信群发工具(基于itchat库)
  13. P2000 拯救世界(生成函数裸题+NTT高精)
  14. linux vim命令翻页,详解Vim编辑器翻页控制命令
  15. 最新智云全能API接口查询PHP源码V1.1
  16. [译] APT分析报告:01.Linux系统下针对性的APT攻击概述
  17. db2还原备份文件详细教程
  18. 详解幂律分布,以及用于重尾分布的Python库powerlaw的使用
  19. Vue3+Quasar实现ins风格图片墙
  20. win10、win7 脚本导证书到系统中

热门文章

  1. server 2012 IIS 启用.NET 4.5
  2. HDU3595_GG and MM
  3. [转]div中放flash运行30秒钟后自动隐藏效果
  4. Android:Toolbar的图标尺寸问题
  5. Git使用- 基本命令
  6. net 进阶学习 WebApi (2)
  7. [模板]01分数规划
  8. (转)AIX的Dump文件学习笔记
  9. [Linux] 常用Linux命令
  10. Windows CE 5.0待机界面定制之一 - Taskbar的位置