torch.optim.lr_scheduler源码和cosine学习率策略学习
torch.optim.lr_scheduler是PyTorch中负责调整学习率的模块,常和torch.optim.Optimizer配合使用。
optimizer模块的源码学习可参见:torch.optim.optimizer源码阅读和灵活使用
class _LRScheduler(object):def __init__(self, optimizer, last_epoch=-1):# 读取相应的Optimizerif not isinstance(optimizer, Optimizer):raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))self.optimizer = optimizer# last_epoch表示上一轮epoch的序号;若为-1,表示当前训练是从头训练if last_epoch == -1:for group in optimizer.param_groups:group.setdefault('initial_lr', group['lr'])else: # last_epoch不为-1,表示当前训练是断点训练,必须有初始学习率for i, group in enumerate(optimizer.param_groups):if 'initial_lr' not in group:raise KeyError("param 'initial_lr' is not specified in param_groups[{}] when resuming an optimizer".format(i))# 读取每一组待优化变量的初始学习率self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))self.last_epoch = last_epoch# Following https://github.com/pytorch/pytorch/issues/20124# We would like to ensure that `lr_scheduler.step()` is called after# `optimizer.step()`def with_counter(method):if getattr(method, '_with_counter', False):# `optimizer.step()` has already been replaced, return.return method# 建立一个method的弱引用。弱引用不增加对象的引用计数,只存在弱引用的对象是可被垃圾回收的;# 弱引用可以解决循环引用的问题。instance_ref = weakref.ref(method.__self__)# Get the unbound method for the same purpose.func = method.__func__ # __func__是method的底层实现,不跟具体的实例绑定cls = instance_ref().__class__ # method的所属类del method@wraps(func)def wrapper(*args, **kwargs):instance = instance_ref()instance._step_count += 1wrapped = func.__get__(instance, cls)return wrapped(*args, **kwargs)# Note that the returned function here is no longer a bound method,# so attributes like `__func__` and `__self__` no longer exist.wrapper._with_counter = Truereturn wrapper# 通过装饰器来为optimizer.step添加计数功能,并初始化计数器self.optimizer.step = with_counter(self.optimizer.step)self.optimizer._step_count = 0self._step_count = 0self.step() # 更新学习率
lr_scheduler在构造函数中主要是获取optimizer并向其添加step计数功能,然后更新一次学习率。
弱引用相关:Python弱引用的使用
__func__相关:Python(类)实例方法的特殊属性
装饰器和functools.wraps相关:探究functools模块wraps装饰器的用途
step函数:
def step(self, epoch=None):# 由于lr_scheduler在构造函数中已经step过一次,故lr_scheduler.step()一定要在optimizer.step()之后。if self._step_count == 1:if not hasattr(self.optimizer.step, "_with_counter"):warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) elif self.optimizer._step_count < 1:warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)self._step_count += 1 # lr_scheduler的step计数# 支持上下文管理器协议的类class _enable_get_lr_call:def __init__(self, o):self.o = odef __enter__(self):self.o._get_lr_called_within_step = Truereturn selfdef __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = Falsewith _enable_get_lr_call(self):if epoch is None: # 未指明从某个具体的epoch开始训练self.last_epoch += 1 # 更新epochvalues = self.get_lr() # 计算新的lr,与具体的lr_scheduler类型有关else: # 指定epoch# EPOCH_DEPRECATION_WARNING是一个提示信息:epoch参数即将被移除warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)self.last_epoch = epochif hasattr(self, "_get_closed_form_lr"):values = self._get_closed_form_lr() # 正式移除epoch之前的lr计算方法else:values = self.get_lr()# 更新optimizer中保存的lrfor param_group, lr in zip(self.optimizer.param_groups, values):param_group['lr'] = lr# _last_lr记录上一轮次更新的lr值self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
step函数主要进行lr的实时计算以及相关参数的更新,包括epoch、lr和optimizer中保存的实时lr。
上下文管理器相关:python中的__enter__和__exit
学习率计算:get_lr
# 计算当前更新轮次的学习率,与具体的lr更新策略有关,由子类实现
def get_lr(self):# Compute learning rate using chainable form of the schedulerraise NotImplementedError
获得上一轮次训练的lr值:
def get_last_lr(self):""" Return last computed learning rate by current scheduler."""return self._last_lr
获取lr_scheduler的相关参数:
def state_dict(self):"""Returns the state of the scheduler as a :class:`dict`.It contains an entry for every variable in self.__dict__ which is not the optimizer."""# lr_scheduler中虽然有optimizer属性来记录与其相对应的优化器,但state_dict中并不包括优化器参数return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
需要注意的是,lr_scheduler的state_dict返回的是scheduler的所有属性,所有不同的scheduler返回的参数各不相同。
加载已有的lr_scheduler参数:
def load_state_dict(self, state_dict):"""Loads the schedulers state.Arguments:state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`."""self.__dict__.update(state_dict) # 用state_dict更新当前lr_scheduler的参数
__dict__相关:Python ____dict__与dir()区别
以cosine学习率调整策略来具体学习lr_scheduler:
lr计算公式:
ηt+1=ηmin+12(ηt−ηmin)(1+cos(Tcur+1Tmaxπ)1+cos(TcurTmaxπ)),Tcur≠(2k+1)Tmax\eta_{t+1} = \eta_{min} + \frac{1}{2}(\eta_{t} - \eta_{min})\left(\frac{1 + \cos\left(\frac{T_{cur}+1}{T_{max}}\pi\right)}{1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)}\right), T_{cur} \neq (2k+1)T_{max}ηt+1=ηmin+21(ηt−ηmin)⎝⎛1+cos(TmaxTcurπ)1+cos(TmaxTcur+1π)⎠⎞,Tcur=(2k+1)Tmax
ηt+1=ηt+12(ηmax−ηmin)(1−cos(1Tmaxπ)),Tcur=(2k+1)Tmax\eta_{t+1} = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), T_{cur} = (2k+1)T_{max}ηt+1=ηt+21(ηmax−ηmin)(1−cos(Tmax1π)),Tcur=(2k+1)Tmax
由于上述公式是递归式,所以lr可以在get_lr之外被修改,若lr尽在get_lr中计算,则公式可统一为:
ηt=ηmin+12(ηmax−ηmin)(1+cos(TcurTmaxπ))\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)ηt=ηmin+21(ηmax−ηmin)(1+cos(TmaxTcurπ))
class CosineAnnealingLR(_LRScheduler):r"""Set the learning rate of each parameter group using a cosine annealingschedule, where :math:`\eta_{max}` is set to the initial lr and:math:`T_{cur}` is the number of epochs since the last restart in SGDR:SGDR: Stochastic Gradient Descent with Warm RestartsArgs:optimizer (Optimizer): Wrapped optimizer.T_max (int): Maximum number of iterations.eta_min (float): Minimum learning rate. Default: 0.last_epoch (int): The index of last epoch. Default: -1."""def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):self.T_max = T_maxself.eta_min = eta_minsuper(CosineAnnealingLR, self).__init__(optimizer, last_epoch)def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)# 在lr_scheduler的step函数中,last_epoch+1发生在get_lr之前,故get_lr中的last_epoch是当前更新轮次if self.last_epoch == 0: # step只执行过一次,即当前轮此为0,对应的学习率是初始学习率return self.base_lrselif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:# T_{cur} = (2k+1)T_{max}:return [group['lr'] + (base_lr - self.eta_min) *(1 - math.cos(math.pi / self.T_max)) / 2for base_lr, group inzip(self.base_lrs, self.optimizer.param_groups)]# base_lr是初始学习率,group['lr']是上一轮的学习率# T_{cur} \neq (2k+1)T_{max}:return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /(1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *(group['lr'] - self.eta_min) + self.eta_minfor group in self.optimizer.param_groups]# step的辅助函数def _get_closed_form_lr(self):return [self.eta_min + (base_lr - self.eta_min) *(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2for base_lr in self.base_lrs]
torch.optim.lr_scheduler源码和cosine学习率策略学习相关推荐
- Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类
当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能.所使用的类 class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer ...
- pytorch中调整学习率: torch.optim.lr_scheduler
文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...
- torch.optim.lr_scheduler.LambdaLR与OneCycleLR
目录 LambdaLR 输出 OneCycleLR 输出 LambdaLR 函数接口: LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=Fa ...
- 【torch.optim】优化器的使用 / 学习率的调整 / SWA策略
torch.optim torch.optim是实现各种优化算法的包.大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法. 优化器 使用优化器 为了使用一个优化器,必须 ...
- class torch.optim.lr_scheduler.ExponentialLR
参考链接: class torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False) 配 ...
- class torch.optim.lr_scheduler.StepLR
参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...
- class torch.optim.lr_scheduler.LambdaLR
参考链接: class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False) 配套 ...
- ImportError: cannot import name ‘SAVE_STATE_WARNING‘ from ‘torch.optim.lr_scheduler‘ (/home/jsj/anac
from transformers import BertModel 报错 ImportError: cannot import name 'SAVE_STATE_WARNING' from 't ...
- java实验项目代码_java web 期末项目实验源码20套,自用学习非常不错!
分享java web 期末项目实验源码20套,自用学习非常不错! 我自己也从里面学习到了很多东西! 1.BBS论坛系统(jsp+sql) 2.ERP管理系统(jsp+servlet) 3.OA办公自动 ...
- Spring源码深度解析(郝佳)-学习-源码解析-基于注解bean定义(一)
我们在之前的博客 Spring源码深度解析(郝佳)-学习-ASM 类字节码解析 简单的对字节码结构进行了分析,今天我们站在前面的基础上对Spring中类注解的读取,并创建BeanDefinition做 ...
最新文章
- Android 追加写入文件的三种方法
- js mysql 住宿系统_[源码和文档分享]基于JavaScript和MySQL实现的酒店管理系统
- 企鹅帝国的疯狂反扑!
- jQuery仿淘宝商城天猫鼠标移动过去,透明度降低
- Discuz1.5 密码错误次数过多,请 15 分钟后重新登录
- Kotlin when 流程判断
- IDT系列:(二)中断处理过程,使用bochs调试IDT中的中断服务程序
- tensorflow.python.framework.errors_impl.NotFoundError: libnvinfer.so.5: cannot open shared object fi
- win32程序测试键盘钩子
- 轻轻松松,一键获取3000个外链
- 我是如何用JSP在网络上架构一个网上招标系统,以推进网站无纸化,过程电子化,管理智能化的发展
- java 蓝桥杯算法训练 瓷砖铺放(题解)
- [APIO2011] 方格染色
- 摄像机、投影、3D旋转、缩放
- python 读取npy文件
- java strut2通配符_Struts2的通配符
- 小学生可以学java编程吗_小学生学编程都要学习哪些内容 家长们知道吗
- string容器模拟实现及使用——C++
- Unity模拟经营类游戏Demo部分代码及技术总结
- 年轻人的第一台挂灯:米家显示器挂灯