我们在前面的文章中熟悉了梯度下降的各种形态,深度学习优化算法之(小批量)随机梯度下降(MXNet),也了解了梯度下降的原理,由每次的迭代,梯度下降都根据自变量的当前位置来更新自变量,做自我迭代。但是如果说自变量的迭代方向只是取决于自变量的当前位置的话,这可能会带来一些问题。比如我们来看下函数  的走势,现在我们来看下这个函数其系数为0.1的情况 在学习率变化时,将会发生什么变化。

从eta=0.4一个比较合适的学习率开始:

import d2lzh as d2l
from mxnet import ndeta=0.4def f_2d(x1,x2):return 0.1*x1**2 + 2*x2**2def gd_2d(x1,x2,s1,s2):return (x1-eta*0.2*x1,x2-eta*4*x2,0,0)d2l.show_trace_2d(f_2d,d2l.train_2d(gd_2d))#epoch 20, x1 -0.943467, x2 -0.000073

图中可以看出,同一个位置上,目标函数在竖直方向(x2轴方向)比在水平方向(x1轴方向)的斜率的绝对值更大,换句话说就是自变量的更新会使自变量在竖直方向比在水平方向移动幅度更大。
我们将学习率调大一点:eta=0.6

我们发现自变量在竖直方向不断越过最优解并逐渐发散了。

动量法

那上面这个问题,我们通过动量法来处理,在前面的文章也有介绍,这里算是一种新的学习与巩固,更重要的是了解为什么动量法能够处理这种上下方向的偏幅。
那很明显上面存在的问题就是自变量在竖直方向的更新不一致,时正时负,找到问题所在之后,那我们就只需要解决这个方向一致的问题就好办了。

对于动量法的推导,我们从指数加权移动平均(Exponentially Weighted  Moving Average)来理解它,还是画图来直观看下其推导过程:

然后我们通过代码来看下实际情况:

eta,gamma=0.4,0.5def f_2d(x1,x2):return 0.1*x1**2 + 2*x2**2#当gamma=0时,就是小批量随机梯度下降
def momentum_gd_2d(x1,x2,v1,v2):v1=gamma*v1 + eta*0.2*x1v2=gamma*v2 + eta*4*x2return x1-v1,x2-v2,v1,v2d2l.show_trace_2d(f_2d,d2l.train_2d(momentum_gd_2d))#epoch 20, x1 -0.062843, x2 0.001202

图中可以看出使用动量法之后在竖直方向上的移动更加平滑了,而且在水平方向也更快逼近最优解。
然后将学习率调大到0.6,也没有出现发散的情况。

飞机机翼噪音测试

import d2lzh as d2l
from mxnet import nd#使用飞机噪音数据集来测试
#https://download.csdn.net/download/weixin_41896770/86513479
features,labels=d2l.get_data_ch7()#1503x5,1503#速度变量用更广义的状态变量states表示
def init_momentum_states():v_w=nd.zeros((features.shape[1],1))v_b=nd.zeros(1)return (v_w,v_b)def sgd_momentum(params,states,hyperparams):for p,v in zip(params,states):v[:]=hyperparams['momentum']*v +hyperparams['lr']*p.gradp[:]-=vd2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.5},features,labels)
#loss: 0.249161, 0.171031 sec per epoch

#看做特殊的小批量随机梯度下降
#最近2个时间步的2倍小批量梯度的加权平均
#d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.5},features,labels)
#最近10个时间步的10倍小批量梯度的加权平均,1/(1-0.9)
d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.9},features,labels)
loss: 0.259894, 0.177999 sec per epoch

图中可以看出后期的迭代不够平滑,因为10倍小批量梯度比2倍小批量梯度大了5倍,我们将学习率调小5倍试下:

#学习率调下5倍,从0.02到0.004
d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.004,'momentum':0.9},features,labels)
#loss: 0.243785, 0.181000 sec per epoch

#简洁实现
d2l.train_gluon_ch7('sgd',{'learning_rate':0.004,'momentum':0.9},features,labels)

动量法的出现主要是解决相邻时间步的自变量的在更新方向上的问题,使得它们更加趋向一致,因为它将过去时间步的梯度做了加权平均,而不仅仅是关注当前变量梯度的位置。

深度学习优化算法之动量法[公式推导](MXNet)相关推荐

  1. Adam 那么棒,为什么还对 SGD 念念不忘?一个框架看懂深度学习优化算法

    作者|Juliuszh 链接 | https://zhuanlan.zhihu.com/juliuszh 本文仅作学术分享,若侵权,请联系后台删文处理 机器学习界有一群炼丹师,他们每天的日常是: 拿来 ...

  2. 2017年深度学习优化算法最新进展:改进SGD和Adam方法

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法 转载的文章,把个人觉得比较好的摘录了一下 AMSGrad 这个前期比sgd快,不能收敛到最优. sgdr 余弦退火的方案比较好 最近的 ...

  3. 2017年深度学习优化算法最新进展:如何改进SGD和Adam方法?

    2017年深度学习优化算法最新进展:如何改进SGD和Adam方法? 深度学习的基本目标,就是寻找一个泛化能力强的最小值,模型的快速性和可靠性也是一个加分点. 随机梯度下降(SGD)方法是1951年由R ...

  4. 深度学习优化算法的总结与梳理(从 SGD 到 AdamW 原理和代码解读)

    作者丨科技猛兽 转自丨极市平台 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 SGD/AdaGrad/Adam https://zhuanlan.zhihu.com/ ...

  5. 大梳理!深度学习优化算法:从 SGD 到 AdamW 原理和代码解读

    ‍ 作者丨知乎 科技猛兽  极市平台 编辑 https://zhuanlan.zhihu.com/p/391947979 本文思想来自下面这篇大佬的文章: Juliuszh:一个框架看懂优化算法之异同 ...

  6. 深度学习优化算法,Adam优缺点分析

    优化算法 首先我们来回顾一下各类优化算法. 深度学习优化算法经历了 SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> Adam -& ...

  7. Adam那么棒,为什么还对SGD念念不忘?一个框架看懂深度学习优化算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者|Juliuszh,https://zhuanlan.zhih ...

  8. Pytorch框架的深度学习优化算法集(优化中的挑战)

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net Py ...

  9. 深度学习优化算法实现(Momentum, Adam)

    目录 Momentum 初始化 更新参数 Adam 初始化 更新参数 除了常见的梯度下降法外,还有几种比较通用的优化算法:表现都优于梯度下降法.本文只记录完成吴恩达深度学习作业时遇到的Momentum ...

最新文章

  1. textisselectable长按再点击_微信朋友圈如何发布长视频?原来打开这个功能就可以,涨知识了...
  2. Linux 查看端口占用情况 并 结束进程
  3. 转行学Java,如何才能成为年薪50万的Java程序员呢?
  4. Linux权限的简单剖析
  5. 小脑袋智能推广软件360专版
  6. uFrame近况(2016年4月8日更新)
  7. html——windows.onload()与$(document).ready()区别
  8. django连接mysql自动同步生成数据表
  9. 联想小新触摸板驱动_联想lenovo笔记本触摸板驱动-联想触摸驱动 win7版下载16.2.5.0 官方版-西西软件下载...
  10. python对json的操作及实例解析
  11. 《原力计划-打卡挑战》总榜名单揭晓!!
  12. MNIST手写体数字识别数据集
  13. LVGL (8) 绘制流程
  14. 凭什么让你“转贴”?
  15. 下载xampp之后还用下载PHP吗,PHP 下载并安装XAMPP
  16. 百度云主机连接FTP
  17. Git最详细的基础教程
  18. K8S系列:Deployment更新、锁定、解锁、回滚版本
  19. Jetpack Compose 从入门到入门(三)
  20. 1500页技术人的黑皮书 免费下载!

热门文章

  1. 机器学习理论与实战(十六)概率图模型04
  2. 视频播放器—纹理-渲染-窗口
  3. 【FPGA教程案例100】深度学习1——基于CNN卷积神经网络的手写数字识别纯Verilog实现,使用mnist手写数字数据库
  4. 当程序员真的好累——IT界那些笑话
  5. EPEL(Extra Packages for Enterprise Linux)的介绍与安装
  6. 如何实现网站多语言版本
  7. 【 javascript】<input> 实现输入框只能输入数字(个人认为最好的)
  8. Hadoop与NoSQL正迅速融企业生产环境
  9. QML绘制圆角多边形(Canvas)
  10. c语言植入手机系统,一种手机课堂C语言编程系统的制作方法