PyTorch 数据处理工具箱

文章目录

  • PyTorch 数据处理工具箱
    • 1、数据处理工具箱概述
    • 2、utils.data 简介
      • 2.1、自定义一个数据集
    • 3、torchvision 简介
      • 3.1、transforms
      • 3.2、ImageFolder
    • 4、可视化工具

1、数据处理工具箱概述

Pytorch 涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如图:

它主要包含 4 个类:

  • Dataset:是一个抽象类,其它数据集需要继承这个类,并且覆写其中的两个方法(_getitem_、_len_);
  • DataLoader:定义一个新的迭代器实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能
  • random_split:把数据集随机拆分为给定长度的非重叠新数据集
  • *sampler:多种采样函数

中间是 Pytorch 可视化处理工具(torchvision)Pytorch 的一个视觉处理工具包,独立于 Pytorch,需要另外安装。它包括 4 个类,各类的主要功能如下:

  • datasets:提供常用的数据集加载,设计上都是继承 torch.utils.data.Dataset,主要包括 MMIST、CIFAR10/100、ImageNet、COCO 等;
  • models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择 pretrained=True),包括 AlexNet, VGG系列、ResNet 系列、Inception 系列等;
  • transforms:常用的数据预处理操作,主要包括对 Tensor 及 PIL Image 对象的操作
  • utils:含两个函数,一个是 make_grid,它能将多张图片拼接在一个网格中;另一个是 save_img,它能将 Tensor 保存成图片

2、utils.data 简介

  • utils.data 包括 Dataset 和 DataLoader:

    1. torch.utils.data.Dataset 为抽象类。自定义数据集需要继承这个类,并实现两个函数。一个是__len__,另一个是__getitem__,前者提供数据的大小(size),后者通过给定索引获取数据和标签

    2. _getitem_ 一次只能获取一个数据,所以通过 torch.utils.data.DataLoader 来定义一个新的迭代器,实现 batch 读取

      DataLoader 的格式为:

      data.DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,
      )
      
      • dataset:加载的数据集;
      • batch_size:批大小;
      • shuffle:是否将数据打乱;
      • sampler:样本抽样;
      • num_workers:使用多进程加载的进程数,0 代表不使用多进程;
      • collate_fn:如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可;
      • pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到 GPU 会快一些;
      • drop_last:dataset 中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch 的数据丢弃。

    2.1、自定义一个数据集

    1. 导入需要的模块

      import torch
      from torch.utils import data
      import numpy as np
      
    2. 定义获取数据集的类

      类继承基类 Dataset,自定义一个数据集及对应标签

      class TestDataset(data.Dataset):#继承Datasetdef __init__(self):self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由2维向量表示的数据集self.Label=np.asarray([0,1,0,1,2])#这是数据集对应的标签def __getitem__(self, index):#把numpy转换为Tensortxt=torch.from_numpy(self.Data[index])label=torch.tensor(self.Label[index])return txt,label def __len__(self):return len(self.Data)
      
    3. 获取数据集中数据

      Test=TestDataset()
      print(Test[2])  #相当于调用__getitem__(2)
      print(Test.__len__())#輸出:
      #(tensor([2, 1]), tensor(0))
      #5
      

      以上数据以 tuple 返回,每次只返回一个样本。实际上,Dateset 只负责数据的抽取,一次调用__getitem__只返回一个样本。如果希望批量处理(batch),同时还要进行 shuffle 和并行加速等操作,可选择 DataLoader。

      test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
      for i,traindata in enumerate(test_loader):print('i:',i)Data,Label=traindataprint('data:',Data)print('Label:',Label)
      

      从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,如对它进行循环操作。不过它不是迭代器,我们可以通过 iter 命令转换为迭代器。

      dataiter=iter(test_loader)
      imgs,labels=next(dataiter)
      #imgs.size()
      

    一般用 data.Dataset 处理同一个目录下的数据。如果数据在不同目录下,不同目录代表不同类别(这种情况比较普遍),使用 data.Dataset 来处理就不很方便。不过,可以使用 Pytorch 另一种可视化数据处理工具(即 torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数。


3、torchvision 简介

  • torchvision 有 4 个功能模块,model、datasets、transforms 和 utils

    1. 利用 datasets 下载一些经典数据集;
    2. 提供深度学习中各种经典的网络结构以及训练好的模型(如果选择 pretrained=True);
    3. datasets 的 ImageFolder处理自定义数据集;
    4. transforms 对源数据进行预处理、增强。

3.1、transforms

transforms 提供了对 PIL Image 对象和 Tensor 对象的常用操作:

  • 对 PIL Image 的常见操作如下:

    1. Scale/Resize:调整尺寸,长宽比保持不变
    2. CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop 和 RandomCrop 在 crop 时是固定size,RandomResizedCrop 则是 random size 的 crop;
    3. Pad:填充
    4. ToTensor:把一个取值范围是 [0,255] 的 PIL.Image 转换成 Tensor。形状为 (H,W,C) 的 numpy.ndarray,转换成形状为 [C,H,W],取值范围是 [0,1.0] 的 torch.FloatTensor;
    5. RandomHorizontalFlip:图像随机水平翻转,翻转概率为 0.5;
    6. RandomVerticalFlip:图像随机垂直翻转;
    7. ColorJitter:修改亮度、对比度和饱和度
  • 对 Tensor 的常见操作如下:

    1. Normalize:标准化,即减均值,除以标准差;
    2. ToPILImage:将 Tensor 转为 PIL Image
  • 如果要对数据集进行多个操作,可通过 Compose 将这些操作像管道一样拼接起来,类似于 nn.Sequential。以下为示例代码:

    transforms.Compose([#将给定的 PIL.Image 进行中心切割,得到给定的 size,#size 可以是 tuple,(target_height, target_width)。#size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。            transforms.CenterCrop(10),#切割中心点的位置随机选取transforms.RandomCrop(20, padding=0),#把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray,#转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensortransforms.ToTensor(),#规范化到[-1,1]transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ])
    

3.2、ImageFolder

  • 当文件依据标签处于不同文件下时,可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset。

    ImageFolder 会将目录中的文件夹名自动转化成序列,那么 DataLoader 载入时,标签自动就是整数序列了

示例代码:

from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt my_trans=transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)for i_batch, img in enumerate(train_loader):if i_batch == 0:print(img[1])fig = plt.figure()grid = utils.make_grid(img[0])plt.imshow(grid.numpy().transpose((1, 2, 0)))plt.show()utils.save_image(grid,'test01.png')break

4、可视化工具

TensorboardX 是 Google TensorFlow 的可视化工具,它可以记录训练数据、评估数据、网络结构、图像等,并且可以在 web 上展示,对于观察神经网路训练的过程非常有帮助。

  • 使用 tensorboardX 的一般步骤为

    1. 导入 tensorboardX,实例化 SummaryWriter 类,指明记录日志路径等信息:

      from tensorboardX import SummaryWriter
      #实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。
      writer = SummaryWriter(log_dir='logs')
      #调用实例
      writer.add_xxx()
      #关闭writer
      writer.close()
      
      SummaryWriter(log_dir=None, comment='', **kwargs)
      #其中comment在文件命名加上comment后缀
      
    2. 调用相应的 API 接口,接口一般格式为:

      add_xxx(tag-name, object, iteration-number)
      #即add_xxx(标签,记录的对象,迭代次数)
      
    3. 启动 tensorboard 服务

      cd 到 logs 目录所在的同级目录,在命令行输入如下命令,logdir 等式右边可以是相对路径或绝对路径。

      tensorboard --logdir=logs --port 6006
      #如果是windows环境,要注意路径解析,如
      #tensorboard --logdir=r'D:\myboard\test\logs' --port 6006
      
    4. web 展示

      在浏览器输入:

      http://服务器IP或名称:6006  #如果是本机,服务器名称可以使用localhost
      

PyTorch数据处理工具箱相关推荐

  1. 第3章 Pytorch神经网络工具箱

    链接 前面我们介绍了Pytorch的数据结构及自动求导机制,充分运行这些技术可以大大提高我们的开发效率.这章将介绍Pytorch的另一利器:神经网络工具箱.利用这个工具箱,设计一个神经网络就像搭积木一 ...

  2. 第3章 PyTorch神经网络工具箱(1/2)

    前面已经介绍了PyTorch的数据结构及自动求导机制,充分运行这些技术可以大大提高我们的开发效率.这章将介绍PyTorch的另一利器:神经网络工具箱.利用这个工具箱,设计一个神经网络就像搭积木一样,可 ...

  3. PyTorch数据处理工具

    PyTorch数据处理工具 概述 PyTorch主要数据处理工具: Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其中的两个方法(getitem_.len). DataLoader: ...

  4. Pytorch:神经网络工具箱nn

    第四章 神经网络工具箱nn 上一章中提到,使用autograd可实现深度学习模型,但其抽象程度较低,如果用其来实现深度学习模型,则需要编写的代码量极大.在这种情况下,torch.nn应运而生,其是专门 ...

  5. 【MODIS数据处理#15】分享一个自制的MODIS数据处理工具箱

    文章目录 一.下载地址 二.工具箱内容 三.配置教程 四.使用教程 后记 整理了本人自制的MODIS数据批处理脚本工具,以ArcGIS共享工具箱(.tbx)的方式免费分享给大家.所有工具都有详细的说明 ...

  6. 计算机视觉PyTorch - 数据处理(库数据和训练自己的数据)

    pytorch实现图像分类数据处理 1. pytorch库自带数据 数据预处理 数据生成 数据加载 2. 训练自己的数据 生成数据集 数据预处理 数据加载 1. pytorch库自带数据 为了更好的理 ...

  7. 京东开源FaceX-Zoo:PyTorch面部识别工具箱

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达小白导读论文是学术研究的精华和未来发展的明灯.小白决心每天为大家带来 ...

  8. pytorch 数据处理(复数处理)记录

    因为无信通信中经常用到复数的乘法,pytorch中又没有现成的处理方式,自己懒得写,就在网上搜了下好心人分享的资料.确实是正确的.但是作者采用循环的方式,在处理大批量数据的时候非常慢,我对几十万复数的 ...

  9. pytorch数据处理的操作

    张量(256,256)转换为三维通道图片张量(256,256,3)可以采用下面方法: 1.先将张量转换为数组. 2.在numpy数组增加维数(相当于在张量里增加了通道的维度),即(256,256,1) ...

最新文章

  1. Codeforces Round #599A~D题解
  2. nginx的权限问题(Permission denied)解决办法
  3. 刷爆技术圈的《知识图谱》终于补货了,最后 968 份,低至 2 折,抢完不补!...
  4. RecyclerView解析--onViewDetachedFromWindow()/onViewAttachedToWindow()
  5. 希尔排序(ShellSort) c源码
  6. 计算机二级vb考试查分,全国计算机二级考试VB程序设计复习试题
  7. OA,ERP等源码一部分演示
  8. Linux 7 关闭、禁用防火墙服务
  9. 初学者python笔记(json模块、pickle模块、xml模块、shelve模块)
  10. 【斯坦福大学新研究】声波、光波等都是RNN
  11. IDEA、MySQL、SQLyog安装教程
  12. FPGA实现sobel边缘检测并Modelsim仿真,与MATLAB实现效果对比
  13. openjudge 买书
  14. Mac 下JDK 1.8 下载地址
  15. CREATE PROCEDURE
  16. 2022年登高架设操作证考试题及在线模拟考试
  17. SEO优化 - robots协议
  18. 一些基础电路和物理量在线换算公式
  19. 程序猿从不缺对象,想要随时可以new出来一个
  20. 安卓分屏神器_8款App打造一个学术型iPad,这才是它秒杀安卓平板的杀手锏

热门文章

  1. win 8.1 64位彻底删除王码98
  2. git之branch分支增删改查、切换、更新远程代码到本地仓库
  3. dell15-5559_Dell Mini 9-实用开发人员评论
  4. QQ2009 暂时无法登陆,请稍候重试 问题解决方法
  5. 问道手游服务器维护,《问道》手游服务器例行维护公告(2016.03.07)
  6. 查找存储卡的路径在WM6 windows mobile
  7. Succeeding At Your Yahoo! Business
  8. Dockerfile最佳实践【原创、很多实践经验】
  9. VM虚拟机 系统出现鼠标定位不准确、双鼠标问题
  10. 怎么将打开的网页在浏览器中隐藏而不关闭