参考文献:
1.https://www.fast.ai/2018/07/02/adam-weight-decay/
2.https://arxiv.org/pdf/1904.00962.pdf
3.https://blog.csdn.net/weixin_43269174/article/details/106255084

前言

说到优化器,我们脑海中首先浮现的可能就是 Stochastic Gradient Descent (SGD)、Adaptive Gradient (AdaGrad)、Root Mean Square prop (RMSprop)、Adaptive Moment estimation (Adam) 等常用的老牌优化器。但是神经网络发展到了现在,大部分 NLP 预训练模型已不再使用这些方法,而是使用 Adam Weight Decay Regularization (AdamW) 和19年首度亮相的 Layer-wise Adaptive Moments optimizer for Batching training (LAMB)。这些新兴优化器的优点是什么呢?为什么如此受欢迎?这些网上已经有很多分析和解释了,这里不再说明,本文的重点就是Adam,AdamW,LAMB的计算公式和代码实现。

1 Adam

为解决 GD 中固定学习率带来的不同参数间收敛速度不一致的弊端,AdaGrad 和 RMSprop 诞生出来,为每个参数赋予独立的学习率。计算梯度后,梯度较大的参数获得的学习率较低,反之亦然。此外,为避免每次梯度更新时都独立计算梯度,导致梯度方向持续变化,Momentum 将上一轮梯度值加入到当前梯度的计算中,通过某种权重对两者加权求和,获得当前批次参数更新的更新值。 Adam 结合了这两项考虑,既为每一个浮点参数自适应性地设置学习率,又将过去的梯度历史纳入考量,其实现原理如下:
mt=β1∗mt−1+(1−β1)∗gtvt=β2∗vt−1+(1−β2)∗gt2mt^=mt/(1−β1t)vt^=vt/(1−β2t)θt=θt−1−α∗mt^vt^+ϵm_t=\beta_1*m_{t-1}+(1-\beta_1)*g_t\\ v_t=\beta_2*v_{t-1}+(1-\beta_2)*g_t^2\\ \hat{m_t}=m_t/(1-\beta_1^t)\\ \hat{v_t}=v_t/(1-\beta_2^t)\\ \theta_t=\theta_{t-1}-\alpha*\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon} mt​=β1​∗mt−1​+(1−β1​)∗gt​vt​=β2​∗vt−1​+(1−β2​)∗gt2​mt​^​=mt​/(1−β1t​)vt​^​=vt​/(1−β2t​)θt​=θt−1​−α∗vt​^​​+ϵmt​^​​
计算一阶、二阶动量矩,加入偏置修正,最后更新参数,gt表示t时刻梯度。从上述公式可以看出,训练前期的学习率和梯度更新是比较激进的,到后期逐渐平稳。虽然 Adam 优化器的使用会导致内存中多出两倍于原参数体量的占用,但与之换来的训练收益使得学术界并没有放弃这一高效的方法。

代码实现比较简单,照着公式敲就行了:

import autograd.numpy as np
from autograd import gradclass Adam:def __init__(self, loss, weights, lr=0.001, beta1=0.9, beta2=0.999, epislon=1e-8):self.loss = lossself.theta = weightsself.lr = lrself.beta1 = beta1self.beta2 = beta2self.epislon = epislonself.get_gradient = grad(loss)self.m = 0self.v = 0self.t = 0def minimize_raw(self):self.t += 1g = self.get_gradient(self.loss)self.m = self.beta1 * self.m + (1 - self.beta1) * gself.v = self.beta2 * self.v + (1 - self.beta2) * (g * g)self.m_hat = self.m / (1 - self.beta1 ** self.t)self.v_hat = self.v / (1 - self.beta2 ** self.t)self.theta = self.theta - self.lr * self.m_hat / (self.v_hat ** 0.5 + self.epislon)

2 AdamW

Adam 虽然收敛速度快,但没能解决参数过拟合的问题。学术界讨论了诸多方案,其中包括在损失函数中引入参数的 L2 正则项。这样的方法在其他的优化器中或许有效,但会因为 Adam 中自适应学习率的存在而对使用 Adam 优化器的模型失效,具体分析可见fastai的这篇文章:AdamW and Super-convergence is now the fastest way to train neural nets。AdamW 的出现便是为了解决这一问题,达到同样使参数接近于 0 的目的。具体的举措,是在最终的参数更新时引入参数自身:
mt=β1∗mt−1+(1−β1)∗gtvt=β2∗vt−1+(1−β2)∗gt2mt^=mt/(1−β1t)vt^=vt/(1−β2t)θt=θt−1−α∗(mt^vt^+ϵ+λ∗θt−1)m_t=\beta_1*m_{t-1}+(1-\beta_1)*g_t\\ v_t=\beta_2*v_{t-1}+(1-\beta_2)*g_t^2\\ \hat{m_t}=m_t/(1-\beta_1^t)\\ \hat{v_t}=v_t/(1-\beta_2^t)\\ \theta_t=\theta_{t-1}-\alpha*(\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon}+\lambda*\theta_{t-1}) mt​=β1​∗mt−1​+(1−β1​)∗gt​vt​=β2​∗vt−1​+(1−β2​)∗gt2​mt​^​=mt​/(1−β1t​)vt​^​=vt​/(1−β2t​)θt​=θt−1​−α∗(vt​^​​+ϵmt​^​​+λ∗θt−1​)
λ 即为权重衰减因子,常见的设置为 0.005/0.01。这一优化策略目前正广泛应用于各大预训练语言模型。

class AdamW:def __init__(self, loss, weights, lambda1, lr=0.001, beta1=0.9, beta2=0.999, epislon=1e-8):self.loss = lossself.theta = weightsself.lr = lrself.beta1 = beta1self.beta2 = beta2self.epislon = epislonself.lambda1 = lambda1self.get_gradient = grad(loss)self.m = 0self.v = 0self.t = 0def minimize_raw(self):self.t += 1g = self.get_gradient(self.loss)self.m = self.beta1 * self.m + (1 - self.beta1) * gself.v = self.beta2 * self.v + (1 - self.beta2) * (g * g)self.m_hat = self.m / (1 - self.beta1 ** self.t)self.v_hat = self.v / (1 - self.beta2 ** self.t)self.theta = self.theta - self.lr * (self.m_hat / (self.v_hat ** 0.5 + self.epislon) + self.lambda1 * self.theta)

3 LAMB

LAMB 优化器是 2019 年出现的一匹新秀,它将bert模型的预训练时间从3天压缩到了76分钟! LAMB 出现的目的是加速预训练进程,这个优化器也成为 NLP 社区为泛机器学习领域做出的一大贡献。在使用 Adam 和 AdamW 等优化器时,一大问题在于 batch size 存在一定的隐式上限,一旦突破这个上限,梯度更新极端的取值会导致自适应学习率调整后极为困难的收敛,从而无法享受增加的 batch size 带来的提速增益。LAMB 优化器的作用便在于使模型在进行大批量数据训练时,能够维持梯度更新的精度。具体来说,LAMB 优化器支持自适应元素级更新(adaptive element-wise updating)和准确的逐层修正(layer-wise correction)。LAMB 可将 BERT 预训练的批量大小扩展到 64K,且不会造成准确率损失。BERT 预训练包括两个阶段:1)前 9/10 的训练 epoch 使用 128 的序列长度,2)最后 1/10 的训练 epoch 使用 512 的序列长度。LAMB的算法如下:
mt=β1∗mt−1+(1−β1)∗gtvt=β2∗vt−1+(1−β2)∗gt2mt^=mt/(1−β1t)vt^=vt/(1−β2t)rt=mt^vt^+ϵθt=θt−1−α∗ϕ(∣∣θt−1∣∣)∣∣rt+λθt−1∣∣(rt+λθt−1)m_t=\beta_1*m_{t-1}+(1-\beta_1)*g_t\\ v_t=\beta_2*v_{t-1}+(1-\beta_2)*g_t^2\\ \hat{m_t}=m_t/(1-\beta_1^t)\\ \hat{v_t}=v_t/(1-\beta_2^t)\\ r_t=\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon}\\ \theta_t=\theta_{t-1}-\alpha*\frac{\phi(||\theta_{t-1}||)}{||r_t+\lambda \theta_{t-1}||}(r_t+\lambda \theta_{t-1}) mt​=β1​∗mt−1​+(1−β1​)∗gt​vt​=β2​∗vt−1​+(1−β2​)∗gt2​mt​^​=mt​/(1−β1t​)vt​^​=vt​/(1−β2t​)rt​=vt​^​​+ϵmt​^​​θt​=θt−1​−α∗∣∣rt​+λθt−1​∣∣ϕ(∣∣θt−1​∣∣)​(rt​+λθt−1​)
其中,$\phi 是一个可选择的映射函数,一种是是一个可选择的映射函数,一种是是一个可选择的映射函数,一种是\phi(z)=z,另一种则为起到归一化作用的,另一种则为起到归一化作用的,另一种则为起到归一化作用的min(max(z,\gamma_l),\gamma_u)。。。\gamma_l,\gamma_u$为预先设定的超参数,分别代表参数调整的下界和上界。这一简单的调整所带来的实际效果非常显著。使用 AdamW 时,batch size 超过 512 便会导致模型效果大幅下降,但在 LAMB 下,batch size 可以直接提到 32,000 而不会导致精度损失。

以下是 LAMB 优化器的 tensorflow1.x 代码,可作为参考以理解算法,具体的代码出处已无法找寻。

class LAMBOptimizer(tf.train.Optimizer):'''LAMBOptimizer optimizer.# Important Note- This is NOT an official implementation.- LAMB optimizer is changed from arXiv v1 ~ v3.- We implement v3 version (which is the latest version on June, 2019.).- Our implementation is based on `AdamWeightDecayOptimizer` in BERT (provided by Google).# References- LAMB optimier: https://github.com/ymcui/LAMB_Optimizer_TF- Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. https://arxiv.org/abs/1904.00962v3- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805# Parameters- There is nothing special, just the same as `AdamWeightDecayOptimizer`.'''def __init__(self,learning_rate,weight_decay_rate=0.01,beta_1=0.9,beta_2=0.999,epsilon=1e-6,exclude_from_weight_decay=None,name="LAMBOptimizer"):"""Constructs a LAMBOptimizer."""super(LAMBOptimizer, self).__init__(False, name)self.learning_rate = learning_rateself.weight_decay_rate = weight_decay_rateself.beta_1 = beta_1self.beta_2 = beta_2self.epsilon = epsilonself.exclude_from_weight_decay = exclude_from_weight_decaydef apply_gradients(self, grads_and_vars, global_step=None, name=None):"""See base class."""assignments = []for (grad, param) in grads_and_vars:if grad is None or param is None:continueparam_name = self._get_variable_name(param.name)m = tf.get_variable(name=param_name + "/lamb_m",shape=param.shape.as_list(),dtype=tf.float32,trainable=False,initializer=tf.zeros_initializer())v = tf.get_variable(name=param_name + "/lamb_v",shape=param.shape.as_list(),dtype=tf.float32,trainable=False,initializer=tf.zeros_initializer())# Standard Adam update.next_m = (tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))next_v = (tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,tf.square(grad)))update = next_m / (tf.sqrt(next_v) + self.epsilon)# Just adding the square of the weights to the loss function is *not*# the correct way of using L2 regularization/weight decay with Adam,# since that will interact with the m and v parameters in strange ways.## Instead we want ot decay the weights in a manner that doesn't interact# with the m/v parameters. This is equivalent to adding the square# of the weights to the loss with plain (non-momentum) SGD.if self._do_use_weight_decay(param_name):update += self.weight_decay_rate * param############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ############### Note: Here are two choices for scaling function \phi(z)# minmax:   \phi(z) = min(max(z, \gamma_l), \gamma_u)# identity: \phi(z) = z# The authors does not mention what is \gamma_l and \gamma_u# UPDATE: after asking authors, they provide me the code below.# ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(#      math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)r1 = tf.sqrt(tf.reduce_sum(tf.square(param)))r2 = tf.sqrt(tf.reduce_sum(tf.square(update)))r = tf.where(tf.greater(r1, 0.0),tf.where(tf.greater(r2, 0.0),r1 / r2,1.0),1.0)eta = self.learning_rate * rupdate_with_lr = eta * updatenext_param = param - update_with_lrassignments.extend([param.assign(next_param),m.assign(next_m),v.assign(next_v)])return tf.group(*assignments, name=name)def _do_use_weight_decay(self, param_name):"""Whether to use L2 weight decay for `param_name`."""if not self.weight_decay_rate:return Falseif self.exclude_from_weight_decay:for r in self.exclude_from_weight_decay:if re.search(r, param_name) is not None:return Falsereturn Truedef _get_variable_name(self, param_name):"""Get the variable name from the tensor name."""m = re.match("^(.*):\\d+$", param_name)if m is not None:param_name = m.group(1)return param_name

Adam,AdamW,LAMB优化器原理与代码相关推荐

  1. 一训练就显存爆炸?Facebook 推出 8 比特优化器,两行代码拯救你的显存

    自从人们发现越大的模型性能越好后,神经网络模型的参数量就在越来越大的道路上一去不复返了.从XX-large到GPT3,再到5300亿参数的Megatron Turing-NLG,深度学习越来越像是只有 ...

  2. 一训练就显存爆炸?Facebook 推出 8 比特优化器,两行代码拯救你的显存!

    文 | jxyxiangyu 编 | 小轶 "小夕,小夕!又出来了个 SOTA 模型!赶紧 follow !" 小夕看了看新模型的参数量, 然后看了看实验室服务器的几张小破卡. 小 ...

  3. APG(Accelerate Proximal Gradient)加速近端梯度算法 和 NAG(Nesterov accelerated gradient)优化器原理 (二)

    文章目录 前言 NAG优化器 APG 与 NAG的结合 Pytorch 代码实现 总结 附录 公式(11)推导 引用 前言 近期在阅读Data-Driven Sparse Structure Sele ...

  4. APG(Accelerate Proximal Gradient)加速近端梯度算法 和 NAG(Nesterov accelerated gradient)优化器原理 (一)

    文章目录 前言 APG(Accelerate Proximal Gradient)加速近端梯度算法[^1] PGD (Proximal Gradient Descent)近端梯度下降法推导[^2] E ...

  5. SQL优化器原理 - Auto Hash Join

    这是MaxCompute有关SQL优化器原理的系列文章之一.我们会陆续推出SQL优化器有关优化规则和框架的其他文章.添加钉钉群"关系代数优化技术"(群号11719083)可以获取最 ...

  6. QQ强聊器原理和代码

    QQ强聊器原理和代码 可以做下面一个试验: 在 IE 地址栏内输入如下字符: http://wpa.qq.com/msgrd?V=1&Uin=123456&Site=ioshenmue ...

  7. 萤火虫算法_40多种智能优化算法原理和代码分享

    40多种智能优化算法原理和代码分享 <智能优化算法讲解>PDF下载地址: <智能优化算法原理讲解>PDF​mianbaoduo.com 包括: 1.海鸥算法SOA 智能优化算法 ...

  8. 【深度学习】小白学深度学习:参数优化与优化器原理

    深度学习的「参数优化」 深度学习模型的优化过程是指调整模型的参数以尽量减小预测误差的过程.下面是深度学习模型优化的基本流程: 确定损失函数:衡量模型预测输出和实际输出之间误差的函数. 梯度反向传播:用 ...

  9. mysql not in优化_98%的人不知道的MySQL优化器原理

    ​| 作者 梁东阳,数据库研发中心数据库内核工程师,负责腾讯云MySQL的内核开发. 在日常运维中,相信不少人都收藏了很多关于查询优化的方法论和小技巧,但是仔细想想,你真的了解这些优化背后的原理吗? ...

最新文章

  1. 我们如何在Python中创建多行注释?
  2. 六个 Linux性能监控命令行工具
  3. ElasticSearch-6.3.2 linux 安装
  4. [快速入门]Spring Boot+springfox-swagger2 之RESTful API自动生成和测试
  5. 26. 复杂链表的复制
  6. 关于元素水平垂直居中的那些事?
  7. 这个 bug,硬是让我折腾了一周
  8. Python:try……excepted捕获方法
  9. 2018年,51LA新版的那些事
  10. 如何提高学习效率,三大法则,五大步骤
  11. 一点点读懂cpufreq(一)
  12. Windows和Xyplorer的完美结合
  13. 一次Ajax报错:“存储空间不足,无法完成此操作”的解决经验
  14. C/C++ opencv 计算 LBP特征 包括旋转不变 uniform 圆形邻域
  15. RTX51tiny 延时长度计算
  16. 蓝桥杯 STEMA 考试 C++ 编程题模拟题
  17. 推荐python入门进阶到大神的书籍
  18. 保本基金的投资组合保险策略运用及建议
  19. *.3ds的文件格式
  20. 成教计算机科学与技术怎么样,华中农业大学成考计算机科学与技术专业就业前景怎么样?...

热门文章

  1. 北京交通大学本科毕业论文答辩PPT模板
  2. windows输入英文-搜狗输入法不提示很恼火怎么办
  3. 建筑标准何其之多,python爬虫半天全梭
  4. ae制h5文字动画_对于8个华丽的HTML5文字动画特效图文赏析
  5. CVPR2021目标检测方向论文
  6. iOS开发之自定义的framework添加第三方framework,lipo和ar命令看.o文件
  7. 连载 | Android之Camera1实现相机开发
  8. Java技术栈学习路线
  9. vue项目中扫码枪收款
  10. 【git】You have not concluded your merge (MERGE_HEAD exists).