原文链接:动手学深度学习pytorch版:7.7 AdaDelta算法
github:https://github.com/ShusenTang/Dive-into-DL-PyTorch

原论文:
[1] Zeiler, M. D. (2012). ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

AdaDelta算法

除了 RMSProp 算法以外,另一个常用优化算法 AdaDelta 算法也针对 AdaGrad 算法在迭代后期可能较难找到有用解的问题做了改进 [1]。有意思的是,AdaDelta 算法没有学习率这一超参数。

算法

AdaDelta 算法也像 RMSProp 算法一样,使用了小批量随机梯度 gtg_tgt​ 按元素平方的指数加权移动平均变量 sts_tst​ 。在时间步0,它的所有元素被初始化为0。给定超参数 0≤ρ<100≤\rho<100≤ρ<10(对应RMSProp算法中的 γ\gammaγ),在时间步 t>0t>0t>0,同 RMSProp 算法一样计算
st←ρst−1+(1−ρ)gt⊙gt{{\text{s}}_{t}}\leftarrow \rho {{s}_{t-1}}+(1-\rho ){{g}_{t}}\odot {{g}_{t}} st​←ρst−1​+(1−ρ)gt​⊙gt​

与RMSProp算法不同的是,AdaDelta算法还维护一个额外的状态变量 Δxt\Delta {{x}_{t}}Δxt​,其元素同样在时间步0时被初始化为0。我们使用 Δxt−1\Delta {{x}_{t-1}}Δxt−1​ 来计算自变量的变化量:
g′t←Δxt−1+εst+ε⊙gtg{{'}_{t}}\leftarrow \sqrt{\frac{\Delta {{x}_{t-1}}+\varepsilon }{{{s}_{t}}+\varepsilon }}\odot {{g}_{t}} g′t​←st​+εΔxt−1​+ε​​⊙gt​

其中 ϵ\epsilonϵ 是为了维持数值稳定性而添加的常数,如 10−510^{-5}10−5。接着更新自变量:
xt←xt−1−g′t{{x}_{t}}\leftarrow {{x}_{t-1}}-g{{'}_{t}} xt​←xt−1​−g′t​

最后,我们使用 Δxt\Delta {{x}_{t}}Δxt​ 来记录自变量变化量 gt′g'_tgt′​ 按元素平方的指数加权移动平均:
Δxt←ρΔxt−1+(1−ρ)g′t⊙g′t\Delta {{x}_{t}}\leftarrow \rho \Delta {{x}_{t-1}}+(1-\rho )g{{'}_{t}}\odot g{{'}_{t}} Δxt​←ρΔxt−1​+(1−ρ)g′t​⊙g′t​

可以看到,如不考虑 ϵ\epsilonϵ 影响,AdaDelta 算法跟 RMSProp 算法的不同之处在于使用 Δxt−1\sqrt{Δx_{t-1}}Δxt−1​​ 来代替学习率 ηηη。

从零开始实现

%matplotlib inline
import torch
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"features, labels = d2l.get_data_ch7()def init_adadelta_states():s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)delta_w, delta_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((s_w, delta_w), (s_b, delta_b))def adadelta(params, states, hyperparams):rho, eps = hyperparams['rho'], 1e-5for p, (s, delta) in zip(params, states):s[:] = rho * s + (1 - rho) * (p.grad.data**2)g =  p.grad.data * torch.sqrt((delta + eps) / (s + eps))p.data -= gdelta[:] = rho * delta + (1 - rho) * g * gd2l.train_ch7(adadelta, init_adadelta_states(), {'rho': 0.9}, features, labels)

输出:

loss: 0.243535, 0.057486 sec per epoch

简洁实现

d2l.train_pytorch_ch7(torch.optim.Adadelta, {'rho': 0.9}, features, labels)

输出:

loss: 0.304975, 0.058008 sec per epoch

小结

  • AdaDelta算法没有学习率超参数,它通过使用有关自变量更新量平方的指数加权移动平均的项来替代RMSProp算法中的学习率。

深度学习优化算法:AdaDelta算法相关推荐

  1. 重磅 | 2017年深度学习优化算法研究亮点最新综述火热出炉

    翻译 | AI科技大本营(微信ID:rgznai100) 梯度下降算法是机器学习中使用非常广泛的优化算法,也是众多机器学习算法中最常用的优化方法.几乎当前每一个先进的(state-of-the-art ...

  2. Adam 那么棒,为什么还对 SGD 念念不忘?一个框架看懂深度学习优化算法

    作者|Juliuszh 链接 | https://zhuanlan.zhihu.com/juliuszh 本文仅作学术分享,若侵权,请联系后台删文处理 机器学习界有一群炼丹师,他们每天的日常是: 拿来 ...

  3. 2017年深度学习优化算法最新进展:如何改进SGD和Adam方法?

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法? 深度学习的基本目标,就是寻找一个泛化能力强的最小值,模型的快速性和可靠性也是一个加分点. 随机梯度下降(SGD)方法是1951年由R ...

  4. 深度学习优化算法的总结与梳理(从 SGD 到 AdamW 原理和代码解读)

    作者丨科技猛兽 转自丨极市平台 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 SGD/AdaGrad/Adam https://zhuanlan.zhihu.com/ ...

  5. adam算法效果差原因_深度学习优化器-Adam两宗罪

    在上篇文章中,我们用一个框架来回顾了主流的深度学习优化算法.可以看到,一代又一代的研究者们为了我们能炼(xun)好(hao)金(mo)丹(xing)可谓是煞费苦心.从理论上看,一代更比一代完善,Ada ...

  6. 大梳理!深度学习优化算法:从 SGD 到 AdamW 原理和代码解读

    ‍ 作者丨知乎 科技猛兽  极市平台 编辑 https://zhuanlan.zhihu.com/p/391947979 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 ...

  7. 深度学习优化算法,Adam优缺点分析

    优化算法 首先我们来回顾一下各类优化算法. 深度学习优化算法经历了 SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> Adam -& ...

  8. Adam那么棒,为什么还对SGD念念不忘?一个框架看懂深度学习优化算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者|Juliuszh,https://zhuanlan.zhih ...

  9. 2017年深度学习优化算法最新进展:改进SGD和Adam方法

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法 转载的文章,把个人觉得比较好的摘录了一下 AMSGrad 这个前期比sgd快,不能收敛到最优. sgdr 余弦退火的方案比较好 最近的 ...

  10. Pytorch框架的深度学习优化算法集(优化中的挑战)

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net Py ...

最新文章

  1. @Data注解使用后get set报错解决方法
  2. java 制作报表案例_javaweb项目报表案例
  3. eclipse web项目 解决“Dynamic Web Module 3.0 requires J
  4. Task2.特征提取
  5. 【源码】java中图片和Base64互相转换源码
  6. TextView 显示图像+文字的方法
  7. 面试字节跳动Android工程师该怎么准备?深度解析,值得收藏
  8. 对VSCode在安装了Dev-cpp的电脑上的配置
  9. 【编译原理笔记20】代码生成:代码生成器的主要任务,一个简单的目标机模型,指令选择,寄存器的选择,寄存器选择函数getReg的设计,窥孔优化
  10. 74HC595中文资料
  11. layui框架——弹出层layer
  12. 1143 Lowest Common Ancestor (30分) 附测试点分析
  13. 讲台计算机的英语怎么读,讲台的英语单词怎么写,英语怎么拼写!
  14. 【从蛋壳到满天飞】JS 数据结构解析和算法实现-链表
  15. mkt sensor1.0 alps
  16. 【数学建模】方差分析与回归分析的SPSS实现
  17. (最新)唯品会WEB端加密参数逆向分析
  18. 索骥馆-DIY操作系统之《30天自制操作系统》扫描版[PDF]
  19. 路由器计算机无法上网,电脑可以上网路由器不能上网怎么回事?
  20. 蓝牙4.0知识百科1 什么是BLE4.0

热门文章

  1. 记一次使用Openssl生成p12证书搭建https证书
  2. 齐全且实用的MySQL函数使用大全
  3. Junit 4 的 @Before 和 @BeforeClass 对比 Junit 5 @BeforeEach 和 @BeforeAll
  4. Gerber文件的输出
  5. [境内法规]中国人民银行关于印发《反洗钱现场检查管理办法(试行)》的通知—银发〔2007〕175号
  6. [java]房屋出租系统
  7. Verp中外部控制的六种方式
  8. PR自学之软件的安装
  9. C#学习之IntPtr类型
  10. cmd查看所有数据库 db2_DB2常用命令