pytorch入门强化教程——数据加载和处理
PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性。
1.下载安装包
- scikit-image:用于图像的IO和变换
- pandas:用于更容易地进行csv解析
2.下载数据集
从此处下载数据集, 数据存于“data / faces /”的目录中。这个数据集实际上是imagenet数据集标注为face的图片当中在 dlib 面部检测 (dlib’s pose estimation) 表现良好的图片。我们要处理的是一个面部姿态的数据集。也就是按如下方式标注的人脸:
2.1 数据集注释
数据集是按如下规则打包成的csv文件:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
代码:
from __future__ import print_function, division
import os
import torch
import pandas as pd #用于更容易地进行csv解析
from skimage import io, transform #用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils# 忽略警告
import warnings
warnings.filterwarnings("ignore")plt.ion() # interactive mode# 读取数据集
# 将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量。读取数据代码如下:
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')'''
n = 65 #对应csv文件中的第67行
#img_name = landmarks_frame.iloc[0, 0] 结合face_landmarks.csv看一下第0行从哪开始
#第0列在数据集中对应着图片名称
img_name = landmarks_frame.iloc[n, 0]
#[n, 1:]中的1:代表着第一列及之后的列(即各标记点的横纵坐标,每张图都有68个标记点)
landmarks = landmarks_frame.iloc[n, 1:].values
#重新塑形为(68,2)的数组,正好存放各标记点
landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
#前四行
print('First 4 Landmarks: {}'.format(landmarks[:4]))
'''
#写一个简单的函数来展示一张图片和它对应的标注点作为例子。
def show_landmarks(image, landmarks):"""显示带有地标的图片"""plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001) # pause a bit so that plots are updated
'''
plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),landmarks)
plt.show()
'''#建立数据集类,继承Dataset
'''
数据集类
torch.utils.data.Dataset是表示数据集的抽象类,因此自定义数据集应继承Dataset并覆盖以下方法 * __len__ 实现 len(dataset) 返还数据集的尺寸。
* __getitem__用来获取一些索引数据,例如 dataset[i] 中的(i)。建立数据集类
为面部数据集创建一个数据集类。我们将在 __init__中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存空间。
只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。
我们的数据样本将按这样一个字典{'image': image, 'landmarks': landmarks}组织。
我们的数据集类将添加一个可选参数transform 以方便对样本进行预处理。下一节我们会看到什么时候需要用到transform参数。
__init__方法如下图所示:
'''
class FaceLandmarksDataset(Dataset):"""面部标记数据集."""def __init__(self, csv_file, root_dir, transform=None):"""csv_file(string):带注释的csv文件的路径。root_dir(string):包含所有图像的目录。transform(callable, optional):一个样本上的可用的可选变换"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}#网上搜一下transform的使用,传入参数主要是函数名func和按行还是按列标记axis,即用哪个函数对数据进行操作#所以默认值设为None,方便判断是否指定了对数据操作的函数if self.transform:print(self.transform)sample = self.transform(sample)return sample#数据可视化
#实例化这个类并遍历数据样本。我们将会打印出前四个例子的尺寸并展示标注的特征点。
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')
'''
fig = plt.figure()for i in range(len(face_dataset)):sample = face_dataset[i]print('图像序号:',i, '图像尺寸:',sample['image'].shape, '标记点的尺寸',sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)plt.tight_layout()ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break
'''#数据变换
'''
通过上面的例子我们会发现图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此我们需要做一些预处理。
让我们创建三个转换: * Rescale:缩放图片 * RandomCrop:对图片进行随机裁剪。这是一种数据增强操作
* ToTensor:把numpy格式图片转为torch格式图片 (我们需要交换坐标轴).
我们会把它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。
我们只需要实现__call__方法,必要的时候实现 __init__方法。我们可以这样调用这些转换:
tsfm = Transform(params)
transformed_sample = tsfm(sample)
观察下面这些转换是如何应用在图像和标签上的。
'''
class Rescale(object):"""将样本中的图像重新缩放到给定大小。.Args:output_size(tuple或int):所需的输出大小。 如果是元组,则输出为与output_size匹配。 如果是int,则匹配较小的图像边缘到output_size保持纵横比相同。"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]# 如果是int也就是是一个整型数字if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for landmarks because for images,# x and y axes are axis 1 and 0 respectivelylandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}class RandomCrop(object):"""随机裁剪样本中的图像.Args:output_size(tuple或int):所需的输出大小。 如果是int,方形裁剪是。"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)#元组else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}class ToTensor(object):"""将样本中的ndarrays转换为Tensors."""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 交换颜色轴因为# numpy包的图片是: H * W * C# torch包的图片是: C * H * W#transpose(z,x,y)-> (x,y,z),012分别指代从左到右的位置索引,numpy中图片012分别对应HWC,#调用transpose((2, 0, 1)之后变成CHWimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}#组合转换
'''
接下来我们把这些转换应用到一个例子上。
我们想要把图像的短边调整为256,然后随机裁剪(randomcrop)为224大小的正方形。
也就是说,我们打算组合一个Rescale和 RandomCrop的变换。
我们可以调用一个简单的类 torchvision.transforms.Compose来实现这一操作。具体实现如下:
'''
'''
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 在样本上应用上述的每个变换。
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()
'''#迭代数据集
'''
让我们把这些整合起来以创建一个带组合转换的数据集。
总结一下,每次这个数据集被采样时: * 及时地从文件中读取图片 * 对读取的图片应用转换 * 由于其中一步操作是随机的 (randomcrop) , 数据被增强了
我们可以像之前那样使用for i in range循环来对所有创建的数据集执行同样的操作。
'''transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/',transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))
'''
for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break
''''''
但是,对所有数据集简单的使用for循环牺牲了许多功能,尤其是: * 批量处理数据 * 打乱数据 * 使用多线程multiprocessingworker 并行加载数据。
torch.utils.data.DataLoader是一个提供上述所有这些功能的迭代器。下面使用的参数必须是清楚的。
一个值得关注的参数是collate_fn, 可以通过它来决定如何对数据进行批处理。但是绝大多数情况下默认值就能运行良好。
'''
dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=0)# 辅助功能:显示批次
def show_landmarks_batch(sample_batched):"""Show image with landmarks for a batch of samples."""images_batch, landmarks_batch = \sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid_border_size = 2grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size,s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())# 观察第4批次并停止。if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break
torchvision:
在这篇教程中我们学习了如何构造和使用数据集类(datasets),转换(transforms)和数据加载器(dataloader)。torchvision
包提供了 常用的数据集类(datasets)和转换(transforms)。你可能不需要自己构造这些类。torchvision
中还有一个更常用的数据集类ImageFolder
。 它假定了数据集是以如下方式构造的:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
其中'ants’,bees’等是分类标签。在PIL.Image
中你也可以使用类似的转换(transforms)例如RandomHorizontalFlip
,Scale
。利 用这些你可以按如下的方式创建一个数据加载器(dataloader) :
import torch
from torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)
通过这篇的学习也让我对数据集有了一个浅浅的认知吧算是。
pytorch入门强化教程——数据加载和处理相关推荐
- pytorch dataset自定义_PyTorch | 数据加载及预处理教程
原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 译者 | kbsc13("算法猿的成长"公众号 ...
- 从numpy里加载_PyTorch强化:01.PyTorch 数据加载和处理
PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...
- pytorch入门(二):数据加载和处理
pytorch入门(二):数据加载和处理 小引 数据加载 引包 数据集 编写辅助函数 显示图像及其特征点 定义数据集类 数据处理 组合变换 遍历数据集 其他注意事项 本章对应pytorch官方文档链接 ...
- Pytorch官方教程练习之数据加载和处理
PyTorch提供了许多工具来简化数据加载,使代码更具可读性. 数据及加载和处理步骤如下 1.熟悉数据集. 拿到数据集后首先了解数据集的信息和结构. 本次的数据集为:imagenet数据集标注为fac ...
- PyTorch强化:01.PyTorch 数据加载和处理
PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解析 from __futur ...
- 项目管理工具dhtmlxGantt甘特图入门教程(八):数据加载(三)
这篇文章给大家讲解如何利用 dhtmlxGantt正确保存和显示任务的结束日期,本节将给你一个明确答案. 点击获DhtmlxGantt官方正式版 首先,让我们考虑一下在处理任务日期时可能会遇到的两种情 ...
- PyTorch基础(四)-----数据加载和预处理
前言 之前已经简单讲述了PyTorch的Tensor.Autograd.torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所 ...
- PyTorch数据加载器
We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...
- pytorch 数据加载和处理
# PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. from __future__ import print_function, division import os impor ...
最新文章
- mongodb常用语句以及SpringBoot中使用mongodb
- R统计笔记(四):中括号与双中括号的差异
- 多个Main函数的应用程序
- boost::hana::less_equal用法的测试程序
- 俄罗斯将封杀LinkedIn 推动个人数据本地化
- Git笔记(25) 选择修订版本
- 2017.3.13 反素数ant 失败总结
- 【转】QDockWidget 停靠窗口和工具栏
- td中bug处理过程_特斯拉的致命BUG,埃安LX的L3能解开吗?
- 浅谈Java的Nio以及报Connection refused: no further information异常原因?
- 拓端tecdat|R语言生态学模拟对广义线性混合模型GLMM进行功率(功效、效能、效力)分析power analysis环境监测数据
- P-6002-10PK,P-6002-2PK脂质研究工具解析
- windows ghost备份
- python第三方库 invalid requirement_Python - 生成 requirement.txt 文件
- python:文件写入出现ASII编码
- 如何禁用C-State功能?关闭intel CPU的C-State省电模式方法
- 四大网络抓包神器,总有一款适合你......
- 苹果手机进水屏幕乱跳怎么办
- phpwindexp.php,phpwind 5.0.1 Sql注射漏洞利用程序脚本安全 -电脑资料
- ESP32 关于HTTPS的使用
热门文章
- 人名中间的小圆点的实现方式
- API调用,淘宝天猫、1688、京东、拼多多商品页面APP端原数据获取
- could not create folder “sftp://xxx.xxx.xxx.xxx/.../venv“. (Permission denied)
- bzoj 4372 烁烁的游戏 - 点分治 - 线段树
- 姿态角速度和机体角速度,横摆角速度(Yaw Rate)估算
- 线上线下一体化趋势下,零售品牌如何利用线上营销为营收赋能?
- 谷歌浏览器突然打不开
- My $650,100 Lunch with Warren Buffett
- S905L(P211)盒子刷android tv以及刷emuelec 4.4/4.5的向导/方法
- 苹果用Android发文,安卓和苹果怎么传文件 安卓和苹果传文件详细教程