导师不是很懂PyTorch的分布式训练流程,我就做了个PyTorch单机多卡的分布式训练介绍,但是他觉得我做的没这篇好PyTorch分布式训练简明教程 - 知乎。这篇讲的确实很好,不过我感觉我做的也还可以,希望大家看完之后能给我一些建议。

目录

1.预备知识

1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。

1.2 World,Rank,Local Rank

1.2.1 World

1.2.2 Rank

1.2.3 Local Rank

2. PyTorch单机多卡数据并行

2.1 多进程启动

2.1.1 多进程启动示例

2.2 启动进程间通信

2.2.1 初始化成功示例

2.2.2 初始化失败示例

2.2.3 进程间通信示例

2.3. 单机多卡数据并行示例

后记:如何拓展到多机多卡?


1.预备知识

多卡训练涉及到多进程和进程间通信,因此有必要先解释一些进程间通信的概念。

1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。

众所周知,每个主机都可以同时运行多个进程,但是在通常情况下每个进程都是做各自的事情,各个进程间是没有关系的。

而在MPI中,我们可以拥有一组能够互相发消息的进程,但是这些进程可以分布在多个主机中,这时我们可以将主机称为节点(Node),进程称为工作结点(Worker)。

由于PyTorch中的主要说法还是进程,所以后面也会统一采用主机和进程的说法。

1.2 World,Rank,Local Rank

对于一组能够互相发消息的进程,我们需要区分每一个进程,因此每个进程会被分配一个序号,称作rank。进程间可以通过指定rank来进行通信。

1.2.1 World

World可以认为是一个集合,由一组能够互相发消息的进程组成。

如下图中假如Host 1的所有进程和Host 2的所有进程都可以进行通信,那么它们就组成了一个World。

因此,world size就表示这组能够互相通信的进程的总数,上图中world size为6。

1.2.2 Rank

Rank可以认为是这组能够互相通信的进程在World中的序号。

1.2.3 Local Rank

Local Rank可以认为是这组能够互相通信的进程在它们相应主机(Host)中的序号。

即在每个Host中,Local rank都是从0开始。

2. PyTorch单机多卡数据并行

数据并行本质上就是增大模型的batch size,但batch size也不是越大越好,所以一般对于大模型才会使用数据并行。

Pytorch进行数据并行主要依赖于它的两个模块multiprocessing和distributed。

所以首先介绍multiprocessing和distributed模块的基本用法。

2.1 多进程启动

由于Python多线程存在GIL(全局解释器锁),为了提高效率,Pytorch实现了一个multiprocessing多进程模块。用于在一个Python进程中启动额外的进程。

2.1.1 多进程启动示例

该程序启动了4个进程,每个进程会输出当前rank,表明与其他进程不同。

#run_multiprocess.py
#运行命令:python run_multiprocess.py
import torch.multiprocessing as mpdef run(rank, size):print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()

输出结果:

world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.

2.2 启动进程间通信

虽然启动了多进程,但是此时进程间并不能进行通信,所以PyTorch设计了另一个distributed模块用于进程间通信。

init_process_group函数是distributed模块用于初始化通信模块的函数。

当该函数初始化成功则表明进程间可以进行通信。

2.2.1 初始化成功示例

只有当world size和实际启动的进程数匹配,init_process_group才可以初始化成功。

#multiprocess_comm.py
#运行命令:python multiprocess_comm.pyimport os
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#等待进程结束p0.join()p1.join()p2.join()p3.join()

输出结果:

world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.

2.2.2 初始化失败示例

当将world size设置为2,但是实际却启动了4个进程,此时init_process_group就会报错。

#multiprocess_comm.py
#运行命令:python multiprocess_comm.pyimport os
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 2mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()

输出结果:

RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1623448224956/work/third_party/gloo/gloo/context.cc:27] rank < size. 3 vs 2

2.2.3 进程间通信示例

当init_process_group初始化成功,进程间就可以进行通信了,这里我以集体通信Allreduce为例。

#multiprocess_allreduce.py
#运行命令:python multiprocess_allreduce.pyimport os
import torch
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):os.environ['MASTER_ADDR'] = '127.0.0.1'os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)#每个进程都创建一个Tensor,Tensor值为该进程相应rank。param = torch.tensor([rank])print("rank {}: tensor before allreduce: {}".format(rank,param))#对该Tensor进行Allreduce。dist.all_reduce(param.data, op=dist.ReduceOp.SUM)print("rank {}: tensor after allreduce: {}".format(rank,param))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()

输出结果:

rank 0: tensor before allreduce: tensor([0])
rank 2: tensor before allreduce: tensor([2])
rank 3: tensor before allreduce: tensor([3])
rank 1: tensor before allreduce: tensor([1])rank 0: tensor after allreduce: tensor([6])
rank 3: tensor after allreduce: tensor([6])
rank 2: tensor after allreduce: tensor([6])
rank 1: tensor after allreduce: tensor([6])

2.3. 单机多卡数据并行示例

当可以启动多进程,并进行进程间通信后,实际上就已经可以进行单机多卡的分布式训练了。

但是Pytorch为了便于用户使用,所以在这之上又增加了很多更高层的封装,如DistributedDataParallel,DistributedSampler等。

所以为了便于理解这中间的一些流程,这里演示一下不使用这些封装时的单机多卡数据并行。

该示例代码和单机训练主要有两个区别:

(1)需要基于每个进程的rank将模型参数放置到不同的GPU。

(2) 在参数更新前需要对梯度进行Allreduce。

#multiprocess_training.py
#运行命令:python multiprocess_training.py
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#用于平均梯度的函数
def average_gradients(model):size = float(dist.get_world_size())for param in model.parameters():dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)param.grad.data /= size
#模型
class ConvNet(nn.Module):def __init__(self, num_classes=10):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc = nn.Linear(7*7*32, num_classes)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.fc(out)return outdef accuracy(outputs,labels):_, preds = torch.max(outputs, 1) # taking the highest value of prediction.correct_number = torch.sum(preds == labels.data)return (correct_number/len(preds)).item()def run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'dist.init_process_group('gloo', rank=rank, world_size=size)#1.数据集预处理train_dataset = torchvision.datasets.MNIST(root='../data',train=True,transform=transforms.ToTensor(),download=True)training_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)#2.搭建模型#device = torch.device("cuda:{}".format(rank))device = torch.device("cpu")print(device)torch.manual_seed(0)model = ConvNet().to(device)torch.manual_seed(rank)criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr = 0.001,momentum=0.9) # fine tuned the lr#3.开始训练epochs = 15batch_num = len(training_loader)running_loss_history = []for e in range(epochs):for i,(inputs, labels) in enumerate(training_loader):inputs = inputs.to(device) labels = labels.to(device)#前向传播outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() #反传loss.backward() #记录lossrunning_loss_history.append(loss.item())#参数更新前需要Allreduce梯度。average_gradients(model)#参数更新optimizer.step() if (i + 1) % 50 == 0 and rank == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f},acc:{:.2f}'.format(e + 1, epochs, i + 1, batch_num,loss.item(),accuracy(outputs,labels)))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()

后记:如何拓展到多机多卡?

在多机多卡环境中初始化init_process_group还需要做一些额外的处理,主要考虑两个问题

(1)需要让其余进程知道rank=0进程的 IP:Port 地址,此时rank=0进程会在相应端口进行监听,其余进程则会给这个IP:Port发消息。这样rank=0进程就可以进行统计,确认初始化是否成功。这一步在PyTorch中是通过设置os.environ['MASTER_ADDR']和os.environ['MASTER_PORT']这两个环境变量来做的。

(2)需要为每个进程确定相应rank,通常采用的做法是给主机编号,因此多机多卡启动时给不同主机传入的参数肯定是不同的。此时参数可以直接手动在每个主机的代码上修改,也可以通过argparse模块在运行时传递不同参数来做。

PyTorch单机多卡分布式训练教程及代码示例相关推荐

  1. pytorch 单机多卡训练distributedDataParallel

    pytorch单机多卡:从DataParallel到DistributedDataParallel 最近想做的实验比较多,于是稍微学习了一下和pytorch相关的加速方式.本人之前一直在使用DataP ...

  2. pytorch单机多卡的正确打开方式 以及可能会遇到的问题和相应的解决方法

    pytorch 单机多卡的正确打开方式 pytorch 使用单机多卡,大体上有两种方式 简单方便的 torch.nn.DataParallel(很 low,但是真的很简单很友好) 使用 torch.d ...

  3. ShardingSphere RAW JDBC 分布式事务XA 代码示例

    ShardingSphere RAW JDBC 分布式事务XA 代码示例 项目工程在:transaction-2pc-xa-raw-jdbc-example 代码简介     基于ShardingSp ...

  4. 用c语言做RFID读卡程序,2.STM32读卡号读写数据代码示例3.0(C语言)

    文件名大小更新时间 2.STM32读卡号读写数据代码示例3.0(C语言)\HFRFID.uvgui.WEIZAI736912016-07-15 2.STM32读卡号读写数据代码示例3.0(C语言)\H ...

  5. PyTorch 单机多卡操作总结:分布式DataParallel,混合精度,Horovod)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨科技猛兽@知乎 来源丨https://zhuanlan.zhihu.com/p/15837505 ...

  6. 收藏 | PyTorch 单机多卡操作总结

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨科技猛兽@知乎 来源丨https://zhuanlan ...

  7. PyTorch单机多卡训练(DDP-DistributedDataParallel的使用)备忘记录

    不做具体的原理分析和介绍(因为我也不咋懂),针对我实际修改可用的一个用法介绍,主要是模型训练入口主函数(main_multi_gpu.py)的四处修改. 以上的介绍来源https://zhuanlan ...

  8. Pytorch 单机多卡训练DDP

    多卡训练方式 1.DP--torch.nn.DataParallel 2.DDP--torch.nn.parallel.DistributedDataParallel 通俗一点讲就是用了4张卡训练,就 ...

  9. Pytorch单机多卡加速

    忙了两个月从收到原始数据到最后在工程项目中加载成功完成测试,好像从元旦上班后就再没休息过,昨天项目通过三期评审终于可以喘口气补点作业了.(年前写的文章,今天转过来了) 多卡并行 一定要使用 torch ...

最新文章

  1. Python xlrd 读取excel表格 常用用法整理
  2. hdu 3033(分组背包)
  3. 聚类结果不好怎么办_使用bert-serving生成词向量并聚类可视化
  4. win10用一会就蓝屏重启_电脑出现蓝屏?教你如何解决
  5. 【渝粤题库】国家开放大学2021春2509学前教育学题目
  6. 支持markdown的服务器,基于tornado实现的一个markdown解析服务器
  7. oracle索引的监控
  8. Python 图形 GUI 库 pyqtgraph
  9. 小米3g刷高格固件_今天小米路由器3G到手就刷 老毛子 固件。
  10. nes模拟器java版_JAVA版手机FC/Nes模拟器vN
  11. InSAR数据处理软件简介
  12. 【有利可图网】PS实战系列:PS美化婚纱照片
  13. STM32读写ADXL345 中断功能
  14. 数据分析精选案例:3行代码上榜Kaggle学生评估赛
  15. 数据结构与算法真的那么重要么?
  16. ARM体系结构2:汇编指令集
  17. mysql修改密码总是报错_mysql修改密码报错 | 吴老二
  18. python ddt安装
  19. 单片机应用系统的基本组成
  20. 用计算机做一克拉等于多少克的单位换算,克和克拉怎么换算(克拉和克的换算单位是多少)...

热门文章

  1. CAD二次开发之LISP读取excel数据
  2. openlayers设置黑色底图,自定义修改天地图颜色
  3. prometheus:原理和部署
  4. 基于Matlab的QPSK系统设计(多径瑞利信道,采用jakes模型以及指数模型)
  5. 《西方经济学》笔记1-需求曲线
  6. c#连接西门子plc
  7. 瑞云渲染 | 全面支持Anima®4渲染插件,实现高精度的群集角色!
  8. 关于hive统计周wau、保留率需求的几种思路
  9. 第九届蓝桥杯C++B组题解
  10. 向外国大师学习敏捷式开发?嫦娥掩面而笑