之前我们的优化,主要是聚焦于对梯度下降运动方向的调整,而在参数迭代更新的过程中,除了梯度,还有一个重要的参数是学习率α,对于学习率的调整也是优化的一个重要方面。

01

学习率衰减

首先我们以一个例子,来说明一下我们为什么需要学习率α衰减(learning rate decay)。如果学习率不衰减的话,如下图蓝线所示,由于噪音影响,代价函数更新路径相对不规则,但总体朝着最低点方向移动,但是移动到最低点附近时,由于学习率较大,每一次会移动相对较远的距离,容易直接跨过最低点,导致代价函数在更新完毕后距离最低点仍然相对较远;当我们随着迭代的次数逐渐降低学习率,那么便如绿线所示,一开始学习率较大,前进速度较快,到达最低点附近后,学习率降低到更小的值,于是最终更新完成后代价函数离最低点更近,也就是模型更加优化,预测值与实际值差距更小。

进行学习率衰减的一种方式如下所示,需要设置如下学习率更新规则:

1个epoch是指将所有mini-batch全部迭代一遍,即遍历一遍。假设α0=0.2,decay_rate=1,那么随着epoch增加,学习率α会如下图变化:

在应用这个公式时,我们需要选择合适的超参数α0和decay_rate。除了这种学习率衰减方式,还有一些其他方式来进行学习率衰减:

此外还有离散衰减,经过一段时间衰减一半:

02

学习率衰减的pytorch实现

指数衰减

我们首先需要确定需要针对哪个优化器执行学习率动态调整策略,也就是首先定义一个优化器:

optimizer_ExpLR = torch.optim.SGD(net.parameters(), lr=0.1)

定义好优化器以后,就可以给这个优化器绑定一个指数衰减学习率控制器:

ExpLR = torch.optim.lr_scheduler.ExponentialLR(optimizer_ExpLR, gamma=0.98)

参数gamma表示衰减的底数,也就是decay_rate,选择不同的gamma值可以获得幅度不同的衰减曲线。

固定步长衰减

即离散型衰减,学习率每隔一定步数(或者epoch)就减少为原来的gamma分之一,使用固定步长衰减依旧先定义优化器,再给优化器绑定StepLR对象:

optimizer_StepLR = torch.optim.SGD(net.parameters(), lr=0.1)StepLR = torch.optim.lr_scheduler.StepLR(optimizer_StepLR, step_size=step_size, gamma=0.65)

其中gamma参数表示衰减的程度,step_size参数表示每隔多少个step进行一次学习率调整,下面对比了不同gamma值下的学习率变化情况:

多步长衰减

有时我们希望不同的区间采用不同的更新频率,或者是有的区间更新学习率,有的区间不更新学习率,这就需要使用MultiStepLR来实现动态区间长度控制:

optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1)torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR,\                milestones=[200, 300, 320, 340, 200], gamma=0.8)

其中milestones参数为表示学习率更新的起止区间,在区间[0. 200]内学习率不更新,而在[200, 300]、[300, 320].....[340, 400]的右侧值都进行一次更新;gamma参数表示学习率衰减为上次的gamma分之一。其图示如下:

从图中可以看出,学习率在区间[200, 400]内快速的下降,这就是milestones参数所控制的,在milestones以外的区间学习率始终保持不变。

余弦退火衰减

严格的说,余弦退火策略不应该算是学习率衰减策略,因为它使得学习率按照周期变化,其定义方式如下:

optimizer_CosineLR = torch.optim.SGD(net.parameters(), lr=0.1)CosineLR = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_CosineLR, T_max=150, eta_min=0)

参数T_max表示余弦函数周期;eta_min表示学习率的最小值,默认它是0表示学习率至少为正值。确定一个余弦函数需要知道最值和周期,其中周期就是T_max,最值是初试学习率。下图展示了不同周期下的余弦学习率更新曲线:

为网络的不同层设置不同的学习率

定义一个简单的网络结构:

class net(nn.Module):    def __init__(self):        super(net, self).__init__()        self.conv1 = nn.Conv2d(3, 64, 1)        self.conv2 = nn.Conv2d(64, 64, 1)        self.conv3 = nn.Conv2d(64, 64, 1)        self.conv4 = nn.Conv2d(64, 64, 1)        self.conv5 = nn.Conv2d(64, 64, 1)    def forward(self, x):        out = conv5(conv4(conv3(conv2(conv1(x)))))        return out

我们希望conv5学习率是其他层的100倍,我们可以:

net = net()lr = 0.001conv5_params = list(map(id, net.conv5.parameters())) # 1 base_params = filter(lambda p: id(p) not in conv5_params,                     net.parameters()) # 2,3optimizer = torch.optim.SGD([            {'params': base_params},            {'params': net.conv5.parameters(), 'lr': lr * 100}], lr=lr, momentum=0.9)

1. conv5_params = list(map(id,net.conv5.parameters()))中id()函数用于获取网络参数的内存地址,map()函数用于将id()函数作用于net.conv5.parameters()得到的每个参数上。2.lambda p: id(p)中lamda表达式是python中用于定义匿名函数的方式,其后面定义的是一个函数操作,冒号前的符号是函数的形式参数,用于接收参数,符合的个数表示需要接收的参数个数,冒号右边是具体的函数操作。3.filter()函数的作用是过滤掉不符合条件(False)的元素,返回一个迭代器对象。该函数接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判断,然后返回True或False,最后返回值为True的元素

Reference

深度学习课程 --吴恩达

https://zhuanlan.zhihu.com/p/93624972

2学习率调整_学习率衰减相关推荐

  1. 【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)

    1. 概述 在论文<SGDR: Stochastic Gradient Descent with Warm Restarts>中主要介绍了带重启的随机梯度下降算法(SGDR),其中就引入了 ...

  2. YOLOv5-优化器和学习率调整策略

    优化器和学习率调整策略 pytorch-优化器和学习率调整 这个链接关于优化器和学习率的一些基础讲得很细,还有相关实现代码 优化器 前向传播的过程,会得到模型输出与真实标签的差,我们称之为损失, 有了 ...

  3. Pytorch —— 学习率调整策略

    1.为什么要调整学习率 学习率控制梯度更新的快慢,在训练中,开始时的学习率比较大,梯度更新步伐比较大,后期时学习率比较小,梯度更新步伐比较小. 梯度下降:wi+1=wi−g(wi)w_{i+1}=w_ ...

  4. pytorch优化器学习率调整策略以及正确用法

    优化器 optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西: ...

  5. PyTorch学习之六个学习率调整策略

    PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现.PyTorch提供的学习率调整策略分为三大类,分别是 a. 有序调整:等间隔调整(Step),按需调整学习率( ...

  6. 【深度学习】图解 9 种PyTorch中常用的学习率调整策略

    learning rate scheduling 学习率调整策略 01 LAMBDA LR 将每个参数组的学习率设置为初始lr乘以给定函数.当last_epoch=-1时,将初始lr设置为初始值. t ...

  7. PyTorch的六个学习率调整

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 一.pytorc ...

  8. PyTorch框架学习十四——学习率调整策略

    PyTorch框架学习十四--学习率调整策略 一._LRScheduler类 二.六种常见的学习率调整策略 1.StepLR 2.MultiStepLR 3.ExponentialLR 4.Cosin ...

  9. 【详解】模型优化技巧之优化器和学习率调整

    目录 PyTorch十大优化器 1 torch.optim.SGD 2 torch.optim.ASGD 3 torch.optim.Rprop 4 torch.optim.Adagrad 5 tor ...

最新文章

  1. pycharm中import呈现灰色原因
  2. 牛客 - 树上子链(树的直径-处理负权)
  3. 数据分析与挖掘实战-家用电器用户行为分析与事件识别
  4. 面试前准备这些,成功率会大大提升!(Java篇)
  5. 12个必备的JavaScript装逼技巧
  6. [转]MySQL5.6.22 安装
  7. 斩获VCR竞赛榜第一,腾讯微视推出BLENDer单模型,超越多模型最好效果
  8. http请求出现406错误解决方案
  9. go语言 Accept error: accept tcp [::]:5551: too many open files;
  10. CNN_原理以及pytorch多分类实践
  11. 如何快速生成100万不重复的8位随机编号?
  12. Scratch3.0安装教程
  13. web前端开发面试题
  14. RGB-D相机原理与选型
  15. docker技术简介
  16. linux ubuntu 播放csf格式视频解决方案
  17. 写好英语科技论文的诀窍: 主动迎合读者期望,预先回答专家可能质疑--周耀旗教授...
  18. Qt 网络聊天室项目
  19. xpath之根据节点获取兄弟节点
  20. 从session里面取得值为null

热门文章

  1. java 修饰_Java 修饰符
  2. visio中公式太小_visio绘图中的数据计算
  3. mac安装完mysql后关机特别慢_mysql-Mac终端下遇到的问题总结
  4. mysql group concat_MySQL 的 GROUP_CONCAT 函数详解
  5. 【分享】 codeReview 的重要性
  6. 文本分析软件_十大针对机器学习的文本注释工具与服务,你选哪个?
  7. 【设计模式 01】简单工厂模式(Simple factory pattern)
  8. linux驱动 cdev,inode结构体
  9. C#串口SerialPort常用属性方法
  10. GNU C 、ANSI C、标准C、标准c++区别和联系