PyTorch Datasets
Implementing a Custom Dataset Class
The Dataset class is the most important class for image datasets in PyTorch, and it is the parent class that every dataset-loading class in PyTorch should inherit from. Two member functions of the parent class must be overridden:
- `__getitem__(self, index)` # supports indexing into the dataset
- `__len__(self)` # returns the size of the dataset
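These two methods are simply Python's standard sequence protocol, so the idea can be shown without PyTorch at all. A minimal, dependency-free sketch (the class and data here are illustrative, not part of PyTorch):

```python
class ToyDataset:
    """Minimal illustration of the two required methods (no torch needed)."""

    def __init__(self, samples):
        self.samples = samples  # e.g. a list of (data, label) pairs

    def __getitem__(self, index):
        # called by dataset[index]
        return self.samples[index]

    def __len__(self):
        # called by len(dataset)
        return len(self.samples)


ds = ToyDataset([("img0", 0), ("img1", 1), ("img2", 0)])
print(len(ds))  # 3
print(ds[1])    # ('img1', 1)
```

A real PyTorch Dataset works the same way; it just additionally inherits from `torch.utils.data.Dataset` and typically returns tensors.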
The skeleton of a custom Dataset:
```python
class CustomDataset(data.Dataset):  # must inherit from data.Dataset
    def __init__(self):
        # TODO
        # Initialize file paths or a list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one sample for the given index from file
        #    (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and its label).
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
```
An example:
```python
class MyDataset(Dataset):
    """
    root:    root directory where the images are stored
    augment: optional image-augmentation transform
    """

    def __init__(self, root, augment=None):
        # This list holds the paths of all images
        self.image_files = np.array([
            x.path for x in os.scandir(root)
            if x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")
        ])
        self.augment = augment

    def __getitem__(self, index):
        # open_image reads an image from disk; it can be implemented
        # with PIL, OpenCV, or a similar library
        image = open_image(self.image_files[index])
        if self.augment:
            image = self.augment(image)  # apply data augmentation
        return to_tensor(image)  # images handed to PyTorch must be tensors

    def __len__(self):
        return len(self.image_files)
```
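In training code, a Dataset like this is usually wrapped in a `torch.utils.data.DataLoader`, which at its core iterates over indices and groups the returned items into batches. A simplified, stdlib-only sketch of that batching logic (an illustration of the idea, not the real DataLoader, which also handles collation, multiprocessing, etc.):

```python
import random

def simple_loader(dataset, batch_size=2, shuffle=True):
    """Yield batches the way a DataLoader does, via __len__ and __getitem__."""
    indices = list(range(len(dataset)))        # uses __len__
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield [dataset[i] for i in batch_idx]  # uses __getitem__

# works with any object implementing the two methods, e.g. a plain list
batches = list(simple_loader(["a", "b", "c", "d", "e"], batch_size=2, shuffle=False))
print(batches)  # [['a', 'b'], ['c', 'd'], ['e']]
```

This is exactly why the two magic methods are mandatory: the loader never looks inside your data, it only asks for the length and for individual indexed items.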
Below is the official MNIST implementation:
```python
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}

    @property
    def targets(self):
        if self.train:
            return self.train_labels
        else:
            return self.test_labels

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
```
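One small detail worth noting in the MNIST code above is how `class_to_idx` maps each human-readable class name to its integer label with a dict comprehension over `enumerate`. The same pattern in isolation:

```python
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
           '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

# enumerate yields (index, name) pairs; the comprehension inverts them
# into a name -> index mapping
class_to_idx = {_class: i for i, _class in enumerate(classes)}

print(class_to_idx['7 - seven'])  # 7
```

This name-to-index mapping is a common convention across torchvision datasets (e.g. `ImageFolder` exposes the same attribute), and it is handy when converting model predictions back to class names.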
Reposted from: https://www.cnblogs.com/xxxxxxxxx/p/11429051.html