前言 

本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强。然后,介绍了分布式训练的数据加载方式,数据读取的整个流程,当面对超大数据集时,内存不足的改进思路。

本文延续了以往的写作态度和风格,即便是自己知道的内容,也仍然在写之前看了很多的文章来保证内容的正确性和全面性,因此写得极累,耗费时间较长。若有读者看完后觉得有所帮助,文末可以赞赏一点。

文末扫描二维码关注公众号CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读,招聘信息发布。

(零) 概述


浮躁是人性的一个典型的弱点,很多人总擅长看别人分享的现成代码解读的文章,看起来学会了好多东西,实际上仍然不具备自己从零搭建一个pipeline的能力。

在公众号(CV技术指南)的交流群里(群内交流氛围不错,有需要的请关注公众号加群),常有不少人问到一些问题,根据这些问题明显能看出是对pipeline不了解,却已经在搞项目或论文了,很难想象如果基本的pipeline都不懂,如何分析代码问题所在?如何分析结果不正常的可能原因?遇到问题如何改?

Pytorch在这几年逐渐成为了学术上的主流框架,其具有简单易懂的特点。网上有很多pytorch的教程,如果是一个已经懂的人去看这些教程,确实pipeline的要素都写到了,感觉这教程挺不错的。但实际上更多地像是写给自己看的一个笔记,记录了pipeline要写哪些东西,却没有介绍要怎么写,为什么这么写,刚入门的小白看的时候容易云里雾里。

鉴于此,本教程尝试对于pytorch搭建一个完整pipeline写一个比较明确且易懂的说明。

本教程将介绍以下内容:

  1. 准备数据,自定义classdataset,分布式训练的数据加载方式,加载超大数据集的改进思路。

  2. 搭建模型与模型初始化。

  3. 编写训练过程,包括加载预训练模型、设置优化器、设置损失函数等。

  4. 可视化并保存训练过程。

  5. 编写推理函数。

(一)数据读取


classdataset的定义

先来看一个完整的classdataset

import torch.utils.data as data
import torchvision.transforms as transformsclass MyDataset(data.Dataset):def __init__(self,data_folder):self.data_folder = data_folderself.filenames = []self.labels = []per_classes = os.listdir(data_folder)for per_class in per_classes:per_class_paths = os.path.join(data_folder, per_class)label = torch.tensor(int(per_class))per_datas = os.listdir(per_class_paths)for per_data in per_datas:self.filenames.append(os.path.join(per_class_paths, per_data))self.labels.append(label)def __getitem__(self, index):image = Image.open(self.filenames[index])label = self.labels[index]data = self.proprecess(image)return data, labeldef __len__(self):return len(self.filenames)def proprecess(self,data):transform_train_list = [transforms.Resize((self.opt.h, self.opt.w), interpolation=3),transforms.Pad(self.opt.pad, padding_mode='edge'),transforms.RandomCrop((self.opt.h, self.opt.w)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]return transforms.Compose(transform_train_list)

classdataset的几个要点:

  1. classdataset类继承torch.utils.data.dataset。

  2. classdataset的作用是将任意格式的数据,通过读取、预处理或数据增强后以tensor的形式输出。其中任意格式的数据可能是以文件夹名作为类别的形式、或以txt文件存储图片地址的形式、或视频、或十几帧图像作为一份样本的形式。而输出则指的是经过处理后的一个batch的tensor格式数据和对应标签。

  3. classdataset主要有三个函数要完成:__init__函数、__getitem__ 函数和__len__函数。

__init__函数

init函数主要是完成两个静态变量的赋值。一个是用于存储所有数据路径的变量,变量的每个元素即为一份训练样本,(注:如果一份样本是十几帧图像,则变量每个元素存储的是这十几帧图像的路径),可以命名为self.filenames。一个是用于存储与数据路径变量一一对应的标签变量,可以命名为self.labels。

假如数据集的格式如下:

#这里的0,1指的是类别0,1
/data_path/0/image0.jpg
/data_path/0/image1.jpg
/data_path/0/image2.jpg
/data_path/0/image3.jpg
......
/data_path/1/image0.jpg
/data_path/1/image1.jpg
/data_path/1/image2.jpg
/data_path/1/image3.jpg

可通过per_classes = os.listdir(data_path) 获得所有类别的文件夹,在此处per_classes的每个元素即为对应的数据标签,通过for遍历per_classes即可获得每个类的标签,将其转换成int的tensor形式即可。在for下获得每个类下每张图片的路径,通过self.join获得每份样本的路径,通过append添加到self.filenames中。

__getitem__ 函数

getitem 函数主要是根据索引返回对应的数据。这个索引是在训练前通过dataloader切片获得的,这里先不管。它的参数默认是index,即每次传回在init函数中获得的所有样本中索引对应的数据和标签。因此,可通过下面两行代码找到对应的数据和标签。

image = Image.open(self.filenames[index]))
label = self.labels[index]

获得数据后,进行数据预处理。数据预处理主要通过 torchvision.transforms 来完成,这里面已经包含了常用的预处理、数据增强方式。其完整使用方式在官网有详细介绍:https://pytorch.org/vision/stable/transforms.html

上面这里介绍了最常用的几种,主要就是resize,随机裁剪,翻转,归一化等。

最后通过transforms.Compose(transform_train_list)来执行。

除了这些已经有的数据增强方式外,在《数据增强方法总结》中还介绍了十几种特殊的数据增强方式,像这种自己设计了一种新的数据增强方式,该如何添加进去呢

下面以随机擦除作为例子。

class RandomErasing(object):""" Randomly selects a rectangle region in an image and erases its pixels.'Random Erasing Data Augmentation' by Zhong et al.See https://arxiv.org/pdf/1708.04896.pdfArgs:probability: The probability that the Random Erasing operation will be performed.sl: Minimum proportion of erased area against input image.sh: Maximum proportion of erased area against input image.r1: Minimum aspect ratio of erased area.mean: Erasing value."""def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):self.probability = probabilityself.mean = meanself.sl = slself.sh = shself.r1 = r1def __call__(self, img):if random.uniform(0, 1) > self.probability:return imgfor attempt in range(100):area = img.size()[1] * img.size()[2]target_area = random.uniform(self.sl, self.sh) * areaaspect_ratio = random.uniform(self.r1, 1 / self.r1)h = int(round(math.sqrt(target_area * aspect_ratio)))w = int(round(math.sqrt(target_area / aspect_ratio)))if w < img.size()[2] and h < img.size()[1]:x1 = random.randint(0, img.size()[1] - h)y1 = random.randint(0, img.size()[2] - w)if img.size()[0] == 3:img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]else:img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]return imgreturn img

如上所示,自己写一个类RandomErasing,继承object,在call函数里完成你的操作。在transform_train_list里添加上RandomErasing的定义即可。

transform_train_list = [transforms.Resize((self.opt.h, self.opt.w), interpolation=3),transforms.Pad(self.opt.pad, padding_mode='edge'),transforms.RandomCrop((self.opt.h, self.opt.w)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0])#添加到这里]

__len__函数

len函数主要就是返回数据长度,即样本的总数量。前面介绍了self.filenames的每个元素即为每份样本的路径,因此,self.filename的长度就是样本的数量。通过return len(self.filenames)即可返回数据长度。

验证classdataset

train_dataset = My_Dataset(data_folder=data_folder)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
print('there are total %s batches for train' % (len(train_loader)))for i,(data,label) in enumerate(train_loader):print(data.size(),label.size())

分布式训练的数据加载方式


前面介绍的是单卡的数据加载,实际上分布式也是这样,但为了高速高效读取,每张卡上也会保存所有数据的信息,即self.filenames和self.labels的信息。只是在DistributedSampler 中会给每张卡分配互不交叉的索引,然后由torch.utils.data.DataLoader来加载。

dataset = My_Dataset(data_folder=data_folder)
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)

数据读取的完整流程


结合上面这段代码,在这里,我们介绍以下读取数据的整个流程。

  1. 首先定义一个classdataset,在初始化函数里获得所有数据的信息。

  2. classdataset中实现getitem函数,通过索引来获取对应的数据,然后对数据进行预处理和数据增强。

  3. 在模型训练前,初始化classdataset,通过Dataloader来加载数据,其加载方式是通过Dataloader中分配的索引,调用getitem函数来获取。

    关于索引的分配,在单卡上,可通过设置shuffle=True来随机生成索引顺序;在多机多卡的分布式训练上,shuffle操作通过DistributedSampler来完成,因此shuffle与sampler只能有一个,另一个必须为None。

超大数据集的加载思路


问题所在

再回顾一下上面这个流程,前面提到所有数据信息在classdataset初始化部分都会保存在变量中,因此当面对超大数据集时,会出现内存不足的情况。

思路

将切片获取索引的步骤放到classdataset初始化的位置,此时每张卡都是保存不同的数据子集。通过这种方式,可以将内存用量减少到原来的world_size倍(world_size指卡的数量)。

参考代码

class RankDataset(Dataset):'''实际流程获取rank和world_size 信息 -> 获取dataset长度 -> 根据dataset长度产生随机indices ->给不同的rank 分配indices -> 根据这些indices产生metas'''def __init__(self, meta_file, world_size, rank, seed):super(RankDataset, self).__init__()random.seed(seed)np.random.seed(seed)self.world_size = world_sizeself.rank = rankself.metas = self.parse(meta_file)def parse(self, meta_file):dataset_size = self.get_dataset_size(meta_file)                                     # 获取metafile的行数local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size)   # 根据world size和rank,获取当前epoch,当前rank需要训练的index。self.metas = self.read_file(meta_file, local_rank_index)def __getitem__(self, idx):return self.metas[idx]def __len__(self):return len(self.metas)##train
for epoch_num in range(epoch_num):dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num)sampler = RandomSampler(datset)dataloader = DataLoader(dataset=dataset,batch_size=32,shuffle=False,num_workers=4,sampler=sampler)

但这种思路比较明显的问题时,为了让每张卡上在每个epoch都加载不同的训练子集,因此需要在每个epoch重新build dataloader。

这一节参考链接:https://zhuanlan.zhihu.com/p/357809861

总结


本篇文章介绍了数据读取的完整流程,如何自定义classdataset,如何进行数据增强,自己设计的数据增强如何写,分布式训练是如何加载数据的,超大数据集的数据加载改进思路。

相信读完本文的读者对数据读取有了比较清晰的认识,下一篇将介绍搭建模型与模型初始化。

关注公众号可加计算机视觉交流群

欢迎关注公众号 CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读。

在公众号中回复关键字 “入门指南“可获取计算机视觉入门所有必备资料。

其它文章

自编码器综述论文:概念、图解和应用

解决图像分割落地场景真实问题,港中文等提出:开放世界实体分割

资源分享 |Nebullvm:一行代码测试多个DL编译器,模型推理提高5-20倍

目标检测、实例分割、多目标跟踪的Anchor-free应用方法总结

Soft Sampling:探索更有效的采样策略

如何解决工业缺陷检测小样本问题

机器学习、深度学习面试知识点汇总

深度学习图像识别的未来:机遇与挑战并存

招聘 | 22-65k!迁移科技:招聘深度学习、传统视觉、3D视觉算法工程师、项目经理、机械设计

关于快速学习一项新技术或新领域的一些个人思维习惯与思想总结

计算机视觉中的图像标注工具总结

计算机视觉中的神经网络可视化工具与项目

计算机视觉中的高效阅读论文的方法总结

计算机视觉中的transformer模型创新思路总结

一文概括机器视觉常用算法以及常用开发库

HOG和SIFT图像特征提取简述    |  特征金字塔技术总结

目标检测中回归损失函数总结    |    实例分割综述总结综合整理版

2021年小目标检测最新研究综述    |    小目标检测常用方法总结

从零搭建Pytorch模型教程(一)数据读取相关推荐

  1. 网站搭建:从零搭建个人网站教程(4)

     系列文章 网站搭建:从零搭建个人网站教程(1) 网站搭建:从零搭建个人网站教程(2) 网站搭建:从零搭建个人网站教程(3) 网站搭建:从零搭建个人网站教程(4) 网站搭建:从零搭建个人网站教程(5) ...

  2. 网站搭建:从零搭建个人网站教程(2)

     系列文章 网站搭建:从零搭建个人网站教程(1) 网站搭建:从零搭建个人网站教程(2) 网站搭建:从零搭建个人网站教程(3) 网站搭建:从零搭建个人网站教程(4) 网站搭建:从零搭建个人网站教程(5) ...

  3. sklearn与pytorch模型的保存与读取

    当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...

  4. Pytorch Note52 灵活的数据读取介绍

    Pytorch Note52 灵活的数据读取介绍 文章目录 Pytorch Note52 灵活的数据读取介绍 灵活的数据读取 读入数据 传入数据预处理方式 Dataset DataLoader 例子 ...

  5. PyTorch实现自由的数据读取

    北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                        ...

  6. Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

    文章目录 数据读取 图像读取 1.pillow 2.opencv 数据读取 数据扩增 数据读取 导入需要的包以及文件路径 import json, glob import numpy as np fr ...

  7. Transforemr模型从零搭建Pytorch逐行实现

    文章目录 摘要 一. 细致理解Transforemr模型Encoder原理讲解与其Pytorch逐行实现 1.1 关于word embedding 1.2 生成源句子与目标句子 1.3 构建posti ...

  8. Datawhale零基础入门NLP赛事 - Task2 数据读取与数据分析

    用pandas处理一下数据,训练集的shape为(200000, 2),建议刚开始可以读取几百条看看效果,全部读取的话内存大概要12G左右才能进行正常处理. import pandas as pd t ...

  9. 数据读取与数据扩增方法

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:樊亮.黄星源.Datawhale优秀学习者 本文对图像数据读取及图 ...

  10. 【深度学习】数据读取与数据扩增方法

    转载自:Datawhale,作者:樊亮.黄星源.Datawhale优秀学习者 本文对图像数据读取及图像数据扩增方法进行了总结,并以阿里天池零基础入门CV赛事为实践,利用Pytorch对数据进行了读取和 ...

最新文章

  1. python爬虫入门urllib库的使用
  2. 如何做好SOC的一点点体会
  3. python那么慢为什么还有人用-为什么Python如此慢
  4. js 得到select所有option里的值
  5. 动态规划 HDU1231-------最大连续子序列
  6. 口译务实——unit10 II
  7. PHP执行linux系统命令
  8. EL表达式中fn函数
  9. java在线订单系统源码_春哥酒店在线预订微信小程序源码系统正式发布!
  10. 利用MentoHUST在路由器上使用锐捷认证来共享校园网
  11. 什么是RS-232-C接口与什么是RS-485接口?
  12. u盘为什么要安全弹出?丢失的数据怎么恢复?
  13. 计算机软件实习每日学习打卡(6)20201227
  14. GEE:GEE平台怎么提高用户内存限制
  15. pscs6怎么做html模板,怎么在Adobe Photoshop CS6里制作表格模板(PS)怎么画表格
  16. 工作笔记:如何用Django连接Kerberized甲骨文(Oracle)数据库
  17. 【安全预警】WINRAR,7ZIP,WINZIP等存在严重漏洞
  18. Serpent.AI - 游戏代理框架(Python)
  19. 丢手绢 【约瑟夫环】
  20. 技术分享 | Slow Query Log 使用详解

热门文章

  1. 青岛自然人税收管理系统服务器地址,青岛市自然人税收管理系统扣缴客户端
  2. 如何使用AxureShare+Axure RP 8.0创建团队项目,实现团队协同
  3. Spring、SpringMVC、Shiro面试题
  4. 《编译原理及实践教程》第一章学习笔记
  5. 编译原理与编译构造 LR文法
  6. 如何获取iOS应用网络权限?
  7. python数据库模糊查询_python中的mysql数据库like模糊查询
  8. QGIS 3初级到高级
  9. 使用 Python 开发 QGIS 插件
  10. 小甲鱼 C语言 19课 字符串的处理函数