dataset__getitem___【小白学PyTorch】3.浅谈Dataset和Dataloader
文章目录:
1 Dataset基类
2 构建Dataset子类
2.1 __Init__
2.2 __getitem__
3 dataloader
1 Dataset基类
PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
先看一下源码:
这里有一个__getitem__
函数,__getitem__
函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑。
其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)
(后面再讲)。
2 构建Dataset子类
下面我们构建一下Dataset的子类,叫他MyDataset类:
import torch from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset): def __init__(self): self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]]) self.label = torch.LongTensor([1,1,0,0])
def __getitem__(self,index): return self.data[index],self.label[index]
def __len__(self): return len(self.data)
2.1 Init
- 初始化中,一般是把数据直接保存在这个类的属性中。像是
self.data,self.label
2.2 getitem
- index是一个索引,这个索引的取值范围是要根据
__len__
这个返回值确定的,在上面的例子中,__len__
的返回值是4,所以这个index会在0,1,2,3这个范围内。
3 dataloader
从上文中,我们知道了MyDataset这个类中的__getitem__
的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。
继续上面的代码,我们接着写代码:
mydataloader = DataLoader(dataset=mydataset, batch_size=1)
我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:
for i,(data,label) in enumerate(mydataloader): print(data,label)
输出结果是:
意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的。
我们稍微修改一下上面的DataLoader的参数:
mydataloader = DataLoader(dataset=mydataset, batch_size=2, shuffle=True)
for i,(data,label) in enumerate(mydataloader): print(data,label)
结果是:
可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:
两次结果不同,这是因为shuffle=True
,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。
【个人感想】
Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:
- 一般标签值应该是Long整数的,所以标签的tensor可以用
torch.LongTensor(数据)
或者用.long()
来转化成Long整数的形式。 - 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用
to()
放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'for i,(data,label) in enumerate(mydataloader): data = data.to(device) label = label.to(device) print(data,label)
看一下输出:
- END -
往期精彩回顾
适合初学者入门人工智能的路线及资料下载
机器学习及深度学习笔记等资料打印
机器学习在线手册
深度学习笔记专辑
《统计学习方法》的代码复现专辑
AI基础下载
机器学习的数学基础专辑
获取一折本站知识星球优惠券,复制链接直接打开:
https://t.zsxq.com/662nyZF
本站qq群1003271085。
加入微信群请扫码进群(如果是博士或者准备读博士请说明):
dataset__getitem___【小白学PyTorch】3.浅谈Dataset和Dataloader相关推荐
- 【小白学PyTorch】3.浅谈Dataset和Dataloader
文章目录: 1 Dataset基类 2 构建Dataset子类 2.1 __Init__ 2.2 __getitem__ 3 dataloader 1 Dataset基类 PyTorch 读取其他的数 ...
- pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构
[机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization
<<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 20 TF2的eager模式与求导
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 19 TF2模型的存储与载入 扩展之Tensorflow2.0 | 18 ...
- 【小白学PyTorch】18.TF2构建自定义模型
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 17 TFrec文件的创建与读取 扩展之Tensorflow2.0 | 1 ...
- 【小白学PyTorch】17.TFrec文件的创建与读取
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分 ...
- 【小白学PyTorch】16.TF2读取图片的方法
<<小白学PyTorch>> 扩展之tensorflow2.0 | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 tensorboardX可视化教程 ...
- 【小白学PyTorch】15.TF2实现一个简单的服装分类任务
<<小白学PyTorch>> 小白学PyTorch | 14 tensorboardX可视化教程 小白学PyTorch | 13 EfficientNet详解及PyTorch实 ...
最新文章
- java二叉树镜像_给定一个二叉树,检查它是否是镜像对称的。
- win10 中redis client提示 ERR Client sent AUTH,but no password is set
- oracle+trunkc,Oracle常用备份与恢复操作
- python批量上传 服务器_Python Tornado批量上传图片并显示功能
- ActionScript for Multiplayer Games and Virtual Worlds 下载。
- Redis基本使用及百亿数据量中的使用技巧分享
- 数据挖掘—BP神经网络(Java实现)
- oracle database 11g 如何正确卸载
- C#实现局域网UDP广播
- nginx启用https访问
- 【方案分享】华为MateBook X Pro上市数字传播方案.pptx(附下载链接)
- Ceres Solver: 高效的非线性优化库(二)实战篇
- android手机无法开机自动启动,手机无法开机怎么刷机?安卓手机救砖教程
- 高效管理之团队梯度建设
- java模拟新浪微博_用java程序模拟登陆新浪微博
- javascript实现中国地图
- 素数----南阳OJ
- 笔记本驱动图标消失怎么办
- 上网日志留存_日志留存系统
- java面试题(一)java面试题集合
热门文章
- 细述 Java垃圾回收机制→How Java Garbage Collection Works?
- 如何快速定位不小心暴露到全局的变量
- [詹兴致矩阵论习题参考解答]习题3.7
- arm qt mysql插件_Ubuntu下编译ARM平台Qt的MySQL插件
- 信息学奥赛一本通 1010:计算分数的浮点数值 | OpenJudge NOI 1.3 05
- 图论 —— 生成树 —— 最小瓶颈生成树
- 最大正方形(洛谷-P1387)
- 二进制分类(信息学奥赛一本通-T1412)
- 小鱼比可爱(洛谷-P1428)
- 20 CO配置-控制-产品成本控制-产品成本计划编制-定义成本核算变式