深度学习优化算法:AdaDelta算法
原文链接:动手学深度学习pytorch版:7.7 AdaDelta算法
github:https://github.com/ShusenTang/Dive-into-DL-PyTorch
原论文:
[1] Zeiler, M. D. (2012). ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
AdaDelta算法
除了 RMSProp 算法以外,另一个常用优化算法 AdaDelta 算法也针对 AdaGrad 算法在迭代后期可能较难找到有用解的问题做了改进 [1]。有意思的是,AdaDelta 算法没有学习率这一超参数。
算法
AdaDelta 算法也像 RMSProp 算法一样,使用了小批量随机梯度 gtg_tgt 按元素平方的指数加权移动平均变量 sts_tst 。在时间步0,它的所有元素被初始化为0。给定超参数 0≤ρ<100≤\rho<100≤ρ<10(对应RMSProp算法中的 γ\gammaγ),在时间步 t>0t>0t>0,同 RMSProp 算法一样计算
st←ρst−1+(1−ρ)gt⊙gt{{\text{s}}_{t}}\leftarrow \rho {{s}_{t-1}}+(1-\rho ){{g}_{t}}\odot {{g}_{t}} st←ρst−1+(1−ρ)gt⊙gt
与RMSProp算法不同的是,AdaDelta算法还维护一个额外的状态变量 Δxt\Delta {{x}_{t}}Δxt,其元素同样在时间步0时被初始化为0。我们使用 Δxt−1\Delta {{x}_{t-1}}Δxt−1 来计算自变量的变化量:
g′t←Δxt−1+εst+ε⊙gtg{{'}_{t}}\leftarrow \sqrt{\frac{\Delta {{x}_{t-1}}+\varepsilon }{{{s}_{t}}+\varepsilon }}\odot {{g}_{t}} g′t←st+εΔxt−1+ε⊙gt
其中 ϵ\epsilonϵ 是为了维持数值稳定性而添加的常数,如 10−510^{-5}10−5。接着更新自变量:
xt←xt−1−g′t{{x}_{t}}\leftarrow {{x}_{t-1}}-g{{'}_{t}} xt←xt−1−g′t
最后,我们使用 Δxt\Delta {{x}_{t}}Δxt 来记录自变量变化量 gt′g'_tgt′ 按元素平方的指数加权移动平均:
Δxt←ρΔxt−1+(1−ρ)g′t⊙g′t\Delta {{x}_{t}}\leftarrow \rho \Delta {{x}_{t-1}}+(1-\rho )g{{'}_{t}}\odot g{{'}_{t}} Δxt←ρΔxt−1+(1−ρ)g′t⊙g′t
可以看到,如不考虑 ϵ\epsilonϵ 影响,AdaDelta 算法跟 RMSProp 算法的不同之处在于使用 Δxt−1\sqrt{Δx_{t-1}}Δxt−1 来代替学习率 ηηη。
从零开始实现
%matplotlib inline
import torch
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"features, labels = d2l.get_data_ch7()def init_adadelta_states():s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)delta_w, delta_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((s_w, delta_w), (s_b, delta_b))def adadelta(params, states, hyperparams):rho, eps = hyperparams['rho'], 1e-5for p, (s, delta) in zip(params, states):s[:] = rho * s + (1 - rho) * (p.grad.data**2)g = p.grad.data * torch.sqrt((delta + eps) / (s + eps))p.data -= gdelta[:] = rho * delta + (1 - rho) * g * gd2l.train_ch7(adadelta, init_adadelta_states(), {'rho': 0.9}, features, labels)
输出:
loss: 0.243535, 0.057486 sec per epoch
简洁实现
d2l.train_pytorch_ch7(torch.optim.Adadelta, {'rho': 0.9}, features, labels)
输出:
loss: 0.304975, 0.058008 sec per epoch
小结
- AdaDelta算法没有学习率超参数,它通过使用有关自变量更新量平方的指数加权移动平均的项来替代RMSProp算法中的学习率。
深度学习优化算法:AdaDelta算法相关推荐
- 重磅 | 2017年深度学习优化算法研究亮点最新综述火热出炉
翻译 | AI科技大本营(微信ID:rgznai100) 梯度下降算法是机器学习中使用非常广泛的优化算法,也是众多机器学习算法中最常用的优化方法.几乎当前每一个先进的(state-of-the-art ...
- Adam 那么棒,为什么还对 SGD 念念不忘?一个框架看懂深度学习优化算法
作者|Juliuszh 链接 | https://zhuanlan.zhihu.com/juliuszh 本文仅作学术分享,若侵权,请联系后台删文处理 机器学习界有一群炼丹师,他们每天的日常是: 拿来 ...
- 2017年深度学习优化算法最新进展:如何改进SGD和Adam方法?
2017年深度学习优化算法最新进展:如何改进SGD和Adam方法? 深度学习的基本目标,就是寻找一个泛化能力强的最小值,模型的快速性和可靠性也是一个加分点. 随机梯度下降(SGD)方法是1951年由R ...
- 深度学习优化算法的总结与梳理(从 SGD 到 AdamW 原理和代码解读)
作者丨科技猛兽 转自丨极市平台 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 SGD/AdaGrad/Adam https://zhuanlan.zhihu.com/ ...
- adam算法效果差原因_深度学习优化器-Adam两宗罪
在上篇文章中,我们用一个框架来回顾了主流的深度学习优化算法.可以看到,一代又一代的研究者们为了我们能炼(xun)好(hao)金(mo)丹(xing)可谓是煞费苦心.从理论上看,一代更比一代完善,Ada ...
- 大梳理!深度学习优化算法:从 SGD 到 AdamW 原理和代码解读
作者丨知乎 科技猛兽 极市平台 编辑 https://zhuanlan.zhihu.com/p/391947979 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 ...
- 深度学习优化算法,Adam优缺点分析
优化算法 首先我们来回顾一下各类优化算法. 深度学习优化算法经历了 SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> Adam -& ...
- Adam那么棒,为什么还对SGD念念不忘?一个框架看懂深度学习优化算法
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者|Juliuszh,https://zhuanlan.zhih ...
- 2017年深度学习优化算法最新进展:改进SGD和Adam方法
2017年深度学习优化算法最新进展:如何改进SGD和Adam方法 转载的文章,把个人觉得比较好的摘录了一下 AMSGrad 这个前期比sgd快,不能收敛到最优. sgdr 余弦退火的方案比较好 最近的 ...
- Pytorch框架的深度学习优化算法集(优化中的挑战)
个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net Py ...
最新文章
- @Data注解使用后get set报错解决方法
- java 制作报表案例_javaweb项目报表案例
- eclipse web项目 解决“Dynamic Web Module 3.0 requires J
- Task2.特征提取
- 【源码】java中图片和Base64互相转换源码
- TextView 显示图像+文字的方法
- 面试字节跳动Android工程师该怎么准备?深度解析,值得收藏
- 对VSCode在安装了Dev-cpp的电脑上的配置
- 【编译原理笔记20】代码生成:代码生成器的主要任务,一个简单的目标机模型,指令选择,寄存器的选择,寄存器选择函数getReg的设计,窥孔优化
- 74HC595中文资料
- layui框架——弹出层layer
- 1143 Lowest Common Ancestor (30分) 附测试点分析
- 讲台计算机的英语怎么读,讲台的英语单词怎么写,英语怎么拼写!
- 【从蛋壳到满天飞】JS 数据结构解析和算法实现-链表
- mkt sensor1.0 alps
- 【数学建模】方差分析与回归分析的SPSS实现
- (最新)唯品会WEB端加密参数逆向分析
- 索骥馆-DIY操作系统之《30天自制操作系统》扫描版[PDF]
- 路由器计算机无法上网,电脑可以上网路由器不能上网怎么回事?
- 蓝牙4.0知识百科1 什么是BLE4.0
热门文章
- 记一次使用Openssl生成p12证书搭建https证书
- 齐全且实用的MySQL函数使用大全
- Junit 4 的 @Before 和 @BeforeClass 对比 Junit 5 @BeforeEach 和 @BeforeAll
- Gerber文件的输出
- [境内法规]中国人民银行关于印发《反洗钱现场检查管理办法(试行)》的通知—银发〔2007〕175号
- [java]房屋出租系统
- Verp中外部控制的六种方式
- PR自学之软件的安装
- C#学习之IntPtr类型
- cmd查看所有数据库 db2_DB2常用命令