本文以torch.utils.data中的Dataset类为例进行说明

Dataset的作用是构建自定义的数据集,以方便使用Dataloader进行加载

语法

我们自定义的数据集需要继承自torch.util.data.Dataset抽象类,并重写相应的两个方法:

  • len:返回数据集的大小。一般情况而言直接用 len(xxx) 进行实现即可
  • getitem:使得 dataset[i] 能够返回数据集中的第i个样本,相应的需要传入一个索引i

原抽象类中相应的定义如下:

def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedError

数据

假设我们在解决一个分类问题。那么,在训练集文件夹train中,我们可以这么给图片加上标签:

到时候就可以通过文件名的方式来判断某张图片对应的分类。

例子

我们构造一个FruitDataset来处理这些数据。首先实现init方法:

def __init__(self, root_dir, transform=None): self.root_dir = root_dir   self.transform = transform self.images = os.listdir(self.root_dir)

init方法一般会有两个基础的参数,一个是dir,用来表示数据集所在的目录;另一个则是transform,可以传入一些方法,以对图片进行处理(一般是进行数据增强)。
此外,在init方法中还会进行基础的数据读取,例如这里使用listdir来列出目录下的所有文件;而如果是表格形式的数据(如kaggle),那么则可以使用切片方法将标签与数据分离,方便后续的处理。

接下来是len方法,即返回数据集的长度。既然我们刚才已经读出了数据集目录下的所有文件,那么只要返回这个文件夹列表的长度即可:

def __len__(self):return len(self.images)

最后则是getitem方法。getitem方法返回的是一个字典,表示相应数据所蕴含的其他信息,有了其他信息一个数据才能变成一个样本。在这里,“其他信息”就是图片所对应的标签,即要返回一个{‘image’:img, ‘label’:label}。习惯上,我们会把这个字典记做sample。
img可以通过imread方法读取图片实际内容得到,而label可以通过处理文件名获得:

def __getitem__(self,index):# 通过路径与索引读图片image_index = self.images[index]img_path = os.path.join(self.root_dir, image_index)img = io.imread(img_path)# 通过文件名读标签label = img_path.split('\\')[-1].split('.')[0]# 组装成字典sample = {'image':img,'label':label}if self.transform:sample = self.transform(sample)return sample

注意这里的if self.transform也算是一种习惯上的用法,即如果传入了变换方法则进行变换后再返回。

Dataloader

我们通过dataloader来分析刚才构建的数据集。一般来说,训练集与测试集各会对应一个dataloader,这里为了演示方便起见就只拿我们刚才的训练集进行说明。
首先,实例化一个Dataset对象。在这里我们没有变换方法,则只需要传入数据所在的目录即可:

data = FruitDataset(r"data\train", transform=None)

dataset对象可以通过下标来访问其中的各个样本,比如:

print(data[0])

然后利用dataloader进行加载:

dataloader = DataLoader(data, batch_size=2, shuffle=True)

一般而言Dataloader需要传入三个参数:

  • dataset:传入Dataset对象,表示需要加载的数据集
  • batch_size:“批大小”,表示一次选取的一批中有几个样本。在这里bs为2,即每轮选取2个样本
  • shuffle:是否需要将数据打乱。一般来说只需要打乱训练集即可,测试集并不需要打乱

查看dataloader的长度。总共有10张图,一批有2张,因此有5批,长度为5:

# 5
print(len(dataloader))

最后迭代整个数据集:

for i_batch, batch_data in enumerate(dataloader):print(i_batch)print(batch_data)

i_batch就是batch的编号,0、1、2、3、4;batch_data就是我们在数据集中定义的sample,在这里两个两个一组出现。

完整代码

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset, DataLoader
from skimage import io
import osclass FruitDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir  self.transform = transform self.images = os.listdir(self.root_dir)def __len__(self):return len(self.images)def __getitem__(self,index):image_index = self.images[index]img_path = os.path.join(self.root_dir, image_index)img = io.imread(img_path)label = img_path.split('\\')[-1].split('.')[0]sample = {'image':img,'label':label}if self.transform:sample = self.transform(sample)return sample data = FruitDataset(r"data\train", transform=None)
print(data[0])
dataloader = DataLoader(dataset=data, batch_size=2, shuffle=True)
print(len(dataloader))
for i_batch, batch_data in enumerate(dataloader):print(i_batch)print(batch_data)

参考

https://blog.csdn.net/xuan_liu123/article/details/101145366

Pytorch Dataset、Dataloader的简单理解与使用相关推荐

  1. pytorch Dataset, DataLoader产生自定义的训练数据

    pytorch Dataset, DataLoader产生自定义的训练数据 目录 pytorch Dataset, DataLoader产生自定义的训练数据 1. torch.utils.data.D ...

  2. Pytorch nn.Fold()的简单理解与用法

    官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html 这个东西基本上就是绑定Unfold使用的.实际上,在没有overla ...

  3. Pytorch Tensor.unfold()的简单理解与用法

    unfold的作用就是手动实现的滑动窗口操作,也就是只有卷,没有积:不过相比于nn.functional中的unfold而言,其窗口的意味更浓,只能是一维的,也就是不存在类似2×2窗口的说法. ret ...

  4. Pytorch nn.BCEWithLogitsLoss()的简单理解与用法

    这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下: import torch import torch.nn as ...

  5. Pytorch之DataLoader Dataset、datasets、models、transforms的认识和学习

    文章目录 利用PyTorch框架来开发深度学习算法时几个基础的模块 Dataset & DataLoader 基础概念 自定义数据集 1 读取自定义数据集 1 自定义数据集 2 自定义数据集3 ...

  6. dataset__getitem___一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  7. 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  8. 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). c ...

  9. 从零开始构建基于textcnn的文本分类模型(上),word2vec向量训练,预训练词向量模型加载,pytorch Dataset、collete_fn、Dataloader转换数据集并行加载

    伴随着bert.transformer模型的提出,文本预训练模型应用于各项NLP任务.文本分类任务是最基础的NLP任务,本文回顾最先采用CNN用于文本分类之一的textcnn模型,意在巩固分词.词向量 ...

最新文章

  1. [转载]C# 二进制与十进制,十进制与十六进制相互转换
  2. 中文分词_中文分词及其应用
  3. python程序入门设计_程序设计入门—Python
  4. boost::math::find_location用法的测试程序
  5. 统计文件中有多少个单词amp;c语言实现
  6. 【CodeForces - 144B 】Meeting (暴力枚举,水题,计算几何)
  7. 三天两夜肝完这篇万字长文,终于拿下了 TCP/IP
  8. 流量一天一个台阶,谈映客直播服务端架构优化之路
  9. CF1109F Sasha and Algorithm of Silence's Sounds LCT、线段树
  10. 音乐机器人活动教案_幼儿园小班音乐教案小熊跳舞律动活动反思【幼儿教案】...
  11. 新型K4宏病毒代码分析报告
  12. mysql 身份证算年龄
  13. Php处理输入法表情,php开发中手机输入法自带的表情、emoji表情、微信表情不显示问题,以及过虑emoji表情方法!...
  14. OpenHarmony开源鸿蒙学习入门-应用开发之使用eTS语法示例项目讲解
  15. jquery gotop插件
  16. 关于母亲节的c语言程序设计教程课后答案,《我的母亲》习题及参考答案
  17. 英语语法浅述-动词、时态和语态
  18. 泰坦尼克号 机器学习_机器学习项目泰坦尼克号问题陈述
  19. Qt任务栏图标增加进度条
  20. mstar Android解锁,年轻人的新宠 当贝小投影C2解锁各种观影姿势

热门文章

  1. visio中公式太小_五金冲压模具中的凹模有哪些注意事项,值得一看
  2. python用字典统计单词出现次数_python - 如何使用字典理解来计算文档中每个单词的出现次数...
  3. 西门子rwd60参数设置调试手册_RWD60 RWD68 RWD62控制器调试指导说明
  4. 饭卡 01背包 DP
  5. 中文信息处理——分词评价程序(计算分词结果的准确率,召回率,F测度)
  6. 可交互绘图——鼠标移到点的上方会显示该点的标签[jupyter notebook]
  7. [转]awesome-tensorflow-chinese
  8. 在markdown (csdn)博客上输出 右下小标,右上小标。
  9. c/c++教程 - 1.9 指针 空指针 野指针 const修饰指针 指针常量 常量指针 指针和数组 指针和函数
  10. 天枰称重 (枚举法|进制转换逢十进一模版)