实现一个定制的 Dataset 类

Dataset 类是 PyTorch 图像数据集中最为重要的一个类,也是 PyTorch 中所有数据集加载类中应该继承的父类。其中,父类的两个私有成员函数必须被重载。

  • getitem(self, index) # 支持数据集索引的函数
  • len(self) # 返回数据集的大小

Datasets 的框架:

class CustomDataset(data.Dataset): # 需要继承 data.Datasetdef __init__(self):# TODO# Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. 从文件中读取指定 index 的数据(例:使用 numpy.fromfile, PIL.Image.open)# 2. 预处理读取的数据(例:torchvision.Transform)# 3. 返回数据对(例:图像和对应标签)passdef __len__(self):# TODO# You should change 0 to the total size of your dataset.return 0

举例:

class MyDataset(Dataset):"""root: 图像存放地址根路径augment:是否需要图像增强"""def __init__(self, root, augment=None):# 这个 list 存放所有图像的地址self.image_files = np.array([x.path for x in os.scandir(root)if x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")])self.augment = augmentdef __getitem__(self, index):if self.augment:image = open_image(self.image_files[index])   # 这里的 open_image 是读取图像的函数,可以用 PIL 或者 OpenCV 等库进行读取image = self.augment(image)   # 这里对图像进行了数据增强return to_tensor(image)       # PyTorch 中得到的图像必须是 tensorelse:image = open_image(self.image_files[index])return to_tensor(image)

下面是官方 MNIST 的例子:

class MNIST(data.Dataset):"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.Args:root (string): Root directory of dataset where ``processed/training.pt``and  ``processed/test.pt`` exist.train (bool, optional): If True, creates dataset from ``training.pt``,otherwise from ``test.pt``.download (bool, optional): If true, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it."""urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz','http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz','http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz','http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',]raw_folder = 'raw'processed_folder = 'processed'training_file = 'training.pt'test_file = 'test.pt'classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four','5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']class_to_idx = {_class: i for i, _class in enumerate(classes)}@propertydef targets(self):if self.train:return self.train_labelselse:return self.test_labelsdef __init__(self, root, train=True, transform=None, target_transform=None, download=False):self.root = os.path.expanduser(root)self.transform = transformself.target_transform = target_transformself.train = train  # training set or test setif download:self.download()if not self._check_exists():raise RuntimeError('Dataset not found.' +' You can use download=True to download it')if self.train:self.train_data, self.train_labels = torch.load(os.path.join(self.root, self.processed_folder, self.training_file))else:self.test_data, self.test_labels = torch.load(os.path.join(self.root, self.processed_folder, self.test_file))def __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is index of the target class."""if self.train:img, target = self.train_data[index], self.train_labels[index]else:img, target = self.test_data[index], self.test_labels[index]# doing this so that it is consistent with all other datasets# to return a PIL Imageimg = Image.fromarray(img.numpy(), mode='L')if self.transform is not None:img = self.transform(img)if self.target_transform is not None:target = self.target_transform(target)return img, targetdef __len__(self):if self.train:return len(self.train_data)else:return len(self.test_data)def _check_exists(self):return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))def download(self):"""Download the MNIST data if it doesn't exist in processed_folder already."""from six.moves import urllibimport gzipif self._check_exists():return# download filestry:os.makedirs(os.path.join(self.root, self.raw_folder))os.makedirs(os.path.join(self.root, self.processed_folder))except OSError as e:if e.errno == errno.EEXIST:passelse:raisefor url in self.urls:print('Downloading ' + url)data = urllib.request.urlopen(url)filename = url.rpartition('/')[2]file_path = os.path.join(self.root, self.raw_folder, filename)with open(file_path, 'wb') as f:f.write(data.read())with open(file_path.replace('.gz', ''), 'wb') as out_f, \gzip.GzipFile(file_path) as zip_f:out_f.write(zip_f.read())os.unlink(file_path)# process and save as torch filesprint('Processing...')training_set = (read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')))test_set = (read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')))with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:torch.save(training_set, f)with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:torch.save(test_set, f)print('Done!')def __repr__(self):fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())tmp = 'train' if self.train is True else 'test'fmt_str += '    Split: {}\n'.format(tmp)fmt_str += '    Root Location: {}\n'.format(self.root)tmp = '    Transforms (if any): 'fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))tmp = '    Target Transforms (if any): 'fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))return fmt_str

转载于:https://www.cnblogs.com/xxxxxxxxx/p/11429051.html

PyTorch 之 Datasets相关推荐

  1. pytorch torchvision.datasets.ImageFolder

    API CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<f ...

  2. pytorch torchvision.datasets

    torchvision 库是服务于pytorch深度学习框架的,用来生成图片,视频数据集,和一些流行的模型类和预训练模型. torchvision.datasets 所有数据集都是 torch.uti ...

  3. PyTorch - torchvision - datasets

    所有的数据集都是torch.utils.data.Dataset的子类,即它们实现了__getitem__和__len__方法. 因此,它们都可以传递给torch.utils.data.DataLoa ...

  4. 【已补蓝奏云链接】PyTorch中MNIST数据集(附datasets.MNIST离线包)下载慢/安装慢的解决方案

    一.问题背景 在学习MNIST数据集手写数字识别demo的时候,笔者碰到了一些问题,现记录如下: 1.必须先确保torchvision已经正确安.如何安装torchvision?请参考PyTorch/ ...

  5. 使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

    使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档. JAX简介 JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本 ...

  6. 毕设日志——在faster rcnn pytorch上训练KITTI数据集

    本次目标: 整理运行代码产生的内容 下载KITTI数据集和LSVH数据集 修改数据集样式为VOC2007 在新的数据集上训测 2019.4.13 一.准备工作 备份之前训练生成的文件models,ou ...

  7. Win环境下配置PyTorch深度学习环境

    目录 0.查看Nvidia驱动 1.下载torch和torchvision 2.安装torch和torchvison 3.YOLOv5环境配置 相较于tensorflow环境配置,PyTorch的配置 ...

  8. Anaconda中GPU版本Pytorch 的whl 安装方法【2023.1最新最详细】(附anaconda以及cudacudnn安装教程)

    本教程郑重承诺,不使用Pytorch官网,不使用清华源镜像资源,全部利用whl文件完成Pytorch 安装. 写在前面的一些艰辛心路历程: 因为毕设需要安装pytorch,学长轻飘飘地丢给我一句&qu ...

  9. pytorch的安装--命令

    --------------------------------------------------pytorch相关安装--------------------------------------- ...

最新文章

  1. 手机游戏深化、改革。
  2. Struts2的类型转换(下)
  3. LESSON 10.410.510.6 贝叶斯优化的基本流程BayesOpt vs HyperOpt vs Optuna batch基于BayesOpt实现高斯过程gp优化
  4. QT解析 JSON 格式的数据
  5. 开源python-打包发布
  6. catch的执行与try的匹配
  7. Redis 重写原理
  8. 面向全球用户的Teams app之Culture计量单位和禁忌篇
  9. Java面向对象(16)--单例(Singleton)设计模式
  10. Linux下架设邮件服务器全攻略(二)
  11. redis key/value 前面出现\xac\xed\x00\x05t\x00\x06 已解决
  12. Android 连接SQLite
  13. 海量数据挖掘MMDS week2: 频繁项集挖掘 Apriori算法的改进:非hash方法
  14. sessionStorage跨标签取值
  15. pytorch自带网络_一篇长文学懂 pytorch
  16. 第二季-专题13-NandFlash变硬盘
  17. Picnic Planning
  18. win10找回永久删除文件【图文教程】
  19. 猜拳java,猜拳小游戏(Java代码实现)
  20. 云大计算机初试最高分,【经验谈】初试总分360+,专业排名前五!云大社会工作专......

热门文章

  1. 查看Linux内核及发行商版本命令
  2. Tactai获美国科学基金会100万美元投资,致力于打造VR触觉体验
  3. C语言在VS2017环境下写俄罗斯方块的感悟
  4. 2017华南理工华为杯D bx回文
  5. div根据内容改变大小并且左右居中
  6. 桥牌笔记:3NT做庄路线
  7. fopen参数mode详解
  8. 关于 Hive 报 SemanticException 错误的问题
  9. 浅析Java内存模型--ClassLoader
  10. 优化算法-共轭梯度法