简述

Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西。
往往网络上给的demo都是基于torch自带的MNIST的相关类。所以,为了解决使用其他的数据集,在查阅了torch关于MNIST数据集的源码之后,很容易就可以推广到了我们自己需要的代码上。

具体操作如下:

准备工作

需要导入一些包。

from torch.utils.data import Dataset, DataLoader

再自定义一个用于当训练集合的类。

class TrainSet(Dataset):def __init__(self, X, Y):# 定义好 image 的路径self.X, self.Y = X, Ydef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)

数据预处理

之后,假设你的训练集合为[X,Y],其中X是训练数据,Y是对应的数据的标签。

首先,需要知道的是,torch能处理的数据只能是torch.Tensor,所以有必要将其他数据转换为torch.Tensor

常见的有几种数据:

  • np.ndarray
  • PIL.Image

如果是图片数据,其实也有多种情况,根据数据维度不同,有些是二维图,有些是三维图(通俗来讲,就是黑白图和彩图)。

所以,我先按照数据类型的模式将一遍,再补充关于图片的处理。

np.ndarray

np.ndarray是非常常见的格式,转成Tensor也非常简单。

torch.Tensor(array)

这样代码的返回格式就是一个Tensor

PIL.Image

import torchvision.transforms as transforms
transforms.ToTensor()(image)

这样代码的返回格式就是一个Tensor

关于图片

  • 彩色的三维图: 上面方法就已经完成了对应的数据处理的步骤
  • 灰白或者是二值的二维图:就需要将数据增加一个维度了(因为往往关于图片,所用到的算法都是包括了卷积的步骤,所以要求增加一个维度)

具体操作如下: 明显,torch.Tensor(X)这样的步骤,其实是重复了上面的将np.ndarray转成torch.Tensor的步骤。同理可以换成上面的关于PIL.Image的方法

X_tensor = torch.unsqueeze(torch.Tensor(X), 1)
Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)

导入数据

建立自己的数据集。

mydataset = TrainSet(X_tensor, Y_tensor)

再把自己的数据集导入到数据加载器上:

  • batch_size表示用将原数据拆分之后,每batch_size个数据作为一组数据被调用。shuffle表示数据是否被洗牌(即刷新顺序,避免训练的时候多次调用结果都遇到同一batch,从而避免误差)
train_loader = DataLoader(mydataset, batch_size=10, shuffle=True)

使用的方式也非常简单:

for step, (x, y) in enumerate(train_loader):

这里的x,y就是每个batch所处理的数据。

另外,附上一个我常用的读取自定义图片的dataset类

main函数部分是对数据集做测试。

import torch.utils.data as data
import glob
import os
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torchimport piexif
import imghdrclass MyDataset(data.Dataset):def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):if resize != -1:transform = transforms.Compose([transforms.Resize(resize),transforms.CenterCrop(resize),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])else:transform = transforms.Compose([transforms.ToTensor(),])img_format = '*.%s' % img_typeif remove_exif:for name in glob.glob(os.path.join(path, img_format)):try:piexif.remove(name)  # 去除exifexcept Exception:continue# imghdr.what(img_path) 判断是否为损坏图片if Len == -1:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format)) if imghdr.what(name)]else:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]self.dataset = np.array(self.dataset)self.dataset = torch.Tensor(self.dataset)self.Train = Traindef __len__(self):return len(self.dataset)def __getitem__(self, idx):return self.dataset[idx]if __name__ == '__main__':path = r'D:\Software\DataSet\faces'dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')print(len(dataset))plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)plt.show()print(dataset[0].max(), dataset[0].min())

Pytorch自定义数据集相关推荐

  1. 数据集制作_轻松学Pytorch自定义数据集制作与使用

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 大家好,这是轻松学Pytorch系列的第六篇分享,本篇你将学会如何从头开始制作自己的数据集,并通过DataLo ...

  2. pytorch: 自定义数据集加载

    很多网络在数据加载方式 pytorch 的输入流水线的操作顺序是这样的: 创建一个 Dataset 对象     创建一个 DataLoader 对象     不停的 循环 这个 DataLoader ...

  3. 利用PyTorch自定义数据集实现猫狗分类

    看了许多关于PyTorch的入门文章,大抵是从torchvision.datasets中自带的数据集进行训练,导致很难把PyTorch运用于自己的数据集上,真正地灵活运用PyTorch. 这里我采用从 ...

  4. pytorch自定义数据集DataLoder

    pytorch官方例程: DATA LOADING AND PROCESSING TUTORIAL torch.utils.data.Dataset 是dataset的抽象类,我们可以同过继承Data ...

  5. 【问题记录】pytorch自定义数据集 No such file or directory, invalid index of a 0-dim

    保存模型: : 保存整个神经网络的结构和模型参数 torch.save(mymodel, 'mymodel.pkl') 只保存神经网络的模型参数 torch.save(mymodel.state_di ...

  6. pytorch自定义数据集语义分割报错备忘RuntimeError: 1only batches of spatial targets supported (3D tensors)

    报错原文:RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: ...

  7. PyTorch版YOLOv4更新了,不仅适用于自定义数据集,还集成了注意力和MobileNet

    机器之心报道 作者:陈萍 距离 YOLO v4 的推出,已经过去 5 个多月.YOLO 框架采用 C 语言作为底层代码,这对于惯用 Python 的研究者来说,实在是有点不友好.因此网上出现了很多基于 ...

  8. 我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    大家好,我是红色石头! 在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch ...

  9. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch 复现了 LeNet-5 ...

最新文章

  1. Python在计算内存时值得注意的几个问题
  2. cdialog创建后马上隐藏_隐藏你的小秘密,这款神器就是玩的这么6!
  3. 三位数除以两位数竖式计算没有余数_二年级数学第三十课:有余数的除法 例4 试商...
  4. 靠谱的N95消毒方法终于来了:放你家电饭锅里,干烧丨伊利诺伊香槟分校出品...
  5. Silverlight 控件开发记录之 extern alias” 关键字
  6. python爬取aspx数据
  7. nssl1209-奇怪的队列【贪心,权值线段树】
  8. bootstraptable中responsehandle获取数据缺失_Python中的向量化字符串操作
  9. mysql使用已有的数据库_使用SQL操作MySQL数据库
  10. python脚本性能分析
  11. 求连续子数组的最大和C语言,求助:最长连续子数组问题
  12. Windows Ubuntu,软件推荐,小技巧总结,免费内网穿透方案
  13. MLAPP————第十四章 核方法
  14. 平面设计和3D建模哪个好找工作?
  15. unity2d 投影_Unity Projector 投影器原理以及优化
  16. localStorage数据丢失
  17. Vue + TypeScript + Element 搭建简洁时尚的博客网站及踩坑记
  18. Android手机获取屏幕分辨率高度因虚拟导航栏带来的问题
  19. 2019年 团体程序设计天梯赛——题解集
  20. 一个简单的TTS文语转换实例

热门文章

  1. Android4.3 屏蔽HOME按键返回桌面详解(源码环境下)
  2. [Android L]SEAndroid开放设备文件结点权限(读或写)方法(涵盖常用操作:sys/xxx、proc/xxx、SystemProperties)热门干货
  3. SylixOS普通定时器精度分析
  4. 操作系统杂谈 mac 和linux windows若干概念
  5. weblogic服务器保存图片失败解决办法
  6. 超强 css 实现 table 隔行 ,隔列 换色
  7. 使用SQL DTS功能实现从DB/2向SQL Server传输数据
  8. 借收购搭桥,风河Workbench软件环境涵盖至测试领域
  9. ps怎么制作流体_PS教程:制作渐变流体效果海报
  10. pandas分批读取csv文件