原题 | DATA LOADING AND PROCESSING TUTORIAL

作者 | Sasank Chilamkurthy

原文 | https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出处,请勿用作商业或者非法用途

简介

本文教程主要是介绍如何加载、预处理并对数据进行增强的方法。

首先需要确保安装以下几个 python 库:

  • scikit-image :处理图片数据
  • pandas :处理 csv 文件

导入模块代码如下:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils# Ignore warnings
import warnings
warnings.filterwarnings("ignore")plt.ion()   # interactive mode

本次教程采用的是一个人脸姿势数据集,其图片如下所示:

每张人脸都是有 68 个人脸关键点,它是由 dlib 生成的,具体实现可以查看其官网介绍:

https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html

数据集下载地址:

https://download.pytorch.org/tutorial/faces.zip

数据集中的 csv 文件的格式如下所示,图片名字和每个关键点的坐标 x, y

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

数据集下载解压缩后放到文件夹 data/faces 中,然后我们先快速打开 face_landmarks.csv 文件,查看文件内容,即标注信息,代码如下所示:

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出如下所示:

接着写一个辅助函数来显示人脸图片及其关键点,代码如下所示:

def show_landmarks(image, landmarks):"""Show image with landmarks"""plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)  # pause a bit so that plots are updatedplt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),landmarks)
plt.show()

输出如下所示:

Dataset 类

torch.utils.data.Dataset 是表示一个数据集的抽象类,在自定义自己的数据集的时候需要继承 Dataset 类别,并重写下方这些方法:

  • len :调用 len(dataset) 时可以返回数据集的数量;
  • getitem:获取数据,可以实现索引访问,即 dataset[i] 可以访问第 i 个样本数据

接下来将给我们的人脸关键点数据集自定义一个类别,在 __init__ 方法中将读取数据集的信息,并在 __getitem__ 方法调用获取的数据集,这主要是基于内存的考虑,这种做法不需要将所有数据一次读取存储在内存中,可以在需要读取数据的时候才读取加载到内存里。

数据集的样本将用一个字典表示:{'image': image, 'landmarks': landmarks},另外还有一个可选参数 transform 用于预处理读取的样本数据,下一节将介绍这个 transform 的用处。

自定义函数的代码如下所示:

class FaceLandmarksDataset(Dataset):"""Face Landmarks dataset."""def __init__(self, csv_file, root_dir, transform=None):"""Args:csv_file (string): 带有标注信息的 csv 文件路径root_dir (string): 图片所在文件夹transform (callable, optional): 可选的用于预处理图片的方法"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):# 读取图片img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)# 读取关键点并转换为 numpy 数组landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample

接下来是一个简单的例子来使用上述我们自定义的数据集类,例子中将读取前 4 个样本并展示:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')fig = plt.figure()
# 读取前 4 张图片并展示
for i in range(len(face_dataset)):sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)plt.tight_layout()ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break

输出结果如下所示:

Transforms

从上述例子输出的结构可以看到一个问题,图片的大小并不一致,但大多数神经网络都需要输入图片的大小固定。因此,接下来是给出一些预处理的代码,主要是下面三种预处理方法:

  • Rescale :调整图片大小
  • RandomCrop:随机裁剪图片,这是一种数据增强的方法
  • ToTensor:将 numpy 格式的图片转换为 pytorch 的数据格式 tensors,这里需要交换坐标。

这几种方法都将写成可调用的类,而不是简单的函数,这样就不需要每次都传递参数。因此,我们需要实现 __call__ 方法,以及有必要的话,__init__ 方法也是要实现的,然后就可以如下所示一样调用这些方法:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

Rescale 方法的实现代码如下:

class Rescale(object):"""将图片调整为给定的大小.Args:output_size (tuple or int): 期望输出的图片大小. 如果是 tuple 类型,输出图片大小就是给定的 output_size;如果是 int 类型,则图片最短边将匹配给的大小,然后调整最大边以保持相同的比例。"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]# 判断给定大小的形式,tuple 还是 int 类型if isinstance(self.output_size, int):# int 类型,给定大小作为最短边,最大边长根据原来尺寸比例进行调整if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# 根据调整前后的尺寸比例,调整关键点的坐标位置,并且 x 对应 w,y 对应 hlandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}

RandomCrop 的代码实现:

class RandomCrop(object):"""给定图片,随机裁剪其任意一个和给定大小一样大的区域.Args:output_size (tuple or int): 期望裁剪的图片大小。如果是 int,将得到一个正方形大小的图片."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_size# 随机选择裁剪区域的左上角,即起点,(left, top),范围是由原始大小-输出大小top = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]# 调整关键点坐标,平移选择的裁剪起点landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}

ToTensor 的方法实现:

class ToTensor(object):"""将 ndarrays 转换为 tensors."""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 调整坐标尺寸,numpy 的维度是 H x W x C,而 torch 的图片维度是 C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}

组合使用预处理方法

接下来就是介绍使用上述自定义的预处理方法的例子。

假设我们希望将图片的最短边长调整为 256,然后随机裁剪一个 224*224 大小的图片区域,也就是我们需要组合调用 RescaleRandomCrop 预处理方法。

torchvision.transforms.Compose 是一个可以实现组合调用欲处理方法的类,实现代码如下所示:

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 对图片数据调用上述 3 种形式预处理方法,即单独使用 Rescale,RandomCrop,组合使用 Rescale和 RandomCrop
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()

输出结构:

迭代整个数据集

现在我们已经定义好一个处理数据集的类,3种预处理数据的类,那么可以将它们整合在一起,实现加载并预处理数据的流程,流程如下所示:

  • 首先根据图片路径读取图片
  • 对图片都调用预处理的方法
  • 预处理方法也可以实现数据增强

实现的代码如下所示:

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/',transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break

输出结果:

上述只是一个简单的处理过程,实际上处理和加载数据的时候,我们一般还对数据做以下的处理:

  • 将数据按给定大小分成一批一批数据
  • 打乱数据排列顺序
  • 采用 multiprocessing 来并行加载数据

torch.utils.data.DataLoader 是一个可以实现上述操作的迭代器。其需要的参数如下代码所示,其中一个参数 collate_fn 是用于指定如何对数据进行分批的操作,但也可以采用默认函数。

dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):"""Show image with landmarks for a batch of samples."""images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid_border_size = 2grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size,s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())# observe 4th batch and stop.if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break

输出结果:

torchvision

最后介绍 torchvision 这个库,它提供了一些常见的数据集和预处理方法,采用这个库就可以不需要自定义类,它比较常用的方法是 ImageFolder ,它假定图片的保存路径如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这里的 antsbees 等等都是类别标签,此外对 PIL.Image 的预处理方法,如 RandomHorizontalFlipScale 都包含在 torchvision 中,一个使用例子如下所示:

import torch
from torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)

小结

本教程主要介绍如何对自己的数据集自定义一个类来加载,以及预处理的方法,同时最后也介绍了 PyTorch 中的 torchvisiontorch.utils.data.DataLoader 方法。

本文的代码上传至 Github:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/pytorch_dataloader_tutorial.ipynb

另外,还有用 dlib 生成人脸关键点的代码:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/create_landmark_dataset.py

pytorch dataset自定义_PyTorch 系列 | 数据加载和预处理教程相关推荐

  1. PyTorch 系列 | 数据加载和预处理教程

    图片来源:Unsplash,作者:Damiano Baschiera 2019 年第 66 篇文章,总第 90 篇文章 本文大约 8000 字,建议收藏阅读 原题 | DATA LOADING AND ...

  2. pytorch dataset自定义_PyTorch | 数据加载及预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 译者 | kbsc13("算法猿的成长"公众号 ...

  3. PyTorch 编写自定义数据集,数据加载器和转换

    本文为 pytorch 官方教程https://pytorch.org/tutorials/beginner/data_loading_tutorial.html代码的注释 w3cschool 的翻译 ...

  4. PyTorch基础(四)-----数据加载和预处理

    前言 之前已经简单讲述了PyTorch的Tensor.Autograd.torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所 ...

  5. 【自然语言处理入门系列】加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例

    [自然语言处理入门系列]加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例 Author: Yirong Chen from South China Univers ...

  6. detectron2使用自定义数据集及数据加载

    1.使用自定义数据集 数据集中列出了detectron2中内置支持的数据集.如果要使用自定义数据集,同时还重复使用detectron2的数据加载器,则需要: 1)注册您的数据集(即,告诉detectr ...

  7. Pytorch基础(三)数据集加载及预处理

    目录 下载数据集及显示样本 数据集类 建立数据集类及显示部分样本 数据变换 后记 python提供了许多工具简化数据加载,使代码更具可读性.经常用到的包有scikit-image.pandas等,本文 ...

  8. PyTorch—torch.utils.data.DataLoader 数据加载类

    文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...

  9. PyTorch 1.0 中文官方教程:数据加载和处理教程

    译者:yportne13 作者:Sasank Chilamkurthy 在解决机器学习问题的时候,人们花了大量精力准备数据.pytorch提供了许多工具来让载入数据更简单并尽量让你的代码的可读性更高. ...

最新文章

  1. 查看SQL SERVER数据库的连接数
  2. Bootstrap 导入js文件,浏览器找不到文件问题
  3. 存储基础:磁盘 IO 为什么总叫你对齐?
  4. webrtc之onicecandidate的 event handler的一点疑惑
  5. Math.random()取随机数一直为0
  6. lintcode:1-10题
  7. 基于DEV控件库的webservice打印.repx模板
  8. JavaScript网页特效5则
  9. 小米蓝牙键盘怎么连接_不是吐槽,是推荐!买了个小米旗下的蓝牙双模键盘。。。...
  10. C++三种创建对象的方法区别
  11. android go官方下载,GoFIT下载
  12. VMware下载,安装及创建虚拟机
  13. 原笔迹手写实现平滑和笔锋效果之:笔锋效果(三)[完结篇]
  14. 触屏笔和电容笔哪个好?非常值得入手的电容笔推荐
  15. python实现商品管理系统_商品管理系统(示例代码)
  16. Python自学难吗?Python课程主要学些什么内容?
  17. 光 颜色 波长 眼睛
  18. 运维监控系列(15)-Alertmanager添加163邮箱、钉钉、微信告警通知功能
  19. Redis笔记基础篇:6分钟看完Redis的八种数据类型
  20. swf to html5 movie maker,SWF to Video Converter Pro(Flash转换视频格式)

热门文章

  1. 简单密码(Caesar密码)--C++实现
  2. 虚幻4 游戏引擎(二):蓝图教学
  3. 基于echarts 24种数据可视化展示,填充数据就可用,动手能力强的还可以DIY(演示地址+下载地址)
  4. iOS虚拟支付被封,6个技巧帮你快速解决烦恼
  5. java开发微信公众号:微信公众号对接
  6. python 登陆网站图片验证,用python登录带弱图片验证码的网站
  7. python国际象棋ai程序_Python开发AI应用-国际象棋应用
  8. 北邮智能车仿真培训(五)—— 数据可视化工具的使用
  9. Ubuntu下PX4飞控开发环境搭建
  10. 智能风控中台设计与落地