处理数据样本的代码可能会变得混乱且难以维护; 理想情况下,我们想要数据集代码与模型训练代码解耦,以获得更好的可读性和模块化。PyTorch 域库提供了许多预加载的数据(例如 FashionMNIST)。这些数据集可以子类化torch.utils.data.Dataset并实现特定于特定数据的功能。 它们可用于对模型进行原型设计和基准测试。

ETL是用来描述将数据从来源端经过抽取、转换、加载至目的端的过程。在机器学习中处理数据集的流程为:

  1. 提取:从数据源提取数据。

  2. 转换:将我们的数据转换为张量形式。

  3. 加载:将我们的数据放入对象以使其易于访问。

一、加载数据集

PyTorch 提供了两个数据原语: 分别是
        torch.utils.data.Dataset和torch.utils.data.DataLoader
可以在预加载的数据集或者自己的数据集上使用。其中
Dataset表示存储样本及其对应的标签,用于表示数据集的抽象类。
DataLoader包裹一个可迭代的迭代器, 这使得 Dataset便于访问样品。包装数据集并提供对基础数据的访问。

说明
torch 顶级PyTorch软件包和张量库。
torch.nn 一个子包,其中包含用于构建神经网络的模块和可扩展类。
torch.optim 一个子包,其中包含SGD和Adam之类的标准优化操作。
torch.nn.functional 一个功能接口,其中包含用于构建神经网络的典型操作,例如损失函数和卷积。
torchvision 一个软件包,提供对流行的数据集,模型体系结构和计算机视觉图像转换的访问。
torchvision.transforms 一个接口,其中包含用于图像处理的常见转换。

首先导入训练模型必需的PyTorch库:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

导入预加载的数据集:
以下代码是演示如何从 TorchVision 加载 Fashion-MNIST 数据集的示例。 Fashion-MNIST 是 Zalando 文章图像的数据集,由 60,000 个训练示例和 10,000 个测试示例组成。 每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。

from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

下载路径root=“data” 可以改成windows的一下路径,比如D://pytorch//data,就会把FashionMNIST 数据集下载到这个路径下。

其中代码中有以下参数:(使用torchvision获取FashionMNIST数据集的实例)

root 存储训练/测试数据的路径
train 如果数据集是训练集,则train=True
download=True 如果数据不可用,则从 Internet 下载数据

transformtarget_transform

指定特征和标签转换

由于希望将图像转换为张量,因此使用了内置的transforms.ToTensor()转换,若该数据集用于训练,则将其命名为training_data,若该数据集用于测试,则将其命名为test_set。当第一次运行此代码时,Fashion-MNIST数据集将在本地下载。后续将在下载数据之前检查数据。从ETL的角度来看,在创建数据集时已经完成了提取,并使用了Torchvision进行了转换:

二、迭代和可视化数据集

可以像一个列表一样手动索引 Datasetstraining_data[index]可以用 matplotlib可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(10, 10))#设置整个画布的大小
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()#随机获得一个训练集中的样本索引值,item()将张量转换为标量img, label = training_data[sample_idx]#得到当前索引值下对应的图像数据和标签figure.add_subplot(rows, cols, i)#在当前画布下创建一个3*3的视图,指定格子(索引号)中创建一个Axesplt.title(labels_map[label])#根据标签获取字典labels_map中对应的名字plt.axis("off")#关闭坐标轴 plt.imshow(img.squeeze(), cmap="gray")#cimg.squeeze()将图像中为1的维度删掉,map="gray"显示灰度图像,
plt.show()#将plt.imshow()处理后的函数显示出来。

其中figure语法及操作:【Python】 【绘图】plt.figure()的使用_欧阳小俊的博客-CSDN博客_plt.figure

pytorch中的randint()方法:

torch.rand()、torch.randn()、torch.randint()、torch.randperm()用法_-CSDN博客_torch.random

Figure的add_subplot()方法 :

Matplotlib学习手册A006_Figure的add_subplot()方法_Python草堂的博客-CSDN_add_subplot()

 pytorch中squeeze()和unsqueeze():

pytorch中squeeze()和unsqueeze()_ying______的博客-CSDN博客_img.squeeze()

运行结果如下:

另外,还可以对数据集进行一些其他的操作:

(1)查看训练集中有多少张图片,可以使用Python len()函数检查数据集的长度:

print(len(training_data))
print(len(test_data))
#60000
#10000

(2)假设要查看每个图像的标签。 可以这样完成:第一个图像是9,接下来的两个是零。 这些值编码实际的类名称或标签。

print(training_data.targets)
#tensor([9, 0, 0,  ..., 3, 0, 5])

(3)要查看数据集中每个标签有多少个,可以使用PyTorch  bincount()函数,如下所示:

print(training_data.targets.bincount())
#tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

三、创建自定义的数据集

为了使用PyTorch创建自定义数据集,torch.utils.data.Dataset方法可以通过创建扩展Dataset类功能的子类来创建自定义数据集。完成操作后,新子类便可以传递给PyTorch DataLoader对象。Dataset的所有子类都必须覆盖提供数据集大小的__len__和支持从0到len(self)互斥的整数索引的__getitem__

class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Datasetdef __init__(self):#对继承自父类的属性进行初始化super(MyDataset,self).__init__()#1、初始化一些参数和函数,方便在__getitem__函数中调用。#2、制作__getitem__函数所要用到的图片和对应标签的list。#也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。passdef __getitem__(self, index):#1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。#2、预处理数据(例如torchvision.Transform)。#3、返回数据对(例如图像和标签)。#这里需要注意的是,这步所处理的是index所对应的一个样本。passdef __len__(self):#返回数据集大小return len()

假设有一个保存为npy格式的numpy数据集,现在需要将其变为pytorch的数据集,并能够被数据加载器DataLoader所加载,首先自定义 Dataset 类必须实现三个函数:
                                            __init__ 、 __len__ 和 __getitem__
否则报错。然后实例化这个类,得到train_data,最后将train_data放入DataLoader数据加载器,完成。若使用torchvision软件包内置的fashion-MNIST数据集类在后台进行此操作,因此不必在项目中执行操作。之前数据集Fashion-MNIST 的实现是 FashionMNIST 图像数据存储在目录 img_dir中,并且它们的标签分别存储在 CSV 文件 annotations_file中接下来的部分,将分解每个函数中具体过程

1、 __init__(初始化)

__init__ 函数在实例化 Dataset 对象时运行一次。 初始化包含图像、注释文件和两个转换的目录。

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

其中labels.csv 文件内容如下所示:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

2、__len__(获取图像)

__len__ 函数返回数据集中的样本数。

    def __len__(self):return len(self.img_labels)

3、__getitem__(数据集数量)

__getitem__ 函数从给定索引处的数据集中加载并返回样本idx。基于索引,它识别图像在磁盘上的位置,使用 read_image将图像数据转​​换为张量检索来自 csv 数据的相应标签 self.img_labels,调用它们的转换函数(如果适用),最后返回张量图像和元组中的相应标签。

传入参数index为下标,返回数据集中对应下标的数据组(数据和标签)

执行步骤
获取img名,拼接路径read_image
获取label名
transform
target_transform
返回’image’和’label’的dict

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

4、完整实例(Dataset+DataLoader)

创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)应该返回数据集的大小;def __getitem__(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。制作这个list通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。具体流程是:

 创建自己的数据集:

Pytorch学习(三)定义自己的数据集及加载训练_cdy艳0917的博客-CSDN博客_pytorch 数据集

(1)收集一组图片作为自己的数据集,然后创建一个txt文件储存图片对应的label。相应的txt文件如下:

(2)创建自己的数据集类,首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取。再经过处理,就可以将自己的数据集输入到神经网络里了。最后查看一下所获得的data_loader:

import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader#路径是自己电脑里所对应的路径
datapath = r'D:\BIANCHENG\deepstudy\data\traindata'
txtpath = r'D:\BIANCHENG\deepstudy\data\lable.txt'class MyDataset(Dataset):def __init__(self,txtpath):#创建一个list用来储存图片和标签信息imgs = []#打开第一步创建的txt文件,按行读取,将结果以元组方式保存在imgs里datainfo = open(txtpath,'r')for line in datainfo:line = line.strip('\n')#同时去掉左右两边的空格words = line.split()#以空格为分割进行切片imgs.append((words[0],words[1]))#将图片的名称和对应标签存入列表imgs[]中self.imgs = imgs#返回数据集大小def __len__(self):return len(self.imgs)#打开index对应图片进行预处理后return回处理后的图片和标签def __getitem__(self, index):#按照索引读取每个元素的具体内容pic,label = self.imgs[index]#根据索引得到对应图片的图像名称和标签pic = Image.open(datapath+'\\'+pic)#打开当前数据的存储路径,读取当前索引下图片名称对应的图像数据,赋值给picpic = transforms.ToTensor()(pic)#对原始图像数据进行张量变换return pic,label#return回处理后的图片和标签
#实例化对象
data = MyDataset(txtpath)
#将数据集导入DataLoader,进行shuffle以及选取batch_size
data_loader = DataLoader(data,batch_size=1,shuffle=True,num_workers=0)
#Windows里num_works只能为0,其他值会报错
for pics,label in data_loader:print(pics,label)

查看加载之后的数据(部分):

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.1569],[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.1569],[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.1569],...,[0.8510, 1.0000, 1.0000,  ..., 0.9412, 0.9412, 0.1569],[0.8510, 1.0000, 1.0000,  ..., 0.9412, 0.9412, 0.1569],[0.8510, 1.0000, 1.0000,  ..., 0.9412, 0.9412, 0.1569]],

结果显示,data_loader这个迭代器里存储的是每1个一组(batch_size)分批次读取的图片像素信息以及对应的标签信息,也就是后续要导入到神经网络里的数据。如果batch_size设置为 2 可能会报错,原因是因为输入dataloader的图片大小不一致。需要在dataset的__getitem__方法中加一行resize。

   pic=pic.resize((224, 224))#设置每一张图片的大小

 python strip()函数 去空格\n\r\t函数的用法:

Python strip() 函数 去空格 \n \r \t 函数的用法_王图思睿的博客-CSDN博客_strip('\n')

5、完整实例(ImageFolder+DataLoader)

在pytorch中提供了torchvision.datasets.ImageFolder训练自己的图像。ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造如下:

ImageFolder的函数说明如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它的主要参数如下:

root 在root指定的路径下寻找图片
transform 对loader读取图片的返回对象进行转换操作(ToTensor等)
target_transform 对label的转换
loader 给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

实现代码如下:

transform = transforms.ToTensor()
root = r'E:\Python\DeepLearning\Datasets\mymnist\train'
# 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
train_data = torchvision.datasets.ImageFolder(root, transform=transform)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=0)test_data = torchvision.datasets.ImageFolder(root, transform=transform)
test_iter = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0)

6、编写自定义数据集,数据加载器和转换

PyTorch 入门学习(六)————编写自定义数据集,数据加载器和转换_夏天的欢的博客-CSDN博客_编写数据集

四、使用 DataLoaders 准备数据

Dataset检索一次数据集的特征并标记一个样本。 在训练模型时,通常希望 以“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过拟合,并使用 Python 的 multiprocessing加快数据检索。 DataLoader是一个可迭代的,它在一个简单的 API 中抽象了这种复杂性。接下来为训练集创建一个DataLoader包装器, 由数据加载器包装(加载到其中)的train_set使我们可以访问基础数据。代码如下:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

该数据加载器的批量大小为64,故一次处理一批64张图像和64个相应的标签的数据。

其中数据加载器加载出来的数据,已经由之前的三个函数把numpy数据类型转化为了tensor类型。

training_set作为参数传递。 现在利用加载程序来完成该任务:

  1. batch_size(可以分批次读取,在示例中为64)
  2. shuffleshuffle=True,数据被打乱,对数据进行洗牌,打乱数据集内数据分布的顺序)
  3. num_workers(可以并行加载数据(加快载入数据的效率,默认为0,表示将使用主进程)

五、遍历 DataLoader

(1)将该数据集加载到 DataLoader并且可以根据需要遍历数据集。 下面的每次迭代都会返回一批 train_featurestrain_labels(包含 batch_size=64特征和标签)。 因为指定 shuffle=True,在遍历所有批次之后,数据被打乱(为了更细粒度的控制 数据加载顺序)。

import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()#h删掉第一行数据中的一维维度
label = train_labels[0]#获取标签数据中的信息
plt.imshow(img, cmap="gray")#以灰度图像显示
plt.show()
print(f"Label: {label}")#打印处对应的标签值
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2

要访问训练集中的单个元素,首先将train_set对象传递给Python的iter()内置函数,该函数会返回一个代表数据流的对象。对于数据流,可以使用Python内置的next()函数来获取数据流中的下一个数据元素。

训练集中检索的每个样本都包含作为张量的图像数据和相应的作为张量的标签。故在图像上调用squeeze()函数,删除尺寸为1的维度。

因为shuffle = True,所以每次调用next时批次将不同。如果 shuffle = False,则在第一次调用next时将返回训练集中的第一个样本。

(2)要绘制一批图像,可以使用torchvision.utils.make_grid()函数创建一个可以如下绘制的网格:

PyTorch - 15 - PyTorch数据集和数据加载器 - 深度学习和AI的训练集探索_许喜远-CSDN博客

import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torchvision
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)batch = next(iter(train_dataloader))#获取数据流中的数据元素。
images, labels = batch
print('types:', type(images), type(labels))
print('shapes:', images.shape, labels.shape)
#images[0].shape
grid = torchvision.utils.make_grid(images, nrow=10)#组成图像的网络,其实就是将多张图片组合成一张图片。
plt.figure(figsize=(15,15))#画布大小
plt.imshow(grid.permute(1,2,0))#permute()可以对某个张量的任意维度进行调换。把grid的第一个维度放到最后面。
print('labels:', labels)
plt.show()

由于batch_size = 64,所以处理的是一批64张图像和64个相应的标签。这就是为什么变量名称上使用复数形式的原因。类型是我们期望的张量。但是,形状与在单个样品中看到的形状不同。没有一个标量值作为标签,而是拥有一个带有64个值的rank-1张量。张量中包含图像数据的每个维度的大小由以下每个值定义:

(批量大小,颜色通道数,图像高度,图像宽度)

代码输出为:

types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([64, 1, 28, 28]) torch.Size([64])
labels: tensor([0, 2, 6, 2, 9, 8, 3, 0, 5, 2, 4, 2, 4, 4, 6, 6, 8, 3, 3, 7, 5, 4, 0, 7,5, 0, 0, 0, 4, 0, 7, 4, 8, 1, 8, 5, 2, 1, 9, 2, 5, 7, 7, 7, 4, 9, 9, 5,2, 1, 7, 0, 7, 9, 4, 5, 9, 8, 8, 6, 4, 7, 4, 3])

PyTorch基础-自定义数据集和数据加载器(2)相关推荐

  1. Pytorch基础(三)数据集加载及预处理

    目录 下载数据集及显示样本 数据集类 建立数据集类及显示部分样本 数据变换 后记 python提供了许多工具简化数据加载,使代码更具可读性.经常用到的包有scikit-image.pandas等,本文 ...

  2. Android自定义简单的图片加载器(ImageLoader)

    废话不多述,首先来说明下 为什么要用图片加载器 呢,就是为了避免图片重复从网络加载.也就是在第一次从网络加载之后就把图片缓存在本地,下次用的时候直接从本地查找,有的话就直接用,没有再从网络加载. 加载 ...

  3. 爬虫训练场基础铺垫,BT加载器,分页,列表组,卡片,下拉菜单一文掌握

    爬虫训练场基础铺垫目录 Bootstrap 5 加载器 Bootstrap 5 分页组件 Bootstrap 5 列表组 Bootstrap 5 卡片 Bootstrap 5 下拉列表 本篇博客为大家 ...

  4. Android Loader(加载器)详解

    Loader(加载器)简介 Android 3.0 中引入了加载器,支持轻松在 Activity 或Fragment中异步加载数据. 加载器具有以下特征: (1)可用于每个 Activity 和 Fr ...

  5. Android 之Loader(加载器)

    介绍 Android 3.0 中引入了加载器,支持轻松在 Activity 或片段中异步加载数据. 加载器具有以下特征: 可用于每个 Activity 和 Fragment. 支持异步加载数据. 监控 ...

  6. 花里胡哨免杀《剪切板加载器》

    前言 最近研究内存加载器魔怔了,我们知道所有内存加载器原理都一个样:申请可执行内存->shellcode写入内存->执行该内存 申请内存还是比较好说的,去win开发手册搜搜就能找到很多申请 ...

  7. pytorch自定义数据集和数据加载器

    假设有一个保存为npy格式的numpy数据集,现在需要将其变为pytorch的数据集,并能够被数据加载器DataLoader所加载 首先自定义一个数据集类,继承torch.utils.data.Dat ...

  8. pytorch dataset自定义_PyTorch 系列 | 数据加载和预处理教程

    原题 | DATA LOADING AND PROCESSING TUTORIAL 作者 | Sasank Chilamkurthy 原文 | https://pytorch.org/tutorial ...

  9. PyTorch 编写自定义数据集,数据加载器和转换

    本文为 pytorch 官方教程https://pytorch.org/tutorials/beginner/data_loading_tutorial.html代码的注释 w3cschool 的翻译 ...

最新文章

  1. 干货丨 简述迁移学习在深度学习中的应用
  2. 对象间的联动——观察者模式
  3. 搭建hexo博客并部署到github上
  4. leetcode:剑指offer----数组中重复的数字
  5. FPGA不可综合语句
  6. SQL语句取得最大件数(MSSQL ORACLE Postgre,top rownum,limit)
  7. AM335X 3款核心板比较
  8. 【离散数学】集合论 第四章 函数与集合(6) 三歧性定理、两集合基数判等定理(基数的比较)、Cantor定理
  9. 2、那智机器人时序基板的TBEX1、TBEX2连接
  10. 前端POST请求下载文件
  11. 光学成像基础-荧光滤色片
  12. unity打开设置虚拟键的界面
  13. 【计算机毕业设计】018母婴商城系统
  14. 微信小程序踩坑记——ColorUI组件的使用
  15. 移动流量转赠给好友_移动的号怎么赠送流量给好友?
  16. Python:快速去除PDF水印
  17. 招商银行信用卡中心大数据
  18. java map扩容机制_Java HashMap的原理、扩容机制、以及性能思考
  19. oracle假如存在才删除该字段,Oracle删除表、字段之前判断表、字段是否存在
  20. Java基础语法(汉罗塔)

热门文章

  1. MT【193】三面角的正余弦定理
  2. 爬取简单静态网站——汽车之家二手车
  3. 复杂网络基础——《链接》
  4. 爬有道在线翻译(已完善)
  5. DMIPS, TOPS, FLOPS, FLOPs, GMACs, FMA
  6. 计算机毕业设计Java印染公司信息管理系统(系统+程序+mysql数据库+Lw文档)
  7. vs2008编译QT开源项目--太阳神三国杀源码分析(二) 客户端添加武将
  8. 计算机病毒是一种能破坏计算机运行的,计算机病毒是一种能破坏计算机运行的()。...
  9. JavaScript学习笔记|数据类型——Object类型、for in循环
  10. 配置完hadoop后调用HDFS的API进行统计英语单词数量