数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset

from skimage import io,transform

import matplotlib.pyplot as plt

import oimport torch

from torchvision import transforms, utils

from PIL import Image

import pandas as pd

import numpy as np

#过滤警告信息

import warnings

warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset

def __init__(self, path_dir, transform=None): #初始化一些属性

self.path_dir = path_dir #文件路径,如'.\data\cat-dog'

self.transform = transform #对图形进行处理,如标准化、截取、转换等

self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中

def __len__(self):#返回整个数据集的大小

return len(self.images)

def __getitem__(self,index):#根据索引index返回图像及标签

image_index = self.images[index]#根据索引获取图像文件名称

img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录

img = Image.open(img_path).convert('RGB')# 读取图像

# 根据目录名称获取图像标签(cat或dog)

label = img_path.split('\\')[-1].split('.')[0]

#把字符转换为数字cat-0,dog-1

label = 1 if 'dog' in label else 0

if self.transform is not None:

img = self.transform(img)

return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)

img, label = dataset[0] #将启动魔法方法__getitem__(0)

print(type(img))

1.2.5 查看图像形状

i=1

for img, label in dataset:

if i

img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T

transform = T.Compose([

T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素

T.CenterCrop(224), # 从图片中间切出224*224的图片

T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]

T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差

])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)

for img, label in dataset:

print("图像img的形状{},标签label的值{}".format(img.shape, label))

print("图像数据预处理后:\n",img)

break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据

dataloader = DataLoader(dataset,batch_size=4,shuffle=True)

for batch_datas, batch_labels in dataloader:

print(batch_datas.size(),batch_labels.size())

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([4, 3, 224, 224]) torch.Size([4])

torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision

import matplotlib.pyplot as plt

import numpy as np

%matplotlib inline

# 显示图像

def imshow(img):

img = img / 2 + 0.5 # unnormalize

npimg = img.numpy()

plt.imshow(np.transpose(npimg, (1, 2, 0)))

plt.show()

# 随机获取部分训练数据

dataiter = iter(dataloader)

images, labels = dataiter.next()

# 显示图像

imshow(torchvision.utils.make_grid(images))

# 打印标签

print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch

from torchvision import transforms, datasets

data_transform = transforms.Compose([

transforms.RandomSizedCrop(224),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize(mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225])

])

hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',

transform=data_transform)

dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

本文标题: PyTorch加载自己的数据集实例详解

本文地址: http://www.cppcns.com/jiaoben/python/303230.html

python从date目录导入数据集_PyTorch加载自己的数据集实例详解相关推荐

  1. linux如何确定共享库路径,摘录Linux下动态共享库加载时的搜索路径详解

    对动态库的实际应用还不太熟悉的读者可能曾经遇到过类似"error while loading shared libraries"这样的错误,这是典型的因为需要的动态库不在动态链接器 ...

  2. Node.js学习笔记——模块加载机制及npm指令详解

    文章目录 二.模块化 1.模块化的基本概念 2.Node.js 中的模块化 Node.js 中模块的分类 加载模块 Node.js 中的模块作用域 向外共享模块作用域中的成员 Node.js 中的模块 ...

  3. python 录制web视频_Python django框架 web端视频加密的实例详解

    视频加密流程图: 后端获取保利威的视频播放授权token,提供接口api给前端 参考文档:http://dev.polyv.net/2019/videoproduct/v-api/v-api-play ...

  4. python 经典脚本文件_Python3.5文件读与写操作经典实例详解

    本文实例讲述了Python3.5文件读与写操作.分享给大家供大家参考,具体如下: 1.文件操作的基本流程: (1)打开文件,得到文件句柄并赋值给一个变量 (2)通过句柄对文件进行操作 (3)关闭文件 ...

  5. python如何调用文件进行换位加密_python 换位密码算法的实例详解

    python 换位密码算法的实例详解 一前言: 换位密码基本原理:先把明文按照固定长度进行分组,然后对每一组的字符进行换位操作,从而实现加密.例如,字符串"Error should neve ...

  6. python标准库对象导入语句_Python标准库之Sys模块使用详解

    sys 模块提供了许多函数和变量来处理 Python 运行时环境的不同部分. 处理命令行参数 在解释器启动后, argv 列表包含了传递给脚本的所有参数, 列表的第一个元素为脚本自身的名称. 使用sy ...

  7. Android插件化开发之动态加载三个关键问题详解

    本文摘选自任玉刚著<Android开发艺术探索>,介绍了Android插件化技术的原理和三个关键问题,并给出了作者自己发起的开源插件化框架. 动态加载技术(也叫插件化技术)在技术驱动型的公 ...

  8. java loadjs_Javarscript中模块(module)、加载(load)与捆绑(bundle)详解

    JS模块简介 js模块化,简单说就是将系统或者功能分隔成单独的.互不影响的代码片段,经过严格定义接口,使各模块间互不影响,且可以为其他所用. 常见的模块化有,C中的include (.h)文件.jav ...

  9. as3加载外部图片的方法详解

    开始之前先做一些准备工作.新建一个空的flash文件,注意选择支持ActionScript 3.0的flash文件,保存该flash文件.再找一张图片并将其和新建的flash文件放在同一目录下(AS3 ...

最新文章

  1. 网游放缓页游疾进 客户端游戏会被取代吗?
  2. Octavia health-manager 与 amphora 故障修复的实现与分析
  3. 查看linux安装redis的位置,linux查看是否安装redis
  4. Python的enumerater
  5. 字符串经典题目(Leetcode题解-Python语言)
  6. 如何在Windows下发布QT应用程序
  7. HTML - 脚本JavaScript
  8. 计算机组成原理(第3版)唐朔飞著 知识点总结
  9. 微信红包后台系统设计
  10. 【喵迹 Pro】GPS轨迹记录安卓APP使用说明
  11. 【初学大数据】CentOS7安装hadoop3.3.2完全分布式详细流程
  12. 使用BeautifulSoup解析网页内容
  13. java日期计算_java中date日期计算使用方法
  14. Linux环境:可变剪切分析软件rMATS安装、使用与解读
  15. Radix Tree总结
  16. ear的英语怎么念_ear英语怎么读谐音
  17. VisionPro相机操作类
  18. 一步一步玩转树莓派~
  19. 20 多个国外优秀Android开源 App
  20. 王者荣耀背后的实时大数据平台用了什么黑科技?

热门文章

  1. blender 子弹时间 动画
  2. weka中文乱码解决办法
  3. *45.程序的装入方式
  4. C++语言map和unordered_map的下标操作
  5. ARM嵌入式开发之JTAG与SWD接口
  6. 10本计算机视觉必读经典图书,入门篇 + 提升篇
  7. STL 之 list 容器详解
  8. C++中list的使用方法及常用list操作总结
  9. 监控USB设备插入/拔出写法2
  10. linux基础知识和命令试题,Linux基础试题及答案