PyTorch's DataLoader is mainly used to load data: given a known dataset, you wrap it in a DataLoader, which then feeds the data into a deep learning network for training. Let's start with its declaration (the official signature from the PyTorch 1.10.0 documentation; see reference 1):

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

Its parameters are as follows:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. (See the sketch after this list.)

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the official documentation for an example.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

  • generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and by multiprocessing to generate a base_seed for workers. (default: None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)

  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
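
To see how a few of these arguments interact, here is a minimal, self-contained sketch (the toy tensors and the my_collate helper are made up for illustration). With 10 samples, batch_size=4, and drop_last=True, the loader yields two full batches and discards the final two samples:

import torch
from torch.utils.data import DataLoader, TensorDataset

# a toy map-style dataset: 10 samples with 5 features each, binary labels
dataset = TensorDataset(torch.randn(10, 5), torch.randint(0, 2, (10,)))

# a custom collate_fn that stacks individual samples into batch tensors
# (this mirrors what the default collate does for a map-style dataset)
def my_collate(batch):
    xs = torch.stack([x for x, _ in batch])
    ys = torch.stack([y for _, y in batch])
    return xs, ys

loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    drop_last=True, collate_fn=my_collate)
for xs, ys in loader:
    print(xs.shape, ys.shape)  # torch.Size([4, 5]) torch.Size([4]), twice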

Now let's look at the dataset types in detail.

Dataset Types

The most important argument of DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

  • map-style datasets,

  • iterable-style datasets.

Map-style datasets

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.

See Dataset for more details.

Iterable-style datasets

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.

See IterableDataset for more details.

NOTE

When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.

The most important argument of DataLoader is dataset, which determines the dataset to load. torch supports two types of datasets:

(1) Map-style datasets.

A map-style dataset is a class that implements the __getitem__() and __len__() protocols; it represents a map from (possibly non-integral) indices/keys to data samples.

For example, when such a dataset is accessed with dataset[idx], it can read the idx-th image and fetch its corresponding label from a folder on disk.

Let's look at the first example (see reference 4):

import torch
from torch.utils.data import DataLoader
import numpy as np

class MyLoader(torch.utils.data.Dataset):
    # the parent class is torch.utils.data.Dataset; object would also work —
    # no particular parent class is required
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):  # return one sample and its label
        data = self.data[index]
        labels = self.label[index]
        return data, labels

    def __len__(self):  # return the total number of samples
        return len(self.data)

source_data = np.random.rand(10, 20)
source_label = np.random.randint(0, 2, (10, 1))

torch_data = MyLoader(source_data, source_label)

for i, data in enumerate(torch_data):
    print('Sample {}: {}'.format(i, data))
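
The loop above iterates over the Dataset itself, one sample at a time. To get real batches, hand it to a DataLoader — a minimal sketch continuing from the code above:

loader = DataLoader(torch_data, batch_size=4, shuffle=True)
for i, (data, label) in enumerate(loader):
    print('Batch {}: data {}, label {}'.format(i, data.shape, label.shape))
# Batch 0: data torch.Size([4, 20]), label torch.Size([4, 1])
# ...and a final batch of 2, since drop_last defaults to False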

For image segmentation, the dataset comes from a course on the PaddlePaddle platform (reference 3). Each label is an image of the same size as its source image, as shown below:

Original image: [image]

Label: [image]

Both the original images and the labels have already been converted. The list file, 14 samples in total, is as follows:

humanseg/aa645bc9cf23db7912a69309072cd9ab325f02cd.jpg visual/aa645bc9cf23db7912a69309072cd9ab325f02cd.png
humanseg/aa63d7e6db0d03137883772c246c6761fc201059.jpg visual/aa63d7e6db0d03137883772c246c6761fc201059.png
humanseg/aa6300f76981dcf8701534dd1d3b2ec19b3dee02.jpg visual/aa6300f76981dcf8701534dd1d3b2ec19b3dee02.png
humanseg/56173ddd1ccb419e1efdeb5f5cb242ab160142cb.jpg visual/56173ddd1ccb419e1efdeb5f5cb242ab160142cb.png
humanseg/aa6bd3eaf471bea1cca7467a95fe93e69b006797.jpg visual/aa6bd3eaf471bea1cca7467a95fe93e69b006797.png
humanseg/aa67b2d074e00942191c4bd2472e7f77538ec113.jpg visual/aa67b2d074e00942191c4bd2472e7f77538ec113.png
humanseg/aa6ff076c7360b8dabc30edd05ebafb65bba9343.jpg visual/aa6ff076c7360b8dabc30edd05ebafb65bba9343.png
humanseg/aa611a0cf92ace38bd2d3b0fe0bc50b5235eea7e.jpg visual/aa611a0cf92ace38bd2d3b0fe0bc50b5235eea7e.png
humanseg/aa65f5b4f85c37ce44dc48473150a16e652b6bc5.jpg visual/aa65f5b4f85c37ce44dc48473150a16e652b6bc5.png
humanseg/aa65c231dbce73de1527101bf35b975b2c2e9d5a.jpg visual/aa65c231dbce73de1527101bf35b975b2c2e9d5a.png
humanseg/aa6b34b24414bafa7fab8393239c793587513ce6.jpg visual/aa6b34b24414bafa7fab8393239c793587513ce6.png
humanseg/aa662fb7540312c51f6e6870c0542c8035495b14.jpg visual/aa662fb7540312c51f6e6870c0542c8035495b14.png
humanseg/aa6f23e6ac596962ee773e4eea0560fb0e4522ac.jpg visual/aa6f23e6ac596962ee773e4eea0560fb0e4522ac.png
humanseg/aa65dc40ae9713e4fe3e63b55a8fd10bd1320822.jpg visual/aa65dc40ae9713e4fe3e63b55a8fd10bd1320822.png

The first entry on each line is the source file and the second is the label file. Development environment: Linux, Python 3.7.4, Anaconda3, torch 1.10.0+cpu.

import os
import math
import random
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

class Transform(object):  # image conversion: resize image and label to a fixed size
    def __init__(self, size=256):
        self.size = size

    def __call__(self, input, label):
        input = cv2.resize(input, (self.size, self.size), interpolation=cv2.INTER_LINEAR)
        label = cv2.resize(label, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
        return input, label

# map-style dataset
class MapDataLoader(torch.utils.data.Dataset):
    def __init__(self, image_folder, image_list_file, transform=True, shuffle=True):
        self.image_folder = image_folder
        self.image_list_file = image_list_file
        self.transform = transform
        self.shuffle = shuffle
        self.data_list = self.read_list()   # read the list file
        self.data_total = self.get_total()  # load all (image, label) pairs into a list

    def __getitem__(self, index):
        data = self.data_total[index][0]
        labels = self.data_total[index][1]
        return data, labels

    def read_list(self):
        data_list = []
        with open(os.path.join(self.image_folder, self.image_list_file)) as infile:
            for line in infile:
                data_path = os.path.join(self.image_folder, line.split()[0])
                label_path = os.path.join(self.image_folder, line.split()[1])
                data_list.append((data_path, label_path))
        random.shuffle(data_list)
        return data_list

    def get_total(self):
        total_list = []
        for data_path, label_path in self.data_list:
            data = cv2.imread(data_path, cv2.IMREAD_COLOR)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            assert data is not None, "NoneType"
            print(data.shape, label.shape)
            data, label = self.preprocess(data, label)
            print('after:', data.shape, label.shape)
            total_list.append((data, label))
        random.shuffle(total_list)
        return total_list

    def preprocess(self, data, label):
        h, w, c = data.shape
        h_gt, w_gt = label.shape
        assert h == h_gt, "Error"
        assert w == w_gt, "Error"
        if self.transform:
            data, label = self.transform(data, label)
        label = label[:, :, np.newaxis]  # add a channel dimension
        return data, label

    def __len__(self):
        return len(self.data_total)

transform = Transform(256)
map_dataloader = MapDataLoader(image_folder='../data', image_list_file='list_linux.txt',
                               transform=transform, shuffle=True)

The output is as follows:

(1000, 706, 3) (1000, 706)
after: (256, 256, 3) (256, 256, 1)
(664, 1000, 3) (664, 1000)
after: (256, 256, 3) (256, 256, 1)
(768, 484, 3) (768, 484)
after: (256, 256, 3) (256, 256, 1)
(1000, 666, 3) (1000, 666)
after: (256, 256, 3) (256, 256, 1)
(960, 717, 3) (960, 717)
after: (256, 256, 3) (256, 256, 1)
(940, 626, 3) (940, 626)
after: (256, 256, 3) (256, 256, 1)
(600, 900, 3) (600, 900)
after: (256, 256, 3) (256, 256, 1)
(565, 800, 3) (565, 800)
after: (256, 256, 3) (256, 256, 1)
(633, 940, 3) (633, 940)
after: (256, 256, 3) (256, 256, 1)
(825, 550, 3) (825, 550)
after: (256, 256, 3) (256, 256, 1)
(634, 950, 3) (634, 950)
after: (256, 256, 3) (256, 256, 1)
(939, 626, 3) (939, 626)
after: (256, 256, 3) (256, 256, 1)
(1000, 737, 3) (1000, 737)
after: (256, 256, 3) (256, 256, 1)
(676, 1000, 3) (676, 1000)
after: (256, 256, 3) (256, 256, 1)
Now wrap the dataset in a DataLoader and iterate for two epochs:

datas = DataLoader(map_dataloader, batch_size=4, shuffle=True, drop_last=False,
                   num_workers=4)  # on Windows, num_workers must be set to 0
num_epoch = 2
for epoch in range(1, num_epoch + 1):
    print(f'Epoch [{epoch}/{num_epoch}]')
    for index, (data, label) in enumerate(datas):
        print(f'Iter {index},data shape:{data.shape} Label shape:{label.shape}')

Output:

Epoch [1/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 3,data shape:torch.Size([2, 256, 256, 3]) Label shape:torch.Size([2, 256, 256, 1])
Epoch [2/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 3,data shape:torch.Size([2, 256, 256, 3]) Label shape:torch.Size([2, 256, 256, 1])
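
One practical note: these batches come out in NHWC layout ([4, 256, 256, 3]) because the dataset returns raw OpenCV arrays, whereas PyTorch convolution layers expect NCHW. In a real training loop you would typically permute first — a minimal sketch, where model and criterion are hypothetical placeholders:

for data, label in datas:
    data = data.permute(0, 3, 1, 2).float()   # NHWC -> NCHW
    label = label.permute(0, 3, 1, 2).float()
    # output = model(data); loss = criterion(output, label); ...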

(2) Iterable-style datasets.

This type subclasses IterableDataset and implements the __iter__() protocol; it represents an iterable over data samples. It is particularly suitable when random reads are expensive or even impossible, and when the batch size depends on the fetched data.

For example, calling iter(dataset) on such a dataset could return a stream of samples read from a database, a remote server, or even logs generated in real time.
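
A minimal sketch of that idea — the LogStream class and the server.log path are hypothetical, assuming a plain text file read line by line:

import torch
from torch.utils.data import IterableDataset, DataLoader

class LogStream(IterableDataset):
    # a hypothetical iterable-style dataset: streams lines from a file;
    # no random access and no __len__ are required
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        with open(self.path) as f:
            for line in f:
                yield line.rstrip('\n')

# for batch in DataLoader(LogStream('server.log'), batch_size=8):
#     ...  # each batch is a list of 8 lines (the default collate keeps strings in a list)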

Note:

When an iterable-style dataset is loaded with multiple worker processes, the same dataset object is replicated into every worker, so the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this; the __iter__ implementation below shows one way.

Again, we use image segmentation as the example.

# iterable-style dataset
class IterDataLoader(torch.utils.data.IterableDataset):  # parent class is torch.utils.data.IterableDataset
    def __init__(self, image_folder, image_list_file, transform=True, shuffle=True):
        self.image_folder = image_folder
        self.image_list_file = image_list_file
        self.transform = transform
        self.shuffle = shuffle
        self.data_list = self.read_list()
        self.data_total = self.get_total()
        self.start = 0
        self.end = len(self.data_total)

    def read_list(self):
        data_list = []
        with open(os.path.join(self.image_folder, self.image_list_file)) as infile:
            for line in infile:
                data_path = os.path.join(self.image_folder, line.split()[0])
                label_path = os.path.join(self.image_folder, line.split()[1])
                data_list.append((data_path, label_path))
        random.shuffle(data_list)
        return data_list

    def get_total(self):
        total_list = []
        for data_path, label_path in self.data_list:
            data = cv2.imread(data_path, cv2.IMREAD_COLOR)
            # data = cv2.cvtColor(data, cv2.COLOR_BAYER_BG2RGB)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            assert data is not None, "NoneType"
            print(data.shape, label.shape)
            data, label = self.preprocess(data, label)
            print('after:', data.shape, label.shape)
            total_list.append((data, label))
        random.shuffle(total_list)
        return total_list

    def preprocess(self, data, label):
        h, w, c = data.shape
        h_gt, w_gt = label.shape
        assert h == h_gt, "Error"
        assert w == w_gt, "Error"
        if self.transform:
            data, label = self.transform(data, label)
        label = label[:, :, np.newaxis]
        return data, label

    def __len__(self):
        return len(self.data_list)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process loading: return the full iterator
            return iter(self.data_total)
        else:  # multi-process loading: each worker gets its own shard
            # number of samples each worker should load
            per_worker = int(math.ceil(len(self.data_total) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
            return iter(self.data_total[iter_start:iter_end])

transform = Transform(256)
iter_dataloader = IterDataLoader(image_folder='../data', image_list_file='list_linux.txt',
                                 transform=transform, shuffle=True)
datas = DataLoader(iter_dataloader, batch_size=4, drop_last=False,
                   num_workers=2)  # on Windows set num_workers to 0; shuffle=True is not allowed here
num_epoch = 2
for epoch in range(1, num_epoch + 1):
    print(f'Epoch [{epoch}/{num_epoch}]')
    for index, (data, label) in enumerate(datas):
        print(f'Iter {index},data shape:{data.shape} Label shape:{label.shape}')

The output is as follows:

Epoch [1/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Iter 3,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Epoch [2/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Iter 3,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])

Note the batch sizes here: with num_workers=2, the 14 samples are split into two shards of 7, each worker batches its own shard independently (4 + 3), and the DataLoader alternates between the workers' batches — hence 4, 4, 3, 3 instead of 4, 4, 4, 2. Judging from these results, a map-style dataset is still the better fit for static data.

References:

1. torch.utils.data — PyTorch 1.10.1 documentation
2. pytorch/dataloader.py at master · pytorch/pytorch · GitHub
3. PaddlePaddle (飞桨) — an open-source deep learning platform grown out of industrial practice
4. Loading your own dataset in PyTorch (using DataLoader with a Dataset) — 北国觅梦, CSDN blog
