来源:AINLPer微信公众号
编辑: ShuYini
校稿: ShuYini
时间: 2019-8-16

引言

    很多人在使用pytorch的时候都会遇到优化器选择的问题,今天就给大家介绍对比一下pytorch中常用的四种优化器。SGD、Momentum、RMSProp、Adam。

随机梯度下降法(SGD)

算法介绍

    对比批量梯度下降法,假设从一批训练样本

中随机选取一个样本
。模型参数为
,代价函数为
,梯度为
,学习率为
,则使用随机梯度下降法更新参数表达式为:

   其中,

,
表示随机选择的一个梯度方向,
表示t时刻的模型参数。
,这里虽然引入了随机性和噪声,但期望仍然等于正确的梯度下降。

    基本策略可以理解为随机梯度下降像是一个盲人下山,不用每走一步计算一次梯度,但是他总能下到山底,只不过过程会显得扭扭曲曲。

算法评价

优点:

    虽然SGD需要走很多步的样子,但是对梯度的要求很低(计算梯度快)。而对于引入噪声,大量的理论和实践工作证明,只要噪声不是特别大,SGD都能很好地收敛。应用大型数据集时,训练速度很快。比如每次从百万数据样本中,取几百个数据点,算一个SGD梯度,更新一下模型参数。相比于标准梯度下降法的遍历全部样本,每输入一个样本更新一次参数,要快得多。

缺点:

    SGD在随机选择梯度的同时会引入噪声,使得权值更新的方向不一定正确。此外,SGD也没能单独克服局部最优解的问题。

标准动量优化算法(Momentum)

算法介绍

    使用动量(Momentum)的随机梯度下降法(SGD),主要思想是引入一个积攒历史梯度信息动量来加速SGD。从训练集中取一个大小为n的小批量

样本,对应的真实值分别为
,则Momentum优化表达式为:

其中,

表示t时刻积攒的加速度。α表示动力的大小,一般取值为0.9(表示最大速度10倍于SGD)。
含义见SGD算法。
表示t时刻模型参数。

算法的理解

    动量主要解决SGD的两个问题:一是随机梯度的方法(引入的噪声);二是Hessian矩阵病态问题(可以理解为SGD在收敛过程中和正确梯度相比来回摆动比较大的问题)。

    简单理解:由于当前权值的改变会受到上一次权值改变的影响,类似于小球向下滚动的时候带上了惯性。这样可以加快小球向下滚动的速度。

RMSProp算法

算法介绍

    与动量梯度下降一样,都是消除梯度下降过程中的摆动来加速梯度下降的方法。 梯度更新公式:

    更新权重的时候,使用除根号的方法,可以使较大的梯度大幅度变小,而较小的梯度小幅度变小,这样就可以使较大梯度方向上的波动小下来,那么整个梯度下降的过程中摆动就会比较小,就能设置较大的learning-rate,使得学习步子变大,达到加快学习的目的。

    在实际的应用中,权重W或者b往往是很多维度权重集合,就是多维的,在进行除根号操作中,会将其中大的维度的梯度大幅降低,不是说权重W变化趋势一样。

    RMSProp算法在经验上已经被证明是一种有效且实用的深度神经网络优化算法。目前它是深度学习从业者经常采用的优化方法之一。

Adam算法

算法介绍

    Adam中动量直接并入了梯度一阶矩(指数加权)的估计。其次,相比于缺少修正因子导致二阶矩估计可能在训练初期具有很高偏置的RMSProp,Adam包括偏置修正,修正从原点初始化的一阶矩(动量项)和(非中心的)二阶矩估计。Adam算法策略可以表示为:

    其中,

分别为一阶动量项和二阶动量项。
为动力值大小通常分别取0.9和0.999;
,
分别为各自的修正值。
表示t时刻即第t迭代模型的参数,
表示t次迭代代价函数关于W的梯度大小;ϵ是一个取值很小的数(一般为1e-8)为了避免分母为0。

算法分析

    该方法和RMSProp很像,除了使用的是平滑版的梯度m,而不是原始梯度dx。推荐参数值eps=1e-8, beta1=0.9, beta2=0.999。在实际操作中,推荐Adam作为默认算法,一般比RMSProp要好一点。

算法比较

    为了验证四种算法的性能,在pytorch中的对同一个网络进行优化,比较四种算法损失函数随着时间的变化情况。代码如下:

opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.8)
opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))

    SGD 是最普通的优化器, 也可以说没有加速效果, 而 Momentum 是 SGD 的改良版, 它加入了动量原则. 后面的 RMSprop 又是 Momentum 的升级版. 而 Adam 又是 RMSprop 的升级版. 不过从这个结果中我们看到, Adam 的效果似乎比 RMSprop 要差一点. 所以说并不是越先进的优化器, 结果越佳。

参考:

https://blog.csdn.net/weixin_40170902/article/details/80092628

Attention

更多自然语言处理相关知识,还请关注AINLPer公众号,极品干货即刻送达。

pytorch梯度下降函数_Pytorch中常用的四种优化器SGD、Momentum、RMSProp、Adam相关推荐

  1. 【温故知新】——原生js中常用的四种循环方式

    一.引言 本文主要是利用一个例子,讲一下原生js中常用的四种循环方式的使用与区别: 实现效果: 在网页中弹出框输入0   网页输出"欢迎下次光临" 在网页中弹出框输入1   网页输 ...

  2. 妈耶,讲得好详细,十分钟彻底看懂深度学习常用优化器SGD、RMSProp、Adam详解分析

    深度学习常用优化器学习总结 常用优化器 SGD RMS Prop Adam 常用优化器 SGD 基本思想:通过当前梯度和历史梯度共同调节梯度的方向和大小 我们首先根据pytorch官方文档上的这个流程 ...

  3. Java中常用的四种线程池

    在Java中使用线程池,可以用ThreadPoolExecutor的构造函数直接创建出线程池实例,在Executors类中,为我们提供了常用线程池的创建方法. ​ 接下来我们就来了解常用的四种: ne ...

  4. iOS中常用的四种数据持久化方法

    iOS中的数据持久化方式,基本上有以下四种:属性列表.对象归档.SQLite3和Core Data 1.属性列表 涉及到的主要类:NSUserDefaults,一般 [NSUserDefaults s ...

  5. pytorch梯度下降函数_Pytorch学习笔记6:激活函数/单层感知机/梯度下降求最小值实例...

    #添加到学习笔记2末尾,直接运行.代码意义可以看注释. #需要import以下库 import torch import numpy as np from matplotlib import pypl ...

  6. js学习总结----js中常用的四种输出方式

    1.alert('内容') 在浏览器中弹出框显示我们的内容    不输入内容弹出undefined  (注意alert弹出的都是字符串) 2.document.write('内容')  在页面中输出显 ...

  7. CSS中常用的4种长度单位

    在现实生活中,我们知道很多的长度单位,例如米,厘米,寸,尺等等,在css的世界中,也有很多的长度单位 以下是css中常用的四种常用的长度单位 1,像素 px - 像素是我们在网页中使用的最多的一个单位 ...

  8. PyTorch: torch.optim 的6种优化器及优化算法介绍

    import torch import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplo ...

  9. 51单片机下载完程序后不亮_程序如何下载到单片机中?单片机常用的四种烧写程序方式介绍...

    单片机是一种可编程控制器,搭好硬件电路后,可以利用程序实现很多非常复杂的逻辑功能,与纯硬件电路相比,简化了硬件外围的设计.方便了逻辑的设计.丰富了逻辑的输出.不同厂家的单片机需要不同编程IDE来实现编 ...

最新文章

  1. melogin宽带连接服务器无响应,输入melogin.cn进不了路由器设置界面怎么办
  2. 【每日一算法】填充同一层的兄弟节点
  3. :src 三目运算
  4. ST17H26只pwm波形特征
  5. 关于SAP Cloud Platform ABAP环境费用的问题
  6. 程序异常异常代码: 0xc0000005_Java基础:看完这篇你还怕碰到异常吗?
  7. 我为什么对TypeScript由黑转粉?
  8. 具有瞬态属性的视图对象的钝化和激活
  9. 面向对象软件开发代码结构(2)
  10. JS 中 Map 与 JSON 转换
  11. 朴素贝叶斯分类器的python实现
  12. bzoj 1030: [JSOI2007]文本生成器(AC自动机+DP)
  13. Python 3.x对.CSV数据按任意行、列读取
  14. 朗途职业规划之一 职业发展报告 (北森测评)
  15. linux维护盘ISO,MYISO XPPE+Win10PE+porteus系统维护盘ISO量产全能版
  16. 公开课丨重中之重!Web安全漏洞与防御
  17. 厦门92坐标参数讨论
  18. 私有云的优缺点_私有云的优缺点是什么?与公有云的区别
  19. Office EXCEL 创建图片超链接打不开怎么办 Excel打开图片提示发生了意外错误怎么办...
  20. 第七章第八章思维导图

热门文章

  1. ios 隐藏app的插件_等了5年终于复活,iPhone上最干净好用的微博App
  2. vim配置之spacevim
  3. 计算机寄存器端口,CPU和外设之间的数据传送方式有哪几种
  4. java 可逆的加密算法_java实现AES可逆加密算法
  5. Python 装饰器详解(下)
  6. 【数据结构1.3笔记】研究内容
  7. python中用函数设计栈的括号匹配问题_数据结构和算法(Python版):利用栈(Stack)实现括号的匹配问题...
  8. python 执行shell_python执行shell命令的方法
  9. python renamer模块_Python - 批量文件重命名
  10. 日常问题———安装新版zookeeper 出现Starting zookeeper ... FAILED TO START