学习率衰减,通常我们英文也叫做scheduler。本文学习率衰减自定义,通过2种方法实现自定义,一是利用lambda,另外一个是继承pytorch的lr_scheduler

import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import *
from torchvision import models
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.fc = nn.Linear(1, 10)def forward(self,x):return self.fc(x)

余弦退火

  1. 当T_max=20
lrs = []
model = Net()
LR = 0.01
epochs = 100
optimizer = Adam(model.parameters(),lr = LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-9)
for epoch in range(epochs): optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

  1. 当T_max = epochs,这就是我们经常用到的弦退火的 scheduler,下面再来看看带Warm-up的
lrs = []
model = Net()
LR = 0.01
epochs = 100
optimizer = Adam(model.parameters(),lr = LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-9)
for epoch in range(epochs): optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

WarmUp

下面来看看 Pytorch定义的余弦退货的公式如下
ηt=ηmin+12(ηmax−ηmin)(1+cos⁡(TcurTmaxπ)),Tcur≠(2k+1)Tmax;ηt+1=ηt+12(ηmax−ηmin)(1−cos⁡(1Tmaxπ)),Tcur=(2k+1)Tmax.\begin{aligned} \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), & T_{cur} \neq (2k+1)T_{max}; \\ \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), & T_{cur} = (2k+1)T_{max}. \end{aligned}ηt​ηt+1​​=ηmin​+21​(ηmax​−ηmin​)(1+cos(Tmax​Tcur​​π)),=ηt​+21​(ηmax​−ηmin​)(1−cos(Tmax​1​π)),​Tcur​​=(2k+1)Tmax​;Tcur​=(2k+1)Tmax​.​

实际上是用下面的公式做为更新的, 当Tcur=TmaxT_{cur} = T_{max}Tcur​=Tmax​是,coscoscos部分为0,所以就等于ηmin\eta_{min}ηmin​

ηt=ηmin+12(ηmax−ηmin)(1+cos⁡(TcurTmaxπ))\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)ηt​=ηmin​+21​(ηmax​−ηmin​)(1+cos(Tmax​Tcur​​π))

这里直接根据公式的定义来画个图看看

etas = []
epochs = 100
eta_max = 1e-4
eta_min = 1e-9
t_max = epochs / 1
for i in range(epoch):t_cur = ieta = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t_cur / t_max))etas.append(eta)plt.figure(figsize=(10, 6))
plt.plot(range(len(etas)), etas, color='r')
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()


从图上来看,跟上面的余弦退化是一样的,眼尖的都会发现lr_min 不等于eta_min=1e-9

利用Lambda来定义的

有个较小的bug(也不算,在description里有指出)

def warm_up_cosine_lr_scheduler(optimizer, epochs=100, warm_up_epochs=5, eta_min=1e-9):"""Description:- Warm up cosin learning rate scheduler, first epoch lr is too smallArguments:- optimizer: input optimizer for the training- epochs: int, total epochs for your training, default is 100. NOTE: you should pass correct epochs for your training- warm_up_epochs: int, default is 5, which mean the lr will be warm up for 5 epochs. if warm_up_epochs=0, means no needto warn up, will be as cosine lr scheduler- eta_min: float, setup ConsinAnnealingLR eta_min while warm_up_epochs = 0Returns:- scheduler"""if warm_up_epochs <= 0:scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=eta_min)else:warm_up_with_cosine_lr = lambda epoch: eta_min + (epoch / warm_up_epochs) if epoch <= warm_up_epochs else 0.5 * (np.cos((epoch - warm_up_epochs) / (epochs - warm_up_epochs) * np.pi) + 1)scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)return scheduler
# warm up consin lr scheduler
lrs = []
model = Net()
LR = 1e-4
warm_up_epochs = 30
epochs = 100
optimizer = SGD(model.parameters(), lr=LR)scheduler = warm_up_cosine_lr_scheduler(optimizer, warm_up_epochs=warm_up_epochs, eta_min=1e-9)for epoch in range(epochs):optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))  ![请添加图片描述](https://img-blog.csdnimg.cn/566b2c036b4a44598ae2a5a0548f2550.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBAamFzbmVpaw==,size_20,color_FFFFFF,t_70,g_se,x_16)plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()


从图上看,第一个lr非常非常小,导致训练时的,第一个epoch基本上不更新

继承lr_scheduler的类

class WarmupCosineLR(lr_scheduler._LRScheduler):def __init__(self, optimizer, lr_min, lr_max, warm_up=0, T_max=10, start_ratio=0.1):"""Description:- get warmup consine lr schedulerArguments:- optimizer: (torch.optim.*), torch optimizer- lr_min: (float), minimum learning rate- lr_max: (float), maximum learning rate- warm_up: (int),  warm_up epoch or iteration- T_max: (int), maximum epoch or iteration- start_ratio: (float), to control epoch 0 lr, if ratio=0, then epoch 0 lr is lr_minExample:<<< epochs = 100<<< warm_up = 5<<< cosine_lr = WarmupCosineLR(optimizer, 1e-9, 1e-3, warm_up, epochs)<<< lrs = []<<< for epoch in range(epochs):<<<     optimizer.step()<<<     lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])<<<     cosine_lr.step()<<< plt.plot(lrs, color='r')<<< plt.show()"""self.lr_min = lr_minself.lr_max = lr_maxself.warm_up = warm_upself.T_max = T_maxself.start_ratio = start_ratioself.cur = 0    # current epoch or iterationsuper().__init__(optimizer, -1)def get_lr(self):if (self.warm_up == 0) & (self.cur == 0):lr = self.lr_maxelif (self.warm_up != 0) & (self.cur <= self.warm_up):if self.cur == 0:lr = self.lr_min + (self.lr_max - self.lr_min) * (self.cur + self.start_ratio) / self.warm_upelse:lr = self.lr_min + (self.lr_max - self.lr_min) * (self.cur) / self.warm_up# print(f'{self.cur} -> {lr}')else:            # this works finelr = self.lr_min + (self.lr_max - self.lr_min) * 0.5 *\(np.cos((self.cur - self.warm_up) / (self.T_max - self.warm_up) * np.pi) + 1)self.cur += 1return [lr for base_lr in self.base_lrs]
# class
epochs = 100
warm_up = 5
cosine_lr = WarmupCosineLR(optimizer, 1e-9, 1e-3, warm_up, epochs, 0.1)
lrs = []
for epoch in range(epochs):optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])cosine_lr.step()plt.figure(figsize=(10, 6))
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()


从图上看出,第一个epoch的lr也不至于非常非常小了,达到了所需预期,当然,如果你说first epoch的lr,你也需要非常非常小(<1e-8),你也可以自己尝试其它值。

Pytorch 学习率衰减 之 余弦退火与余弦warmup 自定义学习率衰减scheduler相关推荐

  1. 学习率衰减之余弦退火(CosineAnnealing)

    1 引言 当我们使用梯度下降算法来优化目标函数的时候,当越来越接近Loss值的全局最小值时,学习率应该变得更小来使得模型尽可能接近这一点,而余弦退火(Cosine annealing)可以通过余弦函数 ...

  2. PyTorch学习率衰减策略:指数衰减(ExponentialLR)、固定步长衰减(StepLR)、多步长衰减(MultiStepLR)、余弦退火衰减(CosineAnnealingLR)

    梯度下降算法需要我们指定一个学习率作为权重更新步幅的控制因子,常用的学习率有0.01.0.001以及0.0001等,学习率越大则权重更新.一般来说,我们希望在训练初期学习率大一些,使得网络收敛迅速,在 ...

  3. PyTorch学习率 warmup + 余弦退火

    PyTorch学习率 warmup + 余弦退火 Pytorch 余弦退火 PyTorch内置了很多学习率策略,详情请参考torch.optim - PyTorch 1.10.1 documentat ...

  4. 垃圾分类、EfficientNet模型B0~B7、Rectified Adam(RAdam)、Warmup、带有Warmup的余弦退火学习率衰减

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 垃圾分类.EfficientNet模型.数据增强(ImageD ...

  5. 【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)

    1. 概述 在论文<SGDR: Stochastic Gradient Descent with Warm Restarts>中主要介绍了带重启的随机梯度下降算法(SGDR),其中就引入了 ...

  6. 【深度学习】(11) 学习率衰减策略(余弦退火衰减,多项式衰减),附TensorFlow完整代码

    大家好,今天和各位分享一下如何使用 TensorFlow 构建 多项式学习率衰减策略.单周期余弦退火学习率衰减策略.多周期余弦退火学习率衰减策略,并使用Mnist数据集来验证构建的方法是否可行. 在上 ...

  7. 【深度学习】(10) 自定义学习率衰减策略(指数、分段、余弦),附TensorFlow完整代码

    大家好,今天和大家分享一下如何使用 TensorFlow 自定义 指数学习率下降.阶梯学习率下降.余弦学习率下降 方法,并使用 Mnist数据集验证自定义的学习率下降策略. 创建的自定义学习率类方法, ...

  8. 【Pytorch神经网络理论篇】 10 优化器模块+退化学习率

    1 优化器模块的作用 1.1 反向传播的核心思想 反向传播的意义在于告诉模型我们需要将权重修改到什么数值可以得到最优解,在开始探索合适权重的过程中,正向传播所生成的结果与实际标签的目标值存在误差,反向 ...

  9. 百面机器学习 #2 模型评估:03 余弦距离和余弦相似度、欧氏距离

    文章目录 余弦相似度 余弦相似度和余弦距离 和欧式距离的比较和关系 余弦距离不是一个严格定义的距离 在模型训练过程中,我们也在不断地评估着样本间的距离,如何评估样本距离也是定义优化目标和训练方法的基础 ...

最新文章

  1. wps电脑版_WPS的前前前前身,是一根绳子?懂点历史没坏处
  2. CNCC 2019 | 计算领域年度盛会—中国计算机大会10月将在苏州举行
  3. Codeforces-868C. Qualification Rounds(状压)
  4. java的System.getProperty()方法能够获取的值
  5. 深度学习之循环神经网络(7)梯度裁剪
  6. php数据库可转java数据库,php转java 系列2 Spring boo 链接数据库jdbc
  7. AndroidStudio_安卓原生开发_保存全局数据---Android原生开发工作笔记141
  8. mysql5.7.11升级_MySQL升级从5.6.18到5.7.11
  9. 哪些Mac快捷键可以精准定位光标位置
  10. 如何配置luogu,codeforces的spj(special judge)
  11. 八条佛曰 66句震撼人心的禅语
  12. tableau server在centos7.6上安装记录
  13. Win7系统如何卸载残留无用驱动设备
  14. 【MW】Drop Materialized View Hangs with 'Enq: JI - Contention'
  15. Javascript特效之删除内容效果
  16. 商品清单计算总和(购物车)
  17. Linux学习笔记(22.2)——基于IIC + Regmap + IIO的AP3216C的设备驱动
  18. 关于Sql语句中的模糊查询like关键字详解
  19. 客服充值退款小套路你们会识别吗?
  20. 如何使用OpenCV进行Delaunay三角剖分和Voronoi图

热门文章

  1. 手机变速齿轮_变速齿轮神途官方版下载-变速齿轮神途手游官方版下载 v2.20190828-114手机乐园...
  2. 同质化游戏做出不同点在于背景音乐
  3. unity游戏场景设计
  4. 人工智能可预测阿茨海默症病情演变
  5. 部署DNS从服务失败,nslookup访问www.linuxprobe.com失败
  6. 一个老程序员写给换行业的朋友的信
  7. 微服务治理之分布式链路追踪--3.zipkin实战
  8. 支付宝小程序状态栏显示图片
  9. python分类汇总_数据分析番外篇13_利用Python实现分类汇总
  10. 三维空间中的旋转--旋转向量