目录

说明

Adam原理

梯度滑动平均

偏差纠正

Adam计算过程

pytorch Adam参数

params

lr

betas

eps

weight_decay

amsgrad


说明

模型每次反向传导都会给各个可学习参数p计算出一个偏导数,用于更新对应的参数p。通常偏导数不会直接作用到对应的可学习参数p上,而是通过优化器做一下处理,得到一个新的,处理过程用函数F表示(不同的优化器对应的F的内容不同),即,然后和学习率lr一起用于更新可学习参数p,即

Adam是在RMSProp和AdaGrad的基础上改进的。先掌握RMSProp的原理,就很容易明白Adam了。本文是在RMSProp这篇博客的基础上写的。

Adam原理

在RMSProp的基础上,做两个改进:梯度滑动平均偏差纠正

梯度滑动平均

在RMSProp中,梯度的平方是通过平滑常数平滑得到的,即(根据论文,梯度平方的滑动均值用v表示;根据pytorch源码,Adam中平滑常数用的是β,RMSProp中用的是α),但是并没有对梯度本身做平滑处理。

在Adam中,对梯度也做了平滑,平滑后的滑动均值用m表示,即,在Adam中有两个β。

偏差纠正

上述m的滑动均值的计算,当时,,由于的初始是0,且β接近1,因此t较小时,m的值是偏向于0的,v也是一样。这里通过除以来进行偏差纠正,即

Adam计算过程

为方便理解,以下伪代码和论文略有差异,其中蓝色部分是比RMSProp多出来的。

  1. 初始:学习率 lr
  2. 初始:平滑常数(或者叫做衰减速率) ,分别用于平滑m和v
  3. 初始:可学习参数 
  4. 初始:
  5. while 没有停止训练 do
  6.         训练次数更新:
  7.         计算梯度:(所有的可学习参数都有自己的梯度,因此 表示的是全部梯度的集合)
  8.         累计梯度:(每个导数对应一个m,因此m也是个集合)
  9.         累计梯度的平方:(每个导数对应一个v,因此v也是个集合)
  10.         偏差纠正m:
  11.         偏差纠正v:
  12.         更新参数:
  13. end while

pytorch Adam参数

torch.optim.Adam(params,lr=0.001,betas=(0.9, 0.999),eps=1e-08,weight_decay=0,amsgrad=False)

params

模型里需要被更新的可学习参数

lr

学习率

betas

平滑常数

eps

,加在分母上防止除0

weight_decay

weight_decay的作用是用当前可学习参数p的值修改偏导数,即:,这里待更新的可学习参数p的偏导数就是

weight_decay的作用是L2正则化,和Adam并无直接关系。

amsgrad

如果amsgrad为True,则在上述伪代码中的基础上,保留历史最大的,记为,每次计算都是用最大的,否则是用当前

amsgrad和Adam并无直接关系。

pytorch优化器详解:Adam相关推荐

  1. pytorch优化器详解:SGD

    目录 说明 SGD参数 params lr momentum dampening weight_decay nesterov 举例(nesterov为False) 第1轮迭代 第2轮迭代 说明 模型每 ...

  2. pytorch优化器详解:RMSProp

    说明 模型每次反向传导都会给各个可学习参数p计算出一个偏导数,用于更新对应的参数p.通常偏导数不会直接作用到对应的可学习参数p上,而是通过优化器做一下处理,得到一个新的值,处理过程用函数F表示(不同的 ...

  3. 深度学习各类优化器详解(动量、NAG、adam、Adagrad、adadelta、RMSprop、adaMax、Nadam、AMSGrad)

    深度学习梯度更新各类优化器详细介绍 文章目录 <center>深度学习梯度更新各类优化器详细介绍 一.前言: 二.梯度下降变形形式 1.批量归一化(BGD) 2.随机梯度下降(SGD) 3 ...

  4. 【深度学习】优化器详解

    优化器 深度学习模型通过引入损失函数,用来计算目标预测的错误程度.根据损失函数计算得到的误差结果,需要对模型参数(即权重和偏差)进行很小的更改,以期减少预测错误.但问题是如何知道何时应更改参数,如果要 ...

  5. Keras深度学习实战(3)——神经网络性能优化技术详解

    Keras深度学习实战(3)--神经网络性能优化技术详解 0. 前言 1. 缩放输入数据集 1.1 数据集缩放的合理性解释 1.2 使用缩放后的数据集训练模型 2. 输入值分布对模型性能的影响 3. ...

  6. Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

    目录 写在前面 一.牛顿法 1.看图理解牛顿法 2.公式推导-三角函数 3.公式推导-二阶泰勒展开 二.BFGS公式推导 三.L-BFGS 四.算法迭代过程 五.代码实现 1.torch.optim. ...

  7. 【GAN优化】详解GAN中的一致优化问题

    GAN的训练是一个很难解决的问题,上期其实只介绍了一些基本的动力学概念以及与GAN的结合,并没有进行过多的深入.动力学是一门比较成熟的学科,有很多非常有用的结论,我们将尝试将其用在GAN上,来得到一些 ...

  8. 【GAN优化】详解SNGAN(频谱归一化GAN)

    今天将和大家一起学习具有很高知名度的SNGAN.之前提出的WGAN虽然性能优越,但是留下一个难以解决的1-Lipschitz问题,SNGAN便是解决该问题的一个优秀方案.我们将先花大量精力介绍矩阵的最 ...

  9. Pytorch:优化器

    4.2 优化器 PyTorch将深度学习中常用的优化方法全部封装在torch.optim中,其设计十分灵活,能够很方便的扩展成自定义的优化方法. 所有的优化方法都是继承基类optim.Optimize ...

最新文章

  1. 在CentOS 6.9 x86_64上开启nginx 1.12.2的stub_status模块(ngx_http_stub_status_module)监控
  2. “中文版GPT-3”来了!用64张V100训练了3周
  3. 《.NET与设计模式》学习(一)
  4. 《Python编程从入门到实践》学习笔记3:列表
  5. 眼电、脑电视频课程汇总
  6. Android之内存溢出(Out Of Memory)的总结
  7. 函数的定义,语法,二维数组,几个练习题
  8. finereport字段显示设置_QA | 表单如何设置字段显示逻辑?
  9. vs2010+open244的永久性配置
  10. 【ubuntu操作系统入门】系统安装
  11. dram和nand哪个难生产_终于有人说清楚了什么是DRAM、什么是NAND Flash
  12. Base32 应用与原理解析
  13. 生意经:网店营销要搭强者的便车
  14. 干货|以产品要素设计解读线上小微信贷
  15. 计算机辅助翻译优缺点,计算机辅助翻译优缺点
  16. Python 打地鼠小游戏
  17. 【Python】pandas的describe参数详解
  18. 今日金融词汇--- 高利润模式
  19. 将Discuz!设置到新版应用中心,无需升级Discuz!版本的方法(临时方案)
  20. 51nod 1301 集合异或和

热门文章

  1. php 打开ppt,怎么播放ppt
  2. jupyter notebook报错:ModuleNotFoundError: No module named ‘cufflinks‘
  3. git pull遇到报错:! [rejected]xxx-> xxx (non-fast-forward)
  4. Android HorizontalScrollView左右滑动
  5. R构建逐步回归模型(Stepwise Regression)
  6. google翻译不能用后chrome浏览器如何翻译网页
  7. 分子对接(docking):蛋白质-蛋白质分子对接
  8. python中的self
  9. Matter理论介绍-通用-1-03:桥接器-数据结构
  10. 线性内插interp1函数用法