一. 基本原理

1.1 引入

Momentum算法在原有的梯度下降法中引入了动量,从物理学上看,引入动量比起普通梯度下降法主要能够增加两个优点。首先,引入动量能够使得物体在下落过程中,当遇到一个局部最优的时候有可能在原有动量的基础上冲出这个局部最优点;并且,普通的梯度下降法方法完全由梯度决定,这就可能导致在寻找最优解的过程中出现严重震荡而速度变慢,但是在有动量的条件下,物体运动方向由动量和梯度共同决定,可以使得物体的震荡减弱,更快地运动到最优解。

1.2 指数加权移动平均

指数加权移动平均是一种常用的序列数据处理方式,用于描述数值的变化趋势,本质上是一种近似求平均的方法。计算公式如下:
vt=βvt−1+(1−β)θtv_t=\beta v_{t-1}+(1-\beta)\theta_tvt​=βvt−1​+(1−β)θt​
vtv_tvt​表示第t个数的估计值,β\betaβ为一个可调参数,能表示vt−1v_{t-1}vt−1​的权重,θt\theta_tθt​表示第t个数的实际值。
比如,假设有以下余弦函数,并随机添加了一些噪声,

于是,使用指数加权移动平均对这些离散点进行计算,可以对这些离散点进行去噪(β=0.9\beta=0.9β=0.9),于是得到图像:

使用指数加权移动平均得到的估计数据能对数据进行去噪,并且得到的数据接近原始功能。
上面进行指数加权移动平均使用的超参数β=0.9\beta=0.9β=0.9,但是是如何确定的该超参数呢?由原始的公式,设v0=0v_0=0v0​=0,得到:
{v1=βv0+(1−β)θ1v2=βv1+(1−β)θ2v3=βv2+(1−β)θ3\left\{ \begin{array}{rcl} v_1=\beta v_0+(1-\beta)\theta_1 \\ v_2=\beta v_1+(1-\beta)\theta_2 \\ v_3=\beta v_2+(1-\beta)\theta_3 \end{array} \right. ⎩⎨⎧​v1​=βv0​+(1−β)θ1​v2​=βv1​+(1−β)θ2​v3​=βv2​+(1−β)θ3​​
于是
vn=βvn−1+(1−β)θn=(1−β)θn+β(1−β)θn−1+β2(1−β)θn−2+...+βn−1(1−β)θ1v_n=\beta v_{n-1}+(1-\beta)\theta_n=(1-\beta)\theta_n+\beta(1-\beta)\theta_{n-1}+\beta^2(1-\beta)\theta_{n-2}+...+\beta^{n-1}(1-\beta)\theta_1vn​=βvn−1​+(1−β)θn​=(1−β)θn​+β(1−β)θn−1​+β2(1−β)θn−2​+...+βn−1(1−β)θ1​
vnv_nvn​为对第n个数的估计值,也是加权平均,由于加权系数随着与第n个数的距离以指数的形式递减,在x方向上越靠近第n个数,权重越大,越远离这个数则权重越小。当β=0.9\beta=0.9β=0.9时,0.910≈0.35≈1/e0.9^{10}\approx0.35\approx1/e0.910≈0.35≈1/e,如果以1/e为分界点,权值衰减到这个值以下时就忽略不计,则对某个点的指数加权平均求的就是这个数以前最近N=1/(1−β)=10N=1/(1-\beta)=10N=1/(1−β)=10个数的加权平均值。
当使用不同的β\betaβ值时,对上图的数据点进行估计,结果如下:

可以见到,当β\betaβ值太大时并不能很好反应总体数据的趋势情况,太小时出现过拟合情况,相邻点间的波动太大,当β=0.9\beta=0.9β=0.9时,既能反应出原数据的总体趋势,波动又不至于太大。所有在一般的Momentum算法中,β\betaβ值一般取0.9。

1.3 Momentum

Momentum就是在普通的梯度下降法中引入指数加权移动平均,即定义一个动量,它是梯度的指数加权移动平均值,然后使用该值代替原来的梯度方向来更新。定义的动量为:
vt=βvt−1+(1−β)∇L(w)v_t=\beta v_{t-1}+(1-\beta)\nabla L(w)vt​=βvt−1​+(1−β)∇L(w)
该式中,vtv_tvt​表示当前动量,β\betaβ就是前文提到的超参数,∇L(w)\nabla L(w)∇L(w)为目标函数的当前梯度,使用该动量带入梯度下降公式:
w=w−αvtw=w-\alpha v_tw=w−αvt​
该式和普通梯度下降法迭代公式基本一致,只是方向vtv_tvt​是定义的动量,α\alphaα为步长,一般也是一个定义的超参数。
在机器学习中,普通的随机梯度下降法中,由于无法计算损失函数的确切导数,嘈杂的数据会使下降过程并不朝着最佳方向前进,使用加权平均能对嘈杂数据进行一定的屏蔽,使前进方向更接近实际梯度。
另外,随机梯度下降法在局部极小值极有可能被困住,但Momentum由于下降方向由最近的一些数共同决定,能在一定程度反应总体的最佳下降方向,所以在这方面被困在局部最优解的可能会减小。随机梯度下降法和Momentum对比:

二. 程序实现

我以一个简单函数y=x1+2x2y=x_1+2x_2y=x1​+2x2​为例,给定一些训练数据,使用Momentum算法进行训练,确定x1x_1x1​和x2x_2x2​的参数,首先给定8组训练数据并确定相关学习参数:

    x = np.array([(1, 1), (1, 2), (2, 2), (3, 1), (1, 3), (2, 4), (2, 3), (3, 3)])y = np.array([3, 5, 6, 5, 7, 10, 8, 9])# 初始化m, dim = x.shapetheta = np.zeros(dim)  # 参数alpha = 0.1  # 学习率beta = 0.9  # betathreshold = 0.0001  # 停止迭代的错误阈值iterations = 1500  # 迭代次数error = 0  # 初始误差为0gradient = 0  # 初始梯度为0

随后就是进行学习的过程,学习结束条件为达到最大迭代次数iterations或者总体误差小于阈值threshold。学习过程就是更新参数theta的过程。更新方法就是使用Momentum算法,首先求得动量,求动量的梯度是使用的近似计算,然后进行使用迭代公式更新参数即可,两个公式和前文相同:
vt=βvt−1+(1−β)∇L(w)v_t=\beta v_{t-1}+(1-\beta)\nabla L(w)vt​=βvt−1​+(1−β)∇L(w) w=w−αvtw=w-\alpha v_tw=w−αvt​
代码实现:

    gradient = beta * gradient + (1 - beta) * (x[j] * (np.dot(x[j], theta) - y[j]))theta -= alpha * gradient

全部代码:

# 带冲量的随机梯度下降SGD  Momentum
# 以 y=x1+2*x2为例import numpy as np# 多元数据
def momentum():# 训练集,每个样本有三个分量x = np.array([(1, 1), (1, 2), (2, 2), (3, 1), (1, 3), (2, 4), (2, 3), (3, 3)])y = np.array([3, 5, 6, 5, 7, 10, 8, 9])# 初始化m, dim = x.shapetheta = np.zeros(dim)  # 参数alpha = 0.1  # 学习率beta = 0.9  # betathreshold = 0.0001  # 停止迭代的错误阈值iterations = 1500  # 迭代次数error = 0  # 初始误差为0gradient = 0  # 初始梯度为0# 迭代开始for i in range(iterations):j = i % merror = 1 / (2 * m) * np.dot((np.dot(x, theta) - y), (np.dot(x, theta) - y))# 迭代停止if abs(error) <= threshold:breakgradient = beta * gradient + (1 - beta) * (x[j] * (np.dot(x[j], theta) - y[j]))theta -= alpha * gradientprint('迭代次数:%d' % (i + 1), 'theta:', theta, 'error:%f' % error)if __name__ == '__main__':momentum()

运行结果:

迭代次数:98 theta: [1.00250137 1.99699225] error:0.000009

可以见到,仅迭代了98次就达到了近似正确结果。

三. 总结

Momentum算法主要的公式只有如下两个,已在前文多次提到:
vt=βvt−1+(1−β)∇L(w)v_t=\beta v_{t-1}+(1-\beta)\nabla L(w)vt​=βvt−1​+(1−β)∇L(w) w=w−αvtw=w-\alpha v_tw=w−αvt​
Momentum算法在随机梯度下降法的基础上引入动量,能够加速向最优解的迭代,同时能够在一定程度上避免随机梯度下降法的被困在局部最优解。动量的计算是使用指数加权移动平均,使用该动量代替了随机梯度下降法的梯度方向。总体的迭代公式和随机梯度下降法基本相同。

四. 参考文献

[1] Vitaly Bushaev. Stochastic Gradient Descent with momentum. https://towardsdatascience.com/stochastic-gradient-descent-with-momentum-a84097641a5d. 2020.5.29

[2] 不会停的蜗牛. 为什么在优化算法中使用指数加权平均. https://www.jianshu.com/p/41218cb5e099?utm_source=oschina-app. 2020.5. 28

梯度下降优化算法Momentum相关推荐

  1. 梯度下降优化算法总结

    写在前面 梯度下降(Gradient descent)算法可以说是迄今最流行的机器学习领域的优化算法.并且,基本上每一个深度学习库都包括了梯度下降算法的实现,比如Lasagne.cafe.keras等 ...

  2. 深度学习-各类梯度下降优化算法回顾

    本文是根据 链接 进行的翻译,回顾了深度学习的各种梯度下降优化算法.*已获得原作者的翻译许可. 文章目录 一.概述 二.引言 三.Gradient Descent Variants(梯度下降法变体) ...

  3. 深度学习中的梯度下降优化算法综述

    1 简介 梯度下降算法是最常用的神经网络优化算法.常见的深度学习库也都包含了多种算法进行梯度下降的优化.但是,一般情况下,大家都是把梯度下降系列算法当作是一个用于进行优化的黑盒子,不了解它们的优势和劣 ...

  4. 基于机器学习梯度下降优化算法来寻找最佳的线性回归模型

    https://www.toutiao.com/a6638782437587419652/ 幻风的AI之路 2018-12-25 18:12:27 线性回归模型 线性回归模型是一个非常简单的算法模型, ...

  5. 梯度下降优化算法综述(转载)

    原文地址:http://www.cnblogs.com/ranjiewen/p/5938944.html 对梯度下降进行详细解释,以及总结不同的梯度下降优化算法的优劣,可以作为参考. 上两张图,简直不 ...

  6. 【深度学习】——梯度下降优化算法(批量梯度下降、随机梯度下降、小批量梯度下降、Momentum、Adam)

    目录 梯度 梯度下降 常用的梯度下降算法(BGD,SGD,MBGD) 梯度下降的详细算法 算法过程 批量梯度下降法(Batch Gradient Descent) 随机梯度下降法(Stochastic ...

  7. 梯度下降优化算法概述

    本文原文是 An overview of gradient descent optimization algorithms,同时作者也在 arXiv 上发了一篇同样内容的 论文. 本文结合了两者来翻译 ...

  8. 梯度下降优化算法综述与PyTorch实现源码剖析

    现代的机器学习系统均利用大量的数据,利用梯度下降算法或者相关的变体进行训练.传统上,最早出现的优化算法是SGD,之后又陆续出现了AdaGrad.RMSprop.ADAM等变体,那么这些算法之间又有哪些 ...

  9. 各种 Optimizer 梯度下降优化算法回顾和总结

    1. 写在前面 当前使用的许多优化算法,是对梯度下降法的衍生和优化.在微积分中,对多元函数的参数求  偏导数,把求得的各个参数的导数以向量的形式写出来就是梯度.梯度就是函数变化最快的地方.梯度下降是迭 ...

  10. 各种 Optimizer 梯度下降优化算法总结

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:DengBoCong,编辑:极市平台 来源:https://zhu ...

最新文章

  1. centos7 开启 关闭 NetworkManager
  2. python怎么做软件程序_如何打包和发布Python程序
  3. AspNetCore 启动地址配置详解
  4. 最全的BAT大厂面试题整理,系列篇
  5. 英语笔记:词组句子:0806
  6. es6 async函数与其他异步处理方法的比较
  7. Android O(29 )---MTK 平台代码同步
  8. 获取响应里面的cookie的方法
  9. 统计模型混响信号预报matlab,基于MATLAB的混响效果设计课程设计
  10. 服务器版dll修复工具,dll修复工具
  11. 安卓初级开发教程 ppt+视频+案例源码
  12. 原理图学习(点读笔调试)
  13. OpenCV MPR.DLL WNetRestoreConnectionA相关问题
  14. MacOS 利用keka.app压缩工具制作dmg文件
  15. 饮料如何畅销市场?看农夫山泉如何玩转营销
  16. 希望各位博友解答一下
  17. JAVA-Word转PDF各种版本实现方式
  18. 显卡超频很简单 RivaTuner使用教程
  19. oracle 19c创建sample schema-HR,OE,SH等等
  20. android mvvm官方demo,Android MVVM实战Demo完全解析

热门文章

  1. [Transformer] PVT系列:PVT CPVT Twins
  2. 2021-07-01 <1000+常用Python库>
  3. 在MacOS系统下DMG文件显示压缩包无法双击安装解决办法
  4. 怎样合并磁盘分区?看这里~
  5. 养成备份的习惯的重要性
  6. C语言的奇技淫巧之三
  7. 思维拓展:用java实现巧妙过桥问题
  8. handsome for Typecho主题重建备忘
  9. 图Android 片缓存文件名,手机图片去了哪?教你理清照片存放路径
  10. 吃饭 睡觉 打豆豆游戏