学习pytorch: 数据加载和处理
简介
结合官方tutorials和源码以及部分博客写出此文。
pytorch
的数据加载和处理相对容易的多,常见的两种形式的导入:
- 一种是整个数据集都在一个文件夹下,内部再另附一个label文件,说明每个文件夹的状态,如这个数据库。这种存放数据的方式可能更适合在非分类问题上得到应用。
- 一种则是更适合使用在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。其形式如下:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
本文首先结合官方turorials介绍第一种方法,以了解其数据加载的原理;然后以代码形式简单介绍第二种方法。其中第二种方法和第一种方法的原理相同,其差别在于第二种方法运用了trochvision
中提供的已写好的工具ImageFolder
,因此实现起来更为简单。
第一种
Dataset class
torch.utils.data.Dataset
是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:
__len__
: 覆写这个方法使得len(dataset)
可以返回整个数据集的大小__getitem__
: 覆写这个方法使得dataset[i]
可以返回数据集中第i
个样本- 不覆写这两个方法会直接返回错误,其源码如下:
def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedError
这里我随便从网上下载了20张图像,10张小猫,10张小狗。为了省事儿(只是想验证下继承Dataset
类是否好用),我没有给数据集增加标签文件,而是直接把1-10号定义为小猫,11-20号定义为小狗,这样会给__len__
和__getitem__
减小麻烦,其目录结构如下:
建立的自定义类如下:
from torch.utils.data import DataLoader, Dataset
from skimage import io, transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np class AnimalData(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transformdef __len__(self):return 20def __getitem__(self, idx):filenames = os.listdir(self.root_dir)filename = filenames[idx]img = io.imread(os.path.join(self.root_dir, filename))# print filename[:-5]if (int(filename[:-5]) > 10):lable = np.array([0])else:lable = np.array([1])sample = {'image': img, 'lable':lable}if self.transform:sample = self.transform(sample)return sample
Transforms & Compose transforms
可以注意到上一节中AnimalData
类中__init__
中有个transform
参数,这也是这一节中要讲清楚的问题。
从网上随便下载的图片必然大小不一,而cnn
的结构却要求输入图像要有固定的大小;numpy
中的图像通道定义为H, W, C
,而pytorch
中的通道定义为C, H, W
; pytorch
中输入数据需要将numpy array
改为tensor
类型;输入数据往往需要归一化,等等。
基于以上考虑,我们可以自定义一些Callable
的类,然后作为trasform
参数传递给上一节定义的dataset
类。为了更加方便,torchvision.transforms.Compose
提供了Compose类,可以一次性将我们自定义的callable
类传递给dataset
类,直接得到转换后的数据。
这里我直接copy
了教程上的三个类:Rescale
, RandomCrop
, ToTensor
,稍作改动,适应我的数据库。
class Rescale(object):"""Rescale the image in a sample to a given size.Args:output_size (tuple or int): Desired output size. If tuple, output ismatched to output_size. If int, smaller of image edges is matchedto output_size keeping aspect ratio the same."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, lable = sample['image'], sample['lable']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for lable because for images,# x and y axes are axis 1 and 0 respectively# lable = lable * [new_w / w, new_h / h]return {'image': img, 'lable': lable}class RandomCrop(object):"""Crop randomly the image in a sample.Args:output_size (tuple or int): Desired output size. If int, square cropis made."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, lable = sample['image'], sample['lable']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]# lable = lable - [left, top]return {'image': image, 'lable': lable}class ToTensor(object):"""Convert ndarrays in sample to Tensors."""def __call__(self, sample):image, lable = sample['image'], sample['lable']# print lable# swap color axis because# numpy image: H x W x C# torch image: C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'lable': torch.from_numpy(lable)}
定义好callable
类之后,通过torchvision.transforms.Compose
将上述三个类结合在一起,传递给AnimalData
类中的transform
参数即可。
trsm = transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()])
data = AnimalData('./all', transform=trsm)
Iterating through the dataset
上一节中得到data
实例之后可以通过for
循环来一个一个读取数据,现在这是效率低下的。torch.utils.data.DadaLoader
类解决了上述问题。其主要有如下特点:
- Batching the data
- Shuffling the data
- Load the data in parallel using
multiprocessing
workers.
实现起来也很简单:
dataloader = DataLoader(data, batch_size=4, shuffle=True, num_workers=4)
for i_batch, bach_data in enumerate(dataloader):print i_batch print bach_data['image'].size()print bach_data['lable']
第二种
torchvision
pytorch
几乎将上述所有工作都封装起来供我们使用,其中一个工具就是torchvision.datasets.ImageFolder
,用于加载用户自定义的数据,要求我们的数据要有如下结构:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
torchvision.transforms
中也封装了各种各样的数据处理的工具,如Resize
, ToTensor
等等功能供我们使用。
修改我下载的数据库结构如下:
加载数据代码如下:
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt train_data = datasets.ImageFolder('./data1', transform=transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor()
]))train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True,)print len(train_loader)
for i_batch, img in enumerate(train_loader):if i_batch == 0:print(img[1])fig = plt.figure()grid = utils.make_grid(img[0])plt.imshow(grid.numpy().transpose((1, 2, 0)))plt.show()break
结果图:
附录
最后欣赏一段torchvision
源码
# vision/torchvision/datasets/folder.pyimport torch.utils.data as datafrom PIL import Image
import os
import os.pathIMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']def is_image_file(filename):"""Checks if a file is an image.Args:filename (string): path to a fileReturns:bool: True if the filename ends with a known image extension"""filename_lower = filename.lower()return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)def find_classes(dir):classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]classes.sort()class_to_idx = {classes[i]: i for i in range(len(classes))}return classes, class_to_idxdef make_dataset(dir, class_to_idx):images = []dir = os.path.expanduser(dir)for target in sorted(os.listdir(dir)):d = os.path.join(dir, target)if not os.path.isdir(d):continuefor root, _, fnames in sorted(os.walk(d)):for fname in sorted(fnames):if is_image_file(fname):path = os.path.join(root, fname)item = (path, class_to_idx[target])images.append(item)return imagesdef pil_loader(path):# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)with open(path, 'rb') as f:img = Image.open(f)return img.convert('RGB')def accimage_loader(path):import accimagetry:return accimage.Image(path)except IOError:# Potentially a decoding problem, fall back to PIL.Imagereturn pil_loader(path)def default_loader(path):from torchvision import get_image_backendif get_image_backend() == 'accimage':return accimage_loader(path)else:return pil_loader(path)class ImageFolder(data.Dataset):"""A generic data loader where the images are arranged in this way: ::root/dog/xxx.pngroot/dog/xxy.pngroot/dog/xxz.pngroot/cat/123.pngroot/cat/nsdf3.pngroot/cat/asd932_.pngArgs:root (string): Root directory path.transform (callable, optional): A function/transform that takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.loader (callable, optional): A function to load an image given its path.Attributes:classes (list): List of the class names.class_to_idx (dict): Dict with items (class_name, class_index).imgs (list): List of (image path, class_index) tuples"""def __init__(self, root, transform=None, target_transform=None,loader=default_loader):classes, class_to_idx = find_classes(root)imgs = make_dataset(root, class_to_idx)if len(imgs) == 0:raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n""Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))self.root = rootself.imgs = imgsself.classes = classesself.class_to_idx = class_to_idxself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is class_index of the target class."""path, target = self.imgs[index]img = self.loader(path)if self.transform is not None:img = self.transform(img)if self.target_transform is not None:target = self.target_transform(target)return img, targetdef __len__(self):return len(self.imgs)def __repr__(self):fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())fmt_str += ' Root Location: {}\n'.format(self.root)tmp = ' Transforms (if any): 'fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))tmp = ' Target Transforms (if any): 'fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))return fmt_str
参考
[1]. Data Loading and Processing Tutorial
[2]. github: pytorch/torch/utils/data/dataset.py
[3]. github: vision/torchvision/datasets/folder.py
[4]. csdn
学习pytorch: 数据加载和处理相关推荐
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...
- PyTorch数据加载器
We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...
- 从numpy里加载_PyTorch强化:01.PyTorch 数据加载和处理
PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...
- PyTorch强化:01.PyTorch 数据加载和处理
PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...
- pytorch 数据加载和处理
# PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. from __future__ import print_function, division import os impor ...
- pytorch数据加载、模型保存及加载
主要涉及的Pytorch官方示例下图红框部分的一些翻译及备注. 1.数据加载及处理 该部分主要是用于进行数据集加载及数据预处理说明,使用的数据集为:人脸+标注坐标.demo程序需要pandas(读 ...
- Datawhale7月组队学习task1数据加载及探索性数据分析
Datawhale7月task1数据加载及探索性数据分析 有幸了解到了Datawhale这样一个开源组织,欣然报名了2021年7月的组队学习的动手学数据分析系列课程 本系列目标:完成kaggl ...
- pytorch数据加载时报错OSError: [Errno 22] Invalid argument
将数据加载器的num_workers删去,即由 train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, ...
- pytorch 数据加载为python提供了cpu全速跑跑其他任务的能力
from torch.utils.data import Dataset from torch.utils.data import DataLoader def task(text):print(te ...
- pytorch 数据加载性能对比
传统方式需要10s,dat方式需要0.6s import os import time import torch import random from common.coco_dataset impo ...
最新文章
- CoordinatorLayout 和 AppBarLayout 实现折叠效果 底部带阴影的处理方法
- linux那个11命令是什么意思,11个有用的Linux命令
- Python tcp编程,TCP服务端和TCP客户端代码编写
- 借呗利息为什么比银行信用贷款高很多?
- [vue] vue要做权限管理该怎么做?如果控制到按钮级别的权限怎么做?
- iOS开发隐藏tabBar的问题
- 批处理脚本手动双击可以执行,但计划任务中执行失败
- Git 基础(六)—— 工作区和暂存区与 Git 对修改的管理
- 大理大学日常作业计算机基础知识,大理学院成人高等教育大学计算机基础课程作业.doc...
- 学计算机和电脑办公的区别,自学编程和计算机科班出身的差别在哪?
- RabbitMQ消费端自定义监听(九)
- Flexbox弹性布局,更优雅的布局
- ORACLE lag,lead
- Axure常见的几种原型图
- Ubuntu下载软件包网速过慢的解决方法
- 算数平均数、中位数、众数和几何平均数
- linux无线网卡驱动编写,博通无线网卡驱动linux版
- 最好用的9个php开发工具推荐
- 无效的m3u8怎么办_关于大视频video播放的问题以及解决方案(m3u8的播放)
- Oracle数据库:oracle内连接inner join on,多表查询各种自链接、内连接、外连接的练习示例
热门文章
- Eclipse maven工程 Missing artifact com.sun:tools:jar:1.7.0:system 解决方法
- 深入理解计算机系统(1.3)------操作系统的抽象概念
- 【转】飞鸽端口号被占用时的解决方法
- vue项目text-overflow:ellipsis;在生产环境上不显示...的问题
- [GRE] GRE协议介绍
- 所有锁的unlock要放到try{}finally{}里,不然发生异常返回就丢了unlock了
- c#DataGridView响应键盘事件
- 程序员求职之道(《程序员面试笔试宝典》)之民间的企业排名的可信度到底有多大?...
- AS 3.0 socket 通信,比较基础比较全【转载】
- application跟消息相关的东东