深度学习中的数据处理概述

深度学习三要素:数据、算力和算法
在工程实践中,数据的重要性越来越引起人们的关注。在数据科学界流传着一种说法,“数据决定了模型的上限,算法决定了模型的下限”,因此在这个“说法”中,明确的表明了,只有好的数据才能够有好的模型,数据才是决定了模型的关键因素。

数据很重要

简单来说,就是找到好的数据,拿给模型“吃”。
但是,什么样才算是“好”的数据,怎么找到“好”的数据,给模型吃了后模型性能有没有变化等等问题是一个非常庞大的课题,本文并不深入探讨,首先从特征工程的角度抛出一张图,另外总结一下深度学习中最常用的几个数据处理过程。

ML/DL的数据处理基本步骤

收集

在进行实验之前,需要收集数据,数据包括原始样本和标签。标签信息一般有收集公开数据集数据、人工标注、自动化/半自动化标注、仿真模拟平台生成等几种方法。

划分

有了原始数据之后,需要对数据集进行划分,把数据集划分为训练集、验证集和测试集

  • 训练集:训练模型
  • 验证集:验证集用于验证模型是否过拟合,通过比较算法在验证集的性能挑选模型的超参数(学习率、优化算法、网络结构等)
  • 测试集: 测试模型的性能,测试模型的泛化能力(往往测试话指标由第三方出,算法同学不接触测试数据和标签)

数据读取

pytorch中数据读取的核心是DataLoader
DataLoader还会细分为两个子模块,Sampler和DataSet;Sample的功能是生成索引,也就是样本的序号;Dataset是根据索引去读取图片以及对应的标签

数据预处理

比如说数据的中心化,标准化,旋转或者翻转等
pytorch中数据预处理是通过transforms进行处理的

PyTorch中的数据读取模块

torch.utils.data.DataLoader

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_works=0,clollate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)
  • 功能:构建可迭代的数据装载器;
  • dataset:Dataset类,决定数据从哪里读取及如何读取;
  • batchsize:批大小;
  • num_works:是否多进程读取数据;
  • shuffle:每个epoch是否乱序;
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;

torch.utils.data.Dataset

class Dataset(object):def __getitem__(self, index):raise NotImplementedErrordef __add__(self, other)return ConcatDataset([self,other])
  • Dataset是用来定义数据从哪里读取,以及如何读取的问题;
  • 功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();
  • 函数__getitem__() 作用:接收一个索引,返回一个样本

一个分类任务的数据读取例子

详情见

这里是引用分类任务DataLoader例子
核心代码:

# 构建MyDataset实例,MyDataset必须是用户自己构建的
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)  # data_dir是数据的路径,transform是数据预处理
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)  # 一个用于训练,一个用于验证#构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle=True,每一个epoch中样本都是乱序的
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

其中,DataLoader会传入一个参数Dataset,也就是前面构建好的RMBDataset;第二个参数是batch_size,shuffle=True,它的作用是每一个epoch中样本都是乱序的

代码中跟踪RMBDataset构建了两个Dataset,一个用于训练,一个用于验证。
核心为重写了函数

def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等return img, label

数据预处理模块Transforms

torchvision

是pytorch的计算机视觉工具包
主要有三个模块:

  • torchvision.transforms,常用的图像预处理方法,在transforms中提供了一系列的图像预处理方法,例如数据的标准化,中心化,旋转,翻转等等;
  • torchvision.datasets,定义了一系列常用的公开数据集的datasets,比如常用的MNIST,CIFAR-10,ImageNet等等;
  • torchvision.model,提供大量常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等等;

torchvision.transforms

常用的图像预处理方法:

  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度及对比度变换

transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,在具体实现的时候,会依次按顺序对图像进行操作
使用:

#设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),  #Resize,将图像缩放到32*32的大小transforms.RandomCrop(32, padding=4),  #RandomCrop,对数据进行随机的裁剪transforms.ToTensor(),  #ToTensor,将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1transforms.Normalize(norm_mean, norm_std),  #标准化操作,将数据的均值变为0,标准差变为1
])   # Resize的功能是缩放,RandomCrop的功能是裁剪,ToTensor的功能是把图片变为张量

(1) transforms.Normalize

  • 功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1。逐通道的意思是,以GRB图像为例,计算训练集的所有图片的R通道之和,并除以 (图像N x W x H)
  • 标准化的计算公式为 o u t p u t = ( i n p u t − m e a n ) / s t d output = (input - mean) /stdoutput=(input−mean)/std
  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原位操作

对数据进行标准化之后可以加快模型的收敛
为什么会呢,这是一个比较大的话题,后续有时间会开一个专栏尝试解答该问题。

PyTorch 入坑六 数据处理模块Dataloader、Dataset、Transforms相关推荐

  1. PyTorch 入坑七:模块与nn.Module学习

    PyTorch 入坑七 模型创建概述 PyTorch中的模块 torch模块 torch.Tensor模块 torch.sparse模块 torch.cuda模块 torch.nn模块 torch.n ...

  2. PyTorch入坑(一)~(三): Tensor的概念,基本操作和线性回归

    PyTorch 一文入门 PyTorch 入坑一:数据类型与Tensor的概念 PyTorch数据类型 Tensor的概念 Tensor与Variable Variable Tensor Tensor ...

  3. PyTorch 入坑四 梯度、链式法则、计算图与反向传播

    PyTorch 一文入门 PyTorch 入坑四:计算图与反向传播 导数.梯度与最优化 方向导数与梯度 最优化 复合函数求梯度 计算图 PyTorch实战 PyTorch中的动态图机制 PyTorch ...

  4. PyTorch 入坑十一: 损失函数、正则化----深刻剖析softmax+CrossEntropyLoss

    这里写目录标题 概念 Loss Function Cost Function Objective Function 正则化 损失函数 交叉熵损失函数nn.CrossEntropyLoss() 自信息 ...

  5. PyTorch 入坑九:权重初始化

    这里写目录标题 Pytorch提供的十种权值初始化方法 为什么要进行权重初始化设计:从梯度消失与爆炸说起 常用的几种权重初始化方法 不考虑激活函数 饱和型激活函数 xavier_uniform方法 非 ...

  6. PyTorch 入坑八:卷积与转置卷积

    卷积 操作 略 输入输出尺寸变化 略 PyTorch实现 nn.Conv2d(in_channels,out_channels,kernel_size,stride=1.padding=0,dilat ...

  7. PyTorch 入坑五 autograd与逻辑回归

    torch.autograd 深度学习模型的训练就是不断更新权值,权值的更新需要求解梯度,梯度在模型训练中是至关重要的.       然而求解梯度十分繁琐,pytorch提供自动求导系统.我们不需要手 ...

  8. Pytorch之DataLoader Dataset、datasets、models、transforms的认识和学习

    文章目录 利用PyTorch框架来开发深度学习算法时几个基础的模块 Dataset & DataLoader 基础概念 自定义数据集 1 读取自定义数据集 1 自定义数据集 2 自定义数据集3 ...

  9. 查看dataloader的大小_一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...

最新文章

  1. Linux系统中运行.sh文件的几种方法
  2. 查找Excel的Sheetname的方法
  3. kafka删除队列_没想到 Kafka 还会这样问,学会这些带你轻松搞定大厂面试!
  4. JavaScript中的的面向对象中的一些知识
  5. libzbar.a armv7
  6. Callable接口-创建线程的第三种方法
  7. 用C++,调用浏览器打开一个网页
  8. “中文直达”能推动移动上网入口洗牌吗?
  9. 用自己的数据集训练Mask-RCNN实现过程中的坑
  10. 【系统】lsof 命令简介
  11. Linux 文件权限、用户权限和用户组管理vim的一些基本使用技巧
  12. JAVA中如何全局监听鼠标事件
  13. postgresql 客户端登录配置
  14. (笔记总结)串行通信接口标准
  15. vmware安装win7 vmware tools
  16. CentOS查看端口
  17. ORA-01033错误解决方法
  18. 计算机图形学期末考试题及答案,计算机图形学期末考试卷答案
  19. 百度云主机(配置、调试环境)文件上传(快速建站,实现首页访问) - FTP篇
  20. 计算机应用 审稿2个月,我的《计算机应用》投稿经历

热门文章

  1. python中random库
  2. PHP基础学习课程强烈推荐
  3. java图片色阶调整、亮度调整
  4. jQuery Mobile中网格grid样式ui-grid-*
  5. matlab2c使用c++实现matlab函数系列教程-diag函数
  6. python打印9宫格,25宫格等奇数格,且横竖斜相加和相等
  7. Kafka:ZK+Kafka+Spark Streaming集群环境搭建(十二)VMW安装四台CentOS,并实现本机与它们能交互,虚拟机内部实现可以上网。...
  8. asp.net GridView控件的列属性
  9. Java基础知识强化87:BigInteger类之BigInteger加减乘除法的使用
  10. 实现iOS长时间后台的两种方法:Audiosession和VOIP(转)