Pytorch Dataset、Dataloader的简单理解与使用
本文以torch.utils.data中的Dataset类为例进行说明
Dataset的作用是构建自定义的数据集,以方便使用Dataloader进行加载
语法
我们自定义的数据集需要继承自torch.util.data.Dataset抽象类,并重写相应的两个方法:
- len:返回数据集的大小。一般情况而言直接用 len(xxx) 进行实现即可
- getitem:使得 dataset[i] 能够返回数据集中的第i个样本,相应的需要传入一个索引i
原抽象类中相应的定义如下:
def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedError
数据
假设我们在解决一个分类问题。那么,在训练集文件夹train中,我们可以这么给图片加上标签:
到时候就可以通过文件名的方式来判断某张图片对应的分类。
例子
我们构造一个FruitDataset来处理这些数据。首先实现init方法:
def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = os.listdir(self.root_dir)
init方法一般会有两个基础的参数,一个是dir,用来表示数据集所在的目录;另一个则是transform,可以传入一些方法,以对图片进行处理(一般是进行数据增强)。
此外,在init方法中还会进行基础的数据读取,例如这里使用listdir来列出目录下的所有文件;而如果是表格形式的数据(如kaggle),那么则可以使用切片方法将标签与数据分离,方便后续的处理。
接下来是len方法,即返回数据集的长度。既然我们刚才已经读出了数据集目录下的所有文件,那么只要返回这个文件夹列表的长度即可:
def __len__(self):return len(self.images)
最后则是getitem方法。getitem方法返回的是一个字典,表示相应数据所蕴含的其他信息,有了其他信息一个数据才能变成一个样本。在这里,“其他信息”就是图片所对应的标签,即要返回一个{‘image’:img, ‘label’:label}。习惯上,我们会把这个字典记做sample。
img可以通过imread方法读取图片实际内容得到,而label可以通过处理文件名获得:
def __getitem__(self,index):# 通过路径与索引读图片image_index = self.images[index]img_path = os.path.join(self.root_dir, image_index)img = io.imread(img_path)# 通过文件名读标签label = img_path.split('\\')[-1].split('.')[0]# 组装成字典sample = {'image':img,'label':label}if self.transform:sample = self.transform(sample)return sample
注意这里的if self.transform也算是一种习惯上的用法,即如果传入了变换方法则进行变换后再返回。
Dataloader
我们通过dataloader来分析刚才构建的数据集。一般来说,训练集与测试集各会对应一个dataloader,这里为了演示方便起见就只拿我们刚才的训练集进行说明。
首先,实例化一个Dataset对象。在这里我们没有变换方法,则只需要传入数据所在的目录即可:
data = FruitDataset(r"data\train", transform=None)
dataset对象可以通过下标来访问其中的各个样本,比如:
print(data[0])
然后利用dataloader进行加载:
dataloader = DataLoader(data, batch_size=2, shuffle=True)
一般而言Dataloader需要传入三个参数:
- dataset:传入Dataset对象,表示需要加载的数据集
- batch_size:“批大小”,表示一次选取的一批中有几个样本。在这里bs为2,即每轮选取2个样本
- shuffle:是否需要将数据打乱。一般来说只需要打乱训练集即可,测试集并不需要打乱
查看dataloader的长度。总共有10张图,一批有2张,因此有5批,长度为5:
# 5
print(len(dataloader))
最后迭代整个数据集:
for i_batch, batch_data in enumerate(dataloader):print(i_batch)print(batch_data)
i_batch就是batch的编号,0、1、2、3、4;batch_data就是我们在数据集中定义的sample,在这里两个两个一组出现。
完整代码
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset, DataLoader
from skimage import io
import osclass FruitDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = os.listdir(self.root_dir)def __len__(self):return len(self.images)def __getitem__(self,index):image_index = self.images[index]img_path = os.path.join(self.root_dir, image_index)img = io.imread(img_path)label = img_path.split('\\')[-1].split('.')[0]sample = {'image':img,'label':label}if self.transform:sample = self.transform(sample)return sample data = FruitDataset(r"data\train", transform=None)
print(data[0])
dataloader = DataLoader(dataset=data, batch_size=2, shuffle=True)
print(len(dataloader))
for i_batch, batch_data in enumerate(dataloader):print(i_batch)print(batch_data)
参考
https://blog.csdn.net/xuan_liu123/article/details/101145366
Pytorch Dataset、Dataloader的简单理解与使用相关推荐
- pytorch Dataset, DataLoader产生自定义的训练数据
pytorch Dataset, DataLoader产生自定义的训练数据 目录 pytorch Dataset, DataLoader产生自定义的训练数据 1. torch.utils.data.D ...
- Pytorch nn.Fold()的简单理解与用法
官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html 这个东西基本上就是绑定Unfold使用的.实际上,在没有overla ...
- Pytorch Tensor.unfold()的简单理解与用法
unfold的作用就是手动实现的滑动窗口操作,也就是只有卷,没有积:不过相比于nn.functional中的unfold而言,其窗口的意味更浓,只能是一维的,也就是不存在类似2×2窗口的说法. ret ...
- Pytorch nn.BCEWithLogitsLoss()的简单理解与用法
这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下: import torch import torch.nn as ...
- Pytorch之DataLoader Dataset、datasets、models、transforms的认识和学习
文章目录 利用PyTorch框架来开发深度学习算法时几个基础的模块 Dataset & DataLoader 基础概念 自定义数据集 1 读取自定义数据集 1 自定义数据集 2 自定义数据集3 ...
- dataset__getitem___一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). c ...
- 从零开始构建基于textcnn的文本分类模型(上),word2vec向量训练,预训练词向量模型加载,pytorch Dataset、collete_fn、Dataloader转换数据集并行加载
伴随着bert.transformer模型的提出,文本预训练模型应用于各项NLP任务.文本分类任务是最基础的NLP任务,本文回顾最先采用CNN用于文本分类之一的textcnn模型,意在巩固分词.词向量 ...
最新文章
- [转载]C# 二进制与十进制,十进制与十六进制相互转换
- 中文分词_中文分词及其应用
- python程序入门设计_程序设计入门—Python
- boost::math::find_location用法的测试程序
- 统计文件中有多少个单词amp;c语言实现
- 【CodeForces - 144B 】Meeting (暴力枚举,水题,计算几何)
- 三天两夜肝完这篇万字长文,终于拿下了 TCP/IP
- 流量一天一个台阶,谈映客直播服务端架构优化之路
- CF1109F Sasha and Algorithm of Silence's Sounds LCT、线段树
- 音乐机器人活动教案_幼儿园小班音乐教案小熊跳舞律动活动反思【幼儿教案】...
- 新型K4宏病毒代码分析报告
- mysql 身份证算年龄
- Php处理输入法表情,php开发中手机输入法自带的表情、emoji表情、微信表情不显示问题,以及过虑emoji表情方法!...
- OpenHarmony开源鸿蒙学习入门-应用开发之使用eTS语法示例项目讲解
- jquery gotop插件
- 关于母亲节的c语言程序设计教程课后答案,《我的母亲》习题及参考答案
- 英语语法浅述-动词、时态和语态
- 泰坦尼克号 机器学习_机器学习项目泰坦尼克号问题陈述
- Qt任务栏图标增加进度条
- mstar Android解锁,年轻人的新宠 当贝小投影C2解锁各种观影姿势
热门文章
- visio中公式太小_五金冲压模具中的凹模有哪些注意事项,值得一看
- python用字典统计单词出现次数_python - 如何使用字典理解来计算文档中每个单词的出现次数...
- 西门子rwd60参数设置调试手册_RWD60 RWD68 RWD62控制器调试指导说明
- 饭卡 01背包 DP
- 中文信息处理——分词评价程序(计算分词结果的准确率,召回率,F测度)
- 可交互绘图——鼠标移到点的上方会显示该点的标签[jupyter notebook]
- [转]awesome-tensorflow-chinese
- 在markdown (csdn)博客上输出 右下小标,右上小标。
- c/c++教程 - 1.9 指针 空指针 野指针 const修饰指针 指针常量 常量指针 指针和数组 指针和函数
- 天枰称重 (枚举法|进制转换逢十进一模版)