目录

混合精度训练

理论原理

三大深度学习框架的打开方式

Pytorch

Tensorflow

PaddlePaddle


混合精度训练

一切还要从2018年ICLR的一篇论文说起。。。
《MIXED PRECISION TRAINING》

这篇论文是百度&Nvidia研究院一起发表的,结合N卡底层计算优化,提出了一种灰常有效的神经网络训练加速方法,不仅是预训练,在全民finetune BERT的今天变得异常有用哇。而且小夕调研了一下,发现不仅百度的paddle框架支持混合精度训练,在Tensorflow和Pytorch中也有相应的实现。下面我们先来讲讲理论,后面再分析混合精度训练在三大深度学习框架中的打开方式

理论原理

训练过神经网络的小伙伴都知道,神经网络的参数和中间结果绝大部分都是单精度浮点数(即float32)存储和计算的,当网络变得超级大时,降低浮点数精度,比如使用半精度浮点数****,显然是提高计算速度,降低存储开销的一个很直接的办法。然而副作用也很显然,如果我们直接降低浮点数的精度直观上必然导致模型训练精度的损失。但是呢,天外有天,这篇文章用了三种机制有效地防止了模型的精度损失。待小夕一一说来o(* ̄▽ ̄*)ブ

权重备份(master weights)我们知道半精度浮点数(float16)在计算机中的表示分为1bit的符号位,5bits的指数位和10bits的尾数位,所以它能表示的最小的正数即2^-24(也就是精度到此为止了)。当神经网络中的梯度灰常小的时候,网络训练过程中每一步的迭代(灰常小的梯度 ✖ 也黑小的learning rate)会变得更小,小到float16精度无法表示的时候,相应的梯度就无法得到更新。

论文统计了一下在Mandarin数据集上训练DeepSpeech 2模型时产生过的梯度,发现在未乘以learning rate之前,就有接近5%的梯度直接悲剧的变成0(精度比2^-24还要高的梯度会直接变成0),造成重大的损失呀/(ㄒoㄒ)/~~
还有更难的,假设迭代量逃过一劫准备奉献自己的时候。。。由于网络中的权重往往远大于我们要更新的量,当迭代量小于Float16当前区间内能表示的最小间隔的时候,更新也会失败(哭瞎┭┮﹏┭┮我怎么这么难鸭)              所以怎么办呢?作者这里提出了一个非常simple but effective的方法,就是前向传播和梯度计算都用float16,但是存储网络参数的梯度时要用float32!这样就可以一定程度上的解决上面说的两个问题啦~~~

我们来看一下训练曲线,蓝色的线是正常的float32精度训练曲线,橙色的线是使用float32存储网络参数的learning curve,绿色滴是不使用float32存储参数的曲线,两者一比就相形见绌啦。
损失放缩(loss scaling)有了上面的master weights已经可以足够高精度的训练很多网络啦,但是有点强迫症的小夕来说怎么还是觉得有点不对呀o((⊙﹏⊙))o.
虽然使用float32来存储梯度,确实不会丢失精度了,但是计算过程中出现的指数位小于 -24 的梯度不还是会丢失的嘛!相当于用漏水的筛子从河边往村里运水,为了多存点水,村民们把储水的碗换成了大缸,燃鹅筛子依然是漏的哇,在路上的时候水就已经漏的木有了。。

于是loss scaling方法来了。首先作者统计了一下训练过程中激活函数梯度的分布情况,由于网络中的梯度往往都非常小,导致在使用FP16的时候右边有大量的范围是没有使用的。这种情况下, 我们可以通过放大loss来把整个梯度右移,减少因为精度随时变为0的梯度。
那么问题来了,怎么合理的放大loss呢?一个最简单的方法是常数缩放,把loss一股脑统一放大S倍。float16能表示的最大正数是2^15*(1+1-2^-10)=65504,我们可以统计网络中的梯度,计算出一个常数S,使得最大的梯度不超过float16能表示的最大整数即可。

当然啦,还有更加智能的动态调整(automatic scaling) o(* ̄▽ ̄*)ブ我们先初始化一个很大的S,如果梯度溢出,我们就把S缩小为原来的二分之一;如果在很多次迭代中梯度都没有溢出,我们也可以尝试把S放大两倍。以此类推,实现动态的loss scaling。              **运算精度(precison of ops)**精益求精再进一步,神经网络中的运算主要可以分为四大类,混合精度训练把一些有更高精度要求的运算,在计算过程中使用float32,存储的时候再转换为float16。

  • **matrix multiplication: **linear, matmul, bmm, conv

  • **pointwise: **relu, sigmoid, tanh, exp, log

  • **reductions: **batch norm, layer norm, sum, softmax

  • **loss functions: **cross entropy, l2 loss, weight decay

像矩阵乘法和绝大多数pointwise的计算可以直接使用float16来计算并存储,而reductions、loss function和一些pointwise(如exp,log,pow等函数值远大于变量的函数)需要更加精细的处理,所以在计算中使用用float32,再将结果转换为float16来存储。

总结陈词混合精度训练做到了在前向和后向计算过程中均使用半精度浮点数,并且没有像之前的一些工作一样还引入额外超参,而且重要的是,实现非常简单却能带来非常显著的收益,在显存half以及速度double的情况下保持模型的精度,简直不能再厉害啦。

三大深度学习框架的打开方式

看完了硬核技术细节之后,我们赶紧来看看代码实现吧!如此强大的混合精度训练的代码实现不要太简单了吧

模型训练慢和显存不够怎么办?GPU加速混合精度训练相关推荐

  1. ResNet实战:单机多卡DDP方式、混合精度训练

    文章目录 摘要 apex DP和DDP Parameter Server架构(PS模式) ring-all-reduce模式 DDP的基本用法 (代码编写流程) Mixup 项目结构 计算mean和s ...

  2. 基于OpenSeq2Seq的NLP与语音识别混合精度训练

    基于OpenSeq2Seq的NLP与语音识别混合精度训练 Mixed Precision Training for NLP and Speech Recognition with OpenSeq2Se ...

  3. 深度神经网络混合精度训练

    深度神经网络混合精度训练 Mixed-Precision Training of Deep Neural Networks 论文链接:https://arxiv.org/abs/1710.03740 ...

  4. 混合精度训练原理总结

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨ZOMI酱@知乎(已授权) 来源丨https://zhuanl ...

  5. 全网最全-神经网络混合精度训练原理

    通常我们训练神经网络模型的时候默认使用的数据类型为单精度FP32.近年来,为了加快训练时间.减少网络训练时候所占用的内存,并且保存训练出来的模型精度持平的条件下,业界提出越来越多的混合精度训练的方法. ...

  6. 浅谈深度学习:如何计算模型以及中间变量的显存占用大小

    原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...

  7. 内存 显存,cpu,GPU,显卡

    内存 显存,cpu,GPU 1 硬件上的区别 1 内存条 2 cpu如下图: 3 显存:属于显卡的组成部分,主要负责存储GPU需要处理的各种数据: 4 GPU:在显卡上,属于显卡的芯片,又称图形处理单 ...

  8. 浅谈深度学习混合精度训练

    ↑ 点击蓝字 关注视学算法 作者丨Dreaming.O@知乎 来源丨https://zhuanlan.zhihu.com/p/103685761 编辑丨极市平台 本文主要记录下在学习和实际试用混合精度 ...

  9. 使用Apex进行混合精度训练

    使用Apex进行混合精度训练 转自:https://fyubang.com/2019/08/26/fp16/ 你想获得双倍训练速度的快感吗? 你想让你的显存空间瞬间翻倍吗? 如果我告诉你只需要三行代码 ...

最新文章

  1. 如何建立顺畅的项目流程
  2. server.xml中也能获取Tomcat相对路径
  3. c++ 获取linux系统信息_linux系统c程序移植
  4. android连接耳机时音量控制,android – 扬声器音量(闹钟)在插入耳机时会降低
  5. x11 gtk qt gnome kde 之间的区别和联系
  6. [vue] vue开发过程中你有使用什么辅助工具吗?
  7. qml 时间控件_Qt编写自定义控件54-时钟仪表盘
  8. 成为软件高手的几个忌讳(转贴)
  9. 使用PowerShell SQL Server DBATools的IDENTITY列阈值
  10. lambda表达式不使用委托(delegate) 用FUNC
  11. angularjs内置63个指令
  12. Netty 核心组件 Pipeline 源码分析(二)一个请求的 pipeline 之旅
  13. office 打开wps乱_Word 打开WPS文档成乱码的解决方法
  14. c语言自学书籍 新闻,如何学习C语言
  15. 老陕解读:陕西10大泡馍的品尝诀窍
  16. 虚拟 IO 服务器(VIOS)和 IBM i
  17. 自然语言处理(NLP)的一般处理流程!
  18. analyze怎么优化oracle,[转] Oracle analyze 命令分析
  19. ChatGPT有效提问技巧
  20. Django建站 - 模板篇

热门文章

  1. UML类图画法及类之间几种关系
  2. Java入门到精通——基础篇之static关键字
  3. 如何用jar命令对java工程进行打包
  4. CSS 中的定位:relative,absolute
  5. ADS中startup.s文件启动分析
  6. 准备 KVM 实验环境 - 每天5分钟玩转 OpenStack(3)
  7. C语言大神进来看看这个题目
  8. ubuntu12.04
  9. Python3——多线程之threading模块
  10. 流水灯c语言实验报告心得,嵌入式流水灯实验心得体会.docx