pytorch 指定卡1_在pytorch中指定显卡
1. 利用CUDA_VISIBLE_DEVICES设置可用显卡
在CUDA中设定可用显卡,一般有2种方式:
(1) 在代码中直接指定
import os
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
(2) 在命令行中执行代码时指定
CUDA_VISIBLE_DEVICES=gpu_ids python3 train.py
如果使用sh脚本文件运行代码,则有3种方式可以设置
(3) 在命令行中执行脚本文件时指定:
CUDA_VISIBLE_DEVICES=gpu_ids sh run.sh
(4) 在sh脚本中指定:
source bashrc
export CUDA_VISIBLE_DEVICES=gpu_ids && python3 train.py
(5) 在sh脚本中指定
source bashrc
CUDA_VISIBLE_DEVICES=gpu_ids python3 train.py
如果同时使用多个设定可用显卡的指令,比如
source bashrc
export CUDA_VISIBLE_DEVICES=gpu_id1 && CUDA_VISIBLE_DEVICES=gpu_id2 python3 train.py
那么高优先级的指令会覆盖第优先级的指令使其失效。优先级顺序为:不使用sh脚本 (1)>(2); 使用sh脚本(1)>(5)>(4)>(3)
个人感觉在炼丹时建议大家从(2)(3)(4)(5)中选择一个指定可用显卡,不要重复指定以防造成代码的混乱。方法(1)虽然优先级最高,但是需要修改源代码,所以不建议使用。
2 .cuda()方法和torch.cuda.set_device()
我们还可以使用.cuda()[包括model.cuda()/loss.cuda()/tensor.cuda()]方法和torch.cuda.set_device()来把模型和数据加载到对应的gpu上。
(1) .cuda()
以model.cuda()为例,加载方法为:
model.cuda(gpu_id) # gpu_id为int类型变量,只能指定一张显卡
model.cuda('cuda:'+str(gpu_ids)) #输入参数为str类型,可指定多张显卡
model.cuda('cuda:1,2') #指定多张显卡的一个示例
(2) torch.cuda.set_device()
使用torch.cuda.set_device()可以更方便地将模型和数据加载到对应GPU上, 直接定义模型之前加入一行代码即可
torch.cuda.set_device(gpu_id) #单卡
torch.cuda.set_device('cuda:'+str(gpu_ids)) #可指定多卡
但是这种写法的优先级低,如果model.cuda()中指定了参数,那么torch.cuda.set_device()会失效,而且pytorch的官方文档中明确说明,不建议用户使用该方法。
第1节和第2节所说的方法同时使用是并不会冲突,而是会叠加。比如在运行代码时使用
CUDA_VISIBLE_DEVICES=2,3,4,5 python3 train.py
而在代码内部又指定
model.cuda(1)
loss.cuda(1)
tensor.cuda(1)
那么代码会在GPU3上运行。原理是CUDA_VISIBLE_DEVICES使得只有GPU2,3,4,5可见,那么这4张显卡,程序就会把它们看成GPU0,1,2,3,.cuda(1)把模型/loss/数据都加载到了程序所以为的GPU1上,则实际使用的显卡是GPU3。
如果利用.cuda()或torch.cuda.set_device()把模型加载到多个显卡上,而实际上只使用一张显卡运行程序的话,那么程序会把模型加载到第一个显卡上,比如如果在代码中指定了
model.cuda('cuda:2,1')
在运行代码时使用
CUDA_VISIBLE_DEVICES=2,3,4,5 python3 train.py
这一指令,那么程序最终会在GPU4上运行。
3.多卡数据并行torch.nn.DataParallel
多卡数据并行一般使用
torch.nn.DataParallel(model,device_ids)
其中model是需要运行的模型,device_ids指定部署模型的显卡,数据类型是list
device_ids中的第一个GPU(即device_ids[0])和model.cuda()或torch.cuda.set_device()中的第一个GPU序号应保持一致,否则会报错。此外如果两者的第一个GPU序号都不是0,比如设置为:
model=torch.nn.DataParallel(model,device_ids=[2,3])
model.cuda(2)
那么程序可以在GPU2和GPU3上正常运行,但是还会占用GPU0的一部分显存(大约500M左右),这是由于pytorch本身的bug导致的(截止1.4.0,没有修复这个bug)。
device_ids的默认值是使用可见的GPU,不设置model.cuda()或torch.cuda.set_device()等效于设置了model.cuda(0)
4. 多卡多线程并行torch.nn.parallel.DistributedDataParallel
(这个我是真的没有搞懂,,,,)
参考了这篇文章和这个代码,关于GPU的指定,多卡多线程中有2个地方需要设置
torch.cuda.set_device(args.local_rank)
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
模型/loss/tensor设置为.cuda()或.cuda(args.local_rank)均可,不影响正常运行。
5. 推荐设置方式:
(1) 单卡
使用CUDA_VISIBLE_DEVICES指定GPU,不要使用torch.cuda.set_device(),不要给.cuda()赋值。
(2) 多卡数据并行
直接指定CUDA_VISIBLE_DEVICES,通过调整可见显卡的顺序指定加载模型对应的GPU,不要使用torch.cuda.set_device(),不要给.cuda()赋值,不要给torch.nn.DataParallel中的device_ids赋值。比如想在GPU1,2,3中运行,其中GPU2是存放模型的显卡,那么直接设置
CUDA_VISIBLE_DEVICES=2,1,3
(3) 多卡多线程
由于这块还有好多地方没有搞懂(尤其是torch.nn.parallel.DistributedDataParallel),所以文章中难免会有很多错误和疏漏,欢迎各位大佬指正。
pytorch 指定卡1_在pytorch中指定显卡相关推荐
- pytorch 指定卡1_[原创][深度][PyTorch] DDP系列第一篇:入门教程
引言 DistributedDataParallel(DDP)是一个支持多机多卡.分布式训练的深度学习工程方法.PyTorch现已原生支持DDP,可以直接通过torch.distributed使用,超 ...
- python读取word指定内容_python读取word 中指定位置的表格及表格数据
1.Word文档如下: 2.代码 # -*- coding: UTF-8 -*- from docx import Document def readSpecTable(filename, specT ...
- pytorch 指定卡1_如何为TensorFlow和PyTorch自动选择空闲GPU,解决抢卡争端
原标题:如何为TensorFlow和PyTorch自动选择空闲GPU,解决抢卡争端 雷锋网按:本文作者天清,原文载于其知乎专栏世界那么大我想写代码,雷锋网获其授权发布. 项目地址:QuantumLiu ...
- pytorch 指定卡1_收藏 | 13则PyTorch使用的小窍门
点击上方"智能与算法之路",选择"星标"公众号 第一时间获取价值内容 仅作学术分享,不代表本公众号立场,侵权联系删除转载于:极市平台,知乎作者丨z.defyin ...
- java读取zip中指定文件_java读取zip中指定文件
public static void main(String args[]) { String file = "c://ssi.zip"; String saveRootDirec ...
- php 删除字符串里指定字符,php删除字符串中指定字符_php删除字符串
在做项目时需要对一个字符串进行处理,也就是删除指定的字符,吾爱编程通过这篇文章主要介绍了PHP实现删除字符串中任何字符的函数,涉及php针对字符串的遍历与截取操作技巧,需要的朋友可以参考一下: PHP ...
- php获取页面中的指定内容,php 获取页面中指定内容的实现类
[email protected] image: Grep.class.php /** grep class * Date: 2013-06-15 * Author: fdipzone * Ver: ...
- python删除txt指定内容_python删除文件中指定内容
更多追问追答 追问 我按你的方法试了下,文件内容还在,没有删掉...... 追答 把你的 file.txt 贴出来,确保 20150723 在要删除行的最开始,前面不能有空格等其他任何字符. 另外, ...
- python删除指定字符_python删除字符串中指定字符的方法
最近开始学机器学习,学习分析垃圾邮件,其中有一部分是要求去除一段字符中的标点符号,查了一下,网上的大多很复杂例如这样 import re temp = "司法局让我和户 1 5. 8 0. ...
最新文章
- ViewStub must have a valid layoutResource
- Git客户端TortoiseGit(Windows系统)的使用方法
- 使用Windows10 software center升级版本1909
- 前端学习(3000):vue+element今日头条管理--远程仓库的issue
- 易到司机无法提现:客服电话变空号,要钱无路
- java虚拟机的内存_Java虚拟机的内存结构
- 【Stimulsoft Reports.WPF教程】在代码中使用报表变量
- 计算机毕业设计——基于Spring Boot框架的网络游戏虚拟交易平台的设计与实现
- android svg 编辑器,Android svg 格式使用小结
- 让cajviewer记住正在浏览的文献,下次启动时自动打开上次浏览的文献
- IDEA清除Local History
- Nuxt.js 如何做SEO
- sqlite3数据库文件损坏修复
- C# 获取适配器网络连接IP地址,子网掩码,DNS,数据包等信息
- xUnit.net入门
- latex 公式换行、对齐
- Android studio模拟器尺寸和真机不一样的原因
- 超星系统登录,信息爬取
- 华为、微软、瑞幸、维达、奈飞、爱彼迎等公司高管变动
- 安卓讲课笔记3.4 网格布局
热门文章
- 一元多项式的建立及加减
- 各大媒体优劣对比_信息流投放广告丨各大平台的信息流都有什么特点与弊端
- typescript箭头函数参数_Typescript 入门基础篇(一)
- php取不到post数据库,安卓post 数据到php 在写入数据库老是不成功, 数据post不到php...
- Chrome划词翻译插件
- python 文件和目录 当前目录以及当前目录的所有子目录下查找文件名包含指定字符串的文件,并打印出相对路径。
- pytorch数据预处理
- 吴恩达作业11:残差网络实现手势数字的识别(基于 keras)+tensorbord显示loss值和acc值
- 安卓逆向_7 --- 六种快速定位关键 Smali 代码的方法 ( 去掉 RE 广告 )
- 小甲鱼 OllyDbg 教程系列 (十一) : inline patch ( 内嵌补丁 )