文章目录

  • 利用PyTorch框架来开发深度学习算法时几个基础的模块
    • Dataset & DataLoader
      • 基础概念
      • 自定义数据集 1
      • 读取自定义数据集 1
      • 自定义数据集 2
      • 自定义数据集3
      • 官方文档写自定义数据集
      • DataLoader加载PyTorch提供的数据集
    • datasets
      • **datasets** 中有的数据集有
      • ImageFolder 和 ImageNet 的配合使用
    • models
      • 常用于Classification的模型
      • 常用于Semantic Segmentation的模型
      • 常用于 Object Detection, Instance Segmentation and Person Keypoint Detection 的模型
    • transforms
      • transforms中对于图像的处理
    • References

利用PyTorch框架来开发深度学习算法时几个基础的模块

  • Dataset & DataLoader
  • datasets
  • models
  • transforms

在利用PyTorch开始进入深度学习“大坑”的时候必须将以上的几个模块熟练掌握,这样才可以运用自如的写自己的算法或者魔改别人算法的code,下面将对以上几个模块逐一介绍其重点和一些注意事项。

Dataset & DataLoader

基础概念

Dataset & DataLoader 属于 torch中 torch.utils.data 中的模块,要使用 Dataset & DataLoader时, 必须预先导入,具体代码如下所示:

# DataLoader & Dataset 同时使用时一起导入
from torch.utils.data import Dataset Dataloader# DataLoader & Dataset 不是同时使用时,需要哪一个就导入哪一个
# 利用 Dataset 来构建自己的数据集时,必须导入Dataset
from torch.utils.data import Dataset
# 利用DataLoader 来加载自己的数据集或者官方提供的数据集时,必须导入Dataloader
from torch.utils.data import Dataloader

Dataset 是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。
自定义的Dataset需要继承它并且实现两个成员方法:

  1. getitem() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本
  2. len() 该方法返回数据集的总长度
    DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。

自定义数据集 1

# 使用kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集
import torch
from torch.utils.data import Dataset  # 导入抽象类Dataset
import pandas as pd  # 本质是使用pandas进行处理,只是相当于进行了封装。# 定义一个数据集
class BulldozerDataset(Dataset):""" 数据集演示 """def __init__(self, csv_file):# 实现初始化方法,在初始化的时候将数据读载入# 数据保存在self.df中self.df=pd.read_csv(csv_file)def __len__(self):  # 本质替换定义了len()函数的作用# 返回df的长度return len(self.df)def __getitem__(self, idx):  # 本质定义了替换iloc[]的作用# 根据 idx 返回一行数据return self.df.iloc[idx].SalePrice
# 实例化一个对象访问它
ds_demo= BulldozerDataset('median_benchmark.csv')  #传入一个.csv文件
#实现了 __len__ 方法所以可以直接使用len获取数据总数
len(ds_demo)
# output: 11573
#用索引可以直接访问对应的数据,对应 __getitem__ 方法
ds_demo[0]
# output: 24000.0

读取自定义数据集 1

# 对刚刚上面建立的Dataset,利用DataLoader 进行读取。
dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)
# DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据
# DataLoader本质是一个类,用来实现复杂的函数功能和其他功能
# .csv(原始数据)--->ds_demo(Dataset类对象)--->dl(DataLoader类对象)
idata=iter(dl)  # iter() 迭代函数
print(next(idata))
# 更常见的用法是使用for循环对其进行遍历
for i, data in enumerate(dl):print(i,data)# 为了节约空间,这里只循环一遍break
# output:0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.], dtype=torch.float64)
# 第一个维度是batch_size==10,每一个元素其实是一个实际的数据

自定义数据集 2

import os
from PIL import Image
from torch.utils.data import Datasetclass PatchDataset(Dataset):def __init__(self, data_dir, transform=None):""":param data_dir: 数据集所在路径:param transform: 数据预处理"""self.data_info = self.get_img_info(data_dir)self.transform = transformdef __getitem__(self, item):path_img, label = self.data_info[item]image = Image.open(path_img).convert('RGB')if self.transform is not None:image = self.transform(image)return image, labeldef __len__(self):return len(self.data_info)@staticmethoddef get_img_info(data_dir):path_dir = os.path.join(data_dir, 'train_dataset.txt')data_info = []with open(path_dir) as file:lines = file.readlines()for line in lines:data_info.append(line.strip('\n').split(' '))return data_info

自定义数据集3

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utilsnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([#transforms.Scale(256),#transforms.CenterCrop(224),transforms.ToTensor(),normalize
])def default_loader(path):img_pil =  Image.open(path)img_pil = img_pil.resize((224,224))img_tensor = preprocess(img_pil)return img_tensor#当然出来的时候已经全都变成了tensor
class trainset(Dataset):def __init__(self, loader=default_loader):#定义好 image 的路径self.images = file_trainself.target = number_trainself.loader = loaderdef __getitem__(self, index):fn = self.images[index]img = self.loader(fn)target = self.target[index]return img,targetdef __len__(self):return len(self.images)

官方文档写自定义数据集

A custom Dataset class must implement three functions: init, len, and getitem.

  • The init function is run once when instantiating the Dataset object. 通常是第一次加载数据。
  • The len function returns the number of samples in our dataset. 获取该数据集的长度。
  • The getitem function loads and returns a sample from the dataset at the given index idx. 根据索引获取数据集中的数据。
# A example
# The FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.
import os
import pandas as pd
from torchvision.io import read_image # 该函数读取图像的结果直接输出是一个tensorclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

:在构建自己的Dataset时候,如果需要用额外的包读图像,最好还是用 PIL,因为pytorch源码里就是用到了 PIL 读取图像。

DataLoader加载PyTorch提供的数据集

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

datasets

datasets 属于 torchvision 中的模块,需要利用datasets中的数据时必须提前将其导入,具体代码如下所示

from torchvision import datasets

datasets 中有的数据集有

  • CelebA
  • CIFAR
  • Cityscapes
  • COCO
  • DatasetFolder
  • EMNIST
  • FakeData
  • Fashion-MNIST
  • Flickr
  • HMDB51
  • ImageFolder
  • ImageNet
  • Kinetics-400
  • KMNIST
  • LSUN
  • MNIST
  • Omniglot
  • PhotoTour
  • Places365
  • QMNIST
  • SBD
  • SBU
  • STL10
  • SVHN
  • UCF101
  • USPS
  • VOC
    这些都继承了torch.utils.data.Dataset这个类,所以这些数据集都可以用torch.utils.data.DataLoader的多线程来进行快速的加载。

ImageFolder 和 ImageNet 的配合使用

# 具体例子 1如下所示:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transformsdata_transform = transforms.Compose([transforms.Resize(299),transforms.CenterCrop(299),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
train_dataset =torchvision.datasets.ImageFolder(root='ILSVRC2012/train',transform=data_transform)
train_dataset_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=4)train_dataset = torchvision.datasets.ImageFolder(root='ILSVRC2012/val',transform=data_transform)
train_dataset_loader = DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=4)
# 具体例子 2 如下所示:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train', # 下载到本地ImageNet的路径transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)

models

models 属于 torchvision 中的模块,需要调用时,需预先导入,具体代码如下所示:

# 直接导入models
import torchvision.models# 导入并用别名
import torchvision.models as models

常用于Classification的模型

alexnet vgg resnet squeezenet densenet inception googlenet shufflenet
You can construct a model with random weights by calling its constructor:


import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()

ImageNet 1-crop error rates (224x224)


# ImageNet 的均值和方差normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

常用于Semantic Segmentation的模型

  • FCN ResNet50, ResNet101
  • DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large
  • LR-ASPP MobileNetV3-Large

The accuracies of the pre-trained models evaluated on COCO val2017 are as follows

常用于 Object Detection, Instance Segmentation and Person Keypoint Detection 的模型

  • Faster R-CNN
  • Mask R-CNN
  • SSD
  • SSDlite
  • RetinaNet

Here are the summary of the accuracies for the models trained on the instances set of COCO train2017 and evaluated on COCO val2017.

transforms

transforms 属于 torchvision 中的模块,用于对数据进行预处理,主要是图像的处理( 有22个function),可以使用Compose将其链接在一起。

transforms中对于图像的处理

  • 裁剪(Crop)—— 中心裁剪:transforms.CenterCrop 随机裁剪:transforms.RandomCrop 随机长宽比裁剪:transforms.RandomResizedCrop 上下左右中心裁剪:transforms.FiveCrop 上下左右中心裁剪后翻转,transforms.TenCrop
  • 翻转和旋转(Flip and Rotation) ——依概率p水平翻转:transforms.RandomHorizontalFlip(p=0.5) 依概率p垂直翻转:transforms.RandomVerticalFlip(p=0.5) 随机旋转:transforms.RandomRotation
  • 图像变换(resize) ——transforms.Resize 标准化:transforms.Normalize 转为tensor,并归一化至[0-1]:transforms.ToTensor 填充:transforms.Pad 修改亮度、对比度和饱和度:transforms.ColorJitter 转灰度图:transforms.Grayscale 线性变换:
    transforms.LinearTransformation() 仿射变换:transforms.RandomAffine 依概率p转为灰度图:transforms.RandomGrayscale 将数据转换为PILImage:transforms.ToPILImage transforms.Lambda:Apply a user-defined lambda as a transform.
  • 对transforms操作,使数据增强更灵活 transforms.RandomChoice(transforms), 从给定的一系列transforms中选一个进行操作 transforms.RandomApply(transforms, p=0.5),给一个transform加上概率,依概率进行操作 transforms.RandomOrder,将transforms中的操作随机打乱

最常用的处理:CenterCrop、Grayscale、RandomCrop、RandomHorizontalFlip、RandomVerticalFlip、RandomRotation、Normalize、ToTensor

# torchvision.transforms.Compose(transforms)
# Compose的例子 1
transforms.Compose([transforms.CenterCrop(10),transforms.ToTensor(), ])
# Compose的例子 2
data_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])

References

  1. Pytorch构建数据集——torch.utils.data.Dataset()和torch.utils.data.DataLoader()
  2. PyTorch 基础 :数据的加载和预处理
  3. 看pytorch official tutorials的新收获
  4. torch.utils.data
  5. DATASETS & DATALOADERS
  6. LEARN THE BASICS
  7. pytorch Dataset, DataLoader产生自定义的训练数据
  8. PyTorch手把手自定义Dataloader读取数据
  9. TORCHVISION.TRANSFORMS
  10. PyTorch 学习笔记:transforms的二十二个方法(transforms用法非常详细)
  11. PyTorch 模型训练实用教程

Pytorch之DataLoader Dataset、datasets、models、transforms的认识和学习相关推荐

  1. dataset__getitem___一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  2. 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

  3. 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). c ...

  4. [Pytorch] Sampler, DataLoader和数据batch的形成

    目录 1. 简介 2. 整体流程 3. Sampler和BatchSampler 3.1 Sampler 3.2 BatchSampler 4. DataLoader 4.1 DataLoader 4 ...

  5. PyTorch 笔记(20)— torchvision 的 datasets、transforms 数据预览和加载、模型搭建(torch.nn.Conv2d/MaxPool2d/Dropout)

    计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,PyTorch 团队专门开发了一个视觉工具包torchvision,这个包独立于 PyTorch,需通过 pip instal torchv ...

  6. PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快

    PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...

  7. 【PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快】

    文章目录 一.引言 二.背景与需求 三.方法的实现 四.代码与数据测试 五.测试结果 5.1.Max elapse 5.2.Multi Load Max elapse 5.3.Min elapse 5 ...

  8. Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

    Pytorch的数据读取主要包含三个类: Dataset DataLoader DataLoaderIter 这三者是一个依次封装的关系: 1.被装进2., 2.被装进3. Dataset类 Pyto ...

  9. PyTorch 之 DataLoader

    DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...

最新文章

  1. 【翻译】在Sencha应用程序中使用插件和混入
  2. opencore0.6.3_大杨随笔2020.11.3
  3. 零编程基础学python-零编程基础怎么自学python?
  4. Kubernetes使用集群联邦实现多集群管理
  5. struts2 获取 session
  6. python 解析xml格式_Python解析XML文件
  7. 【问链财经-区块链基础知识系列】 第三十三课 区块链溯源方案设计-中检集团区块链溯源平台
  8. webpack打包后引用cdn的js_呕心沥血编写的webpack多入口零基础配置 【建议收藏】...
  9. boost::geometry::svg用法的测试程序
  10. linux 安装 powershell
  11. wordpress主题的样式修改
  12. 从敲下一行JS代码到这行代码被执行,中间发生了什么?
  13. 有关i386和i686
  14. SCI论文写作--科研其实远没有那么难
  15. Task01:熟悉新闻推荐系统的基本流程(数据库设计)
  16. hutool 读取扩展名文件_好多公司都要用的一些知识点Office办公软件、文件加密、文件扩展名!...
  17. SAP R3 IDES 4.71电驴资源
  18. 如何破解winrar(可用)
  19. 拼多多数据分析笔试题(附代码答案)
  20. Laya 2.0 微信排行榜数据

热门文章

  1. 进制转换之十进制转换为D进制——整数部分除基取余法
  2. docker部署命令
  3. Istio服务网格进阶②:在Istio服务网格中部署Bookinfo在线书店微服务项目
  4. MACOM推出宽带多级硅基氮化镓 (GaN-on-Si) 功率放大器 (PA) 模块 具备灵活安装性能,实现领先
  5. Velodyne VLP16 接入ros系统
  6. CR2032 电池放电曲线
  7. 顶尖项目管理高手,都在用“敏捷预算”模式!
  8. 【mysql错误】MySQL server has gone away 问题的解决方法
  9. 自定义快捷键打开音量合成器
  10. PMP模拟试题每日5题(5月6日)