Pytorch数据加载顺序


手把手视频讲解+代码讲解

1.如何实现输入(完全免费解析直达,致力干货分享)

2.如何实现模型(完全免费解析直达,致力干货分享)

3.如何实现输出(完全免费解析直达,致力干货分享)


神经网络模型训练过程需要进行梯度更新,梯度更新可分三种方式。1.批梯度下降(batch gradient descent):一次所有数据批计算,过于复杂,计算缓慢;2.随机梯度下降(stochastic gradient descent):每次读一个数据,数据差异大,导致训练波动太大,收敛性不好;3.最小批量梯度下降(mini-batch gradient descent / SGD gradient descent):随机取一定量数据进行训练,既降低计算量,又能提高训练速度。

使用pytorch对数据进行批次量读取构建,首先了解其加载数据顺序分为以下三个点。

pytorch中加载数据的顺序分为以下三个点:
1."创建一个 dataset 对象"; 并加入 transforms 数据增强方案;
2."创建一个 dataloader 对象";
3."获取数据集的 mini_batch"; 循环 dataloader 对象, 获取训练样本送入模型进行训练;其中,
"1.创建一个 dataset 对象", 继承 pytorch 的 torch.utils.data.Dataset; 一般需要含3个主要函数:1.__init__:    传入数据, 或者直接加载固化的数据包;2.__len__:     返回这个数据集一共有多少个item;3.__getitem__:  返回一条训练数据, 并将其转换成tensor;"2.创建一个 dataloader 对象", 采用 pytorch 的 torch.utils.data.DataLoader 整合成 mini_batch;"3.获取数据集的 mini_batch"

Pytorch官方示例与实践改造

Pytorch官方示例与实践改造

数据加载万能模板

针对自己数据集进行分装,数据列表单元+数据增强单元是我们需要关注的点,所以只要在这两个函数进行改造,其他部分和官方的1.dataset对象,2.dataloader对象,3.mini_batch获取一致。

模板代码

######## py内置函数:help-文件架构, dir-代码架构 ########
import torch # 包含基本,加减乘除,张量操作,优化器'torch.optim', 数据索引 'torch.utils.data.DataLoader'
import torch.nn as nn # "类":   包含卷积,池化,激活,损失等 "nn.CrossEntropyLoss()"
import torch.nn.functional as F  # "函数": 包含卷积,池化,激活,损失等 "F.cross_entropy()"
import torchvision # 包含图像算法的基本操作等 torchvision.models; torchvision.datasets;
import torchvision.transforms as T # "类":   包含图像增强方向等 "T.RandomCrop()"
import torchvision.transforms.functional as TF # "函数": 包含图像增强方向等 "TF.center_crop()"
import os
import glob
import math
import numpy as np
import random
from PIL import Image
import PIL
import matplotlib.pyplot as plt#################### 构建 lines 可略 ####################
class MyLinesGetter(object):def __init__(self, FilePath, dtype="seg"):self.FilePath = FilePathself.dtype = dtype # None="cls", "seg"def getter(self):self.datalines = []with open(self.FilePath, "r") as f:lines = f.read().splitlines()if self.dtype is 'seg':for line in lines:img_dir, seg_dir = line.split(" ")[:2]img_dir = os.path.join("data_flowers", "JPEGImages", img_dir)seg_dir = os.path.join("data_flowers", "SegmentationClassRAW", seg_dir)self.datalines.append([img_dir, seg_dir])else:raise "wrong dtype! check dtype on ['seg']!"return self.datalines#################### 创建 dataset class ####################
class SegmentDataset(torch.utils.data.Dataset): # 继承def __init__(self, dataset, transforms=None):self.dataset = datasetself.transforms = transformsdef __len__(self):return len(self.dataset)def __getitem__(self, idx):img_dir, seg_dir = self.dataset[idx]img = Image.open(img_dir)seg = Image.open(seg_dir)if self.transforms is not None:data_dict = self.transforms(img, seg)img = data_dict['image']seg = data_dict["mask"]else:img = TF.to_tensor(img)seg = torch.as_tensor(np.array(seg), dtype=torch.int64)return img, segpass#################### 创建 transforms+Compose 增强方案 ####################
class Resize(object):def __init__(self, size):self.size = sizedef __call__(self, image, target=None, label=None):image = TF.resize(image, self.size)if target is not None:target = TF.resize(target, self.size, interpolation=PIL.Image.BILINEAR) # PIL.Image.BILINEARif label is not None:label = labelreturn image, target, labelpassclass ToTensor(object):def __call__(self, image, target=None, label=None):image = TF.to_tensor(image)if target is not None:target = torch.as_tensor(np.array(target), dtype=torch.int64)return image, target, labelpass# 可用 torchvision 里面的 compose, 为方便看过程,因此自己实现一遍
class Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, mask=None, label=None):for t in self.transforms:image, mask, label = t(image, mask, label)return {'image':image, 'mask':mask, 'label':label}passif __name__=="__main__":# "1.创建一个 dataset 对象"train_dataset = SegmentDataset(MyLinesGetter(FilePath="data_flowers/train.txt", dtype="seg").getter(), transforms=Compose([Resize((256,256)), ToTensor(),]))# "2.创建一个 dataloader 对象"train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)# "3.获取数据集的 mini_batch"for (images, masks) in train_data_loader:plt.figure(figsize=(20,20))plt.imshow(np.hstack(images.permute(0,2,3,1)))plt.show()plt.figure(figsize=(20,20))plt.imshow(np.hstack(masks))plt.show()break

参考链接

植物素材库
代码高亮
Pytorch dataset&dataloader

图像语义分割实践(二)数据增强与读取相关推荐

  1. 图像语义分割实践(一)标签制作与转换

    语义分割实践过程中,网上常用公开数据集和模型都是打包固化的,只要配置环境一致就可以进行所谓"复现",但是这种拿来主义很被动,很多人陷入邯郸学步,贪多嚼不烂的窘迫感,最后浪费大把时间 ...

  2. 图像语义分割实践(三)模型搭建与实现

    众所周知,神经网络搭建常用基础模块有卷积,池化,归一,激活,全连接等等.如果使用Pytorch进行网络的搭建时,除了需要掌握这些基础模块外,还需要熟悉模型容器. Pytorch.nn 的容器conta ...

  3. 图像语义分割实践(四)损失函数与实现

    在确定检测任务和模型构建完成后,随之需要对训练的准则Criterion进行实现,可称之为损失函数或代价函数,简明而言,训练过程中真实值和计算值的误差. ​​ 视频讲解,转移至西瓜视频主页:@智能之心 ...

  4. 图像语义分割实践(五)优化器与学习率

    概述 在数据制作环节中,提到minibatch思想用于数据批次量获取,是一种优化器思想,而该文则是对各种优化器进行介绍. 优化器:最小化损失函数算法,把深度学习当炼丹的话,优化器就是炉子,决定火候大小 ...

  5. 制作用于图像语义分割训练的标签数据【图像分割】【labelme】

    制作用于图像语义分割训练的标签数据 *写在前面 一.使用labelme制作json数据 1.安装labelme 2.利用labelme制作json数据 二.将json数据转化为图像数据 1.单个jso ...

  6. 遥感图像语义分割——从原始图像开始制作自己的数据集(以高分二号为例)

    遥感图像语义分割--从原始图像开始制作自己的数据集(以高分二号为例) 文章目录 遥感图像语义分割--从原始图像开始制作自己的数据集(以高分二号为例) 1.遥感影像获取 2.遥感数据预处理(影像融合) ...

  7. 图像语义分割python_图像语义分割ICNET_飞桨-源于产业实践的开源深度学习平台...

    图像语义分割-ICNET 类别 智能视觉(PaddleCV) 应用 自动驾驶 室内导航 医学图像诊断 穿戴设备 虚拟现实与增强现实 无人机 模型概述 ICNet 主要用于图像实时语义分割,能够兼顾速度 ...

  8. PaddleServing图像语义分割部署实践

    目录 一.任务概述 二.官方示例部署 2.1 安装PaddleServing 2.2 导出静态图模型 2.3 转换为serving模型 2.4 启动服务 2.5 客户端请求 三.基于PipeLine的 ...

  9. 深度学习(二十一)基于FCN的图像语义分割-CVPR 2015-未完待续

    CNN应用之基于FCN的图像语义分割 原文地址:http://blog.csdn.net/hjimce/article/details/50268555 作者:hjimce 一.相关理论     本篇 ...

最新文章

  1. 在网络通讯中,如何自己分配可用的端口号和获取自己的ip地址
  2. Exchange2010 控制台提示您的权限不足,无法查看此数据
  3. CosmoMC第一次测试
  4. Netty基础系列(1) --linux网路I/O模型
  5. R语言入门3---R语言六大基本数据结构
  6. oracle open for using的用法,oracle OPEN FOR [USING] 语句
  7. Python基础python变量
  8. linux中配置tomcat
  9. hdfs的副本数为啥增加了_HDFS架构小结
  10. USB 3.0 是什么
  11. eclipse下的tomcat内存设置大小(转)
  12. FPGA中亚稳态相关问题及跨时钟域处理
  13. 《系统集成项目管理》第四章 项目管理一般知识
  14. Abaqus2020帮助文件无法搜索问题
  15. 剪映怎么把无字幕的英文视频翻译成制作成中文字幕?(附教程+剪映字幕翻译工具免费下载)...
  16. Android蓝牙音量调节,安卓 蓝牙音量控制 Bluetooth Volume Control v2.40 付费高级特别版...
  17. 高通 MSM 8916与MSM8926芯片的区别
  18. php 专业英语,给大家推荐几个专业英语翻译功能强大的网站
  19. 社交电商-京东云小店简介
  20. 【USACO3-4-2】电网 皮克定理

热门文章

  1. httpd关于web-dav的配置
  2. html5怎么设置视频快进,如何在剪映APP中给视频设置快进效果
  3. TQ2440 GPB口控制寄存器GPBCON和GPBUP作用
  4. 简单的网络3D解密游戏
  5. 绘制变形图形--Canvas的基本操作
  6. rfcn 共享_RFCN理解,不完整待补全
  7. 名帖219 赵孟頫 行书《止斋记》
  8. 基于深度学习的特征提取、匹配全解析
  9. MysqlWorkBench将已有数据库转换为mwb模型文件
  10. 基于SSM的社会救助信息管理 毕业设计-附源码211633