点击上方,选择星标置顶,每天给你送干货

阅读大概需要10分钟

跟随小博主,每天进步一丢丢

作者:limzero

地址:https://www.zhihu.com/people/lim0-34

编辑:人工智能前沿讲习

最近深入了解了下pytorch下面余弦退火学习率的使用.网络上大部分教程都是翻译的pytorch官方文档,并未给出一个很详细的介绍,由于官方文档也只是给了一个数学公式,对参数虽然有解释,但是解释得不够明了,这样一来导致我们在调参过程中不能合理的根据自己的数据设置合适的参数.这里作一个笔记,并且给出一些定性和定量的解释和结论.说到pytorch自带的余弦学习率调整方法,通常指下面这两个

CosineAnnealingLR

CosineAnnealingWarmRestarts

CosineAnnealingLR

这个比较简单,只对其中的最关键的Tmax参数作一个说明,这个可以理解为余弦函数的半周期.如果max_epoch=50次,那么设置T_max=5则会让学习率余弦周期性变化5次.

max_opoch=50, T_max=5

CosineAnnealingWarmRestarts

这个最主要的参数有两个:

  • T_0:学习率第一次回到初始值的epoch位置

  • T_mult:这个控制了学习率变化的速度

    • 如果T_mult=1,则学习率在T_0,2T_0,3T_0,....,i*T_0,....处回到最大值(初始学习率)

      • 5,10,15,20,25,.......处回到最大值

    • 如果T_mult>1,则学习率在T_0,(1+T_mult)T_0,(1+T_mult+T_mult**2)T_0,.....,(1+T_mult+T_mult2+...+T_0i)*T0,处回到最大值

      • 5,15,35,75,155,.......处回到最大值

T_0=5, T_mult=1

T_0=5, T_mult=2

所以可以看到,在调节参数的时候,一定要根据自己总的epoch合理的设置参数,不然很可能达不到预期的效果,经过我自己的试验发现,如果是用那种等间隔的退火策略(CosineAnnealingLR和Tmult=1的CosineAnnealingWarmRestarts),验证准确率总是会在学习率的最低点达到一个很好的效果,而随着学习率回升,验证精度会有所下降.所以为了能最终得到一个更好的收敛点,设置T_mult>1是很有必要的,这样到了训练后期,学习率不会再有一个回升的过程,而且一直下降直到训练结束。

下面是使用示例和画图的代码:

import torchfrom torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLRimport torch.nn as nnfrom torchvision.models import resnet18import matplotlib.pyplot as plt#model=resnet18(pretrained=False)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)mode='cosineAnnWarm'if mode=='cosineAnn':    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)elif mode=='cosineAnnWarm':    scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=5,T_mult=1)    '''    以T_0=5, T_mult=1为例:    T_0:学习率第一次回到初始值的epoch位置.    T_mult:这个控制了学习率回升的速度        - 如果T_mult=1,则学习率在T_0,2*T_0,3*T_0,....,i*T_0,....处回到最大值(初始学习率)            - 5,10,15,20,25,.......处回到最大值        - 如果T_mult>1,则学习率在T_0,(1+T_mult)*T_0,(1+T_mult+T_mult**2)*T_0,.....,(1+T_mult+T_mult**2+...+T_0**i)*T0,处回到最大值            - 5,15,35,75,155,.......处回到最大值    example:        T_0=5, T_mult=1    '''plt.figure()max_epoch=50iters=200cur_lr_list = []for epoch in range(max_epoch):    for batch in range(iters):        '''        这里scheduler.step(epoch + batch / iters)的理解如下,如果是一个epoch结束后再.step        那么一个epoch内所有batch使用的都是同一个学习率,为了使得不同batch也使用不同的学习率        则可以在这里进行.step        '''        #scheduler.step(epoch + batch / iters)        optimizer.step()    scheduler.step()    cur_lr=optimizer.param_groups[-1]['lr']    cur_lr_list.append(cur_lr)    print('cur_lr:',cur_lr)x_list = list(range(len(cur_lr_list)))plt.plot(x_list, cur_lr_list)plt.show()

最后,对 scheduler.step(epoch + batch / iters)的一个说明,这里的个人理解:一个epoch结束后再.step, 那么一个epoch内所有batch使用的都是同一个学习率,为了使得不同batch也使用不同的学习率 ,则可以在这里进行.step(将离散连续化,或者说使得采样得更加的密集),下图是以20个epoch,每个epoch5个batch,T0=2,Tmul=2画的学习率变化图

代码:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLR
import torch.nn as nn
from torchvision.models import resnet18
import matplotlib.pyplot as plt
#
model=resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mode='cosineAnnWarm'
if mode=='cosineAnn':scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)
elif mode=='cosineAnnWarm':scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=2,T_mult=2)'''以T_0=5, T_mult=1为例:T_0:学习率第一次回到初始值的epoch位置.T_mult:这个控制了学习率回升的速度- 如果T_mult=1,则学习率在T_0,2*T_0,3*T_0,....,i*T_0,....处回到最大值(初始学习率)- 5,10,15,20,25,.......处回到最大值- 如果T_mult>1,则学习率在T_0,(1+T_mult)*T_0,(1+T_mult+T_mult**2)*T_0,.....,(1+T_mult+T_mult**2+...+T_0**i)*T0,处回到最大值- 5,15,35,75,155,.......处回到最大值example:T_0=5, T_mult=1'''
plt.figure()
max_epoch=20
iters=5
cur_lr_list = []
for epoch in range(max_epoch):print('epoch_{}'.format(epoch))for batch in range(iters):scheduler.step(epoch + batch / iters)optimizer.step()#scheduler.step()cur_lr=optimizer.param_groups[-1]['lr']cur_lr_list.append(cur_lr)print('cur_lr:',cur_lr)print('epoch_{}_end'.format(epoch))
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.show()

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

下载一:中文版!学习TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!后台回复【五件套】
下载二:南大模式识别PPT后台回复【南大模式识别】

说个正事哈

由于微信平台算法改版,公号内容将不再以时间排序展示,如果大家想第一时间看到我们的推送,强烈建议星标我们和给我们多点点【在看】。星标具体步骤为:

(1)点击页面最上方深度学习自然语言处理”,进入公众号主页。

(2)点击右上角的小点点,在弹出页面点击“设为星标”,就可以啦。

感谢支持,比心

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

推荐两个专辑给大家:

专辑 | 李宏毅人类语言处理2020笔记

专辑 | NLP论文解读

专辑 | 情感分析


整理不易,还望给个在看!

pytorch的余弦退火学习率相关推荐

  1. 垃圾分类、EfficientNet模型B0~B7、Rectified Adam(RAdam)、Warmup、带有Warmup的余弦退火学习率衰减

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 垃圾分类.EfficientNet模型.数据增强(ImageD ...

  2. PyTorch学习率 warmup + 余弦退火

    PyTorch学习率 warmup + 余弦退火 Pytorch 余弦退火 PyTorch内置了很多学习率策略,详情请参考torch.optim - PyTorch 1.10.1 documentat ...

  3. 【深度学习】(11) 学习率衰减策略(余弦退火衰减,多项式衰减),附TensorFlow完整代码

    大家好,今天和各位分享一下如何使用 TensorFlow 构建 多项式学习率衰减策略.单周期余弦退火学习率衰减策略.多周期余弦退火学习率衰减策略,并使用Mnist数据集来验证构建的方法是否可行. 在上 ...

  4. pytorch优化器,学习率衰减学习笔记

    目录 LAMB优化器 AdaBelief 优化器 Adam和SGD的结合体 lookahead Ranger RAdam和LookAhead合二为一 余弦退火学习率衰减

  5. PyTorch学习率衰减策略:指数衰减(ExponentialLR)、固定步长衰减(StepLR)、多步长衰减(MultiStepLR)、余弦退火衰减(CosineAnnealingLR)

    梯度下降算法需要我们指定一个学习率作为权重更新步幅的控制因子,常用的学习率有0.01.0.001以及0.0001等,学习率越大则权重更新.一般来说,我们希望在训练初期学习率大一些,使得网络收敛迅速,在 ...

  6. Pytorch 学习率衰减 之 余弦退火与余弦warmup 自定义学习率衰减scheduler

    学习率衰减,通常我们英文也叫做scheduler.本文学习率衰减自定义,通过2种方法实现自定义,一是利用lambda,另外一个是继承pytorch的lr_scheduler import math i ...

  7. 【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)

    1. 概述 在论文<SGDR: Stochastic Gradient Descent with Warm Restarts>中主要介绍了带重启的随机梯度下降算法(SGDR),其中就引入了 ...

  8. 学习率衰减之余弦退火(CosineAnnealing)

    1 引言 当我们使用梯度下降算法来优化目标函数的时候,当越来越接近Loss值的全局最小值时,学习率应该变得更小来使得模型尽可能接近这一点,而余弦退火(Cosine annealing)可以通过余弦函数 ...

  9. PyTorch学习之六个学习率调整策略

    PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现.PyTorch提供的学习率调整策略分为三大类,分别是 a. 有序调整:等间隔调整(Step),按需调整学习率( ...

  10. 【深度学习】图解 9 种PyTorch中常用的学习率调整策略

    learning rate scheduling 学习率调整策略 01 LAMBDA LR 将每个参数组的学习率设置为初始lr乘以给定函数.当last_epoch=-1时,将初始lr设置为初始值. t ...

最新文章

  1. python具有一些突出优点_Python具有一些突出优点,它们是:()
  2. 2013 多校联合4 1011 Fliping game (hdu 4642)
  3. MATLAB GPU编程基础
  4. 整合Flex和Java(上)
  5. iptables 防火墙为什么不占用端口?
  6. c语言的加法和平均值程序,编写求一组整数的和与平均值的程序
  7. teighax是什么_cut up,cut in,cut off,cut down有什么区别?
  8. [LAMP兄弟连李明老师讲Linux].课件Shell编程
  9. LTE IDLE DRX和CDRX
  10. 小学计算机考核,小学信息技术学科考核评价方案.docx
  11. 搜狐股票接口获取数据方法
  12. (OC) interface
  13. 2022年上半年中国企业员工主动离职率大幅下降至6%;三成以上中国企业大部分高管岗位没有后备人选 | 美通社头条...
  14. 拼图游戏java(三)实现鼠标点击图片上下左右移动
  15. Vijos - 想越狱的小杉(最短路)
  16. Unity地图分割组合时出现接缝的处理办法
  17. 机房动环监控系统应用意义
  18. 快速自建Web安全测试环境
  19. Codeforces Round #705 (Div. 2) A-D
  20. 小白兔写话_小学二年级期末写话片段练习 可爱的小白兔

热门文章

  1. linux系统快捷键使用
  2. spark sql cache
  3. JavaScript 函数参数是传值(byVal)还是传址(byRef)?
  4. Altium Designer(五):布板技巧
  5. 任务21 :了解ASP.NET Core 依赖注入,看这篇就够了
  6. 深度学习网络架构(三):VGG
  7. PyCharm导入selenium的webdirver模块出错
  8. /etc/profile、~/.bash_profile、~/.bashrc和/etc/bashrc
  9. jdbc:initialize-database标签的研究
  10. EF直接更新数据(不需查询)