图片来源:Unsplash,作者:Damiano Baschiera

2019 年第 66 篇文章,总第 90 篇文章

本文大约 8000 字,建议收藏阅读

原题 | DATA LOADING AND PROCESSING TUTORIAL

作者 | Sasank Chilamkurthy

译者 | kbsc13("算法猿的成长"公众号作者)

原文 | https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途

简介

本文教程主要是介绍如何加载、预处理并对数据进行增强的方法。

首先需要确保安装以下几个 python 库:

  • scikit-image :处理图片数据

  • pandas :处理 csv 文件

导入模块代码如下:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils# Ignore warnings
import warnings
warnings.filterwarnings("ignore")plt.ion()   # interactive mode

本次教程采用的是一个人脸姿势数据集,其图片如下所示:

每张人脸都是有 68 个人脸关键点,它是由 dlib 生成的,具体实现可以查看其官网介绍:

https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html

数据集下载地址:

https://download.pytorch.org/tutorial/faces.zip

数据集中的 csv 文件的格式如下所示,图片名字和每个关键点的坐标 x, y

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

数据集下载解压缩后放到文件夹 data/faces 中,然后我们先快速打开 face_landmarks.csv 文件,查看文件内容,即标注信息,代码如下所示:

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出如下所示:


接着写一个辅助函数来显示人脸图片及其关键点,代码如下所示:

def show_landmarks(image, landmarks):"""Show image with landmarks"""plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)  # pause a bit so that plots are updatedplt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),landmarks)
plt.show()

输出如下所示:


Dataset 类

torch.utils.data.Dataset 是表示一个数据集的抽象类,在自定义自己的数据集的时候需要继承 Dataset 类别,并重写下方这些方法:

  • len :调用 len(dataset) 时可以返回数据集的数量;

  • getitem:获取数据,可以实现索引访问,即 dataset[i] 可以访问第 i 个样本数据

接下来将给我们的人脸关键点数据集自定义一个类别,在 __init__ 方法中将读取数据集的信息,并在 __getitem__ 方法调用获取的数据集,这主要是基于内存的考虑,这种做法不需要将所有数据一次读取存储在内存中,可以在需要读取数据的时候才读取加载到内存里。

数据集的样本将用一个字典表示:{'image': image, 'landmarks': landmarks},另外还有一个可选参数 transform 用于预处理读取的样本数据,下一节将介绍这个 transform 的用处。

自定义函数的代码如下所示:

class FaceLandmarksDataset(Dataset):"""Face Landmarks dataset."""def __init__(self, csv_file, root_dir, transform=None):"""Args:csv_file (string): 带有标注信息的 csv 文件路径root_dir (string): 图片所在文件夹transform (callable, optional): 可选的用于预处理图片的方法"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):# 读取图片img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)# 读取关键点并转换为 numpy 数组landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample

接下来是一个简单的例子来使用上述我们自定义的数据集类,例子中将读取前 4 个样本并展示:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')fig = plt.figure()
# 读取前 4 张图片并展示
for i in range(len(face_dataset)):sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)plt.tight_layout()ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break

输出结果如下所示:


Transforms

从上述例子输出的结构可以看到一个问题,图片的大小并不一致,但大多数神经网络都需要输入图片的大小固定。因此,接下来是给出一些预处理的代码,主要是下面三种预处理方法:

  • Rescale :调整图片大小

  • RandomCrop:随机裁剪图片,这是一种数据增强的方法

  • ToTensor:将 numpy 格式的图片转换为 pytorch 的数据格式 tensors,这里需要交换坐标。

这几种方法都将写成可调用的类,而不是简单的函数,这样就不需要每次都传递参数。因此,我们需要实现 __call__ 方法,以及有必要的话,__init__ 方法也是要实现的,然后就可以如下所示一样调用这些方法:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

Rescale 方法的实现代码如下:

class Rescale(object):"""将图片调整为给定的大小.Args:output_size (tuple or int): 期望输出的图片大小. 如果是 tuple 类型,输出图片大小就是给定的 output_size;如果是 int 类型,则图片最短边将匹配给的大小,然后调整最大边以保持相同的比例。"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]# 判断给定大小的形式,tuple 还是 int 类型if isinstance(self.output_size, int):# int 类型,给定大小作为最短边,最大边长根据原来尺寸比例进行调整if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# 根据调整前后的尺寸比例,调整关键点的坐标位置,并且 x 对应 w,y 对应 hlandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}

RandomCrop 的代码实现:

class RandomCrop(object):"""给定图片,随机裁剪其任意一个和给定大小一样大的区域.Args:output_size (tuple or int): 期望裁剪的图片大小。如果是 int,将得到一个正方形大小的图片."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_size# 随机选择裁剪区域的左上角,即起点,(left, top),范围是由原始大小-输出大小top = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]# 调整关键点坐标,平移选择的裁剪起点landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}

ToTensor 的方法实现:

class ToTensor(object):"""将 ndarrays 转换为 tensors."""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 调整坐标尺寸,numpy 的维度是 H x W x C,而 torch 的图片维度是 C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}

组合使用预处理方法

接下来就是介绍使用上述自定义的预处理方法的例子。

假设我们希望将图片的最短边长调整为 256,然后随机裁剪一个 224*224 大小的图片区域,也就是我们需要组合调用 Rescale 和 RandomCrop 预处理方法。

torchvision.transforms.Compose 是一个可以实现组合调用欲处理方法的类,实现代码如下所示:

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 对图片数据调用上述 3 种形式预处理方法,即单独使用 Rescale,RandomCrop,组合使用 Rescale和 RandomCrop
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()

输出结构:


迭代整个数据集

现在我们已经定义好一个处理数据集的类,3种预处理数据的类,那么可以将它们整合在一起,实现加载并预处理数据的流程,流程如下所示:

  • 首先根据图片路径读取图片

  • 对图片都调用预处理的方法

  • 预处理方法也可以实现数据增强

实现的代码如下所示:

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/',transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break

输出结果:


上述只是一个简单的处理过程,实际上处理和加载数据的时候,我们一般还对数据做以下的处理:

  • 将数据按给定大小分成一批一批数据

  • 打乱数据排列顺序

  • 采用 multiprocessing 来并行加载数据

torch.utils.data.DataLoader 是一个可以实现上述操作的迭代器。其需要的参数如下代码所示,其中一个参数 collate_fn 是用于指定如何对数据进行分批的操作,但也可以采用默认函数。

dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):"""Show image with landmarks for a batch of samples."""images_batch, landmarks_batch = \sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid_border_size = 2grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size,s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())# observe 4th batch and stop.if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break

输出结果:


torchvision

最后介绍 torchvision 这个库,它提供了一些常见的数据集和预处理方法,采用这个库就可以不需要自定义类,它比较常用的方法是 ImageFolder ,它假定图片的保存路径如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这里的 antsbees 等等都是类别标签,此外对 PIL.Image 的预处理方法,如 RandomHorizontalFlip 、Scale 都包含在 torchvision 中,一个使用例子如下所示:

import torch
from torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)

小结

本教程主要介绍如何对自己的数据集自定义一个类来加载,以及预处理的方法,同时最后也介绍了 PyTorch 中的 torchvision ,torch.utils.data.DataLoader 方法。

本文的代码上传至 Github:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/pytorch_dataloader_tutorial.ipynb

另外,还有用 dlib 生成人脸关键点的代码:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/create_landmark_dataset.py

此外,也可以公众号后台回复“PyTorch”获取本次教程的数据集和代码。

留言时间


欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!


如果觉得不错,在看、转发就是对小编的一个支持!

PyTorch 系列 | 数据加载和预处理教程相关推荐

  1. pytorch dataset自定义_PyTorch 系列 | 数据加载和预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 原文 | https://pytorch.org/tutorial ...

  2. pytorch dataset自定义_PyTorch | 数据加载及预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 译者 | kbsc13("算法猿的成长"公众号 ...

  3. 【自然语言处理入门系列】加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例

    [自然语言处理入门系列]加载和预处理数据-以Cornell Movie-Dialogs Corpus数据集为例 Author: Yirong Chen from South China Univers ...

  4. PyTorch:数据加载,数学原理,猫鱼分类,CNN,预训练,迁移学习

    1,数据加载 PyTorch开发了与数据交互的标准约定,所以能一致地处理数据,而不论处理图像.文本还是音频.与数据交互的两个主要约定是数据集(dataset)和数据加载器(dataloader).数据 ...

  5. PyTorch基础(四)-----数据加载和预处理

    前言 之前已经简单讲述了PyTorch的Tensor.Autograd.torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所 ...

  6. PyTorch 1.0 中文官方教程:数据加载和处理教程

    译者:yportne13 作者:Sasank Chilamkurthy 在解决机器学习问题的时候,人们花了大量精力准备数据.pytorch提供了许多工具来让载入数据更简单并尽量让你的代码的可读性更高. ...

  7. matlab 读取csv_利用Pytorch进行数据加载1--CSV文件的读取和显示

    import os # 文件处理模块,用于处理文件和目录 import torch # pytorch的深度学习框架 import pandas as pd #人脸识别库 from skimage i ...

  8. pytorch学习(一)数据加载之前的预处理(UCSD数据集)

    最近在做有关视频异常检测方面的实验,需要用到UCSD数据集,pytorch自定义加载自己的数据集时需要将自己的数据的路径以及标签存放到txt文档中,方便后续的数据加载. 最后我会给出生成好的UCSD数 ...

  9. Pytorch基础(三)数据集加载及预处理

    目录 下载数据集及显示样本 数据集类 建立数据集类及显示部分样本 数据变换 后记 python提供了许多工具简化数据加载,使代码更具可读性.经常用到的包有scikit-image.pandas等,本文 ...

最新文章

  1. 爬取王垠的博客并生成pdf
  2. python图片-Python图片处理
  3. WebSocket | 为什么你前后端推送不会用?因为你少了WebSocket的帮忙
  4. CentOS添加明细路由
  5. 音视频开发基础(二)常用的直播协议
  6. Pytorch 手工复现交叉熵损失(Cross Entropy Loss)
  7. 日常记录,记下来自己的遇到的问题
  8. CentOS6.5安装python环境
  9. linux USR1亦通常被用来告知应用程序重载配置文件
  10. android studio 创建项目失败原因Failed to create
  11. 搭建国外海外多语言一元云购软件夺宝购商城网站
  12. NovelAI-WebUI安装教程
  13. 【译】2021年十大热门编程语言
  14. c语言回溯法解决倒桥本分数式,回溯法 经典题目 八皇后 桥本分数
  15. Uplift Model
  16. 什么是云服务举例说明_什么叫云服务举例说明(云服务器实例是什么)
  17. 微信web中IOS系统手机摇一摇功能实现及问题解决
  18. u盘插电脑上不显示怎么办?数据恢复还有希望吗
  19. kis商贸系列加密服务器,金蝶KIS商贸标准版系统登录
  20. xctf之warmup

热门文章

  1. 无线服务器密码让别人改了,wifi密码被改了怎么办_wifi密码被别人改了怎么办?-192路由网...
  2. mysql用户_MySQL用户权限管理详解
  3. 解决flask端口被占用的问题
  4. 【 Grey Hack 】综合工具 shellOs
  5. Exynos4412 Uboot 编译工具 —— 交叉工具链 arm-linux-gcc 的安装
  6. .net 获取网站根目录的方法
  7. 简单调试 Python 程序
  8. 关闭eslint检验;vue-cli3搭建的vue项目关闭eslint;脚手架3关闭eslint;
  9. uni-app微信小程序image引入图片;background-image背景图引入图片;小程序预览本地图片;小程序图片过大引入报错;获取本地图片的网络地址;
  10. [react] 你最喜欢React的哪一个特性(说一个就好)