现在 Keras 中你也可以用小的 batch size 实现大 batch size 的效果了——只要你愿意花 n 倍的时间,可以达到 n 倍 batch size 的效果,而不需要增加显存。

作者丨苏剑林

研究方向丨NLP,神经网络

个人主页丨kexue.fm

Github地址:

https://github.com/bojone/accum_optimizer_for_keras

扯淡

在一两年之前,做 NLP 任务都不用怎么担心 OOM 问题,因为相比 CV 领域的模型,其实大多数 NLP 模型都是很浅的,极少会显存不足。幸运或者不幸的是,Bert 出世了,然后火了。Bert 及其后来者们(GPT-2、XLNET 等)都是以足够庞大的 Transformer 模型为基础,通过足够多的语料预训练模型,然后通过 fine tune 的方式来完成特定的 NLP 任务。

即使你很不想用 Bert,但现在的实际情况是:你精心设计的复杂的模型,效果可能还不如简单地 fine tune 一下 Bert 好。所以不管怎样,为了跟上时代,总得需要学习一下 Bert 的 fine tune。

问题是“不学不知道,一学吓一跳”,只要任务稍微复杂一点,或者句子长度稍微长一点,显存就不够用了,batch size 急剧下降——32?16?8?一跌再跌都是有可能的。

这不难理解,Transformer 基于 Attention,而 Attention 理论上空间和时间复杂度都是,虽然在算力足够强的时候,Attention 由于其并行性还是可以表现得足够快,但是显存占用量是省不了了,意味着当你句子长度变成原来的 2 倍时,显存占用基本上就需要原来的 4 倍,这个增长比例肯定就容易 OOM 了。

而更不幸的消息是,大家都在 fine tune 预训练 Bert 的情况下,你 batch_size=8 可能比别人 batch_size=80 低好几个千分点甚至是几个百分点,显然这对于要刷榜的读者是很难受的。难道除了加显卡就没有别的办法了吗?

正事

有!通过梯度缓存和累积的方式,用时间来换取空间,最终训练效果等效于更大的 batch size。因此,只要你跑得起 batch_size=1,只要你愿意花 n 倍的时间,就可以跑出 n 倍的 batch size 了。

梯度累积的思路,在之前的文章“让Keras更酷一些!”:小众的自定义优化器已经介绍了,当时称之为“软 batch(soft batch)”,本文还是沿着主流的叫法称之为“梯度累积(accumulate gradients)”好了。

所谓梯度累积,其实很简单,我们梯度下降所用的梯度,实际上是多个样本算出来的梯度的平均值,以 batch_size=128 为例,你可以一次性算出 128 个样本的梯度然后平均,我也可以每次算 16 个样本的平均梯度,然后缓存累加起来,算够了 8 次之后,然后把总梯度除以 8,然后才执行参数更新。当然,必须累积到了 8 次之后,用 8 次的平均梯度才去更新参数,不能每算 16 个就去更新一次,不然就是 batch_size=16 了。

刚才说了,在之前的文章的那个写法是有误的,因为用到了:

K.switch(cond, K.update(p, new_p), p)

来控制更新,但事实上这个写法不能控制更新,因为 K.switch 只保证结果的选择性,不保证执行的选择性,事实上它等价于:

cond * K.update(p, new_p) + (1 - cond) * p

也就是说不管 cond 如何,两个分支都是被执行了。事实上 Keras 或 Tensorflow“几乎”不存在只执行一个分支的条件写法(说“几乎”是因为在一些比较苛刻的条件下可以做到),所以此路不通。

不能这样写的话,那只能在“更新量”上面下功夫,如前面所言,每次算 16 个样本的梯度,每次都更新参数,只不过 8 次中有 7 次的更新量是 0,而只有 1 次是真正的梯度下降更新。

很幸运的是,这种写法还可以无缝地接入到现有的 Keras 优化器中,使得我们不需要重写优化器!详细写法请看:

https://github.com/bojone/accum_optimizer_for_keras

具体的写法无外乎就是一些移花接木的编程技巧,真正有技术含量的部分不多。关于写法本身不再细讲,如果有疑问欢迎讨论区讨论。

注:这个优化器的修改,使得小 batch size 能起到大 batch size 的效果,前提是模型不包含 Batch Normalization,因为 Batch Normalization 在梯度下降的时候必须用整个 batch 的均值方差。所以如果你的网络用到了 Batch Normalization,想要准确达到大 batch size 的效果,目前唯一的方法就是加显存/加显卡。

实验

至于用法则很简单:

opt = AccumOptimizer(Adam(), 10) # 10是累积步数model.compile(loss='mse', optimizer=opt)model.fit(x_train, y_train, epochs=10, batch_size=10)# 10是累积步数model.compile(loss='mse', optimizer=opt)model.fit(x_train, y_train, epochs=10, batch_size=10)

这样一来就等价于 batch_size=100 的 Adam 优化器了,代价就是你跑了 10 个 epoch,实际上只相当于 batch_size=100 跑了 1 个 epoch,好处是你只需要用到 batch_size=10 的显存量。

可能读者想问的一个问题是:你怎么证明你的写法生效了?也就是说你怎么证明你的结果确实是 batch_size=100 而不是 batch_size=10?

为此,我做了个比较极端的实验,代码在这里:

https://github.com/bojone/accum_optimizer_for_keras/blob/master/mnist_mlp_example.py

代码很简单,就是用多层 MLP 做 MNIST 分类,用 Adam 优化器, fit 的时候 batch_size=1。优化器有两个选择,第一个是直接 Adam() ,第二个是 AccumOptimizer(Adam(), 100) :

如果是直接 Adam() ,那 loss 一直在 0.4 上下徘徊,后面 loss 越来越大了(训练集都这样),val 的准确率也没超过 97%;

如果是 AccumOptimizer(Adam(), 100) ,那么训练集的 loss 越来越低,最终降到 0.02 左右,val 的最高准确率有 98%+;

最后我比较了直接 Adam() 但是 batch_size=100 的结果,发现跟 AccumOptimizer(Adam(), 100) 但是 batch_size=1 时表现差不多。

这个结果足以表明写法生效了,达到了预期的目的。如果这还不够说服力,我再提供一个训练结果作为参考:

在某个 Bert 的 fine tune 实验中,直接用 Adam() 加 batch_size=12,我跑到了 70.33% 的准确率;我用 AccumOptimizer(Adam(), 10) 加 batch_size=12(预期等效 batch size 是 120),我跑到了 71% 的准确率,提高了 0.7%,如果你在刷榜,那么这 0.7% 可能是决定性的。

结论

终于把梯度累积(软 batch)正式地实现了,以后用 Bert 的时候,也可以考虑用大 batch_size 了。

点击以下标题查看作者其他文章:

  • 变分自编码器VAE:原来是这么一回事 | 附源码

  • 再谈变分自编码器VAE:从贝叶斯观点出发

  • 变分自编码器VAE:这样做为什么能成?

  • 简单修改,让GAN的判别器秒变编码器

  • 深度学习中的互信息:无监督提取特征

  • 全新视角:用变分推断统一理解生成模型

  • 细水长flow之NICE:流模型的基本概念与实现

  • 细水长flow之f-VAEs:Glow与VAEs的联姻

  • 深度学习中的Lipschitz约束:泛化与生成模型

#好 书 推 荐#

 深度学习理论与实战:基础篇 

李理 / 编著

本书不仅包含人工智能、机器学习及深度学习的基础知识,如卷积神经网络、循环神经网络、生成对抗网络等,而且也囊括了学会使用 TensorFlow、PyTorch 和 Keras 这三个主流的深度学习框架的*小知识量;不仅有针对相关理论的深入解释,而且也有实用的技巧,包括常见的优化技巧、使用多 GPU 训练、调试程序及将模型上线到生产系统中。

本书希望同时兼顾理论和实战,使读者既能深入理解理论知识,又能把理论知识用于实战,因此本书每介绍完一个模型都会介绍其实现,读者阅读完一个模型的介绍之后就可以运行、阅读和修改相关代码,从而可以更加深刻地理解理论知识。

 长按识别二维码查看详情 

?

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

▽ 点击 | 阅读原文 | 查看作者博客

Keras梯度累积优化器:用时间换取效果相关推荐

  1. tf.keras.optimizers.Adam 优化器 示例

    tf.keras.optimizers.Adam 优化器 示例 tf.keras.optimizers.Adam(learning_rate=0.001, # 学习率 默认 0.001beta_1=0 ...

  2. 神经网络 梯度下降_梯度下降优化器对神经网络训练的影响

    神经网络 梯度下降 co-authored with Apurva Pathak 与Apurva Pathak合着 尝试梯度下降优化器 (Experimenting with Gradient Des ...

  3. 对比学习可以使用梯度累积吗?

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 在之前的文章用时间换取效果:Keras 梯度累积优化器中,我们介绍过"梯度累积&qu ...

  4. Keras 优化器总结

    1. 优化器(Optimizer)用法 优化器是Keras模型Compile()方法所需的参数之一,其决定采用何种方法来训练模型. 优化器两种用法: 实例化优化器对象,然后传入model.compil ...

  5. Keras之小众需求:自定义优化器

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 今天我们来看一个小众需求:自定义优化器. 细想之下,不管用什么框架,自定义优化器这个需求可谓真的是 ...

  6. 优化器介绍—SGD、Adam、Adagrad

    文章目录 深度学习中的优化器 介绍 随机梯度下降优化器 Adam 优化器 Adagrad 优化器 如何选择优化器 结论 深度学习中的优化器 介绍 在深度学习中,优化器是一个非常重要的组成部分,它通过调 ...

  7. 面试准备——机器学习中的优化器算法

    一.优化问题 总体来看,机器学习的核心目标是给出一个模型(一般是映射函数),然后定义对这个模型好坏的评价函数(目标函数),求解目标函数的极大值或者极小值,以确定模型的参数,从而得到我们想要的模型.在这 ...

  8. 针对这一行业痛点,创新工场投资的潞晨科技发布了大规模并行AI训练系统——“夸父”(Colossal-AI) ,通过多维并行、大规模优化器、自适应任务调度、消除冗余内存、降低能量损耗等方式,旨在打造一个

    前沿AI模型越来越大,无论是对于企业还是个人,AI模型的训练成本都越来越高.即使花巨资使用超级计算机集群,聘请专家,也难以实现AI模型的高效训练. 针对这一行业痛点,创新工场投资的潞晨科技发布了大规模 ...

  9. pytorch adagrad_【学习笔记】Pytorch深度学习—优化器(二)

    点击文末 阅读原文,体验感更好哦! 前面学习过了Pytorch中优化器optimizer的基本属性和方法,优化器optimizer的主要功能是 "管理模型中的可学习参数,并利用参数的梯度gr ...

最新文章

  1. HTML基础标签入门
  2. boost::type_erasure::binding相关的测试程序
  3. python selenium鼠标点击_Python+Selenium学习--鼠标事件
  4. 银行营业网点管理系统——entity类(Branches)
  5. C51单片机————串行接口
  6. Java异常处理:如何写出“正确”但被编译器认为有语法错误的程序
  7. php跳转方式带rere_PHP利用REFERER根居访问来地址进行页面跳转
  8. POJ 1458 Common Subsequence
  9. android程序员简历模板
  10. ES6模板字符串中使用变量
  11. 解决QQ空间说说自动被发广告信息办法:取消第三方授权
  12. php的aes加密解密算法,PHP实现的简单AES加密解密算法实例
  13. echarts甘特图
  14. gc java_java内存管理以及GC
  15. Android手机Root授权原理细节全解析
  16. 2016OSC源创会年终盛典-综合技术专场-惠新宸
  17. 博士申请 | 美国密歇根州立大学ACTION Lab招收CV/ML方向全奖博士生
  18. Android蓝牙打印机功能开发完整Demo
  19. 基于Cesium的卫星及空间碎片的轨迹展示
  20. 为了安全起见,要求使用强SA密码。请使用SAPWD开关提供同一密码

热门文章

  1. R语言实战-统计分析基础-描述性统计4-psych-describe
  2. 【thinkPHP框架】Failed opening required 'header.php' include_path='.;c:\php5\pear 终级解决方案...
  3. 从技术面试官的角度来谈谈简历和面试
  4. 机器学习:从入门到第一个模型
  5. DisplayPageBoundaries 打开word后自动将页面间空白隐藏 (auto+定时器)
  6. .NET开发者如何愉快的进行微信公众号开发
  7. 通过编程为ASP.NET页面设置缓存
  8. 二叉树路径和最大python_python3实现在二叉树中找出和为某一值的所有路径(推荐)...
  9. 电子测量与仪器第四版pdf_固定资产管理系统_资产分类名称(电子和通信测量分析仪器篇)...
  10. android动画延迟执行,Android 过渡动画框架