本文介绍最简单的pytorch分布式训练方法:使用torch.nn.DataParallel这个API来实现分布式训练。环境为单机多gpu,不妨假设有4个可用的gpu。

一、构建方法

使用这个API实现分布式训练的步骤非常简单,总共分为3步骤:
1、创建一个model,并将该model推到某个gpu上(这个gpu也将作为output_device,后面具体解释含义),不妨假设推到第0号gpu上,

device = torch.device("cuda:0")
model.to(device)

2、将数据推到output_device对应的gpu上,

data = data.to(device)

3、使用torch.nn.DataParallel这个API来在0,1,2,3四个gpu上构建分布式模型,

model = torch.nn.DataParallel(model, device_ids=[0,1,2,3], output_device=0)

然后这个model就可以像普通的单gpu上的模型一样开始训练了。

二、原理详解

2.1 原理图

  首先通过图来看一下这个最简单的分布式训练API的工作原理,然后结合代码详细阐述。

将模型和数据推入output_device(也就是0号)gpu上。

0号gpu将当前模型在其他几个gpu上进行复制,同步模型的parameter、buffer和modules等;将当前batch尽可能平均的分为len(device)=4份,分别推给每一个设备,并开启多线程分别在每个设备上进行前向传播,得到各自的结果,最后将各自的结果全部汇总在一起,拷贝回0号gpu。

在0号gpu进行反向传播和模型的参数更新,并将结果同步给其他几个gpu,即完成了一个batch的训练。

2.2 代码原理

  通过分析torch.nn.DataParallel的代码,可以看到具体的过程,这里重点看一下几个关键的地方。

# 继承自nn.Module,只要实现__init__和forward函数即可
class DataParallel(Module):# 构造函数里没有什么关键内容,主要是根据传进来的model、device_ids和output_device进行一些变量生成def __init__(self, module, device_ids=None, output_device=None, dim=0):super(DataParallel, self).__init__()device_type = _get_available_device_type()if device_type is None:self.module = moduleself.device_ids = []returnif device_ids is None:device_ids = _get_all_device_indices()if output_device is None:output_device = device_ids[0]self.dim = dimself.module = moduleself.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))self.output_device = _get_device_index(output_device, True)self.src_device_obj = torch.device(device_type, self.device_ids[0])_check_balance(self.device_ids)if len(self.device_ids) == 1:self.module.to(self.src_device_obj)def forward(self, *inputs, **kwargs):if not self.device_ids:return self.module(*inputs, **kwargs)for t in chain(self.module.parameters(), self.module.buffers()):if t.device != self.src_device_obj:raise RuntimeError("module must have its parameters and buffers ""on device {} (device_ids[0]) but found one of ""them on device: {}".format(self.src_device_obj, t.device))inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)if len(self.device_ids) == 1:return self.module(*inputs[0], **kwargs[0])# 在每个gpu上都复制一个modelreplicas = self.replicate(self.module, self.device_ids[:len(inputs)])# 开启多线程进行前向传播,得到结果outputs = self.parallel_apply(replicas, inputs, kwargs)# 将每个gpu上得到的结果都gather到0号gpu上return self.gather(outputs, self.output_device)def replicate(self, module, device_ids):return replicate(module, device_ids, not torch.is_grad_enabled())def scatter(self, inputs, kwargs, device_ids):return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)def parallel_apply(self, replicas, inputs, kwargs):return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])def gather(self, outputs, output_device):return gather(outputs, output_device, dim=self.dim)

再看一下parallel_apply这个关键的函数,

def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):assert len(modules) == len(inputs)if kwargs_tup is not None:assert len(modules) == len(kwargs_tup)else:kwargs_tup = ({},) * len(modules)if devices is not None:assert len(modules) == len(devices)else:devices = [None] * len(modules)devices = list(map(lambda x: _get_device_index(x, True), devices))# 创建一个互斥锁,防止前后两个batch的数据覆盖lock = threading.Lock()results = {}grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()# 线程的target函数,实现每个gpu上进行推理,其中i为gpu编号def _worker(i, module, input, kwargs, device=None):torch.set_grad_enabled(grad_enabled)if device is None:device = get_a_var(input).get_device()try:# 根据当前gpu编号确定推理硬件环境with torch.cuda.device(device), autocast(enabled=autocast_enabled):# this also avoids accidental slicing of `input` if it is a Tensorif not isinstance(input, (list, tuple)):input = (input,)output = module(*input, **kwargs)# 锁住赋值,防止后一个batch的数据将前一个batch的结果覆盖with lock:results[i] = outputexcept Exception:with lock:results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))if len(modules) > 1:# 创建多个线程,进行不同gpu的前向推理threads = [threading.Thread(target=_worker,args=(i, module, input, kwargs, device))for i, (module, input, kwargs, device) inenumerate(zip(modules, inputs, kwargs_tup, devices))]for thread in threads:thread.start()for thread in threads:thread.join()else:_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])# 将不同gpu上推理的结果打包起来,后面会gather到output_device上outputs = []for i in range(len(inputs)):output = results[i]if isinstance(output, ExceptionWrapper):output.reraise()outputs.append(output)return outputs

结论

  至此我们看到了torch.nn.DataParallel模块进行分布式训练的原理,数据和模型首先推入output_device对应的gpu,然后将分成多个子batch的数据和模型分别推给其他gpu,每个gpu单独处理各自的子batch,结果再打包回原output_device对应的gpu算梯度和更新参数,如此循环往复,其本质是一个单进程多线程的并发程序。
  由此我们也很容易得到torch.nn.DataParallel模块进行分布式的缺点,
1、每个batch的数据先分发到各gpu上,结果再打包回output_device上,在output_device一个gpu上进行梯度计算和参数更新,再把更新同步给其他gpu上的model。其中涉及数据的来回拷贝,网络通信耗时严重,GPU使用率低。
2、这种模式只支持单机多gpu的硬件拓扑结构,不支持Apex的混合精度训练等。
3、torch.nn.DataParallel也没有很完整的考虑到多个gpu做数据并行的一些问题,比如batchnorm,在训练时各个gpu上的batchnorm的mean和variance是子batch的计算结果,而不是原来整个batch的值,可能会导致训练不稳定影响收敛等问题。

pytorch分布式训练(一):torch.nn.DataParallel相关推荐

  1. torch.nn.DataParallel()--多个GPU加速训练

    公司配备多卡的GPU服务器,当我们在上面跑程序的时候,当迭代次数或者epoch足够大的时候,我们通常会使用nn.DataParallel函数来用多个GPU来加速训练.一般我们会在代码中加入以下这句: ...

  2. 新手手册:Pytorch分布式训练

    文 | 花花@机器学习算法与自然语言处理 单位 | SenseTime 算法研究员 目录 0X01 分布式并行训练概述 0X02 Pytorch分布式数据并行 0X03 手把手渐进式实战 A. 单机单 ...

  3. PyTorch 分布式训练DDP 单机多卡快速上手

    PyTorch 分布式训练DDP 单机多卡快速上手 本文旨在帮助新人快速上手最有效的 PyTorch 单机多卡训练,对于 PyTorch 分布式训练的理论介绍.多方案对比,本文不做详细介绍,有兴趣的读 ...

  4. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  5. Pytorch - 分布式训练极简体验

    由于工作需要,最近在补充分布式训练方面的知识.经过一番理论学习后仍觉得意犹未尽,很多知识点无法准确get到(例如:分布式原语scatter.all reduce等代码层面应该是什么样的,ring al ...

  6. pytorch分布式训练 DistributedSampler、DistributedDataParallel

    pytorch分布式训练 DistributedSampler.DistributedDataParallel   大家好,我是亓官劼(qí guān jié ),在[亓官劼]公众号.CSDN.Git ...

  7. 【分布式】Pytorch分布式训练原理和实战

    [分布式]基于Horovod的Pytorch分布式训练原理和实战 并行方法: 1. 模型并行 2. 数据并行 3. 两者之间的联系 更新方法: 1. 同步更新 2. 异步更新 分布式算法: 1. Pa ...

  8. PyTorch分布式训练

    PyTorch分布式训练 PyTorch 是一个 Python 优先的深度学习框架,能够在强大的 GPU 加速基础上实现张量和动态神经网络.PyTorch的一大优势就是它的动态图计算特性. Licen ...

  9. pytorch深度学习框架—torch.nn模块(一)

    pytorch深度学习框架-torch.nn模块 torch.nn模块中包括了pytorch中已经准备好的层,方便使用者调用构建的网络.包括了卷积层,池化层,激活函数层,循环层,全连接层. 卷积层 p ...

最新文章

  1. 浪潮发布重磅产品“元脑”,专注AI全栈能力输出
  2. .net分布式系统架构的思路
  3. java double储存原理_Java内存分配原理
  4. 娱乐社交,玩票大的!2021网易云信“融合通信开发者大赛”正式开赛!
  5. 当我们在谈 .NET Core 跨平台时,我们在谈些什么?--学习笔记
  6. 【Java】第一阶段练习题
  7. gdiplus判断一个点是否在圆弧线上_福建教师招聘考试小学数学面试教案:圆的认识...
  8. 聚划算的夜场新生意 “三叉戟”打通夜间消费命脉
  9. Linux服务器的攻防技术
  10. webservice项目部署部署到weblogic报错之解决方案
  11. HDU1642 UVA167 UVALive5227 The Sultan's Successors题解
  12. 工作流系统之三十三 撤回的实现
  13. Microsoft Windows Sharepoint Services V3.0 安装图示
  14. 正二十面体的各个面位置点
  15. h5混合开发框架初识
  16. 用连续自然数之和来表达整数
  17. 制作3D实时交互影像产品,需要用到的技术和软件!
  18. 开箱即用!使用Rancher 2.3 启用Istio初体验
  19. 智慧点餐系统多方面优化餐厅运作效率
  20. 开源了一套wms系统,支持lodop和网页打印入库单、出库单。

热门文章

  1. 虚拟专题:知识图谱 | 知识图谱多跳问答推理研究进展、挑战与展望
  2. 【项目管理】变更管理与过程改进
  3. 【VBS】归纳 Visual Basic Script 内置函数
  4. 贪婪的送礼者(洛谷P1201题题解,Java语言描述)
  5. kali linux2.0下MariaDB修改密码
  6. 《SolidWorks 2013中文版机械设计从入门到精通》一1.4 操作环境设置
  7. 前端开发者必备的20个文档和在线工具
  8. StarkSoft题库管理系统(二)--生成word格式试卷
  9. 模式窗体中调用父页面Javascript
  10. Android蓝牙设备名显示修改