PyTorch 分布式操作之 Barrier

原始文档:https://www.yuque.com/lart/ugkv9f/gy7sva

关于 barrier 的概念

关于 barrier 这个概念可以参考 Wiki 中的介绍:同步屏障(Barrier)是并行计算中的一种同步方法。对于一群进程或线程,程序中的一个同步屏障意味着任何线程/进程执行到此后必须等待,直到所有线程/进程都到达此点才可继续执行下文。

这里要注意,barrier 这一方法并不是 pytorch 独有的,这是并行计算中的一个基本概念,其他的并行计算的场景下也可能会涉及这一概念和操作。本文主要讨论 pytorch 中的情况。

torch.distributed.barrier(group=None, async_op=False, device_ids=None)Synchronizes all processes.This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait().Parameters
group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
async_op (bool, optional) – Whether this op should be an async op
device_ids ([int], optional) – List of device/GPU ids. Valid only for NCCL backend.Returns
Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

在多卡训练的时候,由于不同的 GPU 往往被设定在不同的进程中,有时候为了在单独的进程中执行一些任务,但是又同时希望限制其他进程的执行进度,就有了使用barrier的需求。
一个实际的场景是准备数据集:我们只需要在 0 号进程处理,其他进程没必要也执行这一任务,但是其他进程的后续工作却依赖准备好的数据。于是就需要在 0 号进程执行过程中阻塞其他的进程,使其进入等待状态。等到处理好之后,再一起放行。

这种需求下,一个典型的基于上下文管理器形式的构造如下:

# https://github.com/ultralytics/yolov5/blob/7d56d451241e94cd9dbe4fcb9bfba0e92c6e0e23/utils/torch_utils.py#L29-L38@contextmanager
def torch_distributed_zero_first(local_rank: int):"""Decorator to make all processes in distributed trainingwait for each local_master to do something."""if local_rank not in [-1, 0]:dist.barrier(device_ids=[local_rank])yieldif local_rank == 0:dist.barrier(device_ids=[0])

关于 barrier 的细节

# -*- coding: utf-8 -*-import os
import timeimport torch.distributed as dist
import torch.multiprocessing as mpdef ddp_test_v0(local_rank, word_size):# Initializes the distributed backend which will take care of sychronizing nodes/GPUsdist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)print("first before barrier{}\n".format(local_rank))if local_rank != 0:dist.barrier()print("first after barrier{}\n".format(local_rank))print("inter {}".format(local_rank))print("second before barrier{}\n".format(local_rank))if local_rank == 0:dist.barrier()print("second after barrier{}\n".format(local_rank))print("{} exit".format(local_rank))def ddp_test_v1(local_rank, word_size):# Initializes the distributed backend which will take care of synchronizing nodes/GPUsdist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)if local_rank != 0:print("1 before barrier{}\n".format(local_rank))start = time.time()time.sleep(5)dist.barrier()print(time.time() - start)print("1 after barrier{}\n".format(local_rank))dist.barrier()print("1 after barrier{}\n".format(local_rank))else:print("0 before barrier{}\n".format(local_rank))start = time.time()dist.barrier()print(time.time() - start)print("0 after barrier{}\n".format(local_rank))print("0 after barrier{}\n".format(local_rank))dist.barrier()print("0 after barrier{}\n".format(local_rank))print("{} exit".format(local_rank))def main():world_size = 2os.environ["MASTER_ADDR"] = "127.0.0.1"os.environ["MASTER_PORT"] = "29500"mp.spawn(ddp_test_v0, args=(world_size,), nprocs=world_size, join=True)if __name__ == "__main__":main()

这里展示了两个例子,实际上在官方展示的 dist.barrier 之外显示了该方法的一个重要特性,就是其操作实际上是每一个进程内部都需要对应的执行同样的次数,才会对应的由阻塞变为正常运行。
先看第一个例子:

def ddp_test(local_rank, word_size):# Initializes the distributed backend which will take care of sychronizing nodes/GPUsdist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)print("first before barrier{}\n".format(local_rank))if local_rank != 0:dist.barrier()print("first after barrier{}\n".format(local_rank))print("inter {}".format(local_rank))print("second before barrier{}\n".format(local_rank))if local_rank == 0:dist.barrier()print("second after barrier{}\n".format(local_rank))print("{} exit".format(local_rank))

其输出是:

first before barrier1
first before barrier0first after barrier0inter 0
second before barrier0second after barrier00 exit
first after barrier1inter 1
second before barrier1second after barrier11 exitProcess finished with exit code 0

可以看到,有几个细节:

  • barrier 之前,所有的操作都是各 GPU 进程自己输出自己的。

    • 由于 local_rank=0 执行到自己可见的 barrier 中间会输出多个,而 local_rank=1 则只有一条 first before barrier1
  • second before barrier0 之后,0 号执行到了属于自己的 barrier ,这回让使得其他进程不再阻塞,开始正常运行。由于中间操作的时间,所以先是 0 号输出自己的 second after barrier0 并随之退出,之后 1 号也接着开始输出自己的结果。

这里有一点值得注意,不同进程的 barrier 实际上是互相对应的,必须所有进程都执行一次barrier,才会重新放行正常前进。
对于第二段代码:

def ddp_test_v1(local_rank, word_size):# Initializes the distributed backend which will take care of sychronizing nodes/GPUsdist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)if local_rank != 0:print("1 before barrier{}\n".format(local_rank))start = time.time()time.sleep(5)dist.barrier()print(time.time() - start)print("1 after barrier{}\n".format(local_rank))dist.barrier()print("1 after barrier{}\n".format(local_rank))else:print("0 before barrier{}\n".format(local_rank))start = time.time()dist.barrier()print(time.time() - start)print("0 after barrier{}\n".format(local_rank))print("0 after barrier{}\n".format(local_rank))dist.barrier()print("0 after barrier{}\n".format(local_rank))print("{} exit".format(local_rank))

则是有输出:

1 before barrier1
0 before barrier05.002117395401001
5.0021262168884281 after barrier10 after barrier00 after barrier00 after barrier00 exit
1 after barrier11 exitProcess finished with exit code 0

可以看到一个重要的点,就是这两处 print(time.time() - start) 的输出是基本一样的,不管前面延时多少, barrier 后面的时间都是按照最长到达并执行 barrier 的间隔时间来的。这个更体现了不同进程 barrier 之间的互相限制关系。而 0 到达自己的第二个 barrier 之后,会使得 1 号再次运行。但是此时 0 是先结束的。
另外,可以验证,如果某个编号对应的代码中的两个 barrier 之中的一个,那么另一个就会陷入无限等待之中。
例如:


def ddp_test_v1(local_rank, word_size):# Initializes the distributed backend which will take care of sychronizing nodes/GPUsdist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)if local_rank != 0:print("1 before barrier{}\n".format(local_rank))start = time.time()time.sleep(5)dist.barrier()print(time.time() - start)print("1 after barrier{}\n".format(local_rank))# dist.barrier()print("1 after barrier{}\n".format(local_rank))else:print("0 before barrier{}\n".format(local_rank))start = time.time()time.sleep(3)dist.barrier()print(time.time() - start)print("0 after barrier{}\n".format(local_rank))print("0 after barrier{}\n".format(local_rank))dist.barrier()print("0 after barrier{}\n".format(local_rank))print("{} exit".format(local_rank))

输出:

0 before barrier0
1 before barrier15.002458572387695
1 after barrier11 after barrier11 exit
5.002473831176758
0 after barrier00 after barrier0Traceback (most recent call last):File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 67, in <module>main()File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 63, in mainmp.spawn(ddp_test_v1, args=(world_size,), nprocs=world_size, join=True)File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawnreturn start_processes(fn, args, nprocs, join, daemon, start_method='spawn')File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processeswhile not context.join():File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 75, in joinready = multiprocessing.connection.wait(File "/home/lart/miniconda3/envs/pt17/lib/python3.8/multiprocessing/connection.py", line 931, in waitready = selector.select(timeout)File "/home/lart/miniconda3/envs/pt17/lib/python3.8/selectors.py", line 415, in selectfd_event_list = self._selector.poll(timeout)
KeyboardInterruptProcess finished with exit code 137 (interrupted by signal 9: SIGKILL)

会在第二个 barrier 处无限等待下去。
这一特点在这个回答中也被提到了:

when a process encounters a barrier it will block the position of the barrier is not important (not all processes have to enter the same if-statement, for instance) a process is blocked by a barrier until all processes have encountered a barrier, upon which the barrier is lifted for all processes

https://stackoverflow.com/a/59766443

重要的参考资料

  • [原创][深度][PyTorch] DDP 系列

    • 第一篇:https://zhuanlan.zhihu.com/p/178402798
    • 第二篇:https://zhuanlan.zhihu.com/p/187610959
    • 第三篇:https://zhuanlan.zhihu.com/p/250471767
  • PyTorch 单机多 GPU 训练方法与原理整理
    • https://github.com/jia-zhuang/pytorch-multi-gpu-training
  • Pytorch 分布式训练(图示非常友好)
    • https://zhuanlan.zhihu.com/p/76638962
  • Distribution is all you need(丰富全面)
    • https://github.com/tczhangzhi/pytorch-distributed

另外的话

欢迎关注我的公众号,文章更新提醒更及时哦:

PyTorch之分布式操作中的Barrier相关推荐

  1. 简单介绍pytorch中分布式训练DDP使用 (结合实例,快速入门)

    文章目录 DDP原理 pytorch中DDP使用 相关的概念 使用流程 如何启动 torch.distributed.launch spawn调用方式 针对实例voxceleb_trainer多卡介绍 ...

  2. pytorch GPU分布式训练 单机单卡、单机多卡

    可以用"watch -n 0.1 nvidia-smi"来查看gpu状态,我用的是3块12G的GPU进行实验 本实验将使用一个简单的瞎写的网络进行,网络训练一个分类任务,当然这个不 ...

  3. PyTorch 源码解读之分布式训练了解一下?

    来源丨商汤学术   编辑丨极市平台 本文由浅入深讲解 torch.distributed 这一并行计算包的概念,实现细节和应用方式,并带大家快速入门 PyTorch 分布式训练. 0 前言 由于大规模 ...

  4. [深度学习] 分布式Pytorch介绍(三)

    [深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow介绍(二) [深度学习] 分布式Pytorch介绍(三) [深度学习] 分布式Horovod介绍(四)  一  Pytorc ...

  5. 分布式训练PyTorch 源码解读

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:商汤 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 0 前 ...

  6. Pytorch - 分布式通信原语(附源码)

    前言 由于工作需要,最近在补充分布式训练方面的知识.经过一番理论学习后仍觉得意犹未尽,很多知识点无法准确get到(例如:分布式原语scatter.all reduce等代码层面应该是什么样的,ring ...

  7. pytorch默认初始化_“最全PyTorch分布式教程”来了!

    前言 本文对使用pytorch进行分布式训练(单机多卡)的过程进行了详细的介绍,附加实际代码,希望可以给正在看的你提供帮助.本文分三个部分展开,分别是: 先验知识 使用过程框架 代码解析 若想学习分布 ...

  8. PyTorch单机多卡分布式训练教程及代码示例

    导师不是很懂PyTorch的分布式训练流程,我就做了个PyTorch单机多卡的分布式训练介绍,但是他觉得我做的没这篇好PyTorch分布式训练简明教程 - 知乎.这篇讲的确实很好,不过我感觉我做的也还 ...

  9. Pytorch - 分布式训练极简体验

    由于工作需要,最近在补充分布式训练方面的知识.经过一番理论学习后仍觉得意犹未尽,很多知识点无法准确get到(例如:分布式原语scatter.all reduce等代码层面应该是什么样的,ring al ...

最新文章

  1. Stormpath发布了简化移动和前端身份验证的客户端API
  2. vue设置点击电话跳转到手机拨打电话的界面
  3. EXPLAIN 命令详解
  4. 安卓开发 fastjson 解析json使用详解
  5. wifi测试相关(iwconfig,WPA Supplicant用法)
  6. 转:java网络编程-HTTP编程
  7. 【渝粤教育】国家开放大学2018年秋季 0149-21T现代汉语 参考试题
  8. Linux设备驱动模型1——简介和底层架构
  9. centos下mysql更改数据存放目录_CentOS下mysql更改数据存放目录 --转载
  10. 定位html中的背景图,关于背景图的定位和透明度问题(HTML+CSS笔记)
  11. sql2005 reporting service,我总算找到一个完全程序化绑定报表(ado.net dataset 绑定reprot)的方案,谁能再给我些其他建议呢?...
  12. SQL SERVER数据库备份时出现“操作系统错误5(拒绝访问)。BACKUP DATABASE 正在异常终止。”错误的解决办法...
  13. 转 @PathVariable是什么?详情及用法解析
  14. c语言课程设计实训主要目的,《C语言课程设计实验大纲.doc
  15. Day715. 适配不同的类型的switch匹配 -Java8后最重要新特性
  16. 太赞了!分享一个数据科学利器 PyCaret,几行代码搞定从数据处理到模型部署
  17. WEB渗透测试(一)被动信息收集3(RECON-NG)
  18. 其他笔记 - Scum和KANBAN
  19. Verilog信号上升沿检测
  20. 时间序列中的平稳性检验之单位根检验

热门文章

  1. 小学生台灯哪个品牌更护眼?学习专用的护眼台灯品牌
  2. 苹果ppt_惊艳!苹果发布会最爱用的PPT动画,居然这么简单
  3. 传奇版本中云客户端状态在哪里去掉?
  4. android高仿京东秒杀,Android仿京东首页秒杀倒计时
  5. 腾讯游戏平台下载|腾讯游戏平台使用体验
  6. 《灵飞经5·龙生九子》第二十四章 九王朝阙 上
  7. 利用wrk工具压测腾讯CLB
  8. 机房收费系统合作版(四):一路走来感谢有你相伴
  9. 诺奖得主公司CAR-T细胞疗法临床试验现患者死亡,系今年第6例-1
  10. 南亚Patchwork APT组织新活动特点分析