LR_scheduler及warmup底层原理和代码分析
LR_scheduler
LR_scheduler是用于调节学习率lr的,在代码中,我们经常看到这样的一行代码
scheduler.step()
通过这行代码来实现lr的更新的,那么其中的底层原理是什么呢?我们就进去看看
在pytorch代码中,各种类型scheduler大多基于_LRScheduler类
我们就看看这个类的step()函数到底干了什么
def step(self, epoch=None):# Raise a warning if old pattern is detected# https://github.com/pytorch/pytorch/issues/20124if self._step_count == 1:if not hasattr(self.optimizer.step, "_with_counter"):warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler ""initialization. Please, make sure to call `optimizer.step()` before ""`lr_scheduler.step()`. See more details at ""https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)# Just check if there were two first lr_scheduler.step() calls before optimizer.step()elif self.optimizer._step_count < 1:warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. ""In PyTorch 1.1.0 and later, you should call them in the opposite order: ""`optimizer.step()` before `lr_scheduler.step()`. Failure to do this ""will result in PyTorch skipping the first value of the learning rate schedule. ""See more details at ""https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)self._step_count += 1class _enable_get_lr_call:def __init__(self, o):self.o = odef __enter__(self):self.o._get_lr_called_within_step = Truereturn selfdef __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = Falsewith _enable_get_lr_call(self):if epoch is None:self.last_epoch += 1 # 表示上一个epochvalues = self.get_lr() # 计算学习率lrelse:warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)self.last_epoch = epoch # 直接跳转到参数epochif hasattr(self, "_get_closed_form_lr"):values = self._get_closed_form_lr()else:values = self.get_lr()# 对所有参数权重对应的lr进行修改for i, data in enumerate(zip(self.optimizer.param_groups, values)):param_group, lr = dataparam_group['lr'] = lr # 修改学习率self.print_lr(self.verbose, i, lr, epoch)self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
由上代码可知,step()的目的是计算计算新的学习率并对旧学习率进行修改,其中最重要的函数是get_lr(),我们接下来对这个函数进行分析
def get_lr(self):# Compute learning rate using chainable form of the schedulerraise NotImplementedError
由于_LRScheduler类是一个基类,不表示任何学习率策略,我们选择最简单的StepLR学习策略(学习率阶梯式下降)来分析
def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): # 表示在一个阶梯上,不改变学习率return [group['lr'] for group in self.optimizer.param_groups]return [group['lr'] * self.gamma # 对所有学习率乘以一个小于1的小数,减小学习率for group in self.optimizer.param_groups]
如果step()函数中有epoch参数,需要直接跳转到指定epoch,那么直接乘以固定的小数就不对了,这时候就需要函数_get_closed_form_lr()
def _get_closed_form_lr(self):return [base_lr * self.gamma ** (self.last_epoch // self.step_size) for base_lr in self.base_lrs]
其中self.last_epoch之前在基类_LRScheduler中已经被赋值了self.last_epoch = epoch ,所以直接根据学习率变化公式计算处理
由上可知,get_lr()和_get_closed_form_lr()就是具体的学习率计算方法
这样,我们就可以根据不同的学习率计算方式设计自己的scheduler类了。
warmup
初始训练阶段,直接使用较大学习率会导致权重变化较大,出现振荡现象,使得模型不稳定,加大训练难度。而使用Warmup预热学习率,在开始的几个epoch,逐步增大学习率,如下图所示,使得模型逐渐趋于稳定,等模型相对稳定后再选择预先设置的基础学习率进行训练,使得模型收敛速度变得更快,模型效果更佳
上图中的0-10epoch阶段就是一个warmup操作,学习率缓慢增加,10之后就是常规的学习率递减算法
原理上很简单,接下来从代码上进行分析,warmup可以有两种构成方式:
对已有的scheduler类进行包装重构
直接编写新的类
对于第一种情况,我们以CosineAnnealingLR类为例
scheduler = CosineAnnealingLR( # pytorch自带的类optimizer=optimizer,eta_min=0.000001,T_max=(epochs - warmup_epoch) * n_iter_per_epoch) scheduler = GradualWarmupScheduler( # 重构的类optimizer,multiplier=args.warmup_multiplier,after_scheduler=scheduler,warmup_epoch=warmup_epoch * n_iter_per_epoch)
其中,GradualWarmupScheduler就是基于CosineAnnealingLR重构的类,我们首先查看类中step()函数
def step(self, epoch=None):if epoch is None:epoch = self.last_epoch + 1self.last_epoch = epochif epoch > self.warmup_epoch: # 超过warmup范围,使用自带的类,也就是CosineAnnealingLRself.after_scheduler.step(epoch - self.warmup_epoch) # 注意CosineAnnealingLR要从0epoch开始,所以需要减去else:super(GradualWarmupScheduler, self).step(epoch) # warmup范围,使用当前重构类的()
对于超过warmup范围,直接使用CosineAnnealingLR类,比较简单
对于warmup范围类,使用当前重构类的step()函数,因为也是继承于_LRScheduler类,所以step()同样是运用到get_lr()
def get_lr(self):if self.last_epoch > self.warmup_epoch: # 超过warmup范围,使用CosineAnnealingLR类的get_lr()return self.after_scheduler.get_lr()else: # warmup范围,编写线性变化,也就是上图中0-10区间内的直线return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)for base_lr in self.base_lrs]
对于第二种情况,step()无需构造,直接继承_LRScheduler,需要构造的是get_lr()函数,其中warmup范围外的代码与自带的CosineAnnealingLR类中get_lr()代码一样。
LR_scheduler及warmup底层原理和代码分析相关推荐
- 对dpdk的rte_ring实现原理和代码分析
对dpdk的rte_ring实现原理和代码分析 前言 dpdk的rte_ring是借鉴了linux内核的kfifo实现原理,这里统称为无锁环形缓冲队列. 环形缓冲区通常有一个读指针和一个写指针.读指针 ...
- TrueCrypt 6.2a原理及代码分析
TrueCrypt 6.2a原理及代码分析 3 comments 25th Apr 10 rafa 1 项目物理布局 Project |____ Boot /* MBR部分的代码 */ ...
- 免费的Lucene 原理与代码分析完整版下载
Lucene是一个基于Java的高效的全文检索库. 那么什么是全文检索,为什么需要全文检索? 目前人们生活中出现的数据总的来说分为两类:结构化数据和非结构化数据.很容易理解,结构化数据是有固定格式和结 ...
- Lucene 原理与代码分析完整版
原文地址为: Lucene 原理与代码分析完整版 Lucene 原理与代码分析系列文章已经基本告一段落,可能问题篇还会有新的更新. 完整版pdf可由以下链接下载. Lucene 原理与代码分析完整版 ...
- Lucene原理与代码分析(高手博客备忘)
2019独角兽企业重金招聘Python工程师标准>>> 随笔 - 69 文章 - 77 评论 - 687 随笔分类 - Lucene原理与代码分析 Lucene 4.X 倒排索引 ...
- Runtime底层原理总结--反汇编分析消息转发
消息转发:发送一个消息,也就是sel查找imp,当没有找到imp,接下来进入动态方法解析,如果开发者并没有处理,会进入消息转发. 消息转发 前几篇文章介绍了Runtime底层原理和动态方法解析总结 , ...
- OpenStack 虚拟机冷/热迁移的实现原理与代码分析
目录 文章目录 目录 前文列表 冷迁移代码分析(基于 Newton) Nova 冷迁移实现原理 热迁移代码分析 Nova 热迁移实现原理 向 libvirtd 发出 Live Migration 指令 ...
- stm32-通用定时器原理及代码分析
目录 定时器:基本,通用 一,基本定时器: 作用: 结构图: 二.通用定时器: 作用: 结构图: 三.代码分析: 1.选择时钟 2.配置时基单元 3.产生中断 4.使用定时器 定时器:基本,通用 一, ...
- ROS机器人操作系统底层原理及代码剖析
0 目的 本文介绍ROS机器人操作系统(Robot Operating System)的实现原理,从最底层分析ROS代码是如何实现的. 1 序列化 把通信的内容(也就是消息message)序列化是通信 ...
最新文章
- HASHMAP(JDK1.7)最详细原理分析(二)
- ajax无刷新方式对form表单进行赋值!
- elasticsearch给IK分词器添加自定义词汇
- 【解题报告+感想感言】2019年第十届蓝桥杯【C++省赛B组】【第五题:迷宫】
- JSON-RPC、XML-RPC、SOAP三者的关系
- 仅靠一杯奶茶钱8.8元,你就能转到人工智能专业?
- 番茄花园win11 32位官方纯净版镜像v2021.07
- 华强北二手手机卖不出去,闲鱼砸一亿现金帮扶
- Spring Boot + Spring Cloud 实现权限管理系统 配置中心(Config、Bus)
- 矩池云上cifar10使用说明
- NOIP2015题解
- php如何用if函数算出最大值,在Excel中根据条件用Max函数和IF函数实现求其他数据表的最大值...
- 利用爬虫大量抓取网页图片
- c语言电脑写程序的软件,c语言编程软件下载电脑版
- MATLAB 求导diff
- Controller中使用swagger注解的正确姿势
- Windows设置程序开机自启动_设置程序开机自启动的几种方法_添加启动项
- 贪心法--->1.会议安排问题
- xbox手柄映射_如何在Windows 10中重新映射Xbox One控制器的按钮
- Intel VT-d(1)- 简介