(深度学习)构造属于你自己的Pytorch数据集

1.综述

2.实现原理

3.代码细节

4.详细代码

综述

Pytorch可以说是一个非常便利的深度学习库,它甚至在torchvision.datasets中拥有许多一步到位完成数据集下载、解析、读取的类——然鹅,这样也就养成了我们懒惰依赖的心理。当我们需要用到torchvision.datasets中不曾拥有的数据集时,我们可能就会不知所措。

这篇文章中,我将以CIFAR-10数据集为例(虽然有torchvision.datasets.CIFAR10了),摆脱对torchvision.datasets的依赖,构建一个自己的数据集。

在开始之前,首先你要有CIFAR-10数据集,直接去官网上下载可能较慢(再次感谢我国著名建筑师方斌新院士 ),可以在https://pan.baidu.com/s/1bGVGeeiw001qz-PUk7q1Uw(提取码:m35y)中下载python版本的数据集。

数据集解压后目录情况如下:

实现原理

首先,torch.utils.data.DataLoader不仅生成迭代数据非常方便,而且它也是经过优化的,效率十分之高(肯定比我们自己写一个要高多了),因此我们最好不要舍弃。

因此,我们的目标是根据CIFAR-10数据集构造一个Dataset的子类,使之能够作为torch.utils.data.DataLoader的参数,从而使数据集能被我们用于生成迭代数据进行训练:

cifar10 = MyCIFAR10.MyCIFAR10('./data/cifar-10-batches-py', train=True)
train_loader = torch.utils.data.DataLoader(dataset=cifar10, batch_size=batch_size, shuffle=True)

要构造Dataset的子类,就必须要实现两个方法:

  • _getitem_(self, index):根据index来返回数据集中标号为index的元素及其标签。
  • _len_(self):返回数据集的长度。

因此,实质上我们主要是要通过__init__初始化之时读取数据集,再实现这两个函数便轻而易举。

代码细节

  1. _init_:

    • root是存放解压后的数据集的根目录,根据上图我这里是'./data/cifar-10-batches-py'
    • X的类型是numpy数组,Y的类型是List;由于X作为数据要送入网络中,因此最后需要将其累加值从numpy数组转为Tensor。
    def __init__(self, root, train=True, transform=None, target_transform=None):super(MyCIFAR10, self).__init__()self.transform = transformself.target_transform = target_transformself.imgs = Noneself.labels = []# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片train_lists = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']test_lists = ['test_batch']# 根据train是否为True来选择测试集或训练集if train:lists = train_listselse:lists = test_lists# 读取数据集,构造类中的图像集和标签for list in lists:filename = os.path.join(root, list)with open(filename, 'rb') as f:  # 这里需要'rb' + 'latin1'才能读取datadict = pickle.load(f, encoding='latin1')X = datadict['data'].reshape(-1, 3, 32, 32)Y = datadict['labels']if self.imgs is None:self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)else:self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)self.labels = self.labels + Yself.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor)     # 最后需要将numpy数组转为Tensor
    
  2. _getitem_:

    较为简单,直接给出:

    def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label
    
  3. _len_:

    极其简单,直接给出:

    def __len__(self):return len(self.imgs)
    

详细代码

class MyCIFAR10(Dataset):"""根据CIFAR-10定义的个人数据集类继承自Dataset类,因此能够被torch.utils.data.DataLoader使用,从而更高效地在训练和测试中迭代"""def __init__(self, root, train=True, transform=None, target_transform=None):super(MyCIFAR10, self).__init__()self.transform = transformself.target_transform = target_transformself.imgs = Noneself.labels = []# 根据CIFAR-10官网上下载的数据,训练集分为5个batch文件,每个里有10000张32*32的图片;测试集只有1个batch文件,里面有10000张32*32的图片train_lists = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']test_lists = ['test_batch']# 根据train是否为True来选择测试集或训练集if train:lists = train_listselse:lists = test_lists# 读取数据集,构造类中的图像集和标签for list in lists:filename = os.path.join(root, list)with open(filename, 'rb') as f:  # 这里需要'rb' + 'latin1'才能读取datadict = pickle.load(f, encoding='latin1')X = datadict['data'].reshape(-1, 3, 32, 32)Y = datadict['labels']if self.imgs is None:self.imgs = np.vstack(X).reshape(-1, 3, 32, 32)else:self.imgs = np.vstack((self.imgs, X)).reshape(-1, 3, 32, 32)self.labels = self.labels + Yself.imgs = torch.from_numpy(self.imgs).type(torch.FloatTensor)     # 最后需要将numpy数组转为Tensor# 继承的Dataset类需要实现两个方法之一:__getitem__(self, index)def __getitem__(self, index):img, label = self.imgs[index], self.labels[index]if self.transform is not None:img = self.transform(img)if self.target_transform is not None:label = self.target_transform(label)return img, label# 继承的Dataset类需要实现两个方法之一:__len__(self)def __len__(self):return len(self.imgs)

(深度学习)构造属于你自己的Pytorch数据集相关推荐

  1. 深度学习笔记其七:计算机视觉和PYTORCH

    深度学习笔记其七:计算机视觉和PYTORCH 1. 图像增广 1.1 常用的图像增广方法 1.1.1 翻转和裁剪 1.1.2 改变颜色 1.1.3 结合多种图像增广方法 1.2 使用图像增广进行训练 ...

  2. 基于深度学习的口罩识别与检测PyTorch实现

    基于深度学习的口罩识别与检测PyTorch实现 1. 设计思路 1.1 两阶段检测器:先检测人脸,然后将人脸进行分类,戴口罩与不戴口罩. 1.2 一阶段检测器:直接训练口罩检测器,训练样本为人脸的标注 ...

  3. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  4. Pytorch深度学习(五):加载数据集以及mini-batch的使用

    Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...

  5. 深度学习---从入门到放弃(一)pytorch基础

    深度学习-从入门到放弃(一)pytorch Tensor 类似于numpy的array,pandas的dataframe:在pytorch里的数据结构是tensor,即张量 tensor简单操作 1. ...

  6. 15个小时彻底搞懂NLP自然语言处理(2021最新版附赠课件笔记资料)【LP自然语言处理涉及到深度学习和神经网络的介绍、 Pytorch、 RNN自然语言处理】 笔记

    15个小时彻底搞懂NLP自然语言处理(2021最新版附赠课件笔记资料)[LP自然语言处理涉及到深度学习和神经网络的介绍. Pytorch. RNN自然语言处理] 笔记 教程与代码地址 P1 机器学习与 ...

  7. 【三维重建】【深度学习】windows10下NeRF代码Pytorch实现

    [三维重建][深度学习]windows10下NeRF代码Pytorch实现 提示:最近开始在[三维重建]方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法. 文章目录 [三维重建][深度 ...

  8. 深度学习报错 | THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp

    深度学习报错 | THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp 错误定位 解决历程 错误定位 近日在自己的服务器上跑别人的代码时 ...

  9. 【深度学习】弱/半监督学习解决医学数据集规模小、数据标注难问题

    [深度学习]弱/半监督学习解决医学数据集规模小.数据标注难问题 文章目录 1 概述 2 半监督学习 3 重新思考空洞卷积: 为弱监督和半监督语义分割设计的简捷方法 4 弱监督和半监督分割的训练和学习 ...

最新文章

  1. tidb 架构 ~Tidb学习系列(4)
  2. Paddle 环境中 使用LeNet在MNIST数据集实现图像分类
  3. 1131: 零起点学算法38——求阶乘和
  4. 文本编辑器中命令行参数的应用
  5. 爬虫_微信小程序社区教程(crawlspider)
  6. SD卡读写,首选项,共享首选项
  7. SAP Hybris Commerce启用customer coupon的前提条件
  8. PYQT4 Python GUI 编写与 打包.exe程序
  9. KnockoutJS-与服务端交互
  10. 红橙Darren视频笔记 缓存方案 缓存到数据库(数据库操作) 上
  11. oledb vc访问mdb数据库_一个通用数据库操作组件DBUtil(c#)、支持SqlServer、Oracle、Mysql、postgres、SQLITE...
  12. 书屋(三):《浪潮之巅》品各大百年公司兴衰历程
  13. 特殊的Excel填充序号技巧,总有一种你会遇到【特别实用,赶紧收藏】
  14. 数据科学的原理与技巧 四、数据清理
  15. 中债登——各功能快捷入口
  16. Qt error: C2039: “staticMetaObject”: 不是“QXXX”的成员
  17. 给图片加水印的简单方法,手机图片加水印也可以用
  18. Portainer容器可视化工具
  19. Android 音频源码分析——音量调节流程
  20. 使用HTTPie测试Web服务

热门文章

  1. 搜狐html5,手机搜狐率先发力Html5技术
  2. 关于mysql本地计算机上的MySQL服务启动后停止。某些服务在未由其他服务或程序使用时将自动停止问题
  3. HDFS成员的工作机制
  4. 【转】excel音标乱码
  5. 一个简单的音乐网站设计与实现(HTML+CSS)---爵士乐音乐 3页
  6. 分布式系统原理介绍_分布式系统的全面介绍
  7. 细数二十世纪最伟大的10大算法(Top10)
  8. Matlab2013a学习之男女的声音识别
  9. 《欢乐颂2》狗血的剧情才是生活该有的模样
  10. (纪中)2419. Grass Planting