PyTorch单机多卡分布式训练教程及代码示例
导师不是很懂PyTorch的分布式训练流程,我就做了个PyTorch单机多卡的分布式训练介绍,但是他觉得我做的没这篇好PyTorch分布式训练简明教程 - 知乎。这篇讲的确实很好,不过我感觉我做的也还可以,希望大家看完之后能给我一些建议。
目录
1.预备知识
1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。
1.2 World,Rank,Local Rank
1.2.1 World
1.2.2 Rank
1.2.3 Local Rank
2. PyTorch单机多卡数据并行
2.1 多进程启动
2.1.1 多进程启动示例
2.2 启动进程间通信
2.2.1 初始化成功示例
2.2.2 初始化失败示例
2.2.3 进程间通信示例
2.3. 单机多卡数据并行示例
后记:如何拓展到多机多卡?
1.预备知识
多卡训练涉及到多进程和进程间通信,因此有必要先解释一些进程间通信的概念。
1.1 主机(Host),节点(Node),进程(Process)和工作结点(Worker)。
众所周知,每个主机都可以同时运行多个进程,但是在通常情况下每个进程都是做各自的事情,各个进程间是没有关系的。
而在MPI中,我们可以拥有一组能够互相发消息的进程,但是这些进程可以分布在多个主机中,这时我们可以将主机称为节点(Node),进程称为工作结点(Worker)。
由于PyTorch中的主要说法还是进程,所以后面也会统一采用主机和进程的说法。
1.2 World,Rank,Local Rank
对于一组能够互相发消息的进程,我们需要区分每一个进程,因此每个进程会被分配一个序号,称作rank。进程间可以通过指定rank来进行通信。
1.2.1 World
World可以认为是一个集合,由一组能够互相发消息的进程组成。
如下图中假如Host 1的所有进程和Host 2的所有进程都可以进行通信,那么它们就组成了一个World。
因此,world size就表示这组能够互相通信的进程的总数,上图中world size为6。
1.2.2 Rank
Rank可以认为是这组能够互相通信的进程在World中的序号。
1.2.3 Local Rank
Local Rank可以认为是这组能够互相通信的进程在它们相应主机(Host)中的序号。
即在每个Host中,Local rank都是从0开始。
2. PyTorch单机多卡数据并行
数据并行本质上就是增大模型的batch size,但batch size也不是越大越好,所以一般对于大模型才会使用数据并行。
Pytorch进行数据并行主要依赖于它的两个模块multiprocessing和distributed。
所以首先介绍multiprocessing和distributed模块的基本用法。
2.1 多进程启动
由于Python多线程存在GIL(全局解释器锁),为了提高效率,Pytorch实现了一个multiprocessing多进程模块。用于在一个Python进程中启动额外的进程。
2.1.1 多进程启动示例
该程序启动了4个进程,每个进程会输出当前rank,表明与其他进程不同。
#run_multiprocess.py
#运行命令:python run_multiprocess.py
import torch.multiprocessing as mpdef run(rank, size):print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()
输出结果:
world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.
2.2 启动进程间通信
虽然启动了多进程,但是此时进程间并不能进行通信,所以PyTorch设计了另一个distributed模块用于进程间通信。
init_process_group函数是distributed模块用于初始化通信模块的函数。
当该函数初始化成功则表明进程间可以进行通信。
2.2.1 初始化成功示例
只有当world size和实际启动的进程数匹配,init_process_group才可以初始化成功。
#multiprocess_comm.py
#运行命令:python multiprocess_comm.pyimport os
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#等待进程结束p0.join()p1.join()p2.join()p3.join()
输出结果:
world size:4. I'm rank 1.
world size:4. I'm rank 0.
world size:4. I'm rank 2.
world size:4. I'm rank 3.
2.2.2 初始化失败示例
当将world size设置为2,但是实际却启动了4个进程,此时init_process_group就会报错。
#multiprocess_comm.py
#运行命令:python multiprocess_comm.pyimport os
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)print("world size:{}. I'm rank {}.".format(size,rank))if __name__ == "__main__":world_size = 2mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()
输出结果:
RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1623448224956/work/third_party/gloo/gloo/context.cc:27] rank < size. 3 vs 2
2.2.3 进程间通信示例
当init_process_group初始化成功,进程间就可以进行通信了,这里我以集体通信Allreduce为例。
#multiprocess_allreduce.py
#运行命令:python multiprocess_allreduce.pyimport os
import torch
import torch.distributed as dist
import torch.multiprocessing as mpdef run(rank, size):os.environ['MASTER_ADDR'] = '127.0.0.1'os.environ['MASTER_PORT'] = '29500'#通信模块初始化#进程会阻塞在该函数,直到确定所有进程都可以通信。dist.init_process_group('gloo', rank=rank, world_size=size)#每个进程都创建一个Tensor,Tensor值为该进程相应rank。param = torch.tensor([rank])print("rank {}: tensor before allreduce: {}".format(rank,param))#对该Tensor进行Allreduce。dist.all_reduce(param.data, op=dist.ReduceOp.SUM)print("rank {}: tensor after allreduce: {}".format(rank,param))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()
输出结果:
rank 0: tensor before allreduce: tensor([0])
rank 2: tensor before allreduce: tensor([2])
rank 3: tensor before allreduce: tensor([3])
rank 1: tensor before allreduce: tensor([1])rank 0: tensor after allreduce: tensor([6])
rank 3: tensor after allreduce: tensor([6])
rank 2: tensor after allreduce: tensor([6])
rank 1: tensor after allreduce: tensor([6])
2.3. 单机多卡数据并行示例
当可以启动多进程,并进行进程间通信后,实际上就已经可以进行单机多卡的分布式训练了。
但是Pytorch为了便于用户使用,所以在这之上又增加了很多更高层的封装,如DistributedDataParallel,DistributedSampler等。
所以为了便于理解这中间的一些流程,这里演示一下不使用这些封装时的单机多卡数据并行。
该示例代码和单机训练主要有两个区别:
(1)需要基于每个进程的rank将模型参数放置到不同的GPU。
(2) 在参数更新前需要对梯度进行Allreduce。
#multiprocess_training.py
#运行命令:python multiprocess_training.py
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#用于平均梯度的函数
def average_gradients(model):size = float(dist.get_world_size())for param in model.parameters():dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)param.grad.data /= size
#模型
class ConvNet(nn.Module):def __init__(self, num_classes=10):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc = nn.Linear(7*7*32, num_classes)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.fc(out)return outdef accuracy(outputs,labels):_, preds = torch.max(outputs, 1) # taking the highest value of prediction.correct_number = torch.sum(preds == labels.data)return (correct_number/len(preds)).item()def run(rank, size):#MASTER_ADDR和MASTER_PORT是通信模块初始化需要的两个环境变量。#由于是在单机上,所以用localhost的ip就可以了。os.environ['MASTER_ADDR'] = '127.0.0.1'#端口可以是任意空闲端口os.environ['MASTER_PORT'] = '29500'dist.init_process_group('gloo', rank=rank, world_size=size)#1.数据集预处理train_dataset = torchvision.datasets.MNIST(root='../data',train=True,transform=transforms.ToTensor(),download=True)training_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)#2.搭建模型#device = torch.device("cuda:{}".format(rank))device = torch.device("cpu")print(device)torch.manual_seed(0)model = ConvNet().to(device)torch.manual_seed(rank)criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr = 0.001,momentum=0.9) # fine tuned the lr#3.开始训练epochs = 15batch_num = len(training_loader)running_loss_history = []for e in range(epochs):for i,(inputs, labels) in enumerate(training_loader):inputs = inputs.to(device) labels = labels.to(device)#前向传播outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() #反传loss.backward() #记录lossrunning_loss_history.append(loss.item())#参数更新前需要Allreduce梯度。average_gradients(model)#参数更新optimizer.step() if (i + 1) % 50 == 0 and rank == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f},acc:{:.2f}'.format(e + 1, epochs, i + 1, batch_num,loss.item(),accuracy(outputs,labels)))if __name__ == "__main__":world_size = 4mp.set_start_method("spawn")#创建进程对象#target为该进程要运行的函数,args为target函数的输入参数p0 = mp.Process(target=run, args=(0, world_size))p1 = mp.Process(target=run, args=(1, world_size))p2 = mp.Process(target=run, args=(2, world_size))p3 = mp.Process(target=run, args=(3, world_size))#启动进程p0.start()p1.start()p2.start()p3.start()#当前进程会阻塞在join函数,直到相应进程结束。p0.join()p1.join()p2.join()p3.join()
后记:如何拓展到多机多卡?
在多机多卡环境中初始化init_process_group还需要做一些额外的处理,主要考虑两个问题
(1)需要让其余进程知道rank=0进程的 IP:Port 地址,此时rank=0进程会在相应端口进行监听,其余进程则会给这个IP:Port发消息。这样rank=0进程就可以进行统计,确认初始化是否成功。这一步在PyTorch中是通过设置os.environ['MASTER_ADDR']和os.environ['MASTER_PORT']这两个环境变量来做的。
(2)需要为每个进程确定相应rank,通常采用的做法是给主机编号,因此多机多卡启动时给不同主机传入的参数肯定是不同的。此时参数可以直接手动在每个主机的代码上修改,也可以通过argparse模块在运行时传递不同参数来做。
PyTorch单机多卡分布式训练教程及代码示例相关推荐
- pytorch 单机多卡训练distributedDataParallel
pytorch单机多卡:从DataParallel到DistributedDataParallel 最近想做的实验比较多,于是稍微学习了一下和pytorch相关的加速方式.本人之前一直在使用DataP ...
- pytorch单机多卡的正确打开方式 以及可能会遇到的问题和相应的解决方法
pytorch 单机多卡的正确打开方式 pytorch 使用单机多卡,大体上有两种方式 简单方便的 torch.nn.DataParallel(很 low,但是真的很简单很友好) 使用 torch.d ...
- ShardingSphere RAW JDBC 分布式事务XA 代码示例
ShardingSphere RAW JDBC 分布式事务XA 代码示例 项目工程在:transaction-2pc-xa-raw-jdbc-example 代码简介 基于ShardingSp ...
- 用c语言做RFID读卡程序,2.STM32读卡号读写数据代码示例3.0(C语言)
文件名大小更新时间 2.STM32读卡号读写数据代码示例3.0(C语言)\HFRFID.uvgui.WEIZAI736912016-07-15 2.STM32读卡号读写数据代码示例3.0(C语言)\H ...
- PyTorch 单机多卡操作总结:分布式DataParallel,混合精度,Horovod)
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨科技猛兽@知乎 来源丨https://zhuanlan.zhihu.com/p/15837505 ...
- 收藏 | PyTorch 单机多卡操作总结
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨科技猛兽@知乎 来源丨https://zhuanlan ...
- PyTorch单机多卡训练(DDP-DistributedDataParallel的使用)备忘记录
不做具体的原理分析和介绍(因为我也不咋懂),针对我实际修改可用的一个用法介绍,主要是模型训练入口主函数(main_multi_gpu.py)的四处修改. 以上的介绍来源https://zhuanlan ...
- Pytorch 单机多卡训练DDP
多卡训练方式 1.DP--torch.nn.DataParallel 2.DDP--torch.nn.parallel.DistributedDataParallel 通俗一点讲就是用了4张卡训练,就 ...
- Pytorch单机多卡加速
忙了两个月从收到原始数据到最后在工程项目中加载成功完成测试,好像从元旦上班后就再没休息过,昨天项目通过三期评审终于可以喘口气补点作业了.(年前写的文章,今天转过来了) 多卡并行 一定要使用 torch ...
最新文章
- Python xlrd 读取excel表格 常用用法整理
- hdu 3033(分组背包)
- 聚类结果不好怎么办_使用bert-serving生成词向量并聚类可视化
- win10用一会就蓝屏重启_电脑出现蓝屏?教你如何解决
- 【渝粤题库】国家开放大学2021春2509学前教育学题目
- 支持markdown的服务器,基于tornado实现的一个markdown解析服务器
- oracle索引的监控
- Python 图形 GUI 库 pyqtgraph
- 小米3g刷高格固件_今天小米路由器3G到手就刷 老毛子 固件。
- nes模拟器java版_JAVA版手机FC/Nes模拟器vN
- InSAR数据处理软件简介
- 【有利可图网】PS实战系列:PS美化婚纱照片
- STM32读写ADXL345 中断功能
- 数据分析精选案例:3行代码上榜Kaggle学生评估赛
- 数据结构与算法真的那么重要么?
- ARM体系结构2:汇编指令集
- mysql修改密码总是报错_mysql修改密码报错 | 吴老二
- python ddt安装
- 单片机应用系统的基本组成
- 用计算机做一克拉等于多少克的单位换算,克和克拉怎么换算(克拉和克的换算单位是多少)...