最近在做图像分类实验时,在4个gpu上使用pytorch的DataParallel 函数并行跑程序,批次为16时会报如下所示的错误:
  RuntimeError: CUDA out of memory. Tried to allocate 858.00 MiB (GPU 3; 10.92 GiB total capacity; 10.10 GiB already allocated; 150.69 MiB free; 10.13 GiB reserved in total by PyTorch)

  实验发现,每块gpu最多可以跑2条数据,但是我又想设置batch_size=16,参考https://zhuanlan.zhihu.com/p/86441879了解到transformer-XL官方写的BalancedDataParallel 函数,用来解决DataParallel 显存使用不平衡的问题(参考代码见最后)。
  为了理解BalancedDataParallel 函数用法,我们先来弄清楚几个问题。
1,DataParallel 函数是如何工作的?
  首先将模型加载到主 GPU 上,然后再将模型复制到各个指定的从 GPU 中,然后将输入数据按 batch 维度进行划分,具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward计算,之后会把计算结果传到主GPU 上完成梯度计算和参数更新,最后将更新后的参数复制到从 GPU 中,这样就完成了一次迭代计算。参考https://blog.csdn.net/zhjm07054115/article/details/104799661当gpu=2,batch_size=30时,我们可以从下图清楚的看到首先会在两个gpu上分别分配15条数据,进行forward计算,之后汇总结果再进行梯度计算和参数更新。
  我们可以看到反向传播计算和参数更新完全放在主gpu上进行的,这样会造成显存使用不平衡的问题。

2,梯度累加
  参考https://blog.csdn.net/wuzhongqiang/article/details/102572324做的反向传播梯度累加实验,发现pytorch在反向传播的时候,默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

  了解了DataParallel 函数和梯度累加后,我们就可以来解决显存使用不平衡问题以及如何在显存固定的情况下加大训练批次。
  首先,简单介绍BalancedDataParallel 用法【下图截取自https://github.com/Link-Li/Balanced-DataParallel】

  简单解释一下:当我们需要在3个gpu并行跑程序,每个gpu最多一次可以处理3条数据,分配是[3,3,3],那么3个gpu最多可以同时处理9条数据,也就是batch_size最大可设为9,因为主gpu上还要进行反向传播,所以这里我们设置主gpu处理2条数据,分布就是[2,3,3],batch_size=8。
  此时如果我们想加大批次,使得batch_size=16,那么分布应该是[4,6,6],但是我们知道每个gpu最多可以处理3条数据,这里就用到梯度累加的方法了,即上图中的acc_grad,acc_grad参数表示将batch_size分成多少份送入网络,当acc_grad=2,表示我们会先将16个数据分成2份,每份有8条数据,每次输入8条数据分给3个gpu做并行训练,forward计算结果放入主gpu上进行反向传播,由于梯度可以累加,循环两次后,再更新参数。这样做不仅可以缓解显存不平衡问题也可以解决显存不足的问题。
  下面是我根据https://blog.csdn.net/zhjm07054115/article/details/104799661做了修改,加上BalancedDataParallel 完整代码:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_parallel_balance import BalancedDataParallel# Dataset
class RandomDataset(Dataset):def __init__(self, size, length):self.len = lengthself.data = torch.randn(length, size)self.target=np.random.randint(3,size=length)def __getitem__(self, index):label=torch.tensor(self.target[index])return self.data[index],labeldef __len__(self):return self.len# model
class Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, input):output = self.fc(input)print("\tIn Model: input size", input.size(),"output size", output.size())return output# trian
def train(rand_loader,model,optimizer,criterion):train_loss=0# trainmodel.train()optimizer.zero_grad()for image,target in rand_loader:print('image:',image.shape)if batch_chunk > 0:image_chunks = torch.chunk(image, batch_chunk, 0)target_chunks = torch.chunk(target, batch_chunk, 0)for i in range(len(image_chunks)):print('image_chunks:',i)img=image_chunks[i].to(device)lab=target_chunks[i].to(device)out=model(img)print("Chunks_Outputs: input size", img.size(),"output_size", out.size())loss=criterion(out,lab)# print('{} chunk,loss:{}.'.format(i,loss))train_loss+=loss.item()loss = loss.float().mean().type_as(loss) / len(image_chunks)loss.backward()else:image = image.to(device)target=target.to(device)output = model(image)loss=criterion(output,target)train_loss=loss.item()print("Outside: input size", image.size(),"output_size", output.size())optimizer.step()  return train_lossif __name__=="__main__":input_size = 5output_size = 3batch_size = 32data_size = 70batch_chunk=2gpu0_bsz=8epochs=2device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# datarand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),batch_size=batch_size, shuffle=True)# model                 model = Model(input_size, output_size)if torch.cuda.device_count() > 1:print("Let's use", torch.cuda.device_count(), "GPUs!")if gpu0_bsz >= 0:   model = BalancedDataParallel(gpu0_bsz // batch_chunk, model, dim=0)else:model = nn.DataParallel(model)model.to(device)# optimizeroptimizer= torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)# losscriterion=nn.CrossEntropyLoss()for epoch in range(epochs):print('Epoch:',epoch)train(rand_loader,model,optimizer,criterion)

参考:
pytorch多gpu并行训练
transformer-XL的官方代码
BalancedDataParallel 参考代码
PyTorch-4 nn.DataParallel 数据并行详解
Pytorch反向传播中的细节-计算梯度时的默认累加

欢迎大家留言批评指正!

pytorch多gpu DataParallel 及梯度累加解决显存不平衡和显存不足问题相关推荐

  1. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_40087578/arti ...

  2. Pytorch分布式训练/多卡训练(二) —— Data Parallel并行(DDP)(2.2)(代码示例)(BN同步主卡保存梯度累加多卡测试inference随机种子seed)

    DDP的使用非常简单,因为它不需要修改你网络的配置.其精髓只有一句话 model = DistributedDataPrallel(model, device_ids=[local_rank], ou ...

  3. Gradient Accumulation 梯度累加 (Pytorch)

    我们在训练神经网络的时候,batch_size的大小会对最终的模型效果产生很大的影响.一定条件下,batch_size设置的越大,模型就会越稳定.batch_size的值通常设置在 8-32 之间,但 ...

  4. Pytorch的nn.DataParallel详细解析

    前言 pytorch中的GPU操作默认是异步的,当调用一个使用GPU的函数时,这些操作会在特定设备上排队但不一定在稍后执行.这就使得pytorch可以进行并行计算.但是pytorch异步计算的效果对调 ...

  5. pytorch多gpu并行训练操作指南

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 来源:知乎 作者:link-web 链接:https://zhuanlan.zhi ...

  6. pytorch多gpu并行训练

    pytorch多gpu并行训练 link-web 转自:pytorch多gpu并行训练 - 知乎 目录(目录不可点击) 说明 1.和DataParallel的区别 2.如何启动程序的时候 2.1 单机 ...

  7. [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 文章目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 ...

  8. 【计数网络】梯度累加增加LCFCN的BatchSize

    LCFCN是一个以分割网络为基础的专用于计数的网络. LCFCN模型由于loss的特殊性 batch size 目前只能为1 LCFCN代码 https://github.com/ElementAI/ ...

  9. Pytorch多GPU笔记

    Pytorch分布式笔记 Pytorch多GPU计算笔记 DP和DDP的区别 DP DDP Apex amp的使用 apex.parallel.DistributedDataParallel的使用 D ...

最新文章

  1. 领歌leangoo敏捷工具个人工作台功能
  2. 我被编程语言PUA了!用互联网黑话写代码,每天都在“赋能”变量
  3. Postgis常用函数
  4. 个人对持续集成的理解和实践
  5. api过滤器_了解播放过滤器API
  6. 出发a标签_以用户标签为例,复盘B端产品的需求挖掘方法论
  7. Java文件类boolean setLastModified(long set_new_time)方法,包含示例
  8. mysql安装忘了root_MySQL - 安装:MySQL忘记root密码的解决办法
  9. iOS学习笔记-自定义过渡动画
  10. Responsive自适应网页设计与ResponsiveColumn自适应列实例
  11. iOS开源项目周报1229
  12. 擦拭法 java 泛型_廖雪峰Java4反射与泛型-3范型-4擦拭法
  13. 如何将计算机网络作为热点,教你如何三步让笔记本电脑做wifi热点??
  14. Python数据处理之导入导出excel数据
  15. MATLAB Coder工具箱介绍【如何利用MATLAB Coder将.m文件生成C/C++代码?】
  16. [JZOJ5551] 【NOI2019模拟6.24】旅途【最短路】
  17. 独立网店运营简要分析
  18. Photoshop CS6 在 4k屏上非常小的解决办法
  19. C# 如何给Word文档设置背景颜色和背景图片
  20. 华视100UC 身份证阅读器 Java

热门文章

  1. DELL SCv3020存储日常运维
  2. ABP VNext学习日记24
  3. JavaScript实现简易ATM
  4. 禾赛科技“梦碎”科创板:营收递增、由盈转亏,在专利官司中败退
  5. 种春草肥禾,织数字天下
  6. Ubuntu入门,Ubuntu基本软件,Ubuntu起始配置
  7. 红米K50电竞版上手体验
  8. 对于c++面向对象的深刻认识和理解--哲学角度看问题(源生论)
  9. 【实验3 循环结构】7-14 循环结构 —— 中国古代著名算题。趣味题目:物不知其数。
  10. Delphi XE10.x实现Android下Https双向认证