文章目录:

  • 1 Dataset基类

  • 2 构建Dataset子类

    • 2.1 __Init__

    • 2.2 __getitem__

  • 3 dataloader

1 Dataset基类

PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。

先看一下源码:

这里有一个__getitem__函数,__getitem__函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再讲)。

2 构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

import torch from torch.utils.data import Dataset,DataLoader

class MyDataset(Dataset):    def __init__(self):        self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])        self.label = torch.LongTensor([1,1,0,0])

    def __getitem__(self,index):        return self.data[index],self.label[index]

    def __len__(self):        return len(self.data)

2.1 Init

  • 初始化中,一般是把数据直接保存在这个类的属性中。像是self.data,self.label

2.2 getitem

  • index是一个索引,这个索引的取值范围是要根据__len__这个返回值确定的,在上面的例子中,__len__的返回值是4,所以这个index会在0,1,2,3这个范围内。

3 dataloader

从上文中,我们知道了MyDataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。

继续上面的代码,我们接着写代码:

mydataloader = DataLoader(dataset=mydataset,                          batch_size=1)

我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:

for i,(data,label) in enumerate(mydataloader):    print(data,label)

输出结果是:

意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的

我们稍微修改一下上面的DataLoader的参数:

mydataloader = DataLoader(dataset=mydataset,                          batch_size=2,                          shuffle=True)

for i,(data,label) in enumerate(mydataloader):    print(data,label)

结果是:

可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:

两次结果不同,这是因为shuffle=True,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。

【个人感想】

Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:

  1. 一般标签值应该是Long整数的,所以标签的tensor可以用torch.LongTensor(数据)或者用.long()来转化成Long整数的形式。
  2. 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用to()放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'for i,(data,label) in enumerate(mydataloader):    data = data.to(device)    label = label.to(device)    print(data,label)

看一下输出:

- END -

往期精彩回顾

适合初学者入门人工智能的路线及资料下载

机器学习及深度学习笔记等资料打印

机器学习在线手册

深度学习笔记专辑

《统计学习方法》的代码复现专辑

AI基础下载

机器学习的数学基础专辑

获取一折本站知识星球优惠券,复制链接直接打开:

https://t.zsxq.com/662nyZF

本站qq群1003271085。

加入微信群请扫码进群(如果是博士或者准备读博士请说明):

dataset__getitem___【小白学PyTorch】3.浅谈Dataset和Dataloader相关推荐

  1. 【小白学PyTorch】3.浅谈Dataset和Dataloader

    文章目录: 1 Dataset基类 2 构建Dataset子类 2.1 __Init__ 2.2 __getitem__ 3 dataloader 1 Dataset基类 PyTorch 读取其他的数 ...

  2. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...

  3. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  4. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

  5. 【小白学PyTorch】扩展之Tensorflow2.0 | 20 TF2的eager模式与求导

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 19 TF2模型的存储与载入 扩展之Tensorflow2.0 | 18 ...

  6. 【小白学PyTorch】18.TF2构建自定义模型

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 17 TFrec文件的创建与读取 扩展之Tensorflow2.0 | 1 ...

  7. 【小白学PyTorch】17.TFrec文件的创建与读取

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分 ...

  8. 【小白学PyTorch】16.TF2读取图片的方法

    <<小白学PyTorch>> 扩展之tensorflow2.0 | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 tensorboardX可视化教程 ...

  9. 【小白学PyTorch】15.TF2实现一个简单的服装分类任务

    <<小白学PyTorch>> 小白学PyTorch | 14 tensorboardX可视化教程 小白学PyTorch | 13 EfficientNet详解及PyTorch实 ...

最新文章

  1. java二叉树镜像_给定一个二叉树,检查它是否是镜像对称的。
  2. win10 中redis client提示 ERR Client sent AUTH,but no password is set
  3. oracle+trunkc,Oracle常用备份与恢复操作
  4. python批量上传 服务器_Python Tornado批量上传图片并显示功能
  5. ActionScript for Multiplayer Games and Virtual Worlds 下载。
  6. Redis基本使用及百亿数据量中的使用技巧分享
  7. 数据挖掘—BP神经网络(Java实现)
  8. oracle database 11g 如何正确卸载
  9. C#实现局域网UDP广播
  10. nginx启用https访问
  11. 【方案分享】华为MateBook X Pro上市数字传播方案.pptx(附下载链接)
  12. Ceres Solver: 高效的非线性优化库(二)实战篇
  13. android手机无法开机自动启动,手机无法开机怎么刷机?安卓手机救砖教程
  14. 高效管理之团队梯度建设
  15. java模拟新浪微博_用java程序模拟登陆新浪微博
  16. javascript实现中国地图
  17. 素数----南阳OJ
  18. 笔记本驱动图标消失怎么办
  19. 上网日志留存_日志留存系统
  20. java面试题(一)java面试题集合

热门文章

  1. 细述 Java垃圾回收机制→How Java Garbage Collection Works?
  2. 如何快速定位不小心暴露到全局的变量
  3. [詹兴致矩阵论习题参考解答]习题3.7
  4. arm qt mysql插件_Ubuntu下编译ARM平台Qt的MySQL插件
  5. 信息学奥赛一本通 1010:计算分数的浮点数值 | OpenJudge NOI 1.3 05
  6. 图论 —— 生成树 —— 最小瓶颈生成树
  7. 最大正方形(洛谷-P1387)
  8. 二进制分类(信息学奥赛一本通-T1412)
  9. 小鱼比可爱(洛谷-P1428)
  10. 20 CO配置-控制-产品成本控制-产品成本计划编制-定义成本核算变式