公式是实现的原理,而源码才是让想法落地的媒介。希望能透过源代码,对原理有更具体的理解,回顾公式,也会有更深入的感受。
前期基于Pytorch的源码,对SGD进行了学习:基于Pytorch源码对SGD、momentum、Nesterov学习
本文会基于Pytorch源码,对Adagrad进行学习。


Adagrad
在SGD的年代,我们只能通过学习率(learning rate)来宏观控制网络的参数的学习速度,这从直观上是不太细致的。
随着人脑突触的实验进展,发现人脑神经元是有一定稀疏性的。以及ReLU(Rectified linear unit)激活函数的成功,也使得神经网络隐式地往稀疏性参数学习。以上证明了参数的稀疏性对于神经网络是有先验的现实支撑,以及实验支撑的。
在此大背景下,仅通过单一学习率来调节所有参数的学习,对于稀疏的网络参数而言是不太可行的。毕竟每个参数都有不同的大小,更可能有几个数量级的差距。

Adagrad的更新公式:
vt=vt−1+▽J(θt−1)∗▽J(θt−1)θt=θt−1−αvt+ϵ∗▽J(θt−1)v_t = v_{t-1} + ▽J(θ_{t-1}) * ▽J(θ_{t-1}) \\ θ_t = \theta_{t-1} - \frac{\alpha}{\sqrt{v_t + \epsilon}} * ▽J(θ_{t-1}) vt​=vt−1​+▽J(θt−1​)∗▽J(θt−1​)θt​=θt−1​−vt​+ϵ​α​∗▽J(θt−1​)
其中▽J(θt−1)∗▽J(θt−1)▽J(θ_{t-1}) * ▽J(θ_{t-1})▽J(θt−1​)∗▽J(θt−1​)为按元素相乘,即每个梯度的平方。ϵ\epsilonϵ用于确保分数计算的数值稳定。
Adagrad能让每个参数都有不同的学习率。梯度大的参数学习率较小,反之,梯度小的参数拥有较大的学习率。此番设计能让稀疏的参数得到充足的学习。
缺点是随着模型的训练,学习率的分母会不断累加导致学习率变得很小,模型可能无法最终收敛。

Pytorch源码
Pytorch中对于Adagrad的实现并不复杂,基本上与公式一致:

for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.gradstate = self.state[p]state['sum'].addcmul_(grad, grad, value=1)std = state['sum'].sqrt().add_(group['eps'])p.addcdiv_(grad, std, value=-group['lr'])

遍历模型中的每个参数,计算它的学习率后更新参数。
其中state['sum']就是vtv_tvt​,保存每个参数累计的梯度平方和。
addcmul_表示state[‘sum’] + value * grad * grad。addcdiv_同理。


Adadelta
鉴于Adagrad学习率分母的不断累加,导致学习率减少,模型提前收敛,Adadelta把分母的累加替换为均值,达到学习率相对稳定。
类似于Batch Normalization对于均值和方差使用滑动平均(running average)来近似于训练集的均值和方差,作用于测试和验证。Adagrad也是采用滑动平均来近似于参数梯度的均值。
并且从Pytorch源码中发现,其中使用了两次滑动平均。

Adadelta的更新公式:
vt=γvt−1+(1−γ)▽J(θt−1)2θt=θt−1−αvt+ϵ∗▽J(θt−1)v_t = \gamma v_{t-1} + (1-\gamma)▽J(\theta_{t-1})^2 \\ \theta_t = \theta_{t-1} - \frac{\alpha}{\sqrt{v_t + \epsilon}} * ▽J(\theta_{t-1}) vt​=γvt−1​+(1−γ)▽J(θt−1​)2θt​=θt−1​−vt​+ϵ​α​∗▽J(θt−1​)
vtv_tvt​与Adagrad简单的累加相比,采用了平方的滑动平均,γ为平衡因子,通常取0.9。
vt+ϵ{\sqrt{v_t + \epsilon}}vt​+ϵ​相当于梯度的均方根(RMS:Root Mean Squared),梯度平方后求均值(滑动平均),再开方。

故可简写为:
θt=θt−1−αRMS(vt)∗▽J(θt−1)\theta_t = \theta_{t-1} - \frac{\alpha}{RMS(v_t)} * ▽J(\theta_{t-1}) θt​=θt−1​−RMS(vt​)α​∗▽J(θt−1​)

Pytorch源码

for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.gradstate = self.state[p]# State initializationif len(state) == 0:state['step'] = 0state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format)square_avg, acc_delta = state['square_avg'], state['acc_delta']rho, eps = group['rho'], group['eps']# square_avg = rho * square_avg + (1 - rho) * grad * gradsquare_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)# std = \sqrt{square_avg + eps}std = square_avg.add(eps).sqrt_()# delta = \sqrt{acc_delta + eps} / std * grad = \sqrt{acc_delta + eps} / \sqrt{square_avg + eps} * graddelta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)# p = p - lr * deltap.add(delta, alpha=-group['lr'])# acc_delta = rho * acc_delta + (1 - rho) * delta * deltaacc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)

源码中对 grad 和 delta 进行了滑动平均,分别保存在了 square_avg 和 acc_delta 中,acc应该是指 accumulation 累计。
源码实际上对应的公式为:
square_avg=rho∗square_avg+(1−rho)∗grad∗graddelta=acc_delta+epsstd∗grad=acc_delta+epssquare_avg+eps∗gradp=p−lr∗deltaacc_delta=rho∗acc_delta+(1−rho)∗delta∗deltasquare\_avg = rho * square\_avg + (1 - rho) * grad * grad \\ delta = \frac{\sqrt{acc\_delta + eps}}{std} * grad = \frac{\sqrt{acc\_delta + eps}}{\sqrt{square\_avg + eps}} * grad \\ p = p - lr * delta \\ acc\_delta = rho * acc\_delta + (1 - rho) * delta * delta square_avg=rho∗square_avg+(1−rho)∗grad∗graddelta=stdacc_delta+eps​​∗grad=square_avg+eps​acc_delta+eps​​∗gradp=p−lr∗deltaacc_delta=rho∗acc_delta+(1−rho)∗delta∗delta
抛开超参数lr,p 的学习率取决于 acc_delta 和 square_avg。
分母的 square_avg 相当于 grad 的均值,用于 Adadelta 对不同参数实现不同的学习率大小。分子的 acc_delta 相当于 delta 的均值,我理解是为了对 delta 保持数值波动的稳定。


RMSProp
RMSprop 和 Adadelta 都是为了解决 Adagrad 学习率急剧下降问题的。
同时,对于大的梯度提供较小的学习率,小的梯度提供较大的学习率,降低了模型训练的摆动。

Pytorch源码

for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.gradstate = self.state[p]# State initializationif len(state) == 0:state['step'] = 0state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)square_avg = state['square_avg']alpha = group['alpha']# square_avg = alpha * square_avg + (1 - alpha) * grad * gradsquare_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)# avg = \sqrt{square_avg} + epsavg = square_avg.sqrt().add_(group['eps'])# p = p - lr * grad / avgp.addcdiv_(grad, avg, value=-group['lr'])

代码对应的公式为:
square_avg=alpha∗square_avg+(1−alpha)∗grad∗gradavg=square_avg+epsp=p−lravg∗gradsquare\_avg = alpha * square\_avg + (1 - alpha) * grad * grad \\ avg = \sqrt{square\_avg} + eps \\ p = p - \frac{lr}{avg} * grad square_avg=alpha∗square_avg+(1−alpha)∗grad∗gradavg=square_avg​+epsp=p−avglr​∗grad
其实相当于把Adadelta分子的滑动平均置为1,即学习率仅取决于分母的 avg。


Adam: Adaptive Moment Estimation
Adam相当于加了momentum的RMSProp,把RMSProp中直接相乘的梯度换成带动量的梯度。
RMSProp的更新公式为:
mt=β1mt−1+(1−β1)▽θJ(θt−1)vt=β2vt−1+(1−β2)▽θJ(θt−1)∗▽θJ(θt−1)mt^=mt1−β1tvt^=vt1−β2tθt=θt−1−αvt^+ϵmt^m_t = \beta_1m_{t-1} + (1 - \beta_1)▽_\theta J(\theta_{t-1}) \\ v_t = \beta_2v_{t-1} + (1 - \beta_2) ▽_\theta J(\theta_{t-1}) * ▽_\theta J(\theta_{t-1}) \\ \ \\ \hat{m_t} = \frac{m_t}{1 - \beta_1^t} \\ \hat{v_t} = \frac{v_t}{1 - \beta_2^t} \\ \ \\ \theta_t = \theta_{t-1} - \frac{\alpha}{\sqrt{\hat{v_t}+\epsilon}} \hat{m_t} mt​=β1​mt−1​+(1−β1​)▽θ​J(θt−1​)vt​=β2​vt−1​+(1−β2​)▽θ​J(θt−1​)∗▽θ​J(θt−1​) mt​^​=1−β1t​mt​​vt​^​=1−β2t​vt​​ θt​=θt−1​−vt​^​+ϵ​α​mt​^​
超参数建议:β1=0.9\beta_1=0.9β1​=0.9,β2=0.99/0.999\beta_2=0.99/0.999β2​=0.99/0.999,α=0.001\alpha=0.001α=0.001,ϵ=10−8\epsilon=10^{-8}ϵ=10−8。
这里的vtv_tvt​和RMSProp的square_avg一样计算指数衰减平均值,mtm_tmt​则代替了原本的梯度,采用momentum的方式叠加了动量。
如果 mtm_tmt​ 和 vtv_tvt​ 被初始化为 0 ,那它们就会向 0 偏置,所以mt^\hat{m_t}mt​^​和vt^\hat{v_t}vt​^​做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差。越往后计算,βt\beta^tβt越接近0,校正效果也随之趋向于0。

Pytorch源码

for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.gradamsgrad = group['amsgrad']state = self.state[p]# State initializationif len(state) == 0:state['step'] = 0# Exponential moving average of gradient valuesstate['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)# Exponential moving average of squared gradient valuesstate['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']beta1, beta2 = group['betas']state['step'] += 1bias_correction1 = 1 - beta1 ** state['step']bias_correction2 = 1 - beta2 ** state['step']# Decay the first and second moment running average coefficient# exp_avg = beta1 * exp_avg + (1 - beta1) * gradexp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)# exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * gradexp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)# denom = \sqrt{exp_avg_sq} / \sqrt(bias2) + epsdenom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])# step_size = lr / bias1step_size = group['lr'] / bias_correction1# p = p - step_size * exp_avg / denomp.addcdiv_(exp_avg, denom, value=-step_size)

exp_avg为mtm_tmt​,exp_avg_sq为vtv_tvt​,源码对于的公式为:
p=p−step_size∗exp_avgdenom=p−lrbias1∗mt∗bias2+epsvt=p−lr∗mt^∗这里变不回1vt^+ϵ尴尬但无伤大雅\begin{aligned} p &= p - \frac{step\_size * exp\_avg}{denom} \\ &=p - \frac{lr}{bias1 } * m_t * \frac{\sqrt{bias2}+eps}{\sqrt{v_t}} \\ &=p - lr * \hat{m_t} * 这里变不回\frac{1}{\sqrt{\hat{v_t}+\epsilon}}尴尬但无伤大雅 \end{aligned} p​=p−denomstep_size∗exp_avg​=p−bias1lr​∗mt​∗vt​​bias2​+eps​=p−lr∗mt​^​∗这里变不回vt​^​+ϵ​1​尴尬但无伤大雅​


参考文档
各优化器详解,含公式、附图、优劣:深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
Pytorch的Adagrad类:torch.optim.adagrad.py
Pytorch的Adadelta类:torch.optim.adadelta.py
Pytorch的RMSprop类:torch.optim.rmsprop.py
Pytorch的Adam类:torch.optim.adam.py
Pytorch的Optimizer类:torch.optim.optimizer.py

基于Pytorch源码对Adagrad、Adadelta、RMSProp、Adam等自适应学习率进行学习相关推荐

  1. 基于Pytorch源码对SGD、momentum、Nesterov学习

    目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...

  2. 优化器(AdaGrad,AdaDelta,RmsProp,Adam,Nadam,Nesterovs,Sgd,momentum)

    以下来自: https://my.oschina.net/u/2935389/blog/2967242 https://mp.weixin.qq.com/s/NmSVXezxsQOZzK8pne3pC ...

  3. 优化算法SGD/ASGD/AdaGrad/Adadelta/RMSprop/Adam/Adamax/SparseAdam/L-BFGS/Rprop

    机器学习界有一群炼丹师,他们每天的日常是: 拿来药材(数据),架起八卦炉(模型),点着六味真火(优化算法),就摇着蒲扇等着丹药出炉了. 不过,当过厨子的都知道,同样的食材,同样的菜谱,但火候不一样了, ...

  4. ELMo解读(论文 + PyTorch源码)

    ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...

  5. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  6. PyTorch源码学习系列 - 1.初识

    本系列文章会优先发布于微信公众号和知乎,欢迎大家关注 微信公众号:小飞怪兽屋 知乎: PyTorch源码学习系列 - 1.初识 - 知乎 (zhihu.com) 目录 本系列的目的 PyTorch是什 ...

  7. PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...

  8. PyTorch源码浅析(1):THTensor

    PyTorch源码浅析(1):THTensor PyTorch中Tensor的存储和表示分开,多个THTensor可能共享一个THStorage,每个THTensor可能拥有不同的view(e.g. ...

  9. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

  10. LambdaMART简介——基于Ranklib源码(一 lambda计算)

     LambdaMART简介--基于Ranklib源码(一 lambda计算) 时间:2014-08-09 21:01:49      阅读:168      评论:0      收藏:0      ...

最新文章

  1. php批量处理图片大小,word图片怎么批量调整大小
  2. JavaScript 开发进阶:理解 JavaScript 作用域和作用域链(上)
  3. 汇编解析(6)-二进制文件(嵌入式,纯二进制格式的文件)进行反汇编和汇编
  4. 遥感计算机分类实验的难点,8-遥感实验.doc
  5. java实现遍历文件夹下的文件及文件夹
  6. JPA 系列教程21-JPA2.0-@MapKeyColumn
  7. 聊聊FilterSecurityInterceptor
  8. 利用栈和队列将队列中的元素逆置☆
  9. (淘宝无限适配)手机端rem布局详解(转载非原创)
  10. python读取第一行设为字典_将csv读入字典,第一行成为名称
  11. tomcat老启动不起来问题
  12. 现在完成进行时和现在完成时的区别
  13. 什么是小规模纳税人、小型微利企业、小微企业
  14. 科普丨“垃圾”DNA?转座子在植物抗旱中的逆袭之路
  15. win10清理_别人都说win10不需要装电脑管家,那电脑产生的垃圾该怎么清理呢
  16. 刷新页面Vue Whitelabel Error Page
  17. STM32寄存器ODR,BSRR和BRR
  18. Hive之——Hive和Oozie整合
  19. 【亚马逊(上海)-AI Lab-DGL】实习生投递+面试(凉经)
  20. 阿里云服务器报错: Error response from daemon: Get “https://registry-1.docker.io/v2/“: net/http: request...

热门文章

  1. Unity Editor 查找资源依赖、反向查找资源依赖Dependencies
  2. 创建 VSTO 外接程序的windows安装包
  3. Mysql常用函数大全(分类汇总讲解)
  4. 如何由 XRD 图谱确定所做的样品是准晶结构
  5. 怎么批量修改文件后缀名?
  6. javashop B2C开源电商系统源代码
  7. 如何安装Junit4
  8. linux同时连接内外网的设置
  9. 【Word】如何实现特殊数字 带圈数字
  10. Struts2的OGNL表达式