目的

问题: 当我们使用pytorch训练小模型或者使用较大batch size的时候会发现GPU利用率很低,训练周期比较长。其原因之一是在dataloader加载数据之后在cpu上做一些数据增强的操作(eg.resize、crop等),比较耗时,导致很多时候都是GPU在等CPU的数据,造成了严重的浪费。
解决: 使用nvidia-dali将一些cpu上的数据预处理操作放到gpu上去处理,可以极大的提高训练的效率.
缺点: 好像只提供了固定的几种格式的数据,ImageNet数据格式(分类)、COCO数据集格式(检测)

实现

  1. 使用DALI封装的数据加载代码(暂时看不懂,可以先看官方文档的Install、Getting started)
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIteratorclass HybridTrainPipe(Pipeline):def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False, local_rank=0, world_size=1):super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)dali_device = "gpu"self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)self.res = ops.RandomResizedCrop(device="gpu", size=crop, random_area=[0.08, 1.25])self.cmnp = ops.CropMirrorNormalize(device="gpu",output_dtype=types.FLOAT,output_layout=types.NCHW,image_type=types.RGB,mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],std=[0.229 * 255, 0.224 * 255, 0.225 * 255])self.coin = ops.CoinFlip(probability=0.5)print('DALI "{0}" variant'.format(dali_device))def define_graph(self):rng = self.coin()self.jpegs, self.labels = self.input(name="Reader")images = self.decode(self.jpegs)images = self.res(images)output = self.cmnp(images, mirror=rng)return [output, self.labels]class HybridValPipe(Pipeline):def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, local_rank=0, world_size=1):super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,random_shuffle=False)self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)self.cmnp = ops.CropMirrorNormalize(device="gpu",output_dtype=types.FLOAT,output_layout=types.NCHW,crop=(crop, crop),image_type=types.RGB,mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],std=[0.229 * 255, 0.224 * 255, 0.225 * 255])def define_graph(self):self.jpegs, self.labels = self.input(name="Reader")images = self.decode(self.jpegs)images = self.res(images)output = self.cmnp(images)return [output, self.labels]def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, val_size=256,world_size=1,local_rank=0):if type == 'train':pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,data_dir=image_dir + '/train',crop=crop, world_size=world_size, local_rank=local_rank)pip_train.build()dali_iter_train = DALIClassificationIterator(pip_train, size=pip_train.epoch_size("Reader") // world_size)return dali_iter_trainelif type == 'val':pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,data_dir=image_dir + '/val',crop=crop, size=val_size, world_size=world_size, local_rank=local_rank)pip_val.build()dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader") // world_size)return dali_iter_val
if __name__ == '__main__':train_loader = get_imagenet_iter_dali(type='train', image_dir='/userhome/memory_data/imagenet', batch_size=256,num_threads=4, crop=224, device_id=0, num_gpus=1)print('start iterate')start = time.time()for i, data in enumerate(train_loader):images = data[0]["data"].cuda(non_blocking=True)labels = data[0]["label"].squeeze().long().cuda(non_blocking=True)end = time.time()print('end iterate')print('dali iterate time: %fs' % (end - start))
  1. 使用DALI自定义数据加载类代码(参考官方文档的Tutorials/Data Loading/ ExternalSource operator)
from __future__ import division
import torch
import types
import joblib
import collections
import numpy as np
import pandas as pd
from random import shuffle
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.plugin.pytorch as dalitorch
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIteratordef grid2x2(img):#自定义函数h, w, c = img.shapeleft_top = img[0:h//2, 0:w//2, :]left_bottom = img[h//2:h, 0:w//2, :]right_top = img[0:h//2, w//2:w, :]right_bottom = img[h//2:h, w//2:w, :]return left_top, right_top, left_bottom, left_bottomclass ExternalInputIterator(object):#自定义迭代器def __init__(self, images_dir, txt_path, batch_size, device_id, num_gpus):self.images_dir = images_dirself.batch_size = batch_sizewith open(txt_path, 'r') as f:self.files = [line.rstrip() for line in f if line is not '']# whole data set sizeself.data_set_len = len(self.files)# based on the device_id and total number of GPUs - world size# get proper shardself.files = self.files[self.data_set_len * device_id // num_gpus:self.data_set_len * (device_id + 1) // num_gpus]self.n = len(self.files)def __iter__(self):self.i = 0shuffle(self.files)return selfdef __next__(self):batch = []labels = []if self.i >= self.n:raise StopIterationfor _ in range(self.batch_size):jpeg_filename, label = self.files[self.i].split(',')f = open(self.images_dir + jpeg_filename, 'rb')batch.append(np.frombuffer(f.read(), dtype = np.uint8))labels.append(np.array([int(label)], dtype = np.uint8))self.i = (self.i + 1) % self.nreturn (batch, labels)@propertydef size(self,):return self.data_set_lennext = __next__class ExternalSourcePipeline(Pipeline):#自定义数据增强操作(通过iter_setup函数将迭代器的数据送入define_graph函数处理)def __init__(self, resize, batch_size, num_threads, device_id, external_data):super(ExternalSourcePipeline, self).__init__(batch_size,num_threads,device_id,seed=12,exec_async=False,exec_pipelined=False,)self.input = ops.ExternalSource()self.input_label = ops.ExternalSource()self.decode = ops.ImageDecoder(device = "cpu", output_type = types.RGB)#自定义的函数只能在cpu上运行self.grid = ops.PythonFunction(function=grid2x2, num_outputs=4)   self.resize = ops.Resize(device="gpu", resize_x=resize, resize_y=resize,interp_type=types.INTERP_LINEAR)self.external_data = external_dataself.iterator = iter(self.external_data)def define_graph(self):self.jpegs = self.input()self.labels = self.input_label()images = self.decode(self.jpegs)images1, images2, images3, images4 = self.grid(images)images = self.resize(images.gpu())images1 = self.resize(images1.gpu())images2 = self.resize(images2.gpu())images3 = self.resize(images3.gpu())images4 = self.resize(images4.gpu())return (images, images1, images2, images3, images4, self.labels)def iter_setup(self):try:images, labels = self.iterator.next()self.feed_input(self.jpegs, images)self.feed_input(self.labels, labels)except StopIteration:self.iterator = iter(self.external_data)raise StopIterationdef create_dataloder(img_dir, txt_path, resize,batch_size,device_id=0,num_gpus=1,num_threads=6):eii = ExternalInputIterator(img_dir,txt_path, batch_size=batch_size, device_id=device_id,num_gpus=num_gpus)pipe = ExternalSourcePipeline(resize=resize,batch_size=batch_size, num_threads=num_threads, device_id = 0,external_data = eii)pii = PyTorchIterator(pipe, output_map=["data0", "data1", "data2", "data3", "data4", "label"], size=eii.size, last_batch_padded=True, fill_last_batch=False)return piiif __name__ == '__main__':batch_size = 32num_gpus = 1num_threads = 8epochs = 1pii = create_dataloder('img_path',resize=224,batch_size=batch_size,txt_path='file_path',)for e in range(epochs):for i, data in enumerate(pii):imgs = data[0]["data4"]labels = data[0]["label"]print("epoch: {}, iter {}".format(e, i), imgs.shape, labels.shape)pii.reset()

注,分类任务的数据准备最好按照ImageNet的数据格式,检测任务的数据准备参考检测数据格式


参考1:DALI官方文档(记录一下常用)
目录:

  • Installation:安装DALI命令

  • Getting started:入门的简单案例(分类任务)

  • Tutorials
     General:
      1. Data Loading:
       1.1 ExternalSource operator:自定义数据加载操作(ExternalInputIterator、ExternalSourcePipeline);
       1.2 COCO Reader:COCO数据格式读取(检测任务)
      2. DALI expressions and arithmetic operations:tensor上自定义+ - * /操作
      3. Multiple GPU support:GPU上进行数据增强操作(shard_id:显卡id, num_shards:将数据分成几份)
      4. Normalize operator :正则化
     Image Processing:一些图片处理上的常用操作(Decoder的CPU/Hybrid)
     Use Cases:一些demo(包括用于分类任务和检测任务)

  • Framework integration:DALI在常用框架(Pyotch、tf)的使用

  • Supported operations:DALI中封装的所有函数的使用

参考2:nvidia-dali GPU加速预处理
参考3:pytorch 一种加速dataloder的方法
参考4:NIVIDIA/DALI的github

【Pytorch】nvidia-dali——一种加速数据增强的方法相关推荐

  1. NLP中数据增强的方法

    为什么使用数据增强 当在一些任务中需要大量数据,但是实际上数据量不足时,可以考虑使用数据增强的方式增加数据量 数据增强的方法 数据增强主要有两种方法: 法一:简单数据增强(Easy Data Augm ...

  2. Python PIL库处理图片常用操作,图像识别数据增强的方法

    在博客AlexNet原理及tensorflow实现训练神经网络的时候,做了数据增强,对图片的处理采用的是PIL(Python Image Library), PIL是Python常用的图像处理库. 下 ...

  3. 10种网站数据的采集方法

    10种AI训练数据采集工具排行榜 10种网站数据的采集方法 1.目前常用的10种网站数据 2.如何写Python爬虫: 3.人生第一个 爬虫代码示例: 另外: 10种网站数据的采集方法 如何收集网站数 ...

  4. 10种招聘数据的采集方法

    10种AI训练数据采集工具排行榜 10种招聘数据的采集方法 1.目前常用的10种数据网站 2.如何写Python爬虫: 3.人生第一个 爬虫代码示例: 另外: 10种招聘数据的采集方法 如何收集招聘数 ...

  5. 6种上市公司数据的采集方法和工具

    10种AI训练数据采集工具排行榜 6种上市公司数据的采集方法和工具 1.目前常用的6种数据网站 2.如何写Python爬虫: 3.人生第一个 爬虫代码示例: 另外: 6种上市公司数据的采集方法和工具 ...

  6. 目标检测:python实现多种图像数据增强的方法(光照,对比度,遮挡,模糊)

    图像数据增强的内容(可根据需要自定义选择): 1.直方图均衡化 2.clahe自适应对比度直方图均衡化 3.白平衡 4.亮度增强 5.亮度,饱和度,对比度增强 6.去除图像上的高光部分 7.自适应亮度 ...

  7. 一种加速Github下载速度的方法

    一种加速Github下载速度的方法 Github好多人都用过,不知道您是否遇到过由于某种原因而造成的如龟

  8. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...

    「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...

  9. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度

    @Author:Runsen 上次基于CIFAR-10 数据集,使用PyTorch ​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. im ...

最新文章

  1. 最新批量***dedecms|dedecms最新0day
  2. 很有必要看,这篇 解决 IndexError: list index out of range
  3. 20172332 2017-2018-2 《程序设计与数据结构》实验三报告
  4. java代码耗尽内存_有关Java内存溢出及内存消耗的小知识
  5. sqlmap源码阅读系列init中的_cleanupOptions
  6. 浙大 PAT b1017
  7. 『计算机视觉』Mask-RCNN_推断网络终篇:使用detect方法进行推断
  8. 第二届广东大学生网络安全攻防大赛 个人向Write Up
  9. 写贺卡给毕业师姐怎么写计算机系的,给师兄师姐的毕业贺卡寄语
  10. 计算机重启是什么原因,电脑自动重启是什么原因以及如何解决【图文教程】
  11. Matlab+cpp矩量法代码演示
  12. 如何在A4相纸上打印4张5寸相片
  13. RISC-V MCU低功耗场景的应用分析
  14. java url解码解不了_java – 为什么URL没有完全解码?
  15. Python 用Ursina 3D引擎做一个太阳系行星模拟器
  16. cookie、session与token之间的关系
  17. 神秘大佬写的的运营思维课
  18. 精益生产目视管理法 (zt)
  19. 月销涨3倍:阿里巴巴零售通联合饿了么宣布将赋能近万天猫小店
  20. 中兴c300业务板_中兴C300 EPON 基本业务配置

热门文章

  1. 【学员管理系统】0x01 班级信息管理功能
  2. Machine Learning--决策树(一)
  3. CSS background-position用法
  4. Java程序实现密钥库的维护
  5. 用批处理启动常用服务
  6. 【jQuery笔记Part1】06-jQuery对象与js对象转换
  7. 算法图解学习笔记02之选择排序
  8. 软件设计师18-系统开发和运行01
  9. windows、Linux下nginx搭建集群
  10. (第三章)查看数据库