@Author:Runsen

对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。

之前使用 torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集

Torchvision 中的数据集

MNIST

MNIST 是一个由标准化和中心裁剪的手写图像组成的数据集。它有超过 60,000 张训练图像和 10,000 张测试图像。这是用于学习和实验目的最常用的数据集之一。要加载和使用数据集,使用以下语法导入:torchvision.datasets.MNIST()

Fashion MNIST

Fashion MNIST数据集类似于MNIST,但该数据集包含T恤、裤子、包包等服装项目,而不是手写数字,训练和测试样本数分别为60,000和10,000。要加载和使用数据集,使用以下语法导入:torchvision.datasets.FashionMNIST()

CIFAR

CIFAR数据集有两个版本,CIFAR10和CIFAR100。CIFAR10 由 10 个不同标签的图像组成,而 CIFAR100 有 100 个不同的类。这些包括常见的图像,如卡车、青蛙、船、汽车、鹿等。

torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()

COCO

COCO数据集包含超过 100,000 个日常对象,如人、瓶子、文具、书籍等。这个图像数据集广泛用于对象检测和图像字幕应用。下面是可以加载 COCO 的位置​​:torchvision.datasets.CocoCaptions()

EMNIST

EMNIST数据集是 MNIST 数据集的高级版本。它由包括数字和字母的图像组成。如果您正在处理基于从图像中识别文本的问题,EMNIST是一个不错的选择。下面是可以加载 EMNIST的位置​​::torchvision.datasets.EMNIST()

IMAGE-NET

ImageNet 是用于训练高端神经网络的旗舰数据集之一。它由分布在 10,000 个类别中的超过 120 万张图像组成。通常,这个数据集加载在高端硬件系统上,因为单独的 CPU 无法处理这么大的数据集。下面是加载 ImageNet 数据集的类:torchvision.datasets.ImageNet()

Torchtext 中的数据集

IMDB

IMDB是一个用于情感分类的数据集,其中包含一组 25,000 条高度极端的电影评论用于训练,另外 25,000 条用于测试。使用以下类加载这些数据torchtext:torchtext.datasets.IMDB()

WikiText2

WikiText2语言建模数据集是一个超过 1 亿个标记的集合。它是从维基百科中提取的,并保留了标点符号和实际的字母大小写。它广泛用于涉及长期依赖的应用程序。可以从torchtext以下位置加载此数据:torchtext.datasets.WikiText2()

除了上述两个流行的数据集,torchtext库中还有更多可用的数据集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。

深入查看 MNIST 数据集

MNIST 是最受欢迎的数据集之一。现在我们将看到 PyTorch 如何从 pytorch/vision 存储库加载 MNIST 数据集。让我们首先下载数据集并将其加载到名为 的变量中data_train

from torchvision.datasets import MNIST# Download MNIST
data_train = MNIST('~/mnist_data', train=True, download=True)import matplotlib.pyplot as pltrandom_image = data_train[0][0]
random_image_label = data_train[0][1]# Print the Image using Matplotlib
plt.imshow(random_image)
print("The label of the image is:", random_image_label)

DataLoader加载MNIST

下面我们使用DataLoader该类加载数据集,如下所示。

import torch
from torchvision import transformsdata_train = torch.utils.data.DataLoader(MNIST('~/mnist_data', train=True, download=True, transform = transforms.Compose([transforms.ToTensor()])),batch_size=64,shuffle=True)for batch_idx, samples in enumerate(data_train):print(batch_idx, samples)

CUDA加载

我们可以启用 GPU 来更快地训练我们的模型。现在让我们使用CUDA加载数据时可以使用的(GPU 支持 PyTorch)的配置。

device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('/files/', train=True, download=True),batch_size=batch_size_train, **kwargs)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('files/', train=False, download=True),batch_size=batch_size, **kwargs)

ImageFolder

ImageFolder是一个通用数据加载器类torchvision,可帮助加载自己的图像数据集。处理一个分类问题并构建一个神经网络来识别给定的图像是apple还是orange。要在 PyTorch 中执行此操作,第一步是在默认文件夹结构中排列图像,如下所示:

root
├── orange
│   ├── orange_image1.png
│   └── orange_image1.png
├── apple
│   └── apple_image1.png
│   └── apple_image2.png
│   └── apple_image3.png

可以使用ImageLoader该类加载所有这些图像。

torchvision.datasets.ImageFolder(root, transform)

transforms

PyTorch 转换定义了简单的图像转换技术,可将整个数据集转换为独特的格式。

如果是一个包含不同分辨率的不同汽车图片的数据集,在训练时,我们训练数据集中的所有图像都应该具有相同的分辨率大小。如果我们手动将所有图像转换为所需的输入大小,则很耗时,因此我们可以使用transforms;使用几行 PyTorch 代码,我们数据集中的所有图像都可以转换为所需的输入大小和分辨率。

现在让我们加载 CIFAR10torchvision.datasets并应用以下转换:

  • 将所有图像调整为 32×32
  • 对图像应用中心裁剪变换
  • 将裁剪后的图像转换为张量
  • 标准化图像
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as nptransform = transforms.Compose([# resize 32×32transforms.Resize(32),# center-crop裁剪变换transforms.CenterCrop(32),# to-tensortransforms.ToTensor(),# normalize 标准化transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=False)

在 PyTorch 中创建自定义数据集

下面将创建一个由数字和文本组成的简单自定义数据集。需要封装Dataset 类中的__getitem__()__len__()方法。

  • __getitem__()方法通过索引返回数据集中的选定样本。
  • __len__()方法返回数据集的总大小。

下面是曾经封装FruitImagesDataset数据集的代码,基本是比较好的 PyTorch 中创建自定义数据集的模板。

import os
import numpy as np
import cv2
import torch
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from xml.etree import ElementTree as et
from torchvision import transforms as torchtransclass FruitImagesDataset(torch.utils.data.Dataset):def __init__(self, files_dir, width, height, transforms=None):self.transforms = transformsself.files_dir = files_dirself.height = heightself.width = widthself.imgs = [image for image in sorted(os.listdir(files_dir))if image[-4:] == '.jpg']self.classes = ['_','apple', 'banana', 'orange']def __getitem__(self, idx):img_name = self.imgs[idx]image_path = os.path.join(self.files_dir, img_name)# reading the images and converting them to correct size and colorimg = cv2.imread(image_path)img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)# diving by 255img_res /= 255.0# annotation fileannot_filename = img_name[:-4] + '.xml'annot_file_path = os.path.join(self.files_dir, annot_filename)boxes = []labels = []tree = et.parse(annot_file_path)root = tree.getroot()# cv2 image gives size as height x widthwt = img.shape[1]ht = img.shape[0]# box coordinates for xml files are extracted and corrected for image size givenfor member in root.findall('object'):labels.append(self.classes.index(member.find('name').text))# bounding boxxmin = int(member.find('bndbox').find('xmin').text)xmax = int(member.find('bndbox').find('xmax').text)ymin = int(member.find('bndbox').find('ymin').text)ymax = int(member.find('bndbox').find('ymax').text)xmin_corr = (xmin / wt) * self.widthxmax_corr = (xmax / wt) * self.widthymin_corr = (ymin / ht) * self.heightymax_corr = (ymax / ht) * self.heightboxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])# convert boxes into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)# getting the areas of the boxesarea = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])# suppose all instances are not crowdiscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)labels = torch.as_tensor(labels, dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["area"] = areatarget["iscrowd"] = iscrowd# image_idimage_id = torch.tensor([idx])target["image_id"] = image_idif self.transforms:sample = self.transforms(image=img_res,bboxes=target['boxes'],labels=labels)img_res = sample['image']target['boxes'] = torch.Tensor(sample['bboxes'])return img_res, targetdef __len__(self):return len(self.imgs)def get_transform(train):if train:return A.Compose([A.HorizontalFlip(0.5),ToTensorV2(p=1.0)], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})else:return A.Compose([ToTensorV2(p=1.0)], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
test_dir = '../input/fruit-images-for-object-detection/test_zip/test'dataset = FruitImagesDataset(train_dir, 480, 480)

【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext相关推荐

  1. 【小白学习tensorflow教程】二、TensorBoard可视化模型训练

    @Author:Runsen 本想在Torch和Keras更新TensorBoard,还是决定扔在了tensorflow. TensorBoard是用于可视化图形和其他工具以理解.调试和优化模型的界面 ...

  2. 【小白学习C++ 教程】十七、C++ 中的字符数组和字符串常见的函数

    @Author:Runsen 字符数组 char mychar[6] = {'H', 'e', 'l', 'l', 'o'}; 下面定义的字符串数组在 C/C++ 中的内存表示 #include &l ...

  3. 【小白学习tensorflow教程】四、使用 tfhub中的模型EfficientDet-Lite2 进行对象检测

    @Author:Runsen tfhub是tensorflow官方提供训练好的模型的一个仓库.今天,我使用 tfhub中的模型EfficientDet-Lite2 进行对象检测 选择的模型是Effic ...

  4. 【小白学习C++ 教程】十五、C++ 中的template模板和泛型

    @Author:Runsen template模板在 C++ 中一个简单但非常强大的工具.简单的想法是将数据类型作为参数传递,这样我们就不需要为不同的数据类型编写相同的代码. C++ 添加了两个新关键 ...

  5. 【小白学习Keras教程】四、Keras基于数字数据集建立基础的CNN模型

    @Author:Runsen 文章目录 基本卷积神经网络(CNN) 加载数据集 1.创建模型 2.卷积层 3. 激活层 4. 池化层 5. Dense(全连接层) 6. Model compile & ...

  6. 【小白学习keras教程】二、基于CIFAR-10数据集训练简单的MLP分类模型

    @Author:Runsen 分类任务的MLP 当目标(y)是离散的(分类的) 对于损失函数,使用交叉熵:对于评估指标,通常使用accuracy 数据集描述 CIFAR-10数据集包含10个类中的60 ...

  7. 【小白学习C++ 教程】二十二、C++ 中的STL容器stack、queue和map

    @Author:Runsen STL 中的栈容器是一种容器适配器.在栈容器中,元素在一端插入并在同一端删除. stack 为了实现堆栈容器,我们需要在我们的程序中包含头文件<stack>. ...

  8. 【小白学习C++ 教程】二十一、C++ 中的STL容器Arrays和vector

    @Author:Runsen C++的标准模板库(STL)是提供数组.向量.队列等数据结构的模板类的集合.STL是由容器.算法.迭代器组成的库. 容器 容器存储对象和数据.它们基本上是基于模板的泛型类 ...

  9. 【小白学习C++ 教程】二十、C++ 中的auto关键字

    @Author: Runsen 在 C++ 11 之前,每种数据类型都需要在编译时显式声明,在运行时限制表达式的值,但在 C++ 新版本之后,包含了许多关键字,允许程序员将类型推导留给编译器本身. 有 ...

最新文章

  1. 迪杰斯特拉算法(C语言实现)
  2. Python学习—函数
  3. 笔记-【6】-JS中JSON的基础理解!
  4. 编译原理实验(算符优先文法)
  5. 生活智慧:奇特的人生法则
  6. black.lst 丢失或被破坏,怎么解决
  7. STM32F0单片机快速入门八 聊聊 Coolie DMA
  8. 将已有项目转为se项目_威海将再添国家级非遗项目
  9. Python中出现“TabError: inconsistent use of tabs and spaces in indentation”
  10. lisp 标注螺纹孔_cad螺纹孔怎么标注
  11. 第一章 计算机组成原理 ---- 概述
  12. Project(10)——收货地址 - 设置默认
  13. macbook配置java环境变量_MAC安装JDK及环境变量配置
  14. 【C++PTA】7-1 运算符重载 分数类 约分
  15. 游戏原画和3D游戏建模,哪个更胜一筹?
  16. 蔡康永的201堂情商课
  17. 【学习OpenCV4】图像金字塔总结
  18. 这款免费开源的数据库工具,支持所有主流数据库!
  19. WebJars简介 —— 前端资源的jar包形式
  20. 贝壳的平台模式成长逻辑:如何赋能品牌提升效率

热门文章

  1. ubuntu 下通过 sh 命令运行脚本产生如下错误:[: y: unexpected operator
  2. linux中256错误,YUM安装遭遇: [Errno 256] No more mirrors to try
  3. 超低延迟直播架构解析
  4. c语言字符数组给字符指针,C语言常见有关问题之字符串数组和字符指针数组有关问题...
  5. c++ map 析构函数_说说C++的虚析构函数
  6. 论文阅读笔记03-fast-rcnn
  7. iOSUI视图面试及原理总结
  8. js的作用域链,原型链,以及闭包函数理解
  9. Jzoj4790 选数问题
  10. python 后台服务