(深度学习)构造属于你自己的Pytorch数据集
(深度学习)构造属于你自己的Pytorch数据集
1.综述
2.实现原理
3.代码细节
4.详细代码
综述
Pytorch可以说是一个非常便利的深度学习库,它甚至在torchvision.datasets
中拥有许多一步到位完成数据集下载、解析、读取的类——然鹅,这样也就养成了我们懒惰依赖的心理。当我们需要用到torchvision.datasets
中不曾拥有的数据集时,我们可能就会不知所措。
这篇文章中,我将以CIFAR-10数据集为例(虽然有torchvision.datasets.CIFAR10
了),摆脱对torchvision.datasets
的依赖,构建一个自己的数据集。
在开始之前,首先你要有CIFAR-10数据集,直接去官网上下载可能较慢(再次感谢我国著名建筑师方斌新院士 ),可以在https://pan.baidu.com/s/1bGVGeeiw001qz-PUk7q1Uw(提取码:m35y)中下载python版本的数据集。
数据集解压后目录情况如下:
实现原理
首先,torch.utils.data.DataLoader
不仅生成迭代数据非常方便,而且它也是经过优化的,效率十分之高(肯定比我们自己写一个要高多了),因此我们最好不要舍弃。
因此,我们的目标是根据CIFAR-10数据集构造一个Dataset的子类,使之能够作为torch.utils.data.DataLoader
的参数,从而使数据集能被我们用于生成迭代数据进行训练:
cifar10 = MyCIFAR10.MyCIFAR10('./data/cifar-10-batches-py', train=True)
train_loader = torch.utils.data.DataLoader(dataset=cifar10, batch_size=batch_size, shuffle=True)
要构造Dataset的子类,就必须要实现两个方法:
- _getitem_(self, index):根据index来返回数据集中标号为index的元素及其标签。
- _len_(self):返回数据集的长度。
因此,实质上我们主要是要通过__init__初始化之时读取数据集,再实现这两个函数便轻而易举。
代码细节
_init_:
- root是存放解压后的数据集的根目录,根据上图我这里是
'./data/cifar-10-batches-py'
。 - X的类型是numpy数组,Y的类型是List;由于X作为数据要送入网络中,因此最后需要将其累加值从numpy数组转为Tensor。
def __init__(self, root, train=True, transform=None, target_transform=None):super(MyCIFAR10, self).__init__()self.transform = transformself.target_transform = target_transformself.imgs = Noneself.labels = []# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片train_lists = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']test_lists = ['test_batch']# 根据train是否为True来选择测试集或训练集if train:lists = train_listselse:lists = test_lists# 读取数据集,构造类中的图像集和标签for list in lists:filename = os.path.join(root, list)with open(filename, 'rb') as f: # 这里需要'rb' + 'latin1'才能读取datadict = pickle.load(f, encoding='latin1')X = datadict['data'].reshape(-1, 3, 32, 32)Y = datadict['labels']if self.imgs is None:self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)else:self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)self.labels = self.labels + Yself.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor) # 最后需要将numpy数组转为Tensor
- root是存放解压后的数据集的根目录,根据上图我这里是
_getitem_:
较为简单,直接给出:
def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label
_len_:
极其简单,直接给出:
def __len__(self):return len(self.imgs)
详细代码
class MyCIFAR10(Dataset):"""根据CIFAR-10定义的个人数据集类继承自Dataset类,因此能够被torch.utils.data.DataLoader使用,从而更高效地在训练和测试中迭代"""def __init__(self, root, train=True, transform=None, target_transform=None):super(MyCIFAR10, self).__init__()self.transform = transformself.target_transform = target_transformself.imgs = Noneself.labels = []# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片train_lists = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']test_lists = ['test_batch']# 根据train是否为True来选择测试集或训练集if train:lists = train_listselse:lists = test_lists# 读取数据集,构造类中的图像集和标签for list in lists:filename = os.path.join(root, list)with open(filename, 'rb') as f: # 这里需要'rb' + 'latin1'才能读取datadict = pickle.load(f, encoding='latin1')X = datadict['data'].reshape(-1, 3, 32, 32)Y = datadict['labels']if self.imgs is None:self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)else:self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)self.labels = self.labels + Yself.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor) # 最后需要将numpy数组转为Tensor# 继承的Dataset类需要实现两个方法之一:__getitem__(self, index)def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label# 继承的Dataset类需要实现两个方法之一:__len__(self)def __len__(self):return len(self.imgs)
(深度学习)构造属于你自己的Pytorch数据集相关推荐
- 深度学习笔记其七:计算机视觉和PYTORCH
深度学习笔记其七:计算机视觉和PYTORCH 1. 图像增广 1.1 常用的图像增广方法 1.1.1 翻转和裁剪 1.1.2 改变颜色 1.1.3 结合多种图像增广方法 1.2 使用图像增广进行训练 ...
- 基于深度学习的口罩识别与检测PyTorch实现
基于深度学习的口罩识别与检测PyTorch实现 1. 设计思路 1.1 两阶段检测器:先检测人脸,然后将人脸进行分类,戴口罩与不戴口罩. 1.2 一阶段检测器:直接训练口罩检测器,训练样本为人脸的标注 ...
- 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)
初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...
- Pytorch深度学习(五):加载数据集以及mini-batch的使用
Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...
- 深度学习---从入门到放弃(一)pytorch基础
深度学习-从入门到放弃(一)pytorch Tensor 类似于numpy的array,pandas的dataframe:在pytorch里的数据结构是tensor,即张量 tensor简单操作 1. ...
- 15个小时彻底搞懂NLP自然语言处理(2021最新版附赠课件笔记资料)【LP自然语言处理涉及到深度学习和神经网络的介绍、 Pytorch、 RNN自然语言处理】 笔记
15个小时彻底搞懂NLP自然语言处理(2021最新版附赠课件笔记资料)[LP自然语言处理涉及到深度学习和神经网络的介绍. Pytorch. RNN自然语言处理] 笔记 教程与代码地址 P1 机器学习与 ...
- 【三维重建】【深度学习】windows10下NeRF代码Pytorch实现
[三维重建][深度学习]windows10下NeRF代码Pytorch实现 提示:最近开始在[三维重建]方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法. 文章目录 [三维重建][深度 ...
- 深度学习报错 | THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp
深度学习报错 | THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp 错误定位 解决历程 错误定位 近日在自己的服务器上跑别人的代码时 ...
- 【深度学习】弱/半监督学习解决医学数据集规模小、数据标注难问题
[深度学习]弱/半监督学习解决医学数据集规模小.数据标注难问题 文章目录 1 概述 2 半监督学习 3 重新思考空洞卷积: 为弱监督和半监督语义分割设计的简捷方法 4 弱监督和半监督分割的训练和学习 ...
最新文章
- tidb 架构 ~Tidb学习系列(4)
- Paddle 环境中 使用LeNet在MNIST数据集实现图像分类
- 1131: 零起点学算法38——求阶乘和
- 文本编辑器中命令行参数的应用
- 爬虫_微信小程序社区教程(crawlspider)
- SD卡读写,首选项,共享首选项
- SAP Hybris Commerce启用customer coupon的前提条件
- PYQT4 Python GUI 编写与 打包.exe程序
- KnockoutJS-与服务端交互
- 红橙Darren视频笔记 缓存方案 缓存到数据库(数据库操作) 上
- oledb vc访问mdb数据库_一个通用数据库操作组件DBUtil(c#)、支持SqlServer、Oracle、Mysql、postgres、SQLITE...
- 书屋(三):《浪潮之巅》品各大百年公司兴衰历程
- 特殊的Excel填充序号技巧,总有一种你会遇到【特别实用,赶紧收藏】
- 数据科学的原理与技巧 四、数据清理
- 中债登——各功能快捷入口
- Qt error: C2039: “staticMetaObject”: 不是“QXXX”的成员
- 给图片加水印的简单方法,手机图片加水印也可以用
- Portainer容器可视化工具
- Android 音频源码分析——音量调节流程
- 使用HTTPie测试Web服务