2学习率调整_学习率衰减
01
—
学习率衰减
首先我们以一个例子,来说明一下我们为什么需要学习率α衰减(learning rate decay)。如果学习率不衰减的话,如下图蓝线所示,由于噪音影响,代价函数更新路径相对不规则,但总体朝着最低点方向移动,但是移动到最低点附近时,由于学习率较大,每一次会移动相对较远的距离,容易直接跨过最低点,导致代价函数在更新完毕后距离最低点仍然相对较远;当我们随着迭代的次数逐渐降低学习率,那么便如绿线所示,一开始学习率较大,前进速度较快,到达最低点附近后,学习率降低到更小的值,于是最终更新完成后代价函数离最低点更近,也就是模型更加优化,预测值与实际值差距更小。
进行学习率衰减的一种方式如下所示,需要设置如下学习率更新规则:
1个epoch是指将所有mini-batch全部迭代一遍,即遍历一遍。假设α0=0.2,decay_rate=1,那么随着epoch增加,学习率α会如下图变化:
在应用这个公式时,我们需要选择合适的超参数α0和decay_rate。除了这种学习率衰减方式,还有一些其他方式来进行学习率衰减:
此外还有离散衰减,经过一段时间衰减一半:
02
—
学习率衰减的pytorch实现
指数衰减
我们首先需要确定需要针对哪个优化器执行学习率动态调整策略,也就是首先定义一个优化器:
optimizer_ExpLR = torch.optim.SGD(net.parameters(), lr=0.1)
定义好优化器以后,就可以给这个优化器绑定一个指数衰减学习率控制器:
ExpLR = torch.optim.lr_scheduler.ExponentialLR(optimizer_ExpLR, gamma=0.98)
参数gamma表示衰减的底数,也就是decay_rate,选择不同的gamma值可以获得幅度不同的衰减曲线。
固定步长衰减
即离散型衰减,学习率每隔一定步数(或者epoch)就减少为原来的gamma分之一,使用固定步长衰减依旧先定义优化器,再给优化器绑定StepLR对象:
optimizer_StepLR = torch.optim.SGD(net.parameters(), lr=0.1)StepLR = torch.optim.lr_scheduler.StepLR(optimizer_StepLR, step_size=step_size, gamma=0.65)
其中gamma参数表示衰减的程度,step_size参数表示每隔多少个step进行一次学习率调整,下面对比了不同gamma值下的学习率变化情况:
多步长衰减
有时我们希望不同的区间采用不同的更新频率,或者是有的区间更新学习率,有的区间不更新学习率,这就需要使用MultiStepLR来实现动态区间长度控制:
optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1)torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR,\ milestones=[200, 300, 320, 340, 200], gamma=0.8)
其中milestones参数为表示学习率更新的起止区间,在区间[0. 200]内学习率不更新,而在[200, 300]、[300, 320].....[340, 400]的右侧值都进行一次更新;gamma参数表示学习率衰减为上次的gamma分之一。其图示如下:
从图中可以看出,学习率在区间[200, 400]内快速的下降,这就是milestones参数所控制的,在milestones以外的区间学习率始终保持不变。
余弦退火衰减
严格的说,余弦退火策略不应该算是学习率衰减策略,因为它使得学习率按照周期变化,其定义方式如下:
optimizer_CosineLR = torch.optim.SGD(net.parameters(), lr=0.1)CosineLR = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_CosineLR, T_max=150, eta_min=0)
参数T_max表示余弦函数周期;eta_min表示学习率的最小值,默认它是0表示学习率至少为正值。确定一个余弦函数需要知道最值和周期,其中周期就是T_max,最值是初试学习率。下图展示了不同周期下的余弦学习率更新曲线:
为网络的不同层设置不同的学习率
定义一个简单的网络结构:
class net(nn.Module): def __init__(self): super(net, self).__init__() self.conv1 = nn.Conv2d(3, 64, 1) self.conv2 = nn.Conv2d(64, 64, 1) self.conv3 = nn.Conv2d(64, 64, 1) self.conv4 = nn.Conv2d(64, 64, 1) self.conv5 = nn.Conv2d(64, 64, 1) def forward(self, x): out = conv5(conv4(conv3(conv2(conv1(x))))) return out
我们希望conv5学习率是其他层的100倍,我们可以:
net = net()lr = 0.001conv5_params = list(map(id, net.conv5.parameters())) # 1 base_params = filter(lambda p: id(p) not in conv5_params, net.parameters()) # 2,3optimizer = torch.optim.SGD([ {'params': base_params}, {'params': net.conv5.parameters(), 'lr': lr * 100}], lr=lr, momentum=0.9)
1. conv5_params = list(map(id,net.conv5.parameters()))中id()函数用于获取网络参数的内存地址,map()函数用于将id()函数作用于net.conv5.parameters()得到的每个参数上。2.lambda p: id(p)中lamda表达式是python中用于定义匿名函数的方式,其后面定义的是一个函数操作,冒号前的符号是函数的形式参数,用于接收参数,符合的个数表示需要接收的参数个数,冒号右边是具体的函数操作。3.filter()函数的作用是过滤掉不符合条件(False)的元素,返回一个迭代器对象。该函数接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判断,然后返回True或False,最后返回值为True的元素
Reference
深度学习课程 --吴恩达
https://zhuanlan.zhihu.com/p/93624972
2学习率调整_学习率衰减相关推荐
- 【学习率调整】学习率衰减之周期余弦退火 (cyclic cosine annealing learning rate schedule)
1. 概述 在论文<SGDR: Stochastic Gradient Descent with Warm Restarts>中主要介绍了带重启的随机梯度下降算法(SGDR),其中就引入了 ...
- YOLOv5-优化器和学习率调整策略
优化器和学习率调整策略 pytorch-优化器和学习率调整 这个链接关于优化器和学习率的一些基础讲得很细,还有相关实现代码 优化器 前向传播的过程,会得到模型输出与真实标签的差,我们称之为损失, 有了 ...
- Pytorch —— 学习率调整策略
1.为什么要调整学习率 学习率控制梯度更新的快慢,在训练中,开始时的学习率比较大,梯度更新步伐比较大,后期时学习率比较小,梯度更新步伐比较小. 梯度下降:wi+1=wi−g(wi)w_{i+1}=w_ ...
- pytorch优化器学习率调整策略以及正确用法
优化器 optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西: ...
- PyTorch学习之六个学习率调整策略
PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现.PyTorch提供的学习率调整策略分为三大类,分别是 a. 有序调整:等间隔调整(Step),按需调整学习率( ...
- 【深度学习】图解 9 种PyTorch中常用的学习率调整策略
learning rate scheduling 学习率调整策略 01 LAMBDA LR 将每个参数组的学习率设置为初始lr乘以给定函数.当last_epoch=-1时,将初始lr设置为初始值. t ...
- PyTorch的六个学习率调整
本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 一.pytorc ...
- PyTorch框架学习十四——学习率调整策略
PyTorch框架学习十四--学习率调整策略 一._LRScheduler类 二.六种常见的学习率调整策略 1.StepLR 2.MultiStepLR 3.ExponentialLR 4.Cosin ...
- 【详解】模型优化技巧之优化器和学习率调整
目录 PyTorch十大优化器 1 torch.optim.SGD 2 torch.optim.ASGD 3 torch.optim.Rprop 4 torch.optim.Adagrad 5 tor ...
最新文章
- pycharm中import呈现灰色原因
- 牛客 - 树上子链(树的直径-处理负权)
- 数据分析与挖掘实战-家用电器用户行为分析与事件识别
- 面试前准备这些,成功率会大大提升!(Java篇)
- 12个必备的JavaScript装逼技巧
- [转]MySQL5.6.22 安装
- 斩获VCR竞赛榜第一,腾讯微视推出BLENDer单模型,超越多模型最好效果
- http请求出现406错误解决方案
- go语言 Accept error: accept tcp [::]:5551: too many open files;
- CNN_原理以及pytorch多分类实践
- 如何快速生成100万不重复的8位随机编号?
- Scratch3.0安装教程
- web前端开发面试题
- RGB-D相机原理与选型
- docker技术简介
- linux ubuntu 播放csf格式视频解决方案
- 写好英语科技论文的诀窍: 主动迎合读者期望,预先回答专家可能质疑--周耀旗教授...
- Qt 网络聊天室项目
- xpath之根据节点获取兄弟节点
- 从session里面取得值为null
热门文章
- java 修饰_Java 修饰符
- visio中公式太小_visio绘图中的数据计算
- mac安装完mysql后关机特别慢_mysql-Mac终端下遇到的问题总结
- mysql group concat_MySQL 的 GROUP_CONCAT 函数详解
- 【分享】 codeReview 的重要性
- 文本分析软件_十大针对机器学习的文本注释工具与服务,你选哪个?
- 【设计模式 01】简单工厂模式(Simple factory pattern)
- linux驱动 cdev,inode结构体
- C#串口SerialPort常用属性方法
- GNU C 、ANSI C、标准C、标准c++区别和联系