Pytorch数据加载的效率一直让人头痛,此前我介绍过两个方法,实际使用后数据加载的速度还是不够快,我陆续做了一些尝试,这里做个简单的总结和分析。

1、定位问题

在优化数据加载前,应该先确定是否需要优化数据加载。数据读取并不需要更快,够快就好。一般的,显存占用率很高,利用率却很低的时候,通常会怀疑是数据加载太慢导致,但不是唯一原因,比如模型内大量的循环也会导致GPU利用率低。可以尝试固定数据看看是否可以提高GPU利用率。

确定数据加载需优化后,需要判断是数据加载的哪一部分慢。整个数据处理的流程如下:

  • 读取图片数据(IO,可能存在IO瓶颈)
  • 解码数据(一般是numpy格式,可能存在计算性能瓶颈)
  • 数据增强(可能存在计算性能瓶颈)
  • 类型变换(CPU->GPU,,可能存在数据拷贝瓶颈)

为节省阅读时间,先给结论,数据加载慢主要是由于计算性能的瓶颈,而不是IO瓶颈和数据拷贝瓶颈(测试数据为1920x1080的大图,小图片可能结论不同)。为优化加载速度应该从两个方向下手:

  • 更快的图片解码
  • 更快的数据增强
  • 更强性能的设备,如使用GPU进行数据解码和增强(DALI库)

下面是具体的实验分析,测试环境和数据如下:

  • CPU: Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (正常负载,利用率小于30%)
  • GPU: V100 (正常负载,利用率小于30%)
  • 数据集数据840张,尺寸为1920x1080
  • batch size = 64
  • 无特别说明下num_workers=2
  • 迭代10个epoch(13x10个迭代次数)

2. Baseline

不进行任何额外优化下的速度如下:

其中:

  • 无任何额外操作的输出图片为原始大小(1920x1080)
  • 归一化的具体操作为: x = x.permute(0, 3, 1, 2).float().div(255)
  • 转GPU的具体操作为: x = x.cuda()
  • resize为opencv的resize,插值方式为 cv2.INTER_LINEAR
  • 数据增强的操作使用Numpy和Opencv,包括:
  • random resize
  • random crop
  • random filp
  • random HSV

可以明显的看出耗时主要发生在数据读取和数据增强部分,而CPU到GPU的数据转换等耗时较少。

需要注意的一个地方是【crop(8960x540)、转GPU、归一化】和【转GPU、归一化】的耗时差不多,crop的耗时很小,且crop后图片较小,使得转GPU的操作也变快了,最终二者的耗时差不多。

分析将分为以下几个部分: DataLoader 图片读取 * 数据增强

此外由于【CPU转GPU、数据的归一化转秩】和【DataLoader】比较相关,会一起分析。

3. DataLoader

(1) num_workers

显然的是num_workers并不是越多越好,瓶颈在CPU,太多的数据worker不光不能提升,还会占用过多的CPU资源,影响其他程序的速度。

(2) pin_memory

定义DataLoader时单纯的pin_memory无用。

(3) non_blocking

谈到pin_memory就不得不说一下non_blocking=true操作,先来看一下过程吧:

1. x = x.cuda(non_blocking=True)
2. 进行一些和x无关的操作
3. 执行和x有关的操作

non_blocking=true下,1不会阻塞212并行。但我们已经知道x.cuda()实际上耗时并不多(181s -> 194s,并不算是耗时的主要原因,就算去掉也不能加速太多。

一个比较常见的用法如下:

class DataPrefetcher():def __init__(self, dataset, batch_size=64, shuffle=True, num_workers=2):self.data_loader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=shuffle,pin_memory=True,num_workers=num_workers,drop_last=True)self.loader = iter(self.data_loader)self.stream = torch.cuda.Stream()self.preload()def preload(self):try:self.next_x, self.next_y = next(self.loader)except StopIteration:self.loader = iter(self.data_loader)self.next_x, self.next_y = next(self.loader)with torch.cuda.stream(self.stream):self.next_x = self.next_x.cuda(non_blocking=True)self.next_y = self.next_y.cuda(non_blocking=True)self.next_x = self.next_x.permute(0, 3, 1, 2).float().div(255)self.next_y = self.next_y.float()def next(self):torch.cuda.current_stream().wait_stream(self.stream)data = (self.next_x, self.next_y)self.preload()return data

DataPrefetcher对DataLoader又包了一层,需要注意pin_memory=Truenon_blocking=true才才生效,next()函数直接返回data而无需等待cuda()。实验结果如下,和预期的差不多,并无明显的改善。

说到cuda(),有个小细节需要注意:

1. x = x.cuda().permute(0, 3, 1, 2).float().div(255)
2. x = x.permute(0, 3, 1, 2).float().div(255).cuda()

这两种方式的耗时是不一样的

4. 读取图片 读取图片实际上包括了两部分,一个是图片数据读进内存,一个是图片的解码。耗时主要在图片的解码上,常见的优化方法主要是转换图片的格式(如lmdb)、使用解码更快的库。

(1) lmdb 先分析一下lmdb,一个jpg转lmdb的例子:

import numpy as np
import cv2
import os
import lmdblist_path = "img_list/test.txt"
img_root = './data/test/edu/test/JPEGImages'
img_type = ".jpg"
label_root = './data/test/edu/test/labels'
label_type = ".txt"env_db = lmdb.open("./test", map_size=1024*1024*1024*8)
txn = env_db.begin(write=True)for line in open(list_path):img_name = line.strip() + img_typeimg_path = os.path.join(img_root, img_name)img = cv2.imread(img_path)label_name = line.strip() + label_typelabel_path = os.path.join(label_root, label_name)label = np.loadtxt(label_path).reshape(-1, 5)print(img_name)# txn.put(img_name.encode(), img) ### 不编码txn.put(img_name.encode(), cv2.imencode('.jpg', img)[1])txn.put(label_name.encode(), label)txn.commit()
env_db.close()

使用cv2.imencode('.jpg', img)对图片数据进行了编码,不然会非常大,840张图片jpg大小为350MB,lmdb不编码为4.86GB。

这里插一句,lmdb文件大小为实际使用大小且小于map_size,但window下生成lmdb大小不正常为map size大小,读取一下后会变成实际大小。

读取的操作如下:

img = np.frombuffer(self.txn.get(img_path.encode()), dtype=np.uint8)
img = cv2.imdecode(img, cv2.IMREAD_COLOR)label = np.frombuffer(self.txn.get(label_path.encode())).reshape(-1, 5) #lmdb存储后numpy会丢失形状

实验结果如下:

使用lmdb并没加速效果,实际上IO读取无任何优势,只是无解码时省掉了解码时间,但空间占用太多。

(2) hdf5

和lmdb类似的是hdf5,不压缩同样很大(4.86GB),但使用起来感觉较方便但存在一些bug不建议使用:

f = h5py.File('test.h5', 'w')
img_dataset = f.create_dataset('img', shape=(840, 1080, 1920, 3), maxshape=(None, 1080, 1920, 3), chunks=(1, 1080, 1920, 3), dtype=np.uint8)
for i, line in enumerate(open(list_path)):img_name = line.strip() + img_typeimg_path = os.path.join(img_root, img_name)img = cv2.imread(img_path)img_dataset[i, :] = imgprint(img_name, img.shape)
f.close()### read,h5的多线程读写存在问题,使用压缩无法正常使用
f = h5py.File('test.h5', "r", swmr=True)
# f = h5py.File('test.h5', "r")
for i in range(840):img = f["img"][i, ...]

没有编解码,速度还是有优势的,同时也印证了瓶颈在解码上:

(3) libjpeg-turbo

解码更快的库,主要是使用libjpeg-turbo,分享两个python的封装,jpeg4py和PyTurboJPEG。比较推荐的是PyTurboJPEG,可以指定so位置和缩放图片。

加速非常明显,推荐使用,但这个库比较吃CPU资源,worker开多之后会很卡。

(4) mxnet

mxnet的读取也测试了一下,测试时机器CPU占用发生变化,【数据增强、转GPU、归一化、libjpeg-turbo 】重新测试了一下,速度上无优势。

(5) opencv缩小比例读取

此外,opencv在读取图片时也可以指定缩小比例(1/2,1/4,1/8),提速明显,对部分需要固定缩放的比较友好:

img = cv2.imread(img_path, cv2.IMREAD_REDUCED_COLOR_2)
img = cv2.imread(img_path, cv2.IMREAD_REDUCED_COLOR_4)
img = cv2.imread(img_path, cv2.IMREAD_REDUCED_COLOR_8)

5. 数据增强

(1) 建议 这里先给大家一些建议:

  • opencv一般要比 PIL 快
  • 由于数据增强操作基本是opencv的内置函数,因此cuda的@jit加速是无效的,Cython应该也无加速效果。几个操作的耗时排列应该是这样的,opencv <Cython <jit <pyhton。
  • cv2.UMat()加速几乎无效,测试与不使用无差别,本地测试单图反而比不使用耗时长。
  • 由于图片数据为0~255,部分数值变换可以使用cv2.LUT()查找表直接映射,避免多次计算数值。
  • 尽量提前做好数据处理,减少预处理步骤。

(2) DALI 如果加速还不满意的话,最后还有一个大招,DALLI库,大力出奇迹,极度推荐。 我们已经知道瓶颈在CPU的性能上,把这些计算放到GPU上是很合理的。NVIDIA DALI是一个GPU加速的数据增强和图像加载库,支持单个和批处理图像的解码、缩放、Crop、颜色空间转换等,具体支持的操作。

使用DALI完成所有操作的时间如下:

只要我数据加载的够快,GPU就追不上我。加载的部分和pytorch差不多,出来就是gpu的tensor,具体的代码较多,就不放在本文里了。

pipeline = DataPipeline(**dataset_dict)
pipeline.build()
loader = DALIGenericIterator(pipeline, ["imgs", "labels"], img_len, fill_last_batch=False)for j in range(epoch):for i, data in enumerate(loader):x = data[0]["imgs"]y = data[0]["labels"].cuda()# trainloader.reset()

cbitmap 从内存中加载jpg_Pytorch数据加载的分析相关推荐

  1. VB如何直接显示内存中的二进制图像数据

    有时在进行网络程序设计时,我们希望客户端接收到服务器传来的图像文件的二进制数组的,能够直接显示,而不是通过保存到临时文件后显示,其实通过COM的IPicture接口,在VB里非常容易做到,代码如下: ...

  2. echarts中饼图的异步数据加载绘制

    ECharts 中实现异步数据的更新非常简单,在图表初始化后不管任何时候只要通过 jQuery 等工具异步获取数据后通过 setOption 填入数据和配置项就行,但是从后台异步获取数据后,需要在前台 ...

  3. easyui中的datagrid的数据加载的问题

    我们在第一次使用easyui的datagrid的url加载所需的数据时,如果第二次加载数据我们使用的不是ulr而是数据返回结果进行加载的时候$("#div").datagird(' ...

  4. 内存 增量数据持久_内存中数据模型和大数据持久性

    内存 增量数据持久 ORM框架在需要与关系数据库进行交互时可以帮助开发人员. 对于关系数据库,有许多出色的ORM框架,例如Hibernate和Apache OpenJPA,其中一些确实很棒. 如今,大 ...

  5. 内存中数据模型和大数据持久性

    ORM框架在希望与关系数据库进行交互时可以帮助开发人员. 对于关系数据库,有许多出色的ORM框架,例如Hibernate和Apache OpenJPA,其中一些确实很棒. 如今,大数据正在涌现,越来越 ...

  6. Java 中把声明变量的语句如果写在循环体内,每次执行时栈内存中的变量和数据是如何变化的?

    问题一:如下面的代码示例 1,JVM 是不是会反复回收旧的变量 a 再重新创建新的变量 a 呢?还是旧的变量 a 一直保留在栈内,只是反复赋值 0 而已呢? 代码示例 1: while (true) ...

  7. 聚合中返回source_大数据搜索与可视化分析(9)elasticsearch聚合分析Metric Aggregation...

    在上一篇文章中,我们介绍了<大数据搜索与可视化分析(8)kibana入门教程-2-Discover>,本文学习elasticsearch聚合分析,是对<大数据搜索与可视化分析(3)e ...

  8. cbitmap 从内存中加载jpg_[转载]windows照片查看器无法显示图片内存不足

    问题描述 最近在使用Windows照片查看器打开一个jpg文件的时候异常 Windows照片查看器无法显示此图片,因为计算机上的可用内存可能不足.请关闭一些目前没有使用的程序或者释放部分硬盘空间(如果 ...

  9. Loading 加载直到数据加载完再消失(vue-elementui)

    自定义指令v-loading: 只需要绑定Boolean即可.(默认状况下,Loading 遮罩会插入到绑定元素的子节点,通过添加body修饰符,可以使遮罩插入至 DOM 中的 body 上.) &l ...

最新文章

  1. 学院后勤报修系统php_2020年西航后勤管理服务技能培训
  2. python 二叉树遍历
  3. 干货 | 仅需10分钟,开启你的机器学习之路!
  4. 使用 SAP Business Application Studio 搭建 CAP Java 开发环境
  5. 《C++ Primer 5th》笔记(3 / 19):字符串、向量、迭代器和数组
  6. RTSP播放器开发过程中需要考虑哪些关键因素
  7. 【Java】MapReduce 程序五步走的思想详细描述
  8. 解决webpack4版本在打包时候出现Cannot read property ‘bindings‘ of null 或 Cannot find module ‘@babel/core‘问题
  9. 一点关于MD5计算的封装
  10. 苹果修复被 XCSSET 恶意软件滥用的3个 0day
  11. 利用百度Echarts.js生成雷达图
  12. LINUX安装TensorRT及特别注意事项
  13. Unity3D数字孪生笔记——Unity脚本篇
  14. 小米浏览器禁止java,如何禁止小米手机浏览器中弹出窗口广告
  15. 4.29 笔记+day7作业
  16. Vue el-menu-item路由跳转
  17. Q4营收同比增长34.7%,Saleforces股价为何总停滞不前?
  18. 空间超分辨率(SISR)领域非常不错的blog/论文(长期更新)
  19. ACW829模拟队列
  20. 又拍云upyun 文件上传(Java)

热门文章

  1. SpringBoot 开启关闭自动任务配置(EnableScheduling )
  2. 算法之【折半插入法】
  3. Java.Lang.NoSuchMethod 错误
  4. 自定义导航栏的背景、标题、返回按钮文字颜色[转]
  5. Inno Setup 插件 CallbackCtrl V1.1 (回调函数插件)
  6. 聪明的ITPRO之二IT人做事要“圆”
  7. 10g的客户端从9i的服务器中导出数据时遇到上面的问题
  8. Oracle中的AWR,全称为Automatic Workload Repository
  9. HBase集群环境部署
  10. hadoop中的helloword