点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

作者丨机器学习入坑者@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/93624972

编辑丨极市平台

导读

本文介绍了四种衰减类型:指数衰减、固定步长的衰减、多步长衰、余弦退火衰减并逐一介绍其性质,及pytorch对应的使用方式。

梯度下降算法需要我们指定一个学习率作为权重更新步幅的控制因子,常用的学习率有0.01、0.001以及0.0001等,学习率越大则权重更新。一般来说,我们希望在训练初期学习率大一些,使得网络收敛迅速,在训练后期学习率小一些,使得网络更好的收敛到最优解。下图展示了随着迭代的进行动态调整学习率的4种策略曲线:

上述4种策略为自己根据资料整理得到的衰减类型:指数衰减、固定步长的衰减、多步长衰、余弦退火衰减。下面逐一介绍其性质,及pytorch对应的使用方式,需要注意学习率衰减策略很大程度上是依赖于经验与具体问题的,不能照搬参数。

1、指数衰减

学习率按照指数的形式衰减是比较常用的策略,我们首先需要确定需要针对哪个优化器执行学习率动态调整策略,也就是首先定义一个优化器:

optimizer_ExpLR = torch.optim.SGD(net.parameters(), lr=0.1)

定义好优化器以后,就可以给这个优化器绑定一个指数衰减学习率控制器:

ExpLR = torch.optim.lr_scheduler.ExponentialLR(optimizer_ExpLR, gamma=0.98)

其中参数gamma表示衰减的底数,选择不同的gamma值可以获得幅度不同的衰减曲线,如下:

2、固定步长衰减

有时我们希望学习率每隔一定步数(或者epoch)就减少为原来的gamma分之一,使用固定步长衰减依旧先定义优化器,再给优化器绑定StepLR对象:

optimizer_StepLR = torch.optim.SGD(net.parameters(), lr=0.1)
StepLR = torch.optim.lr_scheduler.StepLR(optimizer_StepLR, step_size=step_size, gamma=0.65)

其中gamma参数表示衰减的程度,step_size参数表示每隔多少个step进行一次学习率调整,下面对比了不同gamma值下的学习率变化情况:

3、多步长衰减

上述固定步长的衰减的虽然能够按照固定的区间长度进行学习率更新 但是有时我们希望不同的区间采用不同的更新频率,或者是有的区间更新学习率,有的区间不更新学习率,这就需要使用MultiStepLR来实现动态区间长度控制:

optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1)
torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR,milestones=[200, 300, 320, 340, 200], gamma=0.8)

其中milestones参数为表示学习率更新的起止区间,在区间[0. 200]内学习率不更新,而在[200, 300]、[300, 320].....[340, 400]的右侧值都进行一次更新;gamma参数表示学习率衰减为上次的gamma分之一。其图示如下:

从图中可以看出,学习率在区间[200, 400]内快速的下降,这就是milestones参数所控制的,在milestones以外的区间学习率始终保持不变。

4、余弦退火衰减

严格的说,余弦退火策略不应该算是学习率衰减策略,因为它使得学习率按照周期变化,其定义方式如下:

optimizer_CosineLR = torch.optim.SGD(net.parameters(), lr=0.1)
CosineLR = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_CosineLR, T_max=150, eta_min=0)

其包含的参数和余弦知识一致,参数T_max表示余弦函数周期;eta_min表示学习率的最小值,默认它是0表示学习率至少为正值。确定一个余弦函数需要知道最值和周期,其中周期就是T_max,最值是初试学习率。下图展示了不同周期下的余弦学习率更新曲线:

5、上述4种学习率动态更新策略的说明

4个负责学习率调整的类:StepLR、ExponentialLR、MultiStepLR和CosineAnnealingLR,其完整对学习率的更新都是在其step()函数被调用以后完成的,这个step表达的含义可以是一次迭代,当然更多情况下应该是一个epoch以后进行一次scheduler.step(),这根据具体问题来确定。此外,根据pytorch官网上给出的说明,scheduler.step()函数的调用应该在训练代码以后:

scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

参考:

https://www.jianshu.com/p/125fe2ab085b

https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

如果觉得有用,就请分享到朋友圈吧!

点个在看 paper不断!

干货|pytorch必须掌握的的4种学习率衰减策略相关推荐

  1. pytorch必须掌握的的4种学习率衰减策略

    原文: pytorch必须掌握的的4种学习率衰减策略 1.指数衰减 2. 固定步长衰减 3. 多步长衰减 4. 余弦退火衰减 5. 上述4种学习率动态更新策略的说明 梯度下降算法需要我们指定一个学习率 ...

  2. polyrate使用方法_pytorch必须掌握的的4种学习率衰减策略

    梯度下降算法需要我们指定一个学习率作为权重更新步幅的控制因子,常用的学习率有0.01.0.001以及0.0001等,学习率越大则权重更新.一般来说,我们希望在训练初期学习率大一些,使得网络收敛迅速,在 ...

  3. PyTorch学习率衰减策略:指数衰减(ExponentialLR)、固定步长衰减(StepLR)、多步长衰减(MultiStepLR)、余弦退火衰减(CosineAnnealingLR)

    梯度下降算法需要我们指定一个学习率作为权重更新步幅的控制因子,常用的学习率有0.01.0.001以及0.0001等,学习率越大则权重更新.一般来说,我们希望在训练初期学习率大一些,使得网络收敛迅速,在 ...

  4. 系统学习Pytorch笔记七:优化器和学习率调整策略

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  5. paddle 12种学习率调度器

    目录 文本框检测的Cosine学习率调度器: 13种调度器 文本框检测的Cosine学习率调度器: 学习率 0.001 效果好像比较好,推荐使用 configs/det/ch_ppocr_v2.0/c ...

  6. PyTorch | 优化神经网络训练的17种方法

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | LORENZ KUHN 来源 | 人工智能前沿讲习 编辑 ...

  7. (转) 干货 | 图解LSTM神经网络架构及其11种变体(附论文)

    干货 | 图解LSTM神经网络架构及其11种变体(附论文) 2016-10-02 机器之心 选自FastML 作者:Zygmunt Z. 机器之心编译  参与:老红.李亚洲 就像雨季后非洲大草原许多野 ...

  8. PyTorch: torch.optim 的6种优化器及优化算法介绍

    import torch import torch.nn.functional as F import torch.utils.data as Data import matplotlib.pyplo ...

  9. 【深度学习】图解 9 种PyTorch中常用的学习率调整策略

    learning rate scheduling 学习率调整策略 01 LAMBDA LR 将每个参数组的学习率设置为初始lr乘以给定函数.当last_epoch=-1时,将初始lr设置为初始值. t ...

最新文章

  1. 电阻(4)之上/下拉电阻
  2. 第六章 计算机网络与i教案,大学计算机基础教案第6章计算机网络基础与应用.docx...
  3. Qt学习:QDomDocument
  4. python库安装错误 in _error_catcher解决之镜像安装
  5. Zuul转发请求时HttpHostConnectException can‘t cast to ZuulException问题解决方法
  6. CentOS8 同步时间chrony ntpdate已无法使用
  7. 【基础】算法的时间复杂度分析
  8. MySQL提取字符串中数字(自定义函数)
  9. 【微信小程序】解决代码上传超过大小限制,小程序分包
  10. 计算机博士专业学位,计算机博士
  11. python帮你获取王者荣耀金币
  12. cloudchat苹果如何下载只能通过ipa吗
  13. 力扣、github网站登不上
  14. python怎样分析文献综述_教你如何做文献综述
  15. mysql 从第几个字符串开始截取_mysql字符串截取
  16. 【转】情牵牛仔裤 情色一生
  17. 形容等待时间长的句子_形容等待时间长的诗句
  18. 计算机远程病理会诊准确率,数字病理远程诊断
  19. 苹果cmsv10自适应模板自带后台系统原创多功能漂亮主题
  20. linux 系统cpu查看

热门文章

  1. vs2005什么时候能出正式版
  2. 谢文睿:西瓜书 + 南瓜书 吃瓜系列 6. 神经网络
  3. 如何制作风格迁移图?
  4. python解析json
  5. 【牛客】CSL 的字符串 (stack map)
  6. 性能超越最新序列推荐模型,华为诺亚方舟提出记忆增强的图神经网络
  7. 微众银行殷磊:AI+卫星,从上帝视角洞察资产管理|BDTC 2019
  8. 天哪!我的十一假期被AI操控了
  9. 《使女的故事》大火,AI是背后最大推手?
  10. 升级人脸识别,小鱼易连要打通企业与个人微信,重塑视频会议3.0!