-柚子皮-

什么是Datasets?

在输入流水线中,准备数据的代码是这么写的

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets?

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们可以让数据变成mini-batch,且在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。

Datasets就是构建这个类的实例的参数之一。

DataLoader的使用参考[PyTorch:数据读取2 - Dataloader]。

数据集划分

1 建议使用sklearn.preprocessing.model_selection

ds_train, ds_eval = model_selection.train_test_split(dataset, test_size=0.2, shuffle=args.if_shuffle_data)

​​​​​​​2 train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

Note: dataloader应该是不能进行划分的。

[Pytorch划分数据集的方法]

-柚子皮-

自定义Datasets

框架

dataset必须继承自torch.utils.data.Dataset。内部要实现两个函数:一个是__lent__用来获取整个数据集的大小,一个是__getitem__用来从数据集中得到一个数据片段item

import torch.utils.data as data
class CustomDataset(data.Dataset):  # 继承data.Dataset
    """Custom data.Dataset compatible with data.DataLoader."""

def __init__(self, filename, data_info, oth_params):
        """Reads source and target sequences from txt files."""
        # # # Initialize file path or list of file names.
        self.file = open(filename, 'r')
        pass
        # # # 或者从外部数据结构data_info中读取数据
        self.all_texts = data_info['all_texts']
        self.all_labels = data_info['all_labels']
        self.vocab = data_info['vocab']

def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # # # 从文件读取
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform或者word2id什么的).
        # 3. Return a data pair(source and target) (e.g. image and label).
        pass
        # # # 或者直接读取
        item_info = {
            "text": self.all_texts[index],
            "label": self.all_labels[index]
        }
        return item_info

def __len__(self):
        # You should change 0 to the total size of your dataset.
        # return 0
        return len(self.all_texts)

小示例

从文件中读取数据定稿Dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath=None,dataLen=None):
        self.file = filepath
        self.dataLen = dataLen
        
    def __getitem__(self, index):
        A,B,path,hop= linecache.getline(self.file, index+1).split('\t')
        return A,B,path.split(' '),int(hop)

def __len__(self):
        return self.dataLen

随机mock一个分类数据

class Dataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

def __init__(self, df, lang: Lang):
        inputs_dim = vars(Config)['inputs_dim']
        self.x = torch.randint(0, 5, (5, inputs_dim), dtype=torch.float)

self.label = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1, 0, 1], dtype=torch.float)

self.src_word2id = lang.word2id
        self.trg_word2id = lang.word2id
        # self.mem_word2id = mem_word2id

def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        x = self.x[index]
        label = self.label[index]

item_info = {
            "x": x,
            "label": label
        }
        return item_info

官方MNIST的例子

(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):def __init__(self, root, train=True, transform=None, target_transform=None, download=False):self.root = rootself.transform = transformself.target_transform = target_transformself.train = train  # training set or test setif download:self.download()if not self._check_exists():raise RuntimeError('Dataset not found.' +' You can use download=True to download it')if self.train:self.train_data, self.train_labels = torch.load(os.path.join(root, self.processed_folder, self.training_file))else:self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))def __getitem__(self, index):if self.train:img, target = self.train_data[index], self.train_labels[index]else:img, target = self.test_data[index], self.test_labels[index]# doing this so that it is consistent with all other datasets# to return a PIL Imageimg = Image.fromarray(img.numpy(), mode='L')if self.transform is not None:img = self.transform(img)if self.target_transform is not None:target = self.target_transform(target)return img, targetdef __len__(self):if self.train:return 60000else:return 10000

from: -柚子皮-

ref: [pytorch学习笔记(六):自定义Datasets]

PyTorch:数据读取1 - Datasets及数据集划分相关推荐

  1. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  2. 使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

    使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档. JAX简介 JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本 ...

  3. PyTorch系列 (二): pytorch数据读取自制数据集并

    PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 ...

  4. Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

    Pytorch的数据读取主要包含三个类: Dataset DataLoader DataLoaderIter 这三者是一个依次封装的关系: 1.被装进2., 2.被装进3. Dataset类 Pyto ...

  5. pytorch数据读取之Dataset与DataLoader

    1. 先前处理数据集的代码经常比较混乱并且难以维护 2. 数据集处理代码应该和训练代码解耦合,从而达到模块化和更好的可读性 因此,pytorch提出了两个数据处理类:DataLoader与Datase ...

  6. Pytorch数据读取加速方法

    1. 方法一:使用prefetcher class data_prefetcher():def __init__(self, loader):self.loader = iter(loader)sel ...

  7. 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  8. 图像数据读取及数据扩增方法

    Datawhale干货 作者:王程伟,Datawhale成员 本文为干货知识+竞赛实践系列分享,旨在理论与实践结合,从学习到项目实践.(零基础入门系列:数据挖掘/cv/nlp/金融风控/推荐系统等,持 ...

  9. 十分钟搞懂Pytorch如何读取MNIST数据集

    前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧- 正文 在阅读教程书籍<深度学习入门之Pytorch>时,文中是如此加载MNIST手写数字训练集的: ...

  10. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

最新文章

  1. VC++ 从View类获取各种指针编程实例
  2. android 自动表单提交数据,Android 使用三种方式获取网页(通过Post,Get进行表单的提交)...
  3. shiro单点登录原理_SSO单点登录三种情况的实现方式详解
  4. 提高篇 第五部分 动态规划 第4章 状态压缩类动态规划
  5. onvif协议服务器端口,通过onvif协议接入海康、大华NVR步骤
  6. [转载] python-TypeError: Object of type ‘Decimal‘ is not JSON serializable 报错
  7. Rabbitmq+Nginx+keepalived高可用热备
  8. [UE4] 虚幻4学习---UE4中的字符串转换
  9. DiskGeniux无损分区
  10. 0712CF解题报告
  11. 支付宝 手机h5支付
  12. 容斥原理解决某个区间[1,n]闭区间与m互质数数量问题
  13. APP创意IDEA记录
  14. 基于opencv的手势识别(HSV)控制鼠标
  15. 玩转星际争霸局部战斗 —— QMIX
  16. 直律云所——让法律变得简单
  17. iPhone4 SIM失败?无效SIM?有效解决
  18. 小米手机系统更新没有数据连接到服务器吗,小米手机无服务怎么解决
  19. 屋面房顶白色外壁降温用凉凉胶隔热面漆 隔热性能十分优异
  20. linux下通过Python代码实现获取硬件接口信息

热门文章

  1. 【转】Windows和Ubuntu双系统,修复UEFI引导的两种办法
  2. 操作系统--文件管理
  3. React的性能优化 - 代码拆分之lazy的使用方法
  4. asp.net—单例模式
  5. ARM处理器的9种模式详解
  6. 读《突然就走到了西藏》 | 保持呼吸,继续向前
  7. 收集WebDriver的执行命令和参数信息
  8. 读《遇见未知的自己》有感
  9. JAVA 调用Web Service的方法(转)
  10. 下载MSN2009享受SkyDrive免费25G网络硬盘