【深度学习】基于PyTorch深度学习框架的序列图像数据装载器
作者 | Harsh Maheshwari
编译 | VK
来源 | Towards Data Science
如今,深度学习和机器学习算法正在统治世界。PyTorch是最常用的深度学习框架之一,用于实现各种深度学习算法。另一方面,基于学习的方法本质上需要一些带注释的训练数据集,这些数据集可以被模型用来提取输入数据和标签之间的关系。为了给神经网络提供数据,我们定义了一个数据加载器。
在这个博客中,我们将看到如何在PyTorch框架中为不同的数据集编写一个数据加载器。
图像数据集的数据加载器
我们将致力于狗与猫的图像分类问题。我们需要对给定的图像进行分类,数据集可以从这里下载:https://www.kaggle.com/c/dogs-vs-cats。训练数据集总共包含25000个图像。因为这是一个分类问题,所以dog的标签是“0”,cat的标签是“1”。
让我们从导入所有必需的库开始。
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn as nn
PyTorch框架的dataset类被定义为一个类,其基本结构如下
class data(Dataset):def __init__(self, param1, param2):# 函数在此处初始化def __len__(self):# 函数返回数据的长度def __getitem__(self, index):# 一次提供一个项目
这个类的最终目的是使用函数
__getitem__
每次提供一个数据点。这是通过使用内部传递给函数的索引完成的,使用Dataloader中定义的sampler函数(将在接下来的博客中讨论)。初始化数据集的对象时,会调用函数
__init__
。在这里,你可以传递多个参数,这些参数对于编写__getitem__
非常有用。函数用于返回数据集的总长度。在此基础上,将生成索引,然后将其提供给
getitem
。
dog vs cat数据集的格式如下-:
data/- dog_1.jpg- dog_2.jpg.........- cat_1.jpg- cat_2.jpg.........
现在我们已经了解了编写数据加载器所需的组件,让我们深入研究一下我们的用例。
class data(Dataset): def __init__(self, path, transform):self.files = os.listdir(path)self.transform = transformself.path = path def __len__(self):return len(self.files) def __getitem__(self, index):filename = self.files[index]input = Image.open(os.path.join(self.path, filename))label = 0 if filename.find("dog")>=0 else 1img_as_tensor = self.transform(input)return img_as_tensor, labeltransformations = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])path = "./data"
train_dataset = data(path, transformations)
dataloader = DataLoader(train_dataset, batch_size=Train_Batch_Size, shuffle=True)
首先让我们了解函数
__init__
。类数据用两个参数path和transform初始化,这两个参数作为参数传递给__init__
。当我们声明这个类的一个对象时,它会在内部调用__init__
。由于使用了len来返回整个数据集的长度,所以我使用len(self.files)来返回相同的长度。
函数getitem是最关键的,它加载图像,然后调整其大小,然后将其转换为张量。这里需要注意的一点是,提供给神经网络的数据应该总是标准化的。我们使用transforms.ToTensor处理规范化。最后,
getitem
返回两个结果,image作为张量,label作为对应的数据点。
在初始化类数据之后,我们使用DataLoader函数自动将整个数据批处理成一个定义的批大小。因此,如果你的原始数据点大小是(3,224,224)(你从__getitem__
获得),那么dataloader的每个项都将具有大小(batch_size,3,224,224),即它会自动对数据点的batch_size数进行采样。
这在我们的例子中是可能的,因为图像的大小是恒定的,所以DataLoader函数能够自动创建批处理。然而,在自然语言处理这样的情况下,当大小不是常数时,我们需要编写自己的批处理函数。
序列数据集的数据加载器
现在让我们来处理序列数据集,即句子、时间序列、音频等。这里的__getitem__
将不再提供相同大小的数据点。例如,考虑情绪分类的任务(在这里解释),那么一句话可以是“The flight service was very good”,另一句话可以是“I did not get my baggage on the belt, pathetic service.”在这里,两句话的长度是不同的。
为了解决这个问题,让我们先回答三个问题。
什么是batch?-批处理是指将多个数据点的张量合并成一个张量
为什么我们需要分批处理?批处理可以用于加快计算速度,因为批处理可以同时处理多个数据点,而不是一次只处理一个数据点。
如何进行batch化?因为我们在这里合并多个张量,所以张量的每个维度的大小都需要相同。由于输出的数据点大小不一,我们手中就有一个问题。
我们现在主要要解决batch化问题。
为了便于我们在这里讨论,我们将使用IMDB数据集,它是一个评论数据集。因为我们在这里处理的是句子,所以处理数据集的方法会有所不同。
因为神经网络只懂数字,不懂单词,所以我们必须把每个单词转换成一个数字。为了做到这一点,我们必须构建一个词汇表,如下代码所述。
import os
import gensim
from collections import Counter
import jsontrain_path = "./aclImdb/train"
test_path = "./aclImdb/test"# simple函数从目录读取数据并返回数据和标签
# 你可以为其他数据集制作自己的读取器。
def reader(path):pos_path = os.path.join(path, "pos")neg_path = os.path.join(path, "neg")data = []label = []for file in os.listdir(pos_path):f = open(os.path.join(pos_path, file))data.append(f.read())label.append(1)for file in os.listdir(neg_path):f = open(os.path.join(neg_path, file))data.append(f.read())label.append(0)# print(data[:1])return data, labeldef build_vocab(data, min_word_count = 5):counter = Counter()for line in data:l = gensim.utils.simple_preprocess(line)counter.update(l)# 初始化一个字典或查找表word2id = {}word2id['<pad>'] = 0word2id['<unk>'] = 1# 只包括那些在字典中出现超过min次的单词。words = [word for word, count in counter.items() if count>min_word_count]for i, word in enumerate(words):word2id[word] = i+2with open("word2id.json", 'w') as f:json.dump(word2id, f)return word2iddata, label = reader(train_path)
word2id = build_vocab(data)
print("Dictionary Formed and saved. The length of dictionary is-: ", len(word2id))
函数读取器用于读取整个数据,它返回所有句子的列表,标签“0”表示消极评论,“1”表示积极评论。
函数build_vocab将数据和最小字数作为输入,并将每个字的映射(称为“word2id”)作为输出,映射到一个唯一的数字。对于每个向前的未知单词,对应的数字将是1。
继续为序列数据集编写数据集类。我们的目标是在给定索引的情况下,一次输出一个item。
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import gensimclass Dataset_seq(Dataset):def __init__(self, word2id, train_path):self.word2id = word2idself.train_path = train_path# 读取数据和标签self.data, self.label = reader(train_path)def __getitem__(self, index):# 返回seq和标签seq = self.preprocess(self.data[index])label = self.label[index]return seq, labeldef __len__(self):return(len(self.data))def preprocess(self, text):# 用于将line转换为token,然后使用word2id将其转换为相应的数字值line = gensim.utils.simple_preprocess(text)seq = []for word in line:if word in self.word2id:seq.append(self.word2id[word])else:seq.append(self.word2id['<unk>'])# 将list转换成张量seq = torch.from_numpy(np.array(seq))return seq
由于上面已经讨论了不同函数的功能,我将简要地回顾一下。
函数
__init__
采用word2id映射和train路径。然后,init调用reader获取与句子对应的数据和标签。函数
__len__
返回整个数据集的长度,即self.data。函数preprocess将输入句子转换成数字张量,其中每个数字对应于句子中的单词。
函数
getitem
用于在索引的帮助下输出一个经过处理的数据点。
下面的代码定义了collate_fn。
train_dataset = Dataset_seq(word2id, train_path)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn)
def collate_fn(data):''' 我们应该构建一个自定义的collate_fn,而不是使用默认的collate_fn,因为每个句子的大小不同,并且默认不支持合并序列。Args:data: 元组列表 (training sequence, label)Return:padded_seq - 填充序列,形状 (batch_size, padded_length)length - 每个序列的原始长度(没有填充), 形状(batch_size)label - 张量形状 (batch_size)'''data.sort(key=lambda x: len(x[0]), reverse=True)sequences, label = zip(*data)length = [len(seq) for seq in sequences]padded_seq = torch.zeros(len(sequences), max(length)).long()for i, seq in enumerate(sequences):end = length[i]padded_seq[i,:end] = seqreturn padded_seq, torch.from_numpy(np.array(length)), torch.from_numpy(np.array(label))
这里需要注意的一点是,在一个元组列表中,每个元组可以有不同的大小,但在张量中,所有维度的大小都必须相同才能合并它们。
collate_fn自动获得一个名为data的输入,这是一个长度等于batch size的元组列表。每个元组包含数字张量及其相应的标签。
为了简单起见,我们将它们分别称为sequence和label。所以最终我们必须以这样一种方式转换每个序列,使它们的大小保持不变。
为了实现这一点,我们执行零填充,如上面的代码所示。由于对整个数据集统一使用零填充,因此模型了解到它没有多大用处,它只是表示浪费值。
我们肯定已经找到了解决办法,但问题是,这是一个最佳的解决办法吗?如果所有序列的原始大小都有很大的差异,或者换言之有很大的差异,那么我们最终会浪费大量的GPU内存,而这些内存是零填充的,这最终是没有用的。必须有一个更好的方法来最小化零填充的要求!
这个问题的解决请关注后续文章!
往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》课件合集
本站qq群851320808,加入微信群请扫码:
【深度学习】基于PyTorch深度学习框架的序列图像数据装载器相关推荐
- 基于PyTorch深度学习无人机遥感影像目标检测、地物分类及语义分割
随着无人机自动化能力的逐步升级,它被广泛的应用于多种领域,如航拍.农业.植保.灾难评估.救援.测绘.电力巡检等.但同时由于无人机飞行高度低.获取目标类型多.以及环境复杂等因素使得对无人机获取的数据处理 ...
- 深度学习必备书籍——《Python深度学习 基于Pytorch》
作为一名机器学习|深度学习的博主,想和大家分享几本深度学习的书籍,让大家更快的入手深度学习,成为AI达人!今天给大家介绍的是:<Python深度学习 基于Pytorch> 文章目录 一.背 ...
- Pytext简介:facebook的基于PyTorch的NLP框架
自然语言处理(NLP)在现代深度学习生态中越来越常见.从流行的深度学习框架到云端API的支持,例如Google云.Azure.AWS或Bluemix,NLP是深度学习平台不可或缺的部分.尽管已经取得了 ...
- 基于PyTorch重写sklearn,《现代大数据算法》
HyperLearn是一个基于PyTorch重写的机器学习工具包Scikit Learn,它的一些模块速度更快.需要内存更少,效率提高了一倍. 专为大数据而设计,HyperLearn可以使用50%以下 ...
- [深度学习工具]基于PyTorch的NLP框架Flair
一个非常简单的框架,用于最先进的NLP.由Zalando Research开发. Flair简介: 一个功能强大的NLP库.Flair允许您将最先进的自然语言处理(NLP)模型应用于您的文本,例如命名 ...
- 基于PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化
我国高分辨率对地观测系统重大专项已全面启动,高空间.高光谱.高时间分辨率和宽地面覆盖于一体的全球天空地一体化立体对地观测网逐步形成,将成为保障国家安全的基础性和战略性资源.未来10年全球每天获取的观测 ...
- 【深度学习】PyTorch深度学习技术生态
PyTorch Author:louwill Machine Learning Lab 随着近几年的大力发展,PyTorch逐渐成为主流的深度学习框架.相应的PyTorch技术生态也逐渐丰富和完善.本 ...
- 论文学习|基于少镜头学习的毛果杨群体叶片性状分析
关于原文 Few-Shot Learning Enables Population-Scale Analysis of Leaf Traits in Populus trichocarpa 基于少镜头 ...
- 基于PyTorch的GAN框架TorchGAN:用架构级API轻松定制GAN项目
机器之心报道 参与:刘晓坤 TorchGAN 是基于 PyTorch 的 GAN 设计开发框架.该框架旨在为流行的 GAN 提供构造模块,且允许为前沿研究进行定制化. 使用 TorchGAN 的模块化 ...
最新文章
- 速度前瞻运动控制c语言程序_整合实时运动控制及多颗相机连接,大幅提升光学影像检测速度...
- 【下】安全HTTPS-全面详解对称加密,非对称加密,数字签名,数字证书和HTTPS
- asp.net 连接 Access 的几种方法
- ant编译重设property的值
- 用 JavaScript 的方式理解递归
- php新闻列表排序,javascript 新闻列表排序简单封装
- 工信部下架37款侵害用户权益APP 114票务网等在列
- 大数据之-Hadoop3.x_Hadoop_HDFS_总结---大数据之hadoop3.x工作笔记0080
- linux退出python环境_Linux中的python虚拟环境
- php 输出tab_php实现读取和写入tab分割的文件
- C# 中的常用正则表达式汇总
- 现代控制理论——非线性系统的lyapunov
- 使用Linux常见问题及其解决办法
- Dynamics CRM开发学习-插件01
- 循环结构:while和do...while循环语句
- 华为Atlas200dk使用第四步------配置CANNtoolkit环境
- 机器学习_决策树与信息熵
- 机器学习常用的五种预测结果评价
- c++11之特性之std::function(书:深入应用c++11)
- go语言环境安装之插件
热门文章
- Python快速入门(1)
- ajax 延迟显示加载中提示
- 关于C/C++的trigraphs和Digraphs
- B计划 第四周(开学第一周)
- Nginx + PHP CGI的fix_pathinfo安全漏洞
- python 线程中出现执行错乱_多处理会导致Python崩溃,并在调用fork()时在另一个线程中出现错误...
- Java实现插值查找算法 Insert search
- 毕业论文 | 便携式环境烟雾监测器(源码、电路图)
- 人工智能 | SLAM与Visual Odometry技术综述(浙江大学智能系统和控制研究所)
- 跳一跳python源码_使用Python实现跳一跳自动跳跃功能