一,数据加载

数据路径:

#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as npclass DogCat(data.Dataset):def __init__(self, path):imgs = os.listdir(path)# 所有图片的绝对路径# 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片self.imgs_list_path = [os.path.join(path, i) for i in imgs]def __getitem__(self, index):img_path = self.imgs_list_path[index]# dog->1, cat->0label = 1 if 'dog' in img_path.split('/')[-1] else 0pil_img = Image.open(img_path)array = np.asarray(pil_img)img = t.from_numpy(array)return img_path,img, labeldef __len__(self):return len(self.imgs_list_path)
if __name__ == '__main__':dataset = DogCat('./data/dogcat/')# img, label = dataset[0]  # 相当于调用dataset.__getitem__(0)print('len(dataset)=',len(dataset))for img_path,img, label in dataset:print(img_path,img.size(), img.float().mean(), label)

打印结果:

二,数据归一化

PyTorch提供了torchvision1。它是一个视觉工具包,提供了很多视觉图像处理的工具,其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image的操作包括:

  • Scale:调整图片尺寸,长宽比保持不变
  • CenterCropRandomCropRandomResizedCrop: 裁剪图片
  • Pad:填充
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]
  • transforms.ColorJitter(0.3, 0.3, 0.2) 颜色抖动
  • transforms.RandomRotation(10)随机旋转

对Tensor的操作包括:

  • Normalize:标准化,即减均值,除以标准差
  • ToPILImage:将Tensor转为PIL Image对象
#coding:utf-8
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
from torchvision import transformstransform = transforms.Compose([transforms.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素transforms.CenterCrop(224), # 从图片中间切出224*224的图片transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差#input[channel] = (input[channel] - mean[channel]) / std[channel]
])class DogCat(data.Dataset):def __init__(self, path,transforms=None):imgs = os.listdir(path)# 所有图片的绝对路径# 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片self.imgs_list_path = [os.path.join(path, i) for i in imgs]self.transforms=transformsdef __getitem__(self, index):img_path = self.imgs_list_path[index]# dog->1, cat->0label = 1 if 'dog' in img_path.split('/')[-1] else 0pil_img = Image.open(img_path)if self.transforms:pil_img=self.transforms(pil_img)array = np.asarray(pil_img)img = t.from_numpy(array)return img_path,img, labeldef __len__(self):return len(self.imgs_list_path)
if __name__ == '__main__':dataset = DogCat('./data/dogcat/',transforms=transform)# img, label = dataset[0]  # 相当于调用dataset.__getitem__(0)print('len(dataset)=',len(dataset))for img_path,img, label in dataset:print(img_path,img.size(), img.float().mean(), label)

三,利用fer2013数据集进行预处理

数据集地址:https://download.csdn.net/download/fanzonghao/11183885

''' Fer2013 Dataset class'''
from __future__ import print_function
from PIL import Image
import numpy as np
import h5py
import torch.utils.data as data
import cv2
import torchvision.transforms as transforms# 定义对数据的预处理
transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor 归一化至0~1transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化])
class FER2013(data.Dataset):"""`FER2013 Dataset.Args:train (bool, optional): If True, creates dataset from training set, otherwisecreates from test set.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``"""def __init__(self, path,split='Training', transform=None):self.transform = transformself.split = split  # training set or test setself.data = h5py.File(path, 'r', driver='core')# now load the picked numpy arraysif self.split == 'Training':self.train_data = self.data['Training_pixel']self.train_labels = self.data['Training_label']self.train_data = np.asarray(self.train_data)self.train_data = self.train_data.reshape((28709, 48, 48))elif self.split == 'PublicTest':self.PublicTest_data = self.data['PublicTest_pixel']self.PublicTest_labels = self.data['PublicTest_label']self.PublicTest_data = np.asarray(self.PublicTest_data)self.PublicTest_data = self.PublicTest_data.reshape((3589, 48, 48))else:self.PrivateTest_data = self.data['PrivateTest_pixel']self.PrivateTest_labels = self.data['PrivateTest_label']self.PrivateTest_data = np.asarray(self.PrivateTest_data)self.PrivateTest_data = self.PrivateTest_data.reshape((3589, 48, 48))def __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is index of the target class."""if self.split == 'Training':img, target = self.train_data[index], self.train_labels[index]elif self.split == 'PublicTest':img, target = self.PublicTest_data[index], self.PublicTest_labels[index]else:img, target = self.PrivateTest_data[index], self.PrivateTest_labels[index]# doing this so that it is consistent with all other datasets# to return a PIL Imageimg = img[:, :, np.newaxis]img = np.concatenate((img, img, img), axis=2)img = Image.fromarray(img)if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):if self.split == 'Training':return len(self.train_data)elif self.split == 'PublicTest':return len(self.PublicTest_data)else:return len(self.PrivateTest_data)if __name__ == '__main__':train_data=FER2013(path='./data/data.h5',split='Training',transform=transform)train_loader = data.DataLoader(dataset=train_data,batch_size=8,shuffle=True,num_workers=2)print(len(train_data))# for i,(img,label) in enumerate(train_data):#     if i<1:#         img=np.transpose(np.array(img),(1,2,0))#         print(img.shape)#         img=(img*0.5+0.5)*255#         cv2.imwrite('1.jpg',img)#         print(label.shape)for i,(img, label) in enumerate(train_loader):if i<1:print('train')img=np.transpose(np.array(img)[0],(1,2,0))img = (img * 0.5 + 0.5) * 255cv2.imwrite('2.jpg',img)

结果:

pytorch数据预处理相关推荐

  1. (第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

    前言:在深度学习中,数据的预处理是第一步,pytorch提供了非常规范的处理接口,本文将针对处理过程中的一些问题来进行说明,本文所针对的主要数据是图像数据集. 本文的案例来源于车道线语义分割,采用的数 ...

  2. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

  3. Pytorch 数据预处理

    数据预处理 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 Shift+Tab 查看 ...

  4. 4.3 pytorch数据预处理:transforms图像增强方法

    一.数据增强概述 二.数据增强方法:裁剪 三.数据增强方法:翻转和旋转 四.数据增强方法:变换 五.transforms方法的选择操作 一.数据增强概述 我们来看图片中的数据增强是怎么样的. 左边的图 ...

  5. PyTorch系列 (二): pytorch数据读取自制数据集并

    PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 ...

  6. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  7. 英伟达DALI加速技巧:使数据预处理比原生PyTorch运算速度快4倍

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 你的数据处理影响整个训练速度,如果加上英伟达 DALI 库,处理速度比原生 PyT ...

  8. 【Pytorch神经网络基础理论篇】 03 数据操作 + 数据预处理

    1.数组样例 2.数据操作(代码实现与结果截图) #首先导入torch,虽然被称为pytorch,但是我们应该导入torch而不是pytorch import torch #张量表示一个数值组成的数组 ...

  9. 英伟达DALI加速技巧:让数据预处理速度比原生PyTorch快4倍

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 选自towardsdatascience 作者:Pieterluitjens 机器之心编译 参与:一鸣.嘉明.思 你的数据处理影响 ...

最新文章

  1. 实战SSM_O2O商铺_11【商铺注册】Controller层的实现
  2. BZOJ2816:[ZJOI2012]网络(LCT)
  3. 个人图书管理系统c语言代码,c语言源代码---------------个人图书管理系统
  4. python怎么设置图的大小_python – 如何调整seaborn中的子图大小?
  5. 操作系统与存储:解析Linux内核全新异步IO引擎io_uring设计与实现
  6. hadoop安装部署(伪分布及集群)
  7. 注释嵌套注释_注释梦Night
  8. 缓存服务的更新策略有哪些?
  9. C# winfrom gridview全部选择和全部取消
  10. 申请CSDN博客专家的成功历程
  11. Android APK实现WIFI协议包抓取(上)-实现思路
  12. 它不是哆啦A梦 也能满足你的挑剔需求
  13. php源码旅行网站模板,背包客旅行扁平网站模板
  14. 【Spinning up】零、DRLib:一个简洁的强化学习库,集成了HER和PER
  15. 数组 reduce 简介及使用场景
  16. [Neo4j] CQL命令
  17. 离婚了我们先同居 (转贴)
  18. 知识总结--性能优化总结(摘录+转载)
  19. latex如何使文字不空格_latex 段开头不空格
  20. HeapCreate()

热门文章

  1. 拒绝无脑吹!从ACL20看预训练缺陷
  2. 论文小综 | 文档级关系抽取方法(下)
  3. 论文浅尝 | 混合注意力原型网络的含噪音少样本的关系分类
  4. 基于MATLAB的Sobel边缘检测算法实现
  5. hdoj-1004-Let the Balloon Rise(map排序)
  6. import-module的注意事项与NDK_MODULE_PATH的配置
  7. MemoryInjector 无痕注入
  8. WiFi共享精灵 - 不需路由器一键轻松把网线共享给手机、笔记本等同时无线上网...
  9. 程序员在群询问破解软件
  10. 2021-07-27 详解TCP连接建立和释放的过程(三报文握手和四次挥手)