pytorch 单机多卡的正确打开方式

pytorch 使用单机多卡,大体上有两种方式

  • 简单方便的 torch.nn.DataParallel(很 low,但是真的很简单很友好)
  • 使用 torch.distributed 加速并行训练(推荐,但是不友好)

首先讲一下这两种方式分别的优缺点

  • nn.DataParallel
    优点:就是简单
    缺点就是:所有的数据要先load到主GPU上,然后再分发给每个GPU去train,注意这时候主GPU的显存占用很大,你想提升batch_size,那你的主GPU就会限制你的batch_size,所以其实多卡提升速度的效果很有限
    注意: 模型是会被copy到每一张卡上的,而且对于每一个BATCH的数据,你设置的batch_size会被分成几个部分,分发给每一张卡,意味着,batch_size最好是卡的数量n的倍数,比如batch_size=6,而你有n=4张卡,那你实际上代码跑起来只能用3张卡,因为6整除3
  • torch.distributed
    优点: 避免了nn.DataParallel的主要缺点,数据不会再分发到主卡上,所以所有卡的显存占用很均匀
    缺点: 不友好,调代码需要点精力,有很多需要注意的问题,我后面会列出

接下来展示如何使用两种方法以及相关注意事项

一、torch.nn.DataParallel

主要的修改就是用nn.DataParallel处理一下你的model
model = nn.DataParallel(model.cuda(), device_ids=gpus, output_device=gpus[0])

这个很简单,就直接上个例子,根据这个例子去改你的代码就好,主要就是注意对model的修改
注意model要放在主GPU上:model.to(device)

# main.py
import torch
import torch.distributed as distgpus = [0, 1, 2, 3]
torch.cuda.set_device('cuda:{}'.format(gpus[0]))train_dataset = ...
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=...)model = ...
model = nn.DataParallel(model.to(device), device_ids=gpus, output_device=gpus[0]) #注意model要放在主GPU上optimizer = optim.SGD(model.parameters())for epoch in range(100):for batch_idx, (data, target) in enumerate(train_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)...output = model(images)loss = criterion(output, target)...optimizer.zero_grad()loss.backward()optimizer.step()

二、torch.distributed加速

与 DataParallel 的单进程控制多 GPU 不同,在 distributed 的帮助下,只需要编写一份代码,torch 就会自动将其分配给多个进程,分别在多个 GPU 上运行。

要想把大象装冰箱,总共分四步!!

(1)要使用torch.distributed,你需要在你的main.py(也就是你的主py脚本)中的主函数中加入一个参数接口:--local_rank

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training')
args = parser.parse_args()
print(args.local_rank)

(2)使用 init_process_group 设置GPU 之间通信使用的后端和端口:

dist.init_process_group(backend='nccl')

(3)使用 DistributedSampler 对数据集进行划分:

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)

(4)使用 DistributedDataParallel 包装模型

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
  • 举个栗子,参照这个例子去设置你的代码结构
# main.py
import torch
import argparse
import torch.distributed as dist
#(1)要使用`torch.distributed`,你需要在你的`main.py(也就是你的主py脚本)`中的主函数中加入一个**参数接口:`--local_rank`**
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training')
args = parser.parse_args()
#(2)使用 init_process_group 设置GPU 之间通信使用的后端和端口:
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
#(3)使用 DistributedSampler 对数据集进行划分:
train_dataset = ...
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
#(4)使用 DistributedDataParallel 包装模型
model = ...
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
optimizer = optim.SGD(model.parameters())for epoch in range(100):for batch_idx, (data, target) in enumerate(train_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)...output = model(images)loss = criterion(output, target)...optimizer.zero_grad()loss.backward()optimizer.step()

然后,使用以下指令,执行你的主脚本,其中--nproc_per_node=4表示你的单个节点的GPU数量

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py

问题来了!!

你可能会在完成代码之后遇到各种问题,我这里列举一些要注意的点,去避坑
如果你遇到的莫名奇妙报错的问题,尝试这样去修改你的代码

  • device 的设置
    你需要设置一个device参数,用来给你的数据加载到GPU上,由于你的数据会在不同线程中被加载到不同的GPU上,你需要传给他们一个参数device,用于a.to(device)的操作(a是一个tensor)
    device如下设置
device = torch.device("cuda", args.local_rank)

你也可以通过设置当前cuda,使用a.cuda()把张量放到GPU上,但是不推荐,可能会有一些问题

torch.cuda.set_device(args.local_rank)
  • find_unused_parameters=True
    这个是为了解决你的模型中定义了一些在forward函数中没有用到的网络层,会被视为“unused_layer”,这会引发错误,所以你在使用 DistributedDataParallel 包装模型的时候,传一个find_unused_parameters=True的参数来避免这个问题,如下:
encoder=nn.parallel.DistributedDataParallel(encoder, device_ids=[args.local_rank],find_unused_parameters=True)
  • num_workers
    很好理解,尽量不要给你的DataLoader设置numworkers参数,可能会有一些问题(不要太强迫症)
  • shuffle=False
    你的DataLoader不要设置shuffle=True
valid_loader = torch.utils.data.DataLoader(part_valid_set, batch_size=BATCH, shuffle=False, num_workers=num_workers,sampler=valid_sampler)

pytorch单机多卡的正确打开方式 以及可能会遇到的问题和相应的解决方法相关推荐

  1. [分布式训练] 单机多卡的正确打开方式:PyTorch

    [分布式训练] 单机多卡的正确打开方式:PyTorch 转自:https://fyubang.com/2019/07/23/distributed-training3/ PyTorch的数据并行相对于 ...

  2. [分布式训练] 单机多卡的正确打开方式:Horovod

    [分布式训练] 单机多卡的正确打开方式:Horovod 转自:https://fyubang.com/2019/07/26/distributed-training4/ 讲完了单机多卡的分布式训练的理 ...

  3. [分布式训练] 单机多卡的正确打开方式:理论基础

    [分布式训练] 单机多卡的正确打开方式:理论基础 转自:https://fyubang.com/2019/07/08/distributed-training/ 瓦砾由于最近bert-large用的比 ...

  4. 收藏 | PyTorch 单机多卡操作总结

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨科技猛兽@知乎 来源丨https://zhuanlan ...

  5. PyTorch 单机多卡操作总结:分布式DataParallel,混合精度,Horovod)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨科技猛兽@知乎 来源丨https://zhuanlan.zhihu.com/p/15837505 ...

  6. pytorch 单机多卡训练distributedDataParallel

    pytorch单机多卡:从DataParallel到DistributedDataParallel 最近想做的实验比较多,于是稍微学习了一下和pytorch相关的加速方式.本人之前一直在使用DataP ...

  7. 拜托!这才是分布式系统CAP的正确打开方式!

    "纸面"上的CAP 相信很多同学都听过CAP这个理论,为了避免我们认知不同,我们先来统一下知识起点. CAP理论在1999年一经提出就成为了分布式系统领域的顶级教义.并表明分布式服 ...

  8. 企业搭建私域流量的正确打开方式

    做私域流量并不是@下新好友就解决了问题,如果这样也别期待私域流量发挥太大的价值.有人说,要给私域流量提供价值.提供价值没有错,错在你给所有用户提供了大家不想要的价值,而且私域流量中有一些用户再提供价值 ...

  9. 为什么说vivo S7才是5G轻薄旗舰的正确打开方式

    8月3日,vivo发布了最新的5G旗舰机型S7.S7 170g的整机重量和7.39mm的机身厚度,瞬间让其成为年轻用户追捧的热点. 一.厚重的5G手机 众所周知,5G手机由于在信号处理的技术要求上比4 ...

最新文章

  1. AI助力清华博士进入周杰伦战队,预告AI应用迎来黄金时代?
  2. 灰度直方图均衡化及其实现
  3. 深度学习面试25问题
  4. Flutter - 弹出底部菜单Show Modal Bottom Sheet
  5. 一名“企业定制化人才”的自诉:“我不愿意,但却无可奈何”
  6. mysql libstdc .so.6_编译安装mysql报错 ./mysqld: /usr/lib64/libstdc++.so.6:
  7. 企业五年后卓越或者死亡,数据战略是关键!
  8. 【MapGIS】常见问题处理
  9. dBm和dB(纯计数单位)
  10. 知乎8.5k赞的回答:自学编程需要注意什么?
  11. SpringBoot海景房出租管理系统+代码讲解
  12. python 招聘 海盐_聚焦普高新课标 提升信息核心素养——海盐县初中信息技术Python课堂教学研讨活动在武原中学举行...
  13. Linux-----Ubuntu通过shell脚本将SSH多次登录失败的IP自动加入黑名单
  14. PAT 1063 计算谱半径
  15. python中的super是什么?
  16. 计算机故障维修四种思路,维修“望闻问切” 电脑故障的排除方法
  17. 计算机单元格数值不保留小数,EXCEL单元格数值实现真正保留2位小数的方法
  18. 计算机组成原理setb,计算机组成原理与汇编语言4
  19. 【活动推荐】2020中国DevOps社区峰会(成都站)
  20. 算法刷题记录(Day 33)

热门文章

  1. android默认exported_android:exported 属性详解-阿里云开发者社区
  2. 易语言多级指针读取_C语言指针难吗?纸老虎而已,纯干货讲解(附代码)
  3. C++和Objective-C混编(官方文档翻译)
  4. XMLDictionary iOS的XML处理包
  5. ajax delete 传递参数,springMVC使用PUT、DELETE方法传递参数解决方案
  6. invoke 按钮点击_h5+ app内点击按钮实现复制功能 实现方法
  7. 修复mysql的view_MYSQL数据损坏修复方法
  8. python中的get函数_python之函数用法get()
  9. docker 挂载目录_Docker容器数据管理
  10. python 退出_如果读完这篇文章不能让你入门Python,那我将永久退出编程界