1. nn.Dataparallel

多GPU加速训练

原理:
模型分别复制到每个卡中,然后把输入切片,分别放入每个卡中计算,然后再用第一块卡进行汇总求loss,反向传播更新参数。

第一块卡占用的内存多一点,因为output loss每次都会在第一块GPU相加计算,这就造成了第一块GPU的负载远远大于剩余其他的显卡。

要求:
batch_size > GPU 数量

第一种方法:

os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
device_ids = [0,1,2,3]
net  = torch.nn. Dataparallel(net, device_ids =device_ids)
net = net.cuda()

第二种方法

os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2"
if torch.cuda.is_available():self.device = "cuda"if torch.cuda.device_count() > 1:self.G = nn.DataParallel(self.G)self.D_A = nn.DataParallel(self.D_A)self.D_B = nn.DataParallel(self.D_B)self.vgg = nn.DataParallel(self.vgg)self.criterionHis = nn.DataParallel(self.criterionHis)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.criterionL1 = nn.DataParallel(self.criterionL1)self.criterionL2 = nn.DataParallel(self.criterionL2)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.G.cuda()self.vgg.cuda()self.criterionHis.cuda()self.criterionGAN.cuda()self.criterionL1.cuda()self.criterionL2.cuda()self.D_A.cuda()self.D_B.cuda()

2.模型分别单独放入每个指定的GPU中

把模型分别放到指定的GPU中,然后在运算的过程中,需要把利用**.to(cuda:x)** 去转移数据。这样暂用的内存比平行计算小。但是配置复杂一点。

 vgg_encoder = VGGEncoder().to('cuda:0')attn=CoAttention(channel=512).to('cuda:1')decoder = Decoder().to('cuda:2')optimizer_decoder = Adam(decoder.parameters(), lr=args.learning_rate)optimizer_attn = Adam(attn.parameters(), lr=args.learning_rate)content = content.cuda()  # 默认的是cuda:0style = style.cuda()content_features = vgg_encoder(content, output_last_feature=True)style_features = vgg_encoder(style, output_last_feature=True)content_features, style_features=attn(content_features.to('cuda:1'),style_features.to('cuda:1')) # 因为attn在cuda:1中

nn.Dataparallel pytorch 平行计算的两种方法相关推荐

  1. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  2. QT时间差计算的两种方法代码

    QT时间差计算的两种方法 提供两种方法,直接贴出代码供参考,主要用到函数secsTo,toTIme_t(): #include <qdatetime.h>#include <wind ...

  3. 基尼系数计算的两种方法:python实现 简单高效

    使用两种方法,通过python计算基尼系数. 在sql中如何计算基尼系数,可以查看我的另一篇文章.两篇文章取数相同,可以结合去看. 文章中方法1的代码来自于:(加入了一些注释,方便理解).为精确计算. ...

  4. html闰年计算方法,闰年计算的两种方法

    说起闰年,估计一些朋友会很糊涂.好像隔个一两年就有闰年,结果闰来闰去,闰得头都快大了.到底什么是闰年?闰年该怎么计算呢? 实际上,闰年是公历的一个计算方式,也就是常说的阳历,或者叫西历也可以.在我国的 ...

  5. pytorch保存模型的两种方法

    文章目录 前言 一.保存整个模型 二.只保存参数 模型不同后缀名的区别 总结 前言 模型的本质是一堆用某种结构存储起来的参数 用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要 ...

  6. 请描述定时器初值的计算方式_单片机C语言编程中定时器初值计算的两种方法...

    单片机C语言编程中,定时器的初值对于初学者真的是比较不好计算,因此我总结了以下几种方法. 第1种方法: #define FOSC 11059200L //晶振的频率 #define TIMS (655 ...

  7. 多项式计算的两种方法(包含秦九韶公式)

    写程序计算给定多项式在定点处的值 普通写法 double f(int n, double a[], double x) {int i;double p = a[0];for(i=1; i<=n; ...

  8. 年龄php,PHP根据生日计算年龄两种方法(周岁)

    温馨提示:本文共1429个字,读完预计4分钟. 1.计算年龄 functionhowOld($birth) { list($birthYear, $birthMonth, $birthDay) = e ...

  9. 行列式计算的两种方法

    #include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #d ...

最新文章

  1. os.environ[CUDA_DEVICE_ORDER] = PCI_BUS_ID os.environ[CUDA_VISIBLE_DEVICES] = 0
  2. TensorFlow 2.2.0-rc0,这次更新让人惊奇!
  3. 使用Excel 通过 ODBC 连接到 MySQL 数据库
  4. 各安全浏览器如何设2345为主页
  5. 切换ip下的sql server用户权限丢失_Zabbix_server高可用之文件同步
  6. Weblogic10 集群配置
  7. LeetCode MySQL 512. 游戏玩法分析 II
  8. 二维码图像去噪文献调研(1)--Real Image Denoising with Feature Attention
  9. Java中需要全部小写的是,java – 如何处理JSR 310中的大写或小写?
  10. 2021-10-12
  11. linux添加磁盘分区,linux添加磁盘分区
  12. green: JRE + Tomcat + Mysql - JaveEE JTM0.9
  13. 自己开发JAVA Swing版★山寨 马里奥★
  14. 云呐|国有企业资产管理系统建设该如何开展_固定资产管理信息系统
  15. 微软逆转互联网战局,错过了智能手机却君临游戏帝国
  16. Android系统模拟位置的使用方法
  17. 算法谜题1,狼羊菜过河
  18. 路径MTU(PMTU)发现控制与DF位
  19. 科研工具篇|看完之后能提高你80%的科研工作效率
  20. 工程师为女朋友自制的硬核礼物

热门文章

  1. tcp实时传输kafka数据_tcp怎么传输大数据
  2. pythonpygame中主函数_从0开始学Python-14.2 pygame的核心对象
  3. mysql 前端proxy_mysql-proxy中间件使用
  4. qt自定义按钮类,每个按钮自带一个右键弹出框,如何使同一时刻只显示一个弹出框
  5. BugkuCTF-WEB题计算器
  6. python调用百度语音实时转为文字_百度语音转文字 (Python)
  7. flink checkpoint 恢复_干货:Flink+Kafka 0.11端到端精确一次处理语义实现
  8. centos7安装php5.2yum源操作_CentOS7使用阿里yum源进行升级和安装php70W
  9. html的课设作业6,第七节课html标签元素属性作业-2019-9-6 作业
  10. isfull mysql_Mysql8.0及以上版本,关于only_full_group_by的问题