torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

Optimizer还支持指定每个参数选项。 只需传递一个可迭代的dict来替换先前可迭代的Variable(变量)。dict中的每一项都可以定义为一个单独的参数组,参数组用一个params键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

如上,model.base.parameters()将使用1e-2的学习率,model.classifier.parameters()将使用1e-3的学习率。0.9的momentum作用于所有的parameters。
优化步骤:
所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新,它有两种调用方法:

optimizer.step()

这是大多数优化器都支持的简化版本,使用如下的backward()方法来计算梯度的时候会调用它。

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
optimizer.step(closure)

一些优化算法,如共轭梯度和LBFGS需要重新评估目标函数多次,所以你必须传递一个closure以重新计算模型。 closure必须清除梯度,计算并返回损失。

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

Adam算法:

adam算法来源:Adam: A Method for Stochastic Optimization

Adam(Adaptive Moment Estimation)本质上是带有动量项的RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。它的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。其公式如下:

其中,前两个公式分别是对梯度的一阶矩估计和二阶矩估计,可以看作是对期望E|gt|,E|gt^2|的估计;
公式3,4是对一阶二阶矩估计的校正,这样可以近似为对期望的无偏估计。可以看出,直接对梯度的矩估计对内存没有额外的要求,而且可以根据梯度进行动态调整。最后一项前面部分是对学习率n形成的一个动态约束,而且有明确的范围

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
lr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)
step(closure=None)函数:执行单一的优化步骤
closure (callable, optional):用于重新评估模型并返回损失的一个闭包 

torch.optim.adam源码:

import math
from .optimizer import Optimizerclass Adam(Optimizer):def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0):defaults = dict(lr=lr, betas=betas, eps=eps,weight_decay=weight_decay)super(Adam, self).__init__(params, defaults)def step(self, closure=None):loss = Noneif closure is not None:loss = closure()for group in self.param_groups:for p in group['params']:if p.grad is None:continuegrad = p.grad.datastate = self.state[p]# State initializationif len(state) == 0:state['step'] = 0# Exponential moving average of gradient valuesstate['exp_avg'] = grad.new().resize_as_(grad).zero_()# Exponential moving average of squared gradient valuesstate['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']beta1, beta2 = group['betas']state['step'] += 1if group['weight_decay'] != 0:grad = grad.add(group['weight_decay'], p.data)# Decay the first and second moment running average coefficientexp_avg.mul_(beta1).add_(1 - beta1, grad)exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)denom = exp_avg_sq.sqrt().add_(group['eps'])bias_correction1 = 1 - beta1 ** state['step']bias_correction2 = 1 - beta2 ** state['step']step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1p.data.addcdiv_(-step_size, exp_avg, denom)return loss

Adam的特点有:
1、结合了Adagrad善于处理稀疏梯度和RMSprop善于处理非平稳目标的优点;
2、对内存需求较小;
3、为不同的参数计算不同的自适应学习率;
4、也适用于大多非凸优化-适用于大数据集和高维空间。

torch.optim优化算法理解之optim.Adam()相关推荐

  1. 优化算法理解以及举例

    按照搜索策略划分搜索算法 盲目搜索 按照预先设定好的搜索策略进行搜索,并且不根据搜索过程中获得的中间信息更改控制策略.又称无信息搜索,需要大量的时间和空间作为基础 深度优先搜索Depth First ...

  2. 优化算法笔记|飞蛾扑火优化算法理解及实现

    飞蛾扑火算法 一.飞蛾扑火算法背景知识 二.飞蛾扑火算法原理 三.算法流程总结 四.飞蛾扑火算法Python实现 一.飞蛾扑火算法背景知识 飞蛾扑火优化(Moth-flame optimization ...

  3. 深度学习最常用的学习算法:Adam优化算法

    上海站 | 高性能计算之GPU CUDA培训 4月13-15日 三天密集式学习  快速带你晋级 阅读全文 > 正文共6267个字,30张图,预计阅读时间16分钟. 听说你了解深度学习最常用的学习 ...

  4. 【转】听说你了解深度学习最常用的学习算法:Adam优化算法?

    深度学习常常需要大量的时间和机算资源进行训练,这也是困扰深度学习算法开发的重大原因.虽然我们可以采用分布式并行训练加速模型的学习,但所需的计算资源并没有丝毫减少.而唯有需要资源更少.令模型收敛更快的最 ...

  5. 深度学习最常用的算法:Adam优化算法

    深度学习常常需要大量的时间和机算资源进行训练,这也是困扰深度学习算法开发的重大原因.虽然我们可以采用分布式并行训练加速模型的学习,但所需的计算资源并没有丝毫减少.而唯有需要资源更少.令模型收敛更快的最 ...

  6. 优化算法选择:SGD、SGDM、NAG、Adam、AdaGrad、RMSProp、Nadam

    目录 优化算法通用框架 SGD 系列:固定学习率的优化算法 SGD SGD (with Momentum) = SGD-M SGD(with Nesterov Acceleration)= NAG 自 ...

  7. 深度学习优化算法实现(Momentum, Adam)

    目录 Momentum 初始化 更新参数 Adam 初始化 更新参数 除了常见的梯度下降法外,还有几种比较通用的优化算法:表现都优于梯度下降法.本文只记录完成吴恩达深度学习作业时遇到的Momentum ...

  8. Adam优化算法(Adam optimization algorithm)

    Adam优化算法(Adam optimization algorithm) Adam优化算法基本上就是将Momentum和RMSprop结合在一起. 初始化 2.在第t次迭代中,用mini-batch ...

  9. 机器学习:优化算法Optimizer比较和总结(SGD/BGD/MBGD/Momentum/Adadelta/Adam/RMSprop)

    文章目录 梯度下降法(Gradient Descent) 批量梯度下降法BGD 随机梯度下降法SGD 小批量梯度下降法 动量优化法 Momentum NAG(Nesterov accelerated ...

  10. 梯度下降优化算法综述,梯度下降法 神经网络

    梯度下降法是什么? 梯度下降法(英语:Gradientdescent)是一个一阶最优化算法,通常也称为最陡下降法. 要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度(或者是近似梯度 ...

最新文章

  1. 程序员颈椎病康复秘籍
  2. ORACLE12C_ADG删除pdb
  3. the art of java 源代码_请不要再说Java中final方法比非final性能更好了
  4. mysql系列之5--完全备份和增量备份
  5. windows docker 空出C盘 迁移到其他盘
  6. Keil使用PC-Lint
  7. 用java写一个日历_使用JAVA写一个简单的日历
  8. 《Linux内核完全注释》《完全剖析》 » 阅读本书所需的基础知识 -- 再次强调。
  9. 【LeetCode笔记】88. 合并两个有序数组(Java、双指针)
  10. GitHub 热榜第一!这个 Python 项目超 8.4k 标星,网友:太实用!
  11. ip服务器ip地址信息配置,服务器ip地址配置
  12. WPF ListBox(ListView) 自定义 Button 项,获取 ListBox(ListView)的SelectedValue
  13. 使用PostgREST的RestAPI操作之 类型转JSON对象 | 嵌入视图
  14. sed替换字符时 ' /等符号的处理
  15. pycharm怎么改成中文版?
  16. C++超详细五子棋游戏(AI实现人机对弈+双人对弈+EasyX图形化界面+详细介绍)
  17. 【Pytorch】AlexNet图像分类实战
  18. 主成分分析法怎么提取图片中的字_论文中做出CNS高逼格的主成分分析图教程
  19. 20175208 张家华 实验四《Android开发基础》实验报告
  20. 2022年终结——人生中最美好的一站

热门文章

  1. scratch图形化编程操作硬件
  2. 最小二乘法求直线的理解
  3. 苹果8a1660是什么版本_苹果a1780是什么版本
  4. 软件测试技术之项目上线流程
  5. HTML页面布局适配不同分辨率
  6. matlab 脉冲频域压缩,大作业-雷达线性调频脉冲压缩的原理及其MATLAB仿真.doc
  7. MSB3644 找不到 .NETFramework,Version=v4.7 的引用程序集。要解决此问题,请为此框架版本安装......
  8. 【学生个人网页设计作品】使用HMTL制作一个超好看的保护海豚动物网页
  9. 【PPT】极简PPT设计方法
  10. 二年级课程表(4月2日-4月8日)