参考 torch Dataloader中的num_workers - 云+社区 - 腾讯云

考虑这么一个场景,有海量txt文件,一个个batch读进来,测试一下torch DataLoader的效率如何。

基本信息:

  • 本机配置:8核32G内存,工作站内置一块2T的机械硬盘,数据均放在该硬盘上
  • 操作系统:ubuntu 16.04 LTS
  • pytorch:1.0
  • python:3.6

1、首先生成很多随机文本txt

def gen_test_txt():population = list(string.ascii_letters) + ['\n']for i in range(1000):with open(f'./test_txt/{i}.txt', 'w') as f:f.write(''.join(random.choices(population, k=1000000)))

2、然后顺序读取作为benchmark

def test_torch_reader():class Dst(Dataset):def __init__(self, paths):self.paths = pathsdef __len__(self):return len(self.paths)def __getitem__(self, i):open(self.paths[i], 'r').read()return 1dst = Dst([f'./test_txt/{i}.txt' for i in range(1000)])loader = DataLoader(dst, 128, num_workers=0)ts = time()time_cost = []for i, ele in enumerate(loader, 1):dur = time() - tstime_cost.append(dur)print(i, dur)ts = time()print(f"{sum(time_cost):.3f}, "f"{np.mean(time_cost):.3f}, "f"{np.std(time_cost):.3f}, "f"{max(time_cost):.3f}, "f"{min(time_cost):.3f}")plt.plot(time_cost)plt.grid()plt.show()

每个batch耗时的基本统计信息如下,

基本维持在0.9 sec / batch

total, mean, std, max, min

7.148, 0.893, 0.074, 1.009, 0.726

可见,一共是1000个文件,batch size 128,也就是8个batch,总共耗时7.1s,接下来清除cache,

3、设置num_workers为4

每隔4个batch,要准备4个batch,且是串行的,因此时间增大4倍,接下来3个batch几乎不占用时间

total, mean, std, max, min

7.667, 0.958, 1.652, 3.983, 0.000

接下来实验在SSD上进行,同样num_workers先0后4,如下

total, mean, std, max, min

3.251, 0.406, 0.026, 0.423, 0.338

SSD上,对比机械硬盘更加稳定

然后是num_workers = 4,

total, mean, std, max, min

1.934, 0.242, 0.421, 1.088, 0.000

观察到同样的现象,但尖峰应该是0.4*4=1.6,这里反而epoch 4 (0-index)降为一半为0.8

基本结论:可以看到,不管是在SSD,还是机械硬盘上,总的耗时基本不变(SSD小一些,但原因也可能是实验不充分),并没有因为numworkers增大而减小,令我很费解!我一贯的理解是:比如num_workers为4,那么每个worker计算一个batch,因为本机多核且大于4,讲道理4个worker并行处理,因此时间为num_workers=0的1/4才合理,那原因是为何呢?(这个实验本来是为了load audio数据,其实在audio上作类似实验也是一致的现象)

补充了一个实验,尝试用ray读取,代码如下,

def test_ray():ray.init()@ray.remotedef read(paths):for path in paths:open(path, 'r').read()return 1def ray_read(paths, n_cpu=4):chunk_size = len(paths) // n_cpuobject_ids = []for i in range(n_cpu):x = read.remote(paths[i * chunk_size: (i + 1) * chunk_size])object_ids.append(x)return ray.get(object_ids)def batch(l, bs):out = []i = 0while i < len(l):out.append(l[i: i + bs])i += bsreturn outpaths = [os.path.expanduser(f'~/test_txt/{i}.txt') for i in range(1000)]paths = batch(paths, 128)time_cost = []ts = time()for i, ele in enumerate(paths, 1):# read(paths[i - 1])ray_read(paths[i - 1], 8)dur = time() - tstime_cost.append(dur)print(i, dur)ts = time()print(f"{sum(time_cost):.3f}, "f"{np.mean(time_cost):.3f}, "f"{np.std(time_cost):.3f}, "f"{max(time_cost):.3f}, "f"{min(time_cost):.3f}")plt.plot(time_cost)plt.grid()plt.show()

流程是这样的:将输入paths分成n_cpu个chunk,chunk之间通过ray异步执行,结果是:同样是在SSD上,理论上每个batch耗时是之前的1/4,也就是0.1s左右,但实测是0.2s,也就是说,n_cpu最大有效值就是2。

torch Dataloader中的num_workers相关推荐

  1. pytorch中Dataloader()中的num_workers设置问题

    pytorch中Dataloader()中的num_workers设置问题: 如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错: import numpy ...

  2. Pytorch dataloader中的num_workers (选择最合适的num_workers值)

    num_workers是Dataloader的概念,默认值是0. 是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关) 如果num_worker设为0,意味着每一 ...

  3. dataloader 源码_pytorch :: Dataloader中的迭代器和生成器应用

    在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭 ...

  4. torch dataloader 数据并行_PyTorch Parallel Training(单机多卡并行、混合精度、同步BN训练指南文档)

    0 写在前面 这篇文章是我做实验室组会汇报的时候顺带整理的文档,在1-3部分参考了很多知乎文章,感谢这些大佬们的工作,所以先贴出Reference,本篇文章结合了这些内容,加上了我的一些理解,不足之处 ...

  5. Torch dataloader 参数详解

    参考文章GPU 利用率低常见原因分析及优化 - 知乎

  6. Pytorch中Dataloader踩坑:RuntimeError: DataLoader worker (pid(s) 6700, 10620) exited unexpectedly

    Pytorch中Dataloader踩坑 环境: 问题背景: 观察报错信息进行分析 根据分析进行修改尝试 总结 环境: 系统:windows10 Pytorch版本:1.5.1+cu101 问题背景: ...

  7. 深度学习和目标检测系列教程 10-300:通过torch训练第一个Faster-RCNN模型

    @Author:Runsen 上次介绍了Faster-RCNN模型,那么今天就开始训练第一个Faster-RCNN模型. 本文将展示如何在水果图像数据集上使用Faster-RCNN模型. 代码的灵感来 ...

  8. python吃显卡还是内存不足_Pythorch中的GPU内存问题,GPUMemoryProblemsinPyTorch,显卡,爆炸,与,利用率,不足...

    如今研究人工智能,跑深度学习算法,显卡/GPU绝对是第一大门槛,所以不管您是1080Ti还是V100,如果不能发挥出GPU的最大能力,那它可能就是不是显卡而是块普通的砖头了吧. 显卡爆炸 显卡爆炸和内 ...

  9. 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...

最新文章

  1. Python 实现九九乘法表
  2. PHP PSR-1 基本代码规范(中文版)
  3. 新装myeclispse8.6GA、@Override出错
  4. pycharm 更改字体和界面样式
  5. Apache访问日志切割
  6. python查看函数参数,在python函数中获取参数名称列表
  7. mysql5.7.14多实例安装
  8. 使用迭代器时如何避免ConcurrentModificationException
  9. 工作流实战_15_flowable 我发起的流程实例查询
  10. 手机站点击商务通无轨迹解决方法
  11. 安全性、监控、调优 的一些思考
  12. cefsharp异步抓取html5,winform插件cefsharp65最新版完美demo,完美flash、html5、和调用摄像头支持,部署就能用...
  13. 修ecshop品牌筛选以LOGO图片形式显示
  14. C#使用oledb操作excel文件的方法
  15. android+号码归属地数据库,Android手机号码归属地的查询
  16. 三阶魔方CFOP cross总结
  17. 如何基于TwinCAT3实现伺服电机控制(一)
  18. shell脚本中source和expert的简单理解
  19. 地理信息系统(GIS)的发展历程
  20. <视觉SLAM十四讲> 李群与李代数

热门文章

  1. c4d如何把文字贴在物体表面_如何使用C4D制作动态滚动文字条
  2. VPS服务-Docker搭建个人博客网站
  3. 智能时代的内容安全,易盾是如何落地的?
  4. ROS 罗技手柄控制机器人(仿真和实体机器人)
  5. SAP MM 采购申请列表选择条件说明
  6. 同学,主业和副业如何选?
  7. html怎么设计关键字,干货分享——关键词如何做标记
  8. mysql常用增删改查命令总结
  9. 360 html页面乱码,360浏览器出现乱码的解决方法
  10. golang的hijack篡取劫持