简介

结合官方tutorials和源码以及部分博客写出此文。

pytorch的数据加载和处理相对容易的多,常见的两种形式的导入:

  1. 一种是整个数据集都在一个文件夹下,内部再另附一个label文件,说明每个文件夹的状态,如这个数据库。这种存放数据的方式可能更适合在非分类问题上得到应用。
  2. 一种则是更适合使用在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。其形式如下:

    root/ants/xxx.png
    root/ants/xxy.jpeg
    root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png

本文首先结合官方turorials介绍第一种方法,以了解其数据加载的原理;然后以代码形式简单介绍第二种方法。其中第二种方法和第一种方法的原理相同,其差别在于第二种方法运用了trochvision中提供的已写好的工具ImageFolder,因此实现起来更为简单。

第一种

Dataset class

torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

  1. __len__: 覆写这个方法使得len(dataset)可以返回整个数据集的大小
  2. __getitem__: 覆写这个方法使得dataset[i]可以返回数据集中第i个样本
  3. 不覆写这两个方法会直接返回错误,其源码如下:
    def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedError

这里我随便从网上下载了20张图像,10张小猫,10张小狗。为了省事儿(只是想验证下继承Dataset类是否好用),我没有给数据集增加标签文件,而是直接把1-10号定义为小猫,11-20号定义为小狗,这样会给__len____getitem__减小麻烦,其目录结构如下:

建立的自定义类如下:

from torch.utils.data import DataLoader, Dataset
from skimage import io, transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np class AnimalData(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transformdef __len__(self):return 20def __getitem__(self, idx):filenames = os.listdir(self.root_dir)filename = filenames[idx]img = io.imread(os.path.join(self.root_dir, filename))# print filename[:-5]if (int(filename[:-5]) > 10):lable = np.array([0])else:lable = np.array([1])sample = {'image': img, 'lable':lable}if self.transform:sample = self.transform(sample)return sample

Transforms & Compose transforms

可以注意到上一节中AnimalData类中__init__中有个transform参数,这也是这一节中要讲清楚的问题。
从网上随便下载的图片必然大小不一,而cnn的结构却要求输入图像要有固定的大小;numpy中的图像通道定义为H, W, C,而pytorch中的通道定义为C, H, W; pytorch中输入数据需要将numpy array改为tensor类型;输入数据往往需要归一化,等等。
基于以上考虑,我们可以自定义一些Callable的类,然后作为trasform参数传递给上一节定义的dataset类。为了更加方便,torchvision.transforms.Compose提供了Compose类,可以一次性将我们自定义的callable类传递给dataset类,直接得到转换后的数据。
这里我直接copy了教程上的三个类:Rescale, RandomCrop, ToTensor,稍作改动,适应我的数据库。

class Rescale(object):"""Rescale the image in a sample to a given size.Args:output_size (tuple or int): Desired output size. If tuple, output ismatched to output_size. If int, smaller of image edges is matchedto output_size keeping aspect ratio the same."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, lable = sample['image'], sample['lable']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for lable because for images,# x and y axes are axis 1 and 0 respectively# lable = lable * [new_w / w, new_h / h]return {'image': img, 'lable': lable}class RandomCrop(object):"""Crop randomly the image in a sample.Args:output_size (tuple or int): Desired output size. If int, square cropis made."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, lable = sample['image'], sample['lable']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]# lable = lable - [left, top]return {'image': image, 'lable': lable}class ToTensor(object):"""Convert ndarrays in sample to Tensors."""def __call__(self, sample):image, lable = sample['image'], sample['lable']# print lable# swap color axis because# numpy image: H x W x C# torch image: C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'lable': torch.from_numpy(lable)}

定义好callable类之后,通过torchvision.transforms.Compose将上述三个类结合在一起,传递给AnimalData类中的transform参数即可。

trsm = transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()])
data = AnimalData('./all', transform=trsm)

Iterating through the dataset

上一节中得到data实例之后可以通过for循环来一个一个读取数据,现在这是效率低下的。torch.utils.data.DadaLoader类解决了上述问题。其主要有如下特点:

  • Batching the data
  • Shuffling the data
  • Load the data in parallel using multiprocessing workers.

实现起来也很简单:

dataloader = DataLoader(data, batch_size=4, shuffle=True, num_workers=4)
for i_batch, bach_data in enumerate(dataloader):print i_batch print bach_data['image'].size()print bach_data['lable']

第二种

torchvision

pytorch几乎将上述所有工作都封装起来供我们使用,其中一个工具就是torchvision.datasets.ImageFolder,用于加载用户自定义的数据,要求我们的数据要有如下结构:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

torchvision.transforms中也封装了各种各样的数据处理的工具,如Resize, ToTensor等等功能供我们使用。
修改我下载的数据库结构如下:

加载数据代码如下:

from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt train_data = datasets.ImageFolder('./data1', transform=transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor()
]))train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True,)print len(train_loader)
for i_batch, img in enumerate(train_loader):if i_batch == 0:print(img[1])fig = plt.figure()grid = utils.make_grid(img[0])plt.imshow(grid.numpy().transpose((1, 2, 0)))plt.show()break

结果图:

附录

最后欣赏一段torchvision源码

# vision/torchvision/datasets/folder.pyimport torch.utils.data as datafrom PIL import Image
import os
import os.pathIMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']def is_image_file(filename):"""Checks if a file is an image.Args:filename (string): path to a fileReturns:bool: True if the filename ends with a known image extension"""filename_lower = filename.lower()return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)def find_classes(dir):classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]classes.sort()class_to_idx = {classes[i]: i for i in range(len(classes))}return classes, class_to_idxdef make_dataset(dir, class_to_idx):images = []dir = os.path.expanduser(dir)for target in sorted(os.listdir(dir)):d = os.path.join(dir, target)if not os.path.isdir(d):continuefor root, _, fnames in sorted(os.walk(d)):for fname in sorted(fnames):if is_image_file(fname):path = os.path.join(root, fname)item = (path, class_to_idx[target])images.append(item)return imagesdef pil_loader(path):# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)with open(path, 'rb') as f:img = Image.open(f)return img.convert('RGB')def accimage_loader(path):import accimagetry:return accimage.Image(path)except IOError:# Potentially a decoding problem, fall back to PIL.Imagereturn pil_loader(path)def default_loader(path):from torchvision import get_image_backendif get_image_backend() == 'accimage':return accimage_loader(path)else:return pil_loader(path)class ImageFolder(data.Dataset):"""A generic data loader where the images are arranged in this way: ::root/dog/xxx.pngroot/dog/xxy.pngroot/dog/xxz.pngroot/cat/123.pngroot/cat/nsdf3.pngroot/cat/asd932_.pngArgs:root (string): Root directory path.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.loader (callable, optional): A function to load an image given its path.Attributes:classes (list): List of the class names.class_to_idx (dict): Dict with items (class_name, class_index).imgs (list): List of (image path, class_index) tuples"""def __init__(self, root, transform=None, target_transform=None,loader=default_loader):classes, class_to_idx = find_classes(root)imgs = make_dataset(root, class_to_idx)if len(imgs) == 0:raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n""Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))self.root = rootself.imgs = imgsself.classes = classesself.class_to_idx = class_to_idxself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is class_index of the target class."""path, target = self.imgs[index]img = self.loader(path)if self.transform is not None:img = self.transform(img)if self.target_transform is not None:target = self.target_transform(target)return img, targetdef __len__(self):return len(self.imgs)def __repr__(self):fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())fmt_str += '    Root Location: {}\n'.format(self.root)tmp = '    Transforms (if any): 'fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))tmp = '    Target Transforms (if any): 'fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))return fmt_str

参考

[1]. Data Loading and Processing Tutorial
[2]. github: pytorch/torch/utils/data/dataset.py
[3]. github: vision/torchvision/datasets/folder.py
[4]. csdn

学习pytorch: 数据加载和处理相关推荐

  1. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...

  2. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  3. 从numpy里加载_PyTorch强化:01.PyTorch 数据加载和处理

    PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...

  4. PyTorch强化:01.PyTorch 数据加载和处理

    PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...

  5. pytorch 数据加载和处理

    # PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. from __future__ import print_function, division import os impor ...

  6. pytorch数据加载、模型保存及加载

    主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理   该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...

  7. Datawhale7月组队学习task1数据加载及探索性数据分析

    Datawhale7月task1数据加载及探索性数据分析 ​ 有幸了解到了Datawhale这样一个开源组织,欣然报名了2021年7月的组队学习的动手学数据分析系列课程 ​ 本系列目标:完成kaggl ...

  8. pytorch数据加载时报错OSError: [Errno 22] Invalid argument

    将数据加载器的num_workers删去,即由 train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, ...

  9. pytorch 数据加载为python提供了cpu全速跑跑其他任务的能力

    from torch.utils.data import Dataset from torch.utils.data import DataLoader def task(text):print(te ...

  10. pytorch 数据加载性能对比

    传统方式需要10s,dat方式需要0.6s import os import time import torch import random from common.coco_dataset impo ...

最新文章

  1. CoordinatorLayout 和 AppBarLayout 实现折叠效果 底部带阴影的处理方法
  2. linux那个11命令是什么意思,11个有用的Linux命令
  3. Python tcp编程,TCP服务端和TCP客户端代码编写
  4. 借呗利息为什么比银行信用贷款高很多?
  5. [vue] vue要做权限管理该怎么做?如果控制到按钮级别的权限怎么做?
  6. iOS开发隐藏tabBar的问题
  7. 批处理脚本手动双击可以执行,但计划任务中执行失败
  8. Git 基础(六)—— 工作区和暂存区与 Git 对修改的管理
  9. 大理大学日常作业计算机基础知识,大理学院成人高等教育大学计算机基础课程作业.doc...
  10. 学计算机和电脑办公的区别,自学编程和计算机科班出身的差别在哪?
  11. RabbitMQ消费端自定义监听(九)
  12. Flexbox弹性布局,更优雅的布局
  13. ORACLE lag,lead
  14. Axure常见的几种原型图
  15. Ubuntu下载软件包网速过慢的解决方法
  16. 算数平均数、中位数、众数和几何平均数
  17. linux无线网卡驱动编写,博通无线网卡驱动linux版
  18. 最好用的9个php开发工具推荐
  19. 无效的m3u8怎么办_关于大视频video播放的问题以及解决方案(m3u8的播放)
  20. Oracle数据库:oracle内连接inner join on,多表查询各种自链接、内连接、外连接的练习示例

热门文章

  1. Eclipse maven工程 Missing artifact com.sun:tools:jar:1.7.0:system 解决方法
  2. 深入理解计算机系统(1.3)------操作系统的抽象概念
  3. 【转】飞鸽端口号被占用时的解决方法
  4. vue项目text-overflow:ellipsis;在生产环境上不显示...的问题
  5. [GRE] GRE协议介绍
  6. 所有锁的unlock要放到try{}finally{}里,不然发生异常返回就丢了unlock了
  7. c#DataGridView响应键盘事件
  8. 程序员求职之道(《程序员面试笔试宝典》)之民间的企业排名的可信度到底有多大?...
  9. AS 3.0 socket 通信,比较基础比较全【转载】
  10. application跟消息相关的东东