torch.optim 是一个实现各种优化算法的包。 大部分常用的方法都已经支持,接口也足够通用,以后也可以轻松集成更复杂的方法。

How to use an optimizer

要使用 torch.optim,您必须构造一个优化器对象,该对象将保存当前状态并根据计算出的梯度更新参数。

构建它

要构造一个优化器,你必须给它一个包含要优化的参数(都应该是 Variable)的迭代。 然后,您可以指定特定于优化器的选项,例如学习率、权重衰减等。

如果您需要通过 .cuda() 将模型移动到 GPU,请在为其构建优化器之前执行此操作。 .cuda() 之后的模型参数将是与调用之前不同的对象。

通常,在构建和使用优化器时,您应该确保优化的参数位于一致的位置。

例子:

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

每个参数选项

优化器还支持指定每个参数的选项。 为此,不要传递 Variable 的迭代,而是传递 dict 的迭代。 它们中的每一个都将定义一个单独的参数组,并且应该包含一个 params 键,其中包含一个属于它的参数列表。 其他键应与优化器接受的关键字参数匹配,并将用作该组的优化选项。

您仍然可以将选项作为关键字参数传递。 它们将在没有覆盖它们的组中用作默认值。 当您只想改变单个选项,同时在参数组之间保持所有其他选项一致时,这很有用。

例如,当您想要指定每层学习率时,这非常有用:

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

这意味着model.base的参数将使用默认的1e-2学习率,model.classifier的参数将使用1e-3的学习率,所有参数将使用0.9的动量。

采取优化步骤

所有优化器都实现了一个 step() 方法,该方法更新参数。 它可以通过两种方式使用:

optimizer.step()
这是大多数优化器支持的简化版本。 一旦使用例如backward()计算梯度,就可以调用该函数。

例子:

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

optimizer.step(closure)

一些优化算法(例如共轭梯度和 LBFGS)需要多次重新计算函数,因此您必须传入一个闭包,以便它们重新计算您的模型。 闭包应该清除梯度,计算损失并返回它。

例子:

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

Base class

CLASS torch.optim.Optimizer(params, defaults)

所有优化器的基类。

warning:

需要将参数指定为具有在运行之间一致的确定性排序的集合。 不满足这些属性的对象的例子是字典值的集合和迭代器。

参数:

1、params (iterable) – 一个 Torch.Tensor 或 dict 的迭代。 指定应该优化哪些张量。

2、defaults – (dict):一个包含优化选项默认值的字典(当参数组没有指定它们时使用)。

Optimizer.add_param_group

向优化器的 param_groups 添加一个参数组。

Optimizer.load_state_dict

加载优化器状态。

Optimizer.state_dict

以字典的形式返回优化器的状态。

Optimizer.step

执行单个优化步骤(参数更新)。

Optimizer.zero_grad

将所有优化的 torch.Tensor 的梯度设置为零。

Algorithms

Adadelta

实现 Adadelta 算法。

Adagrad

实现 Adagrad 算法。

Adam

实现Adam算法。

AdamW

实现 AdamW 算法。

SparseAdam

实现适用于稀疏张量的 Adam 算法的惰性版本。

Adamax

实现 Adamax 算法(基于无穷范数的 Adam 变体)。

ASGD

实现平均随机梯度下降。

LBFGS

实现 L-BFGS 算法,深受 minFunc 的启发。

RMSprop

实现 RMSprop 算法。

Rprop

实现弹性反向传播算法。

SGD

实现随机梯度下降(可选动量)。

How to adjust learning rate

torch.optim.lr_scheduler 提供了多种方法来根据 epoch 数调整学习率。 torch.optim.lr_scheduler.ReduceLROnPlateau 允许基于一些验证测量降低动态学习率。

学习率调度应在优化器更新后应用; 例如,你应该这样写你的代码:

例子:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

大多数学习率调度器都可以称为back-to-back(也称为链式调度器)。 结果是每个调度器都被一个接一个地应用于前一个调度器获得的学习率。

例子:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler1.step()scheduler2.step()

在文档的许多地方,我们将使用以下模板来引用调度程序算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

warning
在 PyTorch 1.1.0 之前,学习率调度器预计会在优化器更新之前被调用; 1.1.0 以一种打破 BC 的方式改变了这种行为。 如果在优化器更新(调用 optimizer.step())之前使用学习率调度器(调用 scheduler.step()),这将跳过学习率调度的第一个值。 如果升级到 PyTorch 1.1.0 后无法重现结果,请检查您是否在错误的时间调用了 scheduler.step()。

lr_scheduler.LambdaLR

Sets the learning rate of each parameter group to the initial lr times a given function.

lr_scheduler.MultiplicativeLR

Multiply the learning rate of each parameter group by the factor given in the specified function.

lr_scheduler.StepLR

Decays the learning rate of each parameter group by gamma every step_size epochs.

lr_scheduler.MultiStepLR

Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.

lr_scheduler.ExponentialLR

Decays the learning rate of each parameter group by gamma every epoch.

lr_scheduler.CosineAnnealingLR

Set the learning rate of each parameter group using a cosine annealing schedule, where \eta_{max}ηmax​ is set to the initial lr and T_{cur}Tcur​ is the number of epochs since the last restart in SGDR:

lr_scheduler.ReduceLROnPlateau

Reduce learning rate when a metric has stopped improving.

lr_scheduler.CyclicLR

Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).

lr_scheduler.OneCycleLR

Sets the learning rate of each parameter group according to the 1cycle learning rate policy.

lr_scheduler.CosineAnnealingWarmRestarts

Set the learning rate of each parameter group using a cosine annealing schedule, where \eta_{max}ηmax​ is set to the initial lr, T_{cur}Tcur​ is the number of epochs since the last restart and T_{i}Ti​ is the number of epochs between two warm restarts in SGDR:

Stochastic Weight Averaging

torch.optim.swa_utils 实现随机权重平均 (SWA)。 特别是,torch.optim.swa_utils.AveragedModel 类实现了 SWA 模型,torch.optim.swa_utils.SWALR 实现了 SWA 学习率调度程序,torch.optim.swa_utils.update_bn() 是一个实用函数,用于在 训练结束。

SWA 已在 Averaging Weights Leads to Wide Optima and Better Generalization 中提出。

Constructing averaged models

AveragedModel 类用于计算 SWA 模型的权重。 您可以通过运行来创建平均模型:

>>> swa_model = AveragedModel(model)

这里模型可以是任意的 torch.nn.Module 对象。 swa_model 将跟踪模型参数的运行平均值。 要更新这些平均值,您可以使用 update_parameters() 函数:

>>> swa_model.update_parameters(model)

SWA learning rate schedules

通常,在 SWA 中,学习率设置为较高的恒定值。 SWALR 是一种学习率调度器,它将学习率退火到一个固定值,然后保持不变。 例如,以下代码创建了一个调度程序,它在每个参数组内的 5 个时期内将学习率从其初始值线性退火到 0.05:

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您还可以通过设置anneal_strategy="cos" 将余弦退火用于固定值而不是线性退火。

Taking care of batch normalization

update_bn() 是一个实用函数,它允许在训练结束时在给定的数据加载器上计算 SWA 模型的batchnorm 统计信息:

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn() 将 swa_model 应用于数据加载器中的每个元素,并计算模型中每个batch normalization层的激活统计信息。

warning

update_bn() 假设数据加载器加载器中的每个批次都是张量或张量列表,其中第一个元素是应应用网络 swa_model 的张量。 如果您的数据加载器具有不同的结构,您可以通过在数据集的每个元素上使用 swa_model 进行前向传递来更新 swa_model 的批量标准化统计信息。

Custom averaging strategies

默认情况下,torch.optim.swa_utils.AveragedModel 计算您提供的参数的运行相等平均值,但您也可以使用带有 avg_fn 参数的自定义平均函数。 在以下示例中,ema_model 计算指数移动平均线。

例子:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

Putting it all together

在下面的示例中,swa_model 是累积权重平均值的 SWA 模型。 我们总共训练了 300 个时期的模型,然后切换到 SWA 学习率计划并开始收集第 160 个时期参数的 SWA 平均值:

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

13、TORCH.OPTIM相关推荐

  1. torch.load、torch.save、torch.optim.Adam的用法

    目录 一.保存模型-torch.save() 1.只保存model的权重 2.保存多项内容 二.加载模型-torch.load() 1.从本地模型中读取数据 2.加载上一步读取的数据 load_sta ...

  2. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  3. PyTorch 笔记(18)— torch.optim 优化器的使用

    到目前为止,代码中的神经网络权重的参数优化和更新还没有实现自动化,并且目前使用的优化方法都有固定的学习速率,所以优化函数相对简单,如果我们自己实现一些高级的参数优化算法,则优化函数部分的代码会变得较为 ...

  4. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  5. PyTorch 笔记(08)— Tensor 比较运算(torch.gt、lt、ge、le、eq、ne、torch.topk、torch.sort、torch.max、torch.min)

    1. 常用函数 比较函数中有一些是逐元素比较,操作类似逐元素操作,还有一些类似归并操作,常用的比较函数如下表所示. 表中第一行的比较操作已经实现了运算符重载,因此可以使用 a>=b,a>b ...

  6. PyTorch: torch.optim 的6种优化器及优化算法介绍

    import torch import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplo ...

  7. 【torch.optim】优化器的使用 / 学习率的调整 / SWA策略

    torch.optim torch.optim是实现各种优化算法的包.大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法. 优化器 使用优化器 为了使用一个优化器,必须 ...

  8. class torch.optim.lr_scheduler.StepLR

    参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...

  9. pytorch每日一学22(torch.empty()、torch.empty_like()、torch.empty_strided())创建未初始化数据的tensor

    第22个方法 由于下面这三个方法比较相似,都是创建未初始化的tensor(第三个是创建一个tensor的视图),所以就放到一起来讲: torch.empty(*size, *, out=None, d ...

  10. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

最新文章

  1. 大连理工IIAU Lab提出SSLSOD:自监督预训练的RGB-D显著性目标检测模型(AAAI 22)
  2. 大厂php怎么做前端,大厂前端经典面试问题精选(附答案)
  3. 15个顶级Java多线程面试题及回答(高级java工程师)
  4. hive 插入数据映射到hbase_大数据基础知识:Hadoop分布式系统介绍
  5. C#中TransactionScope的使用方法和原理
  6. 满二叉树及完全二叉树的相关性质证明
  7. Linux常用命令—文件处理命令—文件处理命令
  8. linux 服务管理
  9. 系统学习深度学习(二十三)--SqueezeNet
  10. [转]Java 强引用、 软引用、 弱引用、虚引用
  11. 知行之桥EDI系统30天试用导航
  12. 创建JSON集合使用JSONArray.fromObject 转化后得到的jsonArray集为空?
  13. 计算机考研408复习路线,不再让你头大啦
  14. Go语言十大排序算法
  15. 计算机视觉文献综述选题,综述论文2021-计算机视觉十大领域最新综述文章分类大盘点...
  16. 苹果手机在哪搜索测试版软件,如何在 beta 版软件上测试你的 App
  17. 蓝桥杯 算法训练 合集1 C++
  18. JAVA基础——关键字与保留字——标识符——进制转换
  19. android数字转汉字大写字母,将数字金额转成汉字大写的
  20. 常见物联网近距离无线通信技术解析

热门文章

  1. mfc word转pdf
  2. 判断推理---逻辑判断
  3. qt 获取本地文件夹、文件路径
  4. 微信小程序:王者荣耀吃鸡气泡等等头像框DIY在线生成N种风格
  5. flask框架可以做什么?
  6. 方差分析ANOVA、单因素方差分析、协变量方差分析ANCOVA、重复测量方差分析、双因素方差分析( two-way ANOVA)、多元方差分析MANOVA、多元协方差分析MANCOVA
  7. Excel如何批量删除工作表中的所有空列
  8. openwrt网关服务器性能,单一ipv6地址做网关的三种方法之openwrt篇
  9. Git 版本控制系统的安装与使用
  10. 通过游戏编程学Python(番外篇)— 乱序成语、猜单词