一只小狐狸带你解锁 炼丹术&NLP 秘籍

作者:苏剑林(来自追一科技,人称“苏神”)

前言

在前不久的文章《BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)》中介绍了一个叫做“重计算”的技巧(附pytorch和paddlepaddle实现)。简单来说重计算就是用来省显存的方法,让平均训练速度慢一点,但batch_size可以增大好几倍,该技巧首先发布于论文《Training Deep Nets with Sublinear Memory Cost》。

最近笔者发现,重计算的技巧在tensorflow也有实现。事实上从tensorflow1.8开始,tensorflow就已经自带了该功能了,当时被列入了tf.contrib这个子库中,而从tensorflow1.15开始,它就被内置为tensorflow的主函数之一,那就是tf.recompute_grad。找到 tf.recompute_grad 之后,笔者就琢磨了一下它的用法,经过一番折腾,最终居然真的成功地用起来了,居然成功地让 batch_size 从48增加到了144!然而,在继续整理测试的过程中,发现这玩意居然在tensorflow 2.x是失效的...于是再折腾了两天,查找了各种资料并反复调试,最终算是成功地补充了这一缺陷。

最后是笔者自己的开源实现:

Github地址:

https://github.com/bojone/keras_recompute

该实现已经内置在bert4keras中,使用bert4keras的读者可以升级到最新版本(0.7.5+)来测试该功能。

使用

笔者的实现也命名为recompute_grad,它是一个装饰器,用于自定义Keras层的 call函数,比如

from recompute import recompute_gradclass MyLayer(Layer):
@recompute_grad
def call(self, inputs):
return inputs * 2

对于已经存在的层,可以通过继承的方式来装饰:

from recompute import recompute_grad
class MyDense(Dense):@recompute_graddef call(self, inputs):return super(MyDense, self).call(inputs)

自定义好层之后,在代码中嵌入自定义层,然后在执行代码之前,加入环境变量RECOMPUTE=1来启用重计算。

注意:不是在总模型里插入了@recomputr_grad,就能达到省内存的目的,而是要在每个层都插入@recomputr_grad才能更好地省显存。简单来说,就是插入的@recomputr_grad越多,就省显存。具体原因请仔细理解重计算的原理。

效果

bert4keras0.7.5已经内置了重计算,直接传入环境变量RECOMPUTE=1就会启用重计算,读者可以自行尝试,大概的效果是:

1、在BERT Base版本下,batch_size可以增大为原来的3倍左右;

2、在BERT Large版本下,batch_size可以增大为原来的4倍左右;

3、平均每个样本的训练时间大约增加25%;

4、理论上,层数越多,batch_size可以增大的倍数越大。

环境

在下面的环境下测试通过:

tensorflow 1.14 + keras 2.3.1

tensorflow 1.15 + keras 2.3.1

tensorflow 2.0 + keras 2.3.1

tensorflow 2.1 + keras 2.3.1

tensorflow 2.0 + 自带tf.keras

tensorflow 2.1 + 自带tf.keras

确认不支持的环境:

tensorflow 1.x + 自带tf.keras

欢迎报告更多的测试结果。

顺便说一下,强烈建议用keras2.3.1配合tensorflow1.x/2.x来跑,强烈不建议使用tensorflow 2.x自带的tf.keras来跑

  • 算法工程师的效率神器——vim篇

  • 硬核推导Google AdaFactor:一个省显存的宝藏优化器

  • 数据缺失、混乱、重复怎么办?最全数据清洗指南让你所向披靡

  • LayerNorm是Transformer的最优解吗?

  • ACL2020|FastBERT:放飞BERT的推理速度

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

巨省显存的重计算技巧在TF、Keras中的正确打开方式相关推荐

  1. 深圳大学计算机双学位绩点规定,以深圳大学小伙伴为例 为你展开绩点计算的正确打开方式...

    原标题:以深圳大学小伙伴为例 为你展开绩点计算的正确打开方式 深圳大学的平均GPA究竟应该怎么算 一般谈到计算绩点,同学们都是简单粗暴地把每个学期的GPA相加,然后除以总的学期数. 而实际上这种坊间常 ...

  2. 模型显存占用及其计算量

    1. 显存的占用 当在GPU上跑一个模型时,显存的占用主要有两部分: 模型的输出(特征图.特征图的梯度).模型的参数(权重矩阵.偏置值.梯度) 1. 模型参数的显存占用:(例如:卷积核的参数.BN层. ...

  3. 神经网络占用内存(显存)的计算

    所占用内存 KB = 参数x4 / 1024 所占用内存 MB = 参数x4 / 1024 / 1024 比如:某网络权重参数量106073,那么他占用的内存是106073x4/1024=414.34 ...

  4. 这才是Matlab的正确打开方式!——Matlab矩阵、绘图、函数计算与数据读取

    Matlab基础学习笔记 基础及预设置 矩阵 各种函数 二维制图 三维制图 运算 输入/输出 各种语句 数据读出/写入 这里用的是Matlab2016a版本 基础及预设置 1.设置路径 选择路径,或是 ...

  5. 怎么避免options请求_和上级沟通的正确打开方式:3种技巧,轻松让领导答应你的请求...

    点击右上角[关注]霸王课头条号,收获更多加薪秘籍.本文共2130字,阅读全文约3分钟 今天刚和广告商谈完合作回来,我就看见几个人围在同事姗姗那,不用猜也知道发生了什么事,正巧这时收到同事小王的微信-- ...

  6. xshell6使用技巧_Xshell6的正确打开方式

    WX众号:基因学苑 Q群:32798724 更多精彩内容等你发掘! 远程连接服务器的工具有很多,一般都是支持ssh协议,例如putty,mobaxterm,SSH Secure Shell Clien ...

  7. 【Latex】高级插入图片技巧: 双栏中如何正确插入图片 + 如何多图

    一.双栏中正确使用图片 [问题描述] 貌似multicols环境中不能放图片,即 \begin{multicols}{2} \begin{figure} \centering \includegrap ...

  8. CNN模型的计算量、参数、显存占用

    经典CNN模型的计算量.参数.显存占用 文章目录 经典CNN模型的计算量.参数.显存占用 1. 深度学习复杂度 2. FLOPS概念 3.参数量计算 4. 输出特征图尺寸 5. 常用模型的FlOPs和 ...

  9. 延迟渲染G-buffer所占显存带宽计算(解决移动端和抗锯齿的若干疑问)

    延迟渲染需要在前面阶段,将计算的内容保留在N张G-buffer中,但是网上的文章只是提及了G-buffer应该压缩,并且尽量少用,没有说明G-buffer所占带宽应该是多少,我将在下面介绍G-buff ...

最新文章

  1. RHEL5.1单域主/从NIS服务器配置及测试
  2. apache安装 windows
  3. vue 定义全局函数
  4. C# winform开发:Graphics、pictureBox同时画多个矩形
  5. mysql 临时列_mysql – 在SQL中添加一个临时列,其中值取决于另一列
  6. java中gc是怎么工作的_java中的GC(gabage collection)如何工作
  7. xml:使用xmlspy创建xml文件,且通过xml文件生成对应的dtd文件
  8. Java NIO、BIO介绍
  9. 电脑开机自动推送微信通知
  10. android数据格式化,手机格式化了?教你找回安卓手机误删数据
  11. mand-mobile TabPicker 多级联动选择
  12. 怎样使用github?(转)
  13. 2018-07-03 根据Excel后缀名获取WorkBook
  14. 北航计算机控制系统实验报告,北航计算机控制系统实验报告教程.doc
  15. 利用css特性布局页面制作京东特价框
  16. matlab getprmdflt,DFLT40A-7中文资料
  17. phpRedis函数使用总结
  18. Aspen中物性方法选择
  19. HTML表单制作,上传到服务器
  20. 复旦大学-陈果老师笔记

热门文章

  1. JAVA NIO 简介(转)
  2. 天气预报Dom解析(转)
  3. DIV Scroll属性
  4. 单片机的引脚,你都清楚吗?
  5. 串口,com口,ttl,max232你应该知道的事
  6. 杭电java期末试卷2015_2014年杭州电子科技大学Java期末试卷.doc
  7. string最大容量_string初步使用
  8. 相邻位数字差值的绝对值不能超过_热点争议中技术问题,伺服控制有几个零点?对应真绝对值多圈编码器意义...
  9. 从拉格朗日乘子法到SVM
  10. HDFS依然是存储的王者