torch.optim

torch.optim是实现各种优化算法的包。大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法。

优化器

使用优化器

为了使用一个优化器,必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度更新参数。

  • 构建优化器

要构造一个优化器,必须给它一个包含参数(所有的参数应该是Variable)的可迭代对象来优化。然后,可以指定特定于优化器的选项,如学习率、权重衰减等。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • Per-parameter

优化器还支持指定per-parameter options。要做到这一点,不是传递一个Variable的迭代对象,而是传递一个dict 的迭代对象。它们每个都将定义一个单独的参数组,并且应该包含一个params键,包含属于它的参数列表。其他键应该与优化器接受的关键字参数匹配,并将用作该组的优化选项。

例如,当你想指定每层的学习率时:

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

这意味**model.base**的参数将使用默认的学习率1e-2,模型。model.classifier的参数将使用1e-3的学习率,所有参数将使用0.9momentum

  • optimization step

所有优化器都实现了一个step()方法,用于更新参数。它有两种用法:

(1)optimizer.step() : 这是大多数优化器支持的简化版本。该函数可以在梯度计算完成后调用,例如backward()

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

(2)optimizer.step(closure)

一些优化算法(如Conjugate GradientLBFGS)需要多次重新计算函数,所以你必须传入一个闭包,允许它们重新计算你的模型。闭包应该清除梯度,计算损失并返回。

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

优化器基类 Optimizer

  • Optimizer
torch.optim.Optimizer(params, defaults)
  • 参数

    • params (*iterable*) :指定应当被优化的张量,是torch.Tensordict的迭代器类型。
    • defaults(dict) : ****包含优化选项默认值的字典(当参数组没有指定这些值时使用)。
  • 方法
Optimizer.add_param_group 在优化器的param_groups中添加一个参数组。
Optimizer.load_state_dict 加载优化器状态。
Optimizer.state_dict 以字典的形式返回优化器的状态。
Optimizer.step 执行单个优化步骤(参数更新)。
Optimizer.zero_grad 设置所有优化的 torch.Tensor 的梯度为零

优化算法

Adadelta Adadelta算法的实现
Adagrad Adagrad 算法的实现
Adam Adam算法的实现
AdamW AdamW 算法的实现
SparseAdam 适合稀疏张量的Adam算法的实现
Adamax Adamax algorithm 算法的实现
ASGD Averaged Stochastic Gradient Descent. 算法的实现
LBFGS L-BFGS 算法的实现
NAdam NAdam 算法的实现
RAdam RAdam 算法的实现
RMSprop RMSprop 算法的实现
Rprop resilient backpropagation 算法的实现
SGD stochastic gradient descent (optionally with momentum). 算法的实现

学习率

调整学习率的基类

**torch.optim.lr_scheduler**提供了几种基于epoch数量调整学习率的方法。

**torch.optim.lr_scheduler.ReduceLROnPlateau** 允许基于某些验证度量动态地降低学习率。

学习率调度应在优化器更新后应用 :

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

大多数学习率调度器可以称为back-to-back调度器(也称为链式调度器)。每个调度器一个接一个地应用于前一个调度器获得的学习率。

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler1.step()scheduler2.step()

我们将使用以下模板来引用调度器算法。

scheduler = ...
for epoch in range(100):train(...)validate(...)scheduler.step()

调整学习率的算法

lr_scheduler.LambdaLR 将每个参数组的学习率设置为初始lr乘以给定函数。
lr_scheduler.MultiplicativeLR 将每个参数组的学习率乘以指定函数中给定的因子。
lr_scheduler.StepLR 每个step_size epoch将每个参数组的学习率衰减gamma。
lr_scheduler.MultiStepLR 一旦epoch的数量达到一个里程碑,将每个参数组的学习率衰减为gamma。
lr_scheduler.ConstantLR 将每个参数组的学习率衰减一个小的常数因子,直到epoch的数量达到预定义的里程碑:total_iters。
lr_scheduler.LinearLR 通过线性改变小的乘法因子衰减每个参数组的学习率,直到epoch的数量达到预定义的里程碑:total_iters。
lr_scheduler.ExponentialLR 每一个epoch衰减每个参数组的学习率gamma。
lr_scheduler.PolynomialLR 在给定的total_iters中使用多项式函数衰减每个参数组的学习率。
lr_scheduler.CosineAnnealingLR 使用余弦退火计划设置每个参数组的学习率
lr_scheduler.ChainedScheduler 学习速率调度器的链表。
lr_scheduler.SequentialLR 接收预计在优化过程和里程碑点期间顺序调用的调度器列表,该列表提供精确的间隔,以反映在给定时间段里应该调用哪个调度器。
lr_scheduler.ReduceLROnPlateau 当metric停止改进时,降低学习率。
lr_scheduler.CyclicLR 根据CLR (cycle learning rate policy)策略设置各参数组的学习率。
lr_scheduler.OneCycleLR 按照1cycle学习率策略设置各参数组的学习率。
lr_scheduler.CosineAnnealingWarmRestarts 使用余弦退火计划设置每个参数组的学习率

随机加权平均(Stochastic Weight Averaging)

torch.optim.swa_utils 实现随机加权平均(SWA)。特别是, torch.optim.swa_utils.AveragedModel 类实现SWA模型,torch.optim.swa_utils.SWALR 实现了SWA学习率调度器和 torch.optim.swa_utils.update_bn() 是一个效用函数,用于在训练结束时更新SWA批归一化统计信息。

构造 Averaged modeles

AveragedModel 类用于计算SWA模型的权重。可以通过运行以下命令创建一个averaged model:

swa_model = AveragedModel(model)

这里的模型Model可以是任意的torch.nn.Module对象。swa_model将跟踪模型参数的运行平均值。要更新这些平均值,你可以使用update_parameters()函数:

swa_model.update_parameters(model)

SWA学习速率策略

通常,在SWA中,学习率被设置为一个高常数值。SWALR是一个学习率调度器,它将学习率退火到一个固定的值,然后保持它不变。例如,下面的代码创建了一个调度器,它在每个参数组内的5个周期内将学习率从初始值线性退火到0.05 :

swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

你也可以使用余弦退火到一个固定的值,而不是线性退火通过设置**anneal_strategy="cos"**

batch normalization

update_bn()允许在训练结束时计算给定数据加载器loader上SWA模型的batchnorm统计数据:

torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model应用于数据加载器中的每个元素,并计算模型中每个批处理规范化层的激活统计信息。

自定义averaging策略

默认情况下,**torch.optim.swa_utils.AveragedModel**计算您提供的参数的平均值,但是您也可以使用带有avg_fn参数的自定义平均函数。在下面的例子中,ema_model计算一个指数移动平均。

ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\0.1 * averaged_model_parameter + 0.9 * model_parameter
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

汇总

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
for epoch in range(300):for input, target in loader:optimizer.zero_grad()loss_fn(model(input), target).backward()optimizer.step()if epoch > swa_start:swa_model.update_parameters(model)swa_scheduler.step()else:scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)

【torch.optim】优化器的使用 / 学习率的调整 / SWA策略相关推荐

  1. PyTorch 中 torch.optim优化器的使用

    一.优化器基本使用方法 建立优化器实例 循环: 清空梯度 向前传播 计算Loss 反向传播 更新参数 示例: from torch import optim input = ..... optimiz ...

  2. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  3. PyTorch 笔记(18)— torch.optim 优化器的使用

    到目前为止,代码中的神经网络权重的参数优化和更新还没有实现自动化,并且目前使用的优化方法都有固定的学习速率,所以优化函数相对简单,如果我们自己实现一些高级的参数优化算法,则优化函数部分的代码会变得较为 ...

  4. 【PyTorch】Optim 优化器

    文章目录 五.Optim 优化器 1.SGD 五.Optim 优化器 参考文档:https://pytorch.org/docs/stable/optim.html 1.SGD 参考文档:https: ...

  5. 『PyTorch』第十一弹_torch.optim优化器 每层定制参数

    一.简化前馈网络LeNet 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 im ...

  6. 【Pytorch神经网络理论篇】 10 优化器模块+退化学习率

    1 优化器模块的作用 1.1 反向传播的核心思想 反向传播的意义在于告诉模型我们需要将权重修改到什么数值可以得到最优解,在开始探索合适权重的过程中,正向传播所生成的结果与实际标签的目标值存在误差,反向 ...

  7. 深度学习调参:优化算法,优化器optimizer,学习率learning rate

    在优化模型的过程中,有高原.高峰.洼地,我们的目的是找到最低的那个洼地. 选择不同的学习率和优化器,可能进入不同的洼地,或者在洼地附近震荡,无法收敛. 1 优化器的选择 Adam那么棒,为什么还对SG ...

  8. 优化器:torch.optimizer

    优化器:对参数进行调整,降低误差减小损失 torch.optimizer parameters(模型参数),lr(float)–>(学习速率) 代码: import torch import t ...

  9. 关于SGD优化器的学习率衰减的问题

    由于学术需要,这段时间再训练一个分类器,但其效果不太好,loss下降不明显.便考虑是不是学习率的问题,由于使用的是SGD,其中一个参数为decay,借鉴别人的参设默认值,decay 一般设为1x10- ...

最新文章

  1. WebForm开发常用代码
  2. 微软开发x86模拟器,让Windows for ARM能运行x86应用
  3. OpenCV数字图像处理(5) 像素访问之添加椒盐实例 通道分离与合并
  4. Windows内核新手上路3——挂钩KeUserModeCallBack
  5. 从需求的角度去理解Linux系列:总线、设备和驱动
  6. matlab mobile中文版,MATLAB Mobile
  7. oracle vm 实施图解
  8. android可点击的列表,如何在Android的可扩展列表视图中的子点击...
  9. 【python】用正则表达式进行文字局部替换
  10. 有一种小说叫“纯爱”:为“纯爱小说系列写的序言
  11. [css] img标签是行内元素,为什么却能设置宽高
  12. 懒加载Lazy Loading
  13. linux下用mail发送邮件
  14. Python的BoundedSemaphore对象和Pool对象实例
  15. 蓝桥杯 ALGO-110 算法训练 字符串的展开
  16. html5 新标签xss,HTML5 localStorageXSS漏洞
  17. U盘修复工具哪个好?7款U盘低格工具详解
  18. 如何拯救拖延症??11大招总有一招帮你搞定
  19. dfuse 新版 client-js 库发布
  20. 深度:老年旅游市场迎来转折点,50/60新老年消费升级带来结构性创新机会—营销/产品/运营

热门文章

  1. formality形式验证里的案件分析
  2. java计算机毕业设计HTML5“守护萌宠”网站设计与实现MyBatis+系统+LW文档+源码+调试部署
  3. 【Python学习教程】Python变量类型和运算符
  4. html中盒子的定位,css盒子的定位有哪些方法?
  5. 贝壳网webpack案例
  6. ubuntu安装nvidia 750ti显卡驱动
  7. 联合证券|“仰望”概念爆发,多股涨停!人气龙头股罕见“炸板”
  8. [乡土民间故事_徐苟三传奇]第廿九回_蠢财主落水知上当
  9. 拉钩作业:Bikeshare数据集 预测共享单车骑行量
  10. 西安交大计算机学院 栾佳锡,史椸-西安交通大学-自动化科学与工程学院