前言

PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyTorch在视觉和NLP领域的顶级会议上已呈一统之势。

这篇文章笔者将和大家聚焦于PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。我们从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

from torch.utils.data import Dataset

本文将围绕Dataset对象分别从原始模板、torchvision的transforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,我们这里称之为原始模板。其代码结构如下:

from torch.utils.data import Dataset
class CustomDataset(Dataset):def __init__(self, ...):# stuffdef __getitem__(self, index):# stuffreturn (img, label)def __len__(self):# return examples sizereturn count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务,分别往__init__()、__getitem__()和__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。

  • __len__()函数则用于返回样本数量。

现在我们往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:

from torch.utils.data import Dataset
class CustomDataset(Dataset):def __init__(self):self.samples = list(range(1, 101))def __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]if __name__ == '__main__':dataset = CustomDataset()print(len(dataset))print(dataset[50])print(dataset[1:100])

添加torchvision.transforms

然后我们来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。

添加了tranforms之后的读取模块可以改写为:

from torch.utils.data import Dataset
from torchvision import transforms as Tclass CustomDataset(Dataset):def __init__(self, ...):# stuff...# compose the transforms methodsself.transform = T.Compose([T.CenterCrop(100),T.ToTensor()])def __getitem__(self, index):# stuff...data = # Some data read from a file or image# execute the transformdata = self.transform(data)return (img, label)def __len__(self):# return examples sizereturn countif __name__ == '__main__':# Call the datasetcustom_dataset = CustomDataset(...)

可以看到,我们使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。我们以猫狗图像数据为例进行说明。

定义数据读取方法如下:

class DogCat(Dataset):    def __init__(self, root, transforms=None, train=True, val=False):"""get images and execute transforms."""self.val = valimgs = [os.path.join(root, img) for img in os.listdir(root)]# train: Cats_Dogs/trainset/cat.1.jpg# val: Cats_Dogs/valset/cat.10004.jpgimgs = sorted(imgs, key=lambda x: x.split('.')[-2])self.imgs = imgs         if transforms is None:# normalize      normalize = T.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])# trainset and valset have different data transform # trainset need data augmentation but valset don't.# valsetif self.val:self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])# trainsetelse:self.transforms = T.Compose([T.Resize(256),T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])def __getitem__(self, index):"""return data and label"""img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):"""return images size."""return len(self.imgs)if __name__ == "__main__":train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)print(len(train_dataset))print(train_dataset[0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

class CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv filetransform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = transforms.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":# Call datasetdataset =  CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件为示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pdclass CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv file            transform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = T.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Third column is for an operation indicatorself.operation_arr = np.asarray(self.data_info.iloc[:, 2])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Check if there is an operationsome_operation = self.operation_arr[index]# If there is an operationif some_operation:# Do some operation on image# ...# ...pass# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":transform = T.Compose([T.ToTensor()])dataset = CustomDatasetFromCSV('./mnist_labels.csv')print(len(dataset))print(dataset[5])

运行示例如下:

训练集验证集划分

一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

以kaggle的花朵数据为例:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_splittransform = T.Compose([T.Resize((224, 224)),T.RandomHorizontalFlip(),T.ToTensor()])dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)trainset, valset = random_split(dataset, [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img, label)valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下:

使用DataLoader

dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as Tif __name__ == "__main__":# Define transformstransformations = T.Compose([T.ToTensor()])# Define custom datasetdataset = CustomDatasetFromCSV('./labels.csv')# Define data loaderdata_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)for images, labels in data_loader:# Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。

本文原创首发于公众号【机器学习实验室】,开创了【深度学习60讲】、【机器学习算法手推30讲】和【深度学习100问】三大系列文章。

一个算法工程师的成长之路


长按二维码.关注机器学习实验室

机器学习实验室的近期文章:

  • 机器学习公式推导和算法手写之XGBoost

  • 机器学习公式推导和算法手写之马尔科夫链蒙特卡洛

  • 如何部署一个轻量级深度学习项目?

  • 基于C++的PyTorch模型部署

  • PyTorch数据Pipeline标准化代码模板

  • 算法工程师的一天

参考文献

【1】https://pytorch.org/docs/stable/data.html

【2】https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f

【3】https://github.com/utkuozbulak/pytorch-custom-dataset-examples

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

PyTorch数据Pipeline标准化代码模板相关推荐

  1. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

  2. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  3. PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差

    PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差 1.数据归一化处理:transforms.Normalize 1.1 理解torchvision 1 ...

  4. (第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

    前言:在深度学习中,数据的预处理是第一步,pytorch提供了非常规范的处理接口,本文将针对处理过程中的一些问题来进行说明,本文所针对的主要数据是图像数据集. 本文的案例来源于车道线语义分割,采用的数 ...

  5. pytorch数据增广albumentations

    pytorch数据增广albumentations 图像增强库官方英文介绍 安装 pip install albumentations 支持的目标检测bbox格式 pascal_voc [x_min, ...

  6. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...

  7. PyTorch 数据并行处理

    PyTorch 数据并行处理 可选择:数据并行处理(文末有完整代码下载) 本文将学习如何用 DataParallel 来使用多 GPU. 通过 PyTorch 使用多个 GPU 非常简单.可以将模型放 ...

  8. 数据的标准化和标准化方法

    数据的标准化(normalization)是将数据按比例缩放,使之落入一个小的特定区间.在某些比较和评价的指标处理中经常会用到,去除数据的单位限制,将其转化为无量纲的纯数值,便于不同单位或量纲的指标能 ...

  9. 吴裕雄 python 机器学习——数据预处理标准化StandardScaler模型

    from sklearn.preprocessing import StandardScaler#数据预处理标准化StandardScaler模型 def test_StandardScaler(): ...

最新文章

  1. 如何在JavaScript中反转字符串?
  2. [转载]poj 计算几何题全集(转)
  3. mysql怎么显示、查询现有数据库列表?(show databases;)怎么删除现有数据库?(drop database <库名>)
  4. [云炬创业基础笔记]第九章企业的法律形态测试5
  5. eq linux_音乐家和音乐爱好者的开放硬件 | Linux 中国
  6. go语言web编程,初学点滴记录1
  7. 用geoda软件进行空间自相关分析示例
  8. JAVA计算机毕业设计随心淘网管理系统源码+系统+mysql数据库+lw文档
  9. iOS 应用内付费(IAP)开发步骤
  10. 智能水杯设计方案_智能水杯的设计与营销
  11. 35枚不同风格的设计师个人网站欣赏
  12. ​A* 算法简介 from Red Blob Games​(译文)
  13. android逆向知乎,Android逆向之路---为什么从后台切换回app又显示广告了
  14. 服务器2016自动备份怎么取消,wps中ppt的制作怎样取消掉定时自动备份
  15. 如何做一个基于微信共享会议室预约小程序系统毕业设计毕设作品
  16. 无穷项和求极限(夹逼准则)
  17. Learn more study less 读后感
  18. Unity的Handles类
  19. 面试题9:AJAX是什么【JS】
  20. 【可视化大屏】屏幕多分辨率适配方案

热门文章

  1. var与dynamic区别
  2. VC++中忽略所有默认库纯Win32 API编译及链接 - 计算机软件编程 - Wangye's Space
  3. Oracle定时器(Job)各时间段写法汇总
  4. Postgre体系结构图
  5. 陆奇给工程师们的5个建议
  6. 这种扯淡的嵌入式项目,尽量不要碰
  7. TQ210——时钟系统
  8. wordvba编程代码大全_这几本基础编程书籍一定要看
  9. 【Pytorch神经网络理论篇】 18 循环神经网络结构:LSTM结构+双向RNN结构
  10. 指令系统——数据存放、指令寻址(详解)