0.深入理解GPU训练加速原理

GPU是如何加速的呢?

我打算从两个方面来解答:

  • 单个GPU较于CPU加速:

在训练网络中,其实大量的运算资源都消耗在了数值计算上面,大部分网络训练的过程都是1.计算loss,2.根据loss求梯度,3.再根据梯度更新参数(梯度下降原理)。无论在GPU还是CPU中,都是不断重复123步。但是由于CPU是通用计算单元(并不擅长数值运行),而GPU特长是图像处理(数值计算)。所以GPU更加适合训练网络,从而起到加速效果。

  • 多GPU较于单GPU加速:

一般在GPU训练中,同一个GPU中,batch_size的大小,决定训练的速度,batch_size越小,训练一轮所需的步数(data_len/batch_size)就会越大,从而花费时间越多。

下面介绍下使用多GPU数据并行加速原理:

假设一台机器上有k块GPU。给定需要训练的模型,每块GPU及其相应的显存将分别独立维护一份完整的模型参数。在模型训练的任意一次迭代中,给定一个随机小批量,我们将该批量中的样本划分成k份并分给每块显卡的显存一份。然后,每块GPU将根据相应显存所分到的小批量子集和所维护的模型参数分别计算模型参数的本地梯度。接下来,我们把k块显卡的显存上的本地梯度相加,便得到当前的小批量随机梯度。之后,每块GPU都使用这个小批量随机梯度分别更新相应显存所维护的那一份完整的模型参数。下图描绘了使用2块GPU的数据并行下的小批量随机梯度的计算。

使用2块GPU的数据并行下的小批量随机梯度的计算

我们回忆下梯度下降的过程,1.计算loss,2.根据loss求梯度,3.再根据梯度更新参数。

使用上述的多GPU数据并行方法,可以理解为把batch_size扩大了k倍,从而总的时间缩短为了k分之1,实现了多GPU计算训练。

其实每一个GPU上网络的参数都是相同的,因为都是从相同的loss做的更新。


1.如何在 GPU 上运行 Keras?

如果你以 TensorFlow 或 CNTK 后端运行,只要检测到任何可用的 GPU,那么代码将自动在 GPU 上运行。

如果你以 Theano 后端运行,则可以使用以下方法之一:

方法 1: 使用 Theano flags。

THEANO_FLAGS=device=gpu,floatX=float32 python my_keras_script.py

"gpu" 可能需要根据你的设备标识符(例如gpu0,gpu1等)进行更改。

方法 2: 创建 .theanorc: 指导教程

方法 3: 在代码的开头手动设置 theano.config.device, theano.config.floatX:

import theanotheano.config.device = 'gpu'theano.config.floatX = 'float32'

2.如何在多 GPU 上运行 Keras 模型?

我们建议使用 TensorFlow 后端来执行这项任务。有两种方法可在多个 GPU 上运行单个模型:数据并行设备并行

在大多数情况下,你最需要的是数据并行。

数据并行

数据并行包括在每个设备上复制一次目标模型,并使用每个模型副本处理不同部分的输入数据。Keras 有一个内置的实用函数 keras.utils.multi_gpu_model,它可以生成任何模型的数据并行版本,在多达 8 个 GPU 上实现准线性加速。

有关更多信息,请参阅 multi_gpu_model 的文档。这里是一个快速的例子:

from keras.utils import multi_gpu_model# 将 `model` 复制到 8 个 GPU 上。# 假定你的机器有 8 个可用的 GPU。parallel_model = multi_gpu_model(model, gpus=8)parallel_model.compile(loss='categorical_crossentropy', optimizer='rmsprop')# 这个 `fit` 调用将分布在 8 个 GPU 上。# 由于 batch size 为 256,每个 GPU 将处理 32 个样本。parallel_model.fit(x, y, epochs=20, batch_size=256)

设备并行

设备并行性包括在不同设备上运行同一模型的不同部分。对于具有并行体系结构的模型,例如有两个分支的模型,这种方式很合适。

这种并行可以通过使用 TensorFlow device scopes 来实现。这里是一个简单的例子:

# 模型中共享的 LSTM 用于并行编码两个不同的序列input_a = keras.Input(shape=(140, 256))input_b = keras.Input(shape=(140, 256))shared_lstm = keras.layers.LSTM(64)# 在一个 GPU 上处理第一个序列with tf.device_scope('/gpu:0'): encoded_a = shared_lstm(tweet_a)# 在另一个 GPU上 处理下一个序列with tf.device_scope('/gpu:1'): encoded_b = shared_lstm(tweet_b)# 在 CPU 上连接结果with tf.device_scope('/cpu:0'): merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)

3.参考

1.http://zh.d2l.ai/chapter_computational-performance/multiple-gpus.html

2.https://keras.io/zh/getting-started/faq/#how-can-i-run-a-keras-model-on-multiple-gpus

原文:https://blog.csdn.net/xiaosongshine/article/details/88567519

keras用cpu加速_GPU训练加速原理(附KerasGPU训练技巧)相关推荐

  1. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

  2. Yolov5如何在训练意外中断后接续训练

    Yolov5如何在训练意外中断后接续训练 1.配置环境 2.问题描述 3.解决方法 3.1设置需要接续训练的结果 3.2设置训练代码 4.原理 5.结束语 1.配置环境 操作系统:Ubuntu20.0 ...

  3. PyTorch训练加速17种技巧

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 文自 机器之心 作者:LORENZ KUHN 编辑:陈萍 掌握这 ...

  4. 海量小文件场景下训练加速优化之路

    作者:星辰算力平台 1. 背景 随着大数据.人工智能技术的蓬勃发展,人类对于算力资源的需求也迎来大幅度的增长.在腾讯内部,星辰算力平台以降本增效为目标,整合了公司的GPU训练卡资源,为算法工程师们提供 ...

  5. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  6. GLM国产大模型训练加速:性能最高提升3倍,显存节省1/3,低成本上手

    作者|BBuf.谢子鹏.冯文 2017 年,Google 提出了 Transformer 架构,随后 BERT .GPT.T5等预训练模型不断涌现,并在各项任务中都不断刷新 SOTA 纪录.去年,清华 ...

  7. 实践教程|PyTorch训练加速技巧

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨用什么名字没那么重要@知乎(已授权) 来源丨https://z ...

  8. 上海交大:基于近似随机Dropout的LSTM训练加速

    机器之心发布 作者:宋卓然.王儒.茹栋宇.彭正皓.蒋力 上海交通大学 在这篇文章中,作者利用 Dropout 方法在神经网络训练过程中产生大量的稀疏性进行神经网络的训练加速.该论文已经被 Design ...

  9. PyTorch训练加速技巧

    PyTorch训练加速技巧 由于最近的程序对速度要求比较高,想要快速出结果,因此特地学习了一下混合精度运算和并行化操作,由于已经有很多的文章介绍相关的原理,因此本篇只讲述如何应用PyTorch实现混合 ...

  10. 实践经验|PyTorch训练加速技巧

    最近程序对速度要求比较高,想要快速出结果,因此特地研究了一下混合精度运算和并行化操作,由于已经有很多的文章介绍相关的原理,因此本篇只讲述如何应用torch实现混合精度运算.数据并行和分布式运算,不具体 ...

最新文章

  1. opencv中Range类的使用
  2. Makefile —— 如何在文件内使用变量?
  3. 蓝桥杯-未名湖边的烦恼(java)
  4. 数据科学家十年后彻底消失?25年行业元老:无稽之谈!
  5. CUDA学习日志:常量内存和纹理内存
  6. 10个常用的Python图像处理工具,建议收藏!
  7. 一篇博客读懂设计模式之---单例模式
  8. java实现线性回归(简单明了,适合理解)
  9. 二叉树的几种遍历方法
  10. 机器视觉:平行光源在双远心系统中的应用
  11. 2014.10.18笔记
  12. Win11如何开启聚焦功能?Win11开启聚焦功能的方法
  13. Android wm指令的用法笔记
  14. 用glew,glfw实现opengl学习笔记5课纹理(2)
  15. vim批量删除与插入
  16. Web表单提交之disabled问题
  17. 计算机网络——应用层
  18. CSS 实现 系统登录界面 (二)
  19. 理解两个函数乘积的导数的一种视角
  20. H.264视频的RTP有效负载格式 (RFC-6184)

热门文章

  1. 4. Browser 对象 - Navigator 对象(2)
  2. 开源Web安全测试工具调研
  3. BIOS、BootLoader、uboot对比
  4. (转)linux sort 命令详解
  5. Ubuntu安装Atom编辑器
  6. 计算机网络第五次笔记
  7. asp.net中使用水晶报表 ---pull
  8. IBM服务器诊断面板
  9. Android初级教程:对文件和字符串进行MD5加密工具类
  10. 微信小程序——诉讼费计算