LR_scheduler

LR_scheduler是用于调节学习率lr的,在代码中,我们经常看到这样的一行代码

 scheduler.step()

通过这行代码来实现lr的更新的,那么其中的底层原理是什么呢?我们就进去看看


在pytorch代码中,各种类型scheduler大多基于_LRScheduler

我们就看看这个类的step()函数到底干了什么

     def step(self, epoch=None):# Raise a warning if old pattern is detected# https://github.com/pytorch/pytorch/issues/20124if self._step_count == 1:if not hasattr(self.optimizer.step, "_with_counter"):warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler ""initialization. Please, make sure to call `optimizer.step()` before ""`lr_scheduler.step()`. See more details at ""https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)​# Just check if there were two first lr_scheduler.step() calls before optimizer.step()elif self.optimizer._step_count < 1:warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. ""In PyTorch 1.1.0 and later, you should call them in the opposite order: ""`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this ""will result in PyTorch skipping the first value of the learning rate schedule. ""See more details at ""https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)self._step_count += 1​class _enable_get_lr_call:​def __init__(self, o):self.o = o​def __enter__(self):self.o._get_lr_called_within_step = Truereturn self​def __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = False​with _enable_get_lr_call(self):if epoch is None:self.last_epoch += 1  # 表示上一个epochvalues = self.get_lr()  # 计算学习率lrelse:warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)self.last_epoch = epoch  # 直接跳转到参数epochif hasattr(self, "_get_closed_form_lr"):values = self._get_closed_form_lr()else:values = self.get_lr()​# 对所有参数权重对应的lr进行修改for i, data in enumerate(zip(self.optimizer.param_groups, values)):param_group, lr = dataparam_group['lr'] = lr # 修改学习率self.print_lr(self.verbose, i, lr, epoch)​self._last_lr = [group['lr'] for group in self.optimizer.param_groups]​

由上代码可知,step()的目的是计算计算新的学习率并对旧学习率进行修改,其中最重要的函数是get_lr(),我们接下来对这个函数进行分析

     def get_lr(self):# Compute learning rate using chainable form of the schedulerraise NotImplementedError

由于_LRScheduler类是一个基类,不表示任何学习率策略,我们选择最简单的StepLR学习策略(学习率阶梯式下降)来分析

     def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)​if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): # 表示在一个阶梯上,不改变学习率return [group['lr'] for group in self.optimizer.param_groups]return [group['lr'] * self.gamma # 对所有学习率乘以一个小于1的小数,减小学习率for group in self.optimizer.param_groups]​

如果step()函数中有epoch参数,需要直接跳转到指定epoch,那么直接乘以固定的小数就不对了,这时候就需要函数_get_closed_form_lr()

     def _get_closed_form_lr(self):return [base_lr * self.gamma ** (self.last_epoch // self.step_size)  for base_lr in self.base_lrs]

其中self.last_epoch之前在基类_LRScheduler中已经被赋值了self.last_epoch = epoch ,所以直接根据学习率变化公式计算处理

由上可知,get_lr()和_get_closed_form_lr()就是具体的学习率计算方法

这样,我们就可以根据不同的学习率计算方式设计自己的scheduler类了。


warmup

初始训练阶段,直接使用较大学习率会导致权重变化较大,出现振荡现象,使得模型不稳定,加大训练难度。而使用Warmup预热学习率,在开始的几个epoch,逐步增大学习率,如下图所示,使得模型逐渐趋于稳定,等模型相对稳定后再选择预先设置的基础学习率进行训练,使得模型收敛速度变得更快,模型效果更佳

上图中的0-10epoch阶段就是一个warmup操作,学习率缓慢增加,10之后就是常规的学习率递减算法

原理上很简单,接下来从代码上进行分析,warmup可以有两种构成方式:

  1. 对已有的scheduler类进行包装重构

  2. 直接编写新的类

对于第一种情况,我们以CosineAnnealingLR类为例

         scheduler = CosineAnnealingLR( # pytorch自带的类optimizer=optimizer,eta_min=0.000001,T_max=(epochs - warmup_epoch) * n_iter_per_epoch)  scheduler = GradualWarmupScheduler( # 重构的类optimizer,multiplier=args.warmup_multiplier,after_scheduler=scheduler,warmup_epoch=warmup_epoch * n_iter_per_epoch)​

其中,GradualWarmupScheduler就是基于CosineAnnealingLR重构的类,我们首先查看类中step()函数

     def step(self, epoch=None):if epoch is None:epoch = self.last_epoch + 1self.last_epoch = epochif epoch > self.warmup_epoch: # 超过warmup范围,使用自带的类,也就是CosineAnnealingLRself.after_scheduler.step(epoch - self.warmup_epoch) # 注意CosineAnnealingLR要从0epoch开始,所以需要减去else:super(GradualWarmupScheduler, self).step(epoch)  # warmup范围,使用当前重构类的()​

对于超过warmup范围,直接使用CosineAnnealingLR类,比较简单

对于warmup范围类,使用当前重构类的step()函数,因为也是继承于_LRScheduler类,所以step()同样是运用到get_lr()

     def get_lr(self):if self.last_epoch > self.warmup_epoch:  # 超过warmup范围,使用CosineAnnealingLR类的get_lr()return self.after_scheduler.get_lr()else:  # warmup范围,编写线性变化,也就是上图中0-10区间内的直线return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)for base_lr in self.base_lrs]​

对于第二种情况,step()无需构造,直接继承_LRScheduler,需要构造的是get_lr()函数,其中warmup范围外的代码与自带的CosineAnnealingLR类中get_lr()代码一样。

LR_scheduler及warmup底层原理和代码分析相关推荐

  1. 对dpdk的rte_ring实现原理和代码分析

    对dpdk的rte_ring实现原理和代码分析 前言 dpdk的rte_ring是借鉴了linux内核的kfifo实现原理,这里统称为无锁环形缓冲队列. 环形缓冲区通常有一个读指针和一个写指针.读指针 ...

  2. TrueCrypt 6.2a原理及代码分析

    TrueCrypt 6.2a原理及代码分析 3 comments 25th Apr 10 rafa 1 项目物理布局 Project     |____ Boot /* MBR部分的代码 */     ...

  3. 免费的Lucene 原理与代码分析完整版下载

    Lucene是一个基于Java的高效的全文检索库. 那么什么是全文检索,为什么需要全文检索? 目前人们生活中出现的数据总的来说分为两类:结构化数据和非结构化数据.很容易理解,结构化数据是有固定格式和结 ...

  4. Lucene 原理与代码分析完整版

    原文地址为: Lucene 原理与代码分析完整版 Lucene 原理与代码分析系列文章已经基本告一段落,可能问题篇还会有新的更新. 完整版pdf可由以下链接下载. Lucene 原理与代码分析完整版 ...

  5. Lucene原理与代码分析(高手博客备忘)

    2019独角兽企业重金招聘Python工程师标准>>> 随笔 - 69  文章 - 77  评论 - 687 随笔分类 - Lucene原理与代码分析 Lucene 4.X 倒排索引 ...

  6. Runtime底层原理总结--反汇编分析消息转发

    消息转发:发送一个消息,也就是sel查找imp,当没有找到imp,接下来进入动态方法解析,如果开发者并没有处理,会进入消息转发. 消息转发 前几篇文章介绍了Runtime底层原理和动态方法解析总结 , ...

  7. OpenStack 虚拟机冷/热迁移的实现原理与代码分析

    目录 文章目录 目录 前文列表 冷迁移代码分析(基于 Newton) Nova 冷迁移实现原理 热迁移代码分析 Nova 热迁移实现原理 向 libvirtd 发出 Live Migration 指令 ...

  8. stm32-通用定时器原理及代码分析

    目录 定时器:基本,通用 一,基本定时器: 作用: 结构图: 二.通用定时器: 作用: 结构图: 三.代码分析: 1.选择时钟 2.配置时基单元 3.产生中断 4.使用定时器 定时器:基本,通用 一, ...

  9. ROS机器人操作系统底层原理及代码剖析

    0 目的 本文介绍ROS机器人操作系统(Robot Operating System)的实现原理,从最底层分析ROS代码是如何实现的. 1 序列化 把通信的内容(也就是消息message)序列化是通信 ...

最新文章

  1. HASHMAP(JDK1.7)最详细原理分析(二)
  2. ajax无刷新方式对form表单进行赋值!
  3. elasticsearch给IK分词器添加自定义词汇
  4. 【解题报告+感想感言】2019年第十届蓝桥杯【C++省赛B组】【第五题:迷宫】
  5. JSON-RPC、XML-RPC、SOAP三者的关系
  6. 仅靠一杯奶茶钱8.8元,你就能转到人工智能专业?
  7. 番茄花园win11 32位官方纯净版镜像v2021.07
  8. 华强北二手手机卖不出去,闲鱼砸一亿现金帮扶
  9. Spring Boot + Spring Cloud 实现权限管理系统 配置中心(Config、Bus)
  10. 矩池云上cifar10使用说明
  11. NOIP2015题解
  12. php如何用if函数算出最大值,在Excel中根据条件用Max函数和IF函数实现求其他数据表的最大值...
  13. 利用爬虫大量抓取网页图片
  14. c语言电脑写程序的软件,c语言编程软件下载电脑版
  15. MATLAB 求导diff
  16. Controller中使用swagger注解的正确姿势
  17. Windows设置程序开机自启动_设置程序开机自启动的几种方法_添加启动项
  18. 贪心法--->1.会议安排问题
  19. xbox手柄映射_如何在Windows 10中重新映射Xbox One控制器的按钮
  20. Intel VT-d(1)- 简介

热门文章

  1. 母亲节倒计时,选礼物救急指南
  2. Anaconda安装Tensorflow-GPU
  3. 组织结构 - 职能型,矩阵型和项目型
  4. 又一个网友放生后拍出红莲瓣!
  5. Unity 游戏设计模式 — 策略模式(Strategy)
  6. php基础知识速记,php基础速记
  7. 裸辞2个月找不到工作,我慌了
  8. 游戏辅助原理与制作02-植物大战僵尸00-概述
  9. 【历史上的今天】8 月 4 日:第一位图灵奖女性得主;NVIDIA 收购 MediaQ;首届网络安全挑战大赛完成
  10. python 图片转换成py文件