torch.optim.lr_scheduler是PyTorch中负责调整学习率的模块,常和torch.optim.Optimizer配合使用。
optimizer模块的源码学习可参见:torch.optim.optimizer源码阅读和灵活使用

class _LRScheduler(object):def __init__(self, optimizer, last_epoch=-1):# 读取相应的Optimizerif not isinstance(optimizer, Optimizer):raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))self.optimizer = optimizer# last_epoch表示上一轮epoch的序号;若为-1,表示当前训练是从头训练if last_epoch == -1:for group in optimizer.param_groups:group.setdefault('initial_lr', group['lr'])else: # last_epoch不为-1,表示当前训练是断点训练,必须有初始学习率for i, group in enumerate(optimizer.param_groups):if 'initial_lr' not in group:raise KeyError("param 'initial_lr' is not specified in param_groups[{}] when resuming an optimizer".format(i))# 读取每一组待优化变量的初始学习率self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))self.last_epoch = last_epoch# Following https://github.com/pytorch/pytorch/issues/20124# We would like to ensure that `lr_scheduler.step()` is called after# `optimizer.step()`def with_counter(method):if getattr(method, '_with_counter', False):# `optimizer.step()` has already been replaced, return.return method# 建立一个method的弱引用。弱引用不增加对象的引用计数,只存在弱引用的对象是可被垃圾回收的;# 弱引用可以解决循环引用的问题。instance_ref = weakref.ref(method.__self__)# Get the unbound method for the same purpose.func = method.__func__  # __func__是method的底层实现,不跟具体的实例绑定cls = instance_ref().__class__  # method的所属类del method@wraps(func)def wrapper(*args, **kwargs):instance = instance_ref()instance._step_count += 1wrapped = func.__get__(instance, cls)return wrapped(*args, **kwargs)# Note that the returned function here is no longer a bound method,# so attributes like `__func__` and `__self__` no longer exist.wrapper._with_counter = Truereturn wrapper# 通过装饰器来为optimizer.step添加计数功能,并初始化计数器self.optimizer.step = with_counter(self.optimizer.step)self.optimizer._step_count = 0self._step_count = 0self.step() # 更新学习率

lr_scheduler在构造函数中主要是获取optimizer并向其添加step计数功能,然后更新一次学习率。

弱引用相关:Python弱引用的使用

__func__相关:Python(类)实例方法的特殊属性

装饰器和functools.wraps相关:探究functools模块wraps装饰器的用途

step函数:

def step(self, epoch=None):# 由于lr_scheduler在构造函数中已经step过一次,故lr_scheduler.step()一定要在optimizer.step()之后。if self._step_count == 1:if not hasattr(self.optimizer.step, "_with_counter"):warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)     elif self.optimizer._step_count < 1:warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)self._step_count += 1 # lr_scheduler的step计数# 支持上下文管理器协议的类class _enable_get_lr_call:def __init__(self, o):self.o = odef __enter__(self):self.o._get_lr_called_within_step = Truereturn selfdef __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = Falsewith _enable_get_lr_call(self):if epoch is None:  # 未指明从某个具体的epoch开始训练self.last_epoch += 1   # 更新epochvalues = self.get_lr() # 计算新的lr,与具体的lr_scheduler类型有关else:  # 指定epoch# EPOCH_DEPRECATION_WARNING是一个提示信息:epoch参数即将被移除warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)self.last_epoch = epochif hasattr(self, "_get_closed_form_lr"):values = self._get_closed_form_lr() # 正式移除epoch之前的lr计算方法else:values = self.get_lr()# 更新optimizer中保存的lrfor param_group, lr in zip(self.optimizer.param_groups, values):param_group['lr'] = lr# _last_lr记录上一轮次更新的lr值self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

step函数主要进行lr的实时计算以及相关参数的更新,包括epoch、lr和optimizer中保存的实时lr。

上下文管理器相关:python中的__enter__和__exit

学习率计算:get_lr

# 计算当前更新轮次的学习率,与具体的lr更新策略有关,由子类实现
def get_lr(self):# Compute learning rate using chainable form of the schedulerraise NotImplementedError

获得上一轮次训练的lr值:

def get_last_lr(self):""" Return last computed learning rate by current scheduler."""return self._last_lr

获取lr_scheduler的相关参数:

def state_dict(self):"""Returns the state of the scheduler as a :class:`dict`.It contains an entry for every variable in self.__dict__ which is not the optimizer."""# lr_scheduler中虽然有optimizer属性来记录与其相对应的优化器,但state_dict中并不包括优化器参数return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

需要注意的是,lr_scheduler的state_dict返回的是scheduler的所有属性,所有不同的scheduler返回的参数各不相同。
加载已有的lr_scheduler参数:

def load_state_dict(self, state_dict):"""Loads the schedulers state.Arguments:state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`."""self.__dict__.update(state_dict) # 用state_dict更新当前lr_scheduler的参数

__dict__相关:Python ____dict__与dir()区别

以cosine学习率调整策略来具体学习lr_scheduler:

lr计算公式:

ηt+1=ηmin+12(ηt−ηmin)(1+cos⁡(Tcur+1Tmaxπ)1+cos⁡(TcurTmaxπ)),Tcur≠(2k+1)Tmax\eta_{t+1} = \eta_{min} + \frac{1}{2}(\eta_{t} - \eta_{min})\left(\frac{1 + \cos\left(\frac{T_{cur}+1}{T_{max}}\pi\right)}{1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)}\right), T_{cur} \neq (2k+1)T_{max}ηt+1​=ηmin​+21​(ηt​−ηmin​)⎝⎛​1+cos(Tmax​Tcur​​π)1+cos(Tmax​Tcur​+1​π)​⎠⎞​,Tcur​​=(2k+1)Tmax​

ηt+1=ηt+12(ηmax−ηmin)(1−cos⁡(1Tmaxπ)),Tcur=(2k+1)Tmax\eta_{t+1} = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), T_{cur} = (2k+1)T_{max}ηt+1​=ηt​+21​(ηmax​−ηmin​)(1−cos(Tmax​1​π)),Tcur​=(2k+1)Tmax​

由于上述公式是递归式,所以lr可以在get_lr之外被修改,若lr尽在get_lr中计算,则公式可统一为:

ηt=ηmin+12(ηmax−ηmin)(1+cos⁡(TcurTmaxπ))\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)ηt​=ηmin​+21​(ηmax​−ηmin​)(1+cos(Tmax​Tcur​​π))

class CosineAnnealingLR(_LRScheduler):r"""Set the learning rate of each parameter group using a cosine annealingschedule, where :math:`\eta_{max}` is set to the initial lr and:math:`T_{cur}` is the number of epochs since the last restart in SGDR:SGDR: Stochastic Gradient Descent with Warm RestartsArgs:optimizer (Optimizer): Wrapped optimizer.T_max (int): Maximum number of iterations.eta_min (float): Minimum learning rate. Default: 0.last_epoch (int): The index of last epoch. Default: -1."""def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):self.T_max = T_maxself.eta_min = eta_minsuper(CosineAnnealingLR, self).__init__(optimizer, last_epoch)def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)# 在lr_scheduler的step函数中,last_epoch+1发生在get_lr之前,故get_lr中的last_epoch是当前更新轮次if self.last_epoch == 0: # step只执行过一次,即当前轮此为0,对应的学习率是初始学习率return self.base_lrselif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:# T_{cur} = (2k+1)T_{max}:return [group['lr'] + (base_lr - self.eta_min) *(1 - math.cos(math.pi / self.T_max)) / 2for base_lr, group inzip(self.base_lrs, self.optimizer.param_groups)]# base_lr是初始学习率,group['lr']是上一轮的学习率# T_{cur} \neq (2k+1)T_{max}:return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /(1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *(group['lr'] - self.eta_min) + self.eta_minfor group in self.optimizer.param_groups]# step的辅助函数def _get_closed_form_lr(self):return [self.eta_min + (base_lr - self.eta_min) *(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2for base_lr in self.base_lrs]

torch.optim.lr_scheduler源码和cosine学习率策略学习相关推荐

  1. Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类

    当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能.所使用的类 class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer ...

  2. pytorch中调整学习率: torch.optim.lr_scheduler

    文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...

  3. torch.optim.lr_scheduler.LambdaLR与OneCycleLR

    目录 LambdaLR 输出 OneCycleLR 输出 LambdaLR 函数接口: LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=Fa ...

  4. 【torch.optim】优化器的使用 / 学习率的调整 / SWA策略

    torch.optim torch.optim是实现各种优化算法的包.大多数常用的方法都已得到支持,而且接口足够通用,因此将来还可以轻松集成更复杂的方法. 优化器 使用优化器 为了使用一个优化器,必须 ...

  5. class torch.optim.lr_scheduler.ExponentialLR

    参考链接: class torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False) 配 ...

  6. class torch.optim.lr_scheduler.StepLR

    参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...

  7. class torch.optim.lr_scheduler.LambdaLR

    参考链接: class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False) 配套 ...

  8. ImportError: cannot import name ‘SAVE_STATE_WARNING‘ from ‘torch.optim.lr_scheduler‘ (/home/jsj/anac

    from transformers import BertModel 报错   ImportError: cannot import name 'SAVE_STATE_WARNING' from 't ...

  9. java实验项目代码_java web 期末项目实验源码20套,自用学习非常不错!

    分享java web 期末项目实验源码20套,自用学习非常不错! 我自己也从里面学习到了很多东西! 1.BBS论坛系统(jsp+sql) 2.ERP管理系统(jsp+servlet) 3.OA办公自动 ...

  10. Spring源码深度解析(郝佳)-学习-源码解析-基于注解bean定义(一)

    我们在之前的博客 Spring源码深度解析(郝佳)-学习-ASM 类字节码解析 简单的对字节码结构进行了分析,今天我们站在前面的基础上对Spring中类注解的读取,并创建BeanDefinition做 ...

最新文章

  1. Android 追加写入文件的三种方法
  2. js mysql 住宿系统_[源码和文档分享]基于JavaScript和MySQL实现的酒店管理系统
  3. 企鹅帝国的疯狂反扑!
  4. jQuery仿淘宝商城天猫鼠标移动过去,透明度降低
  5. Discuz1.5 密码错误次数过多,请 15 分钟后重新登录
  6. Kotlin when 流程判断
  7. IDT系列:(二)中断处理过程,使用bochs调试IDT中的中断服务程序
  8. tensorflow.python.framework.errors_impl.NotFoundError: libnvinfer.so.5: cannot open shared object fi
  9. win32程序测试键盘钩子
  10. 轻轻松松,一键获取3000个外链
  11. 我是如何用JSP在网络上架构一个网上招标系统,以推进网站无纸化,过程电子化,管理智能化的发展
  12. java 蓝桥杯算法训练 瓷砖铺放(题解)
  13. [APIO2011] 方格染色
  14. 摄像机、投影、3D旋转、缩放
  15. python 读取npy文件
  16. java strut2通配符_Struts2的通配符
  17. 小学生可以学java编程吗_小学生学编程都要学习哪些内容 家长们知道吗
  18. string容器模拟实现及使用——C++
  19. Unity模拟经营类游戏Demo部分代码及技术总结
  20. 年轻人的第一台挂灯:米家显示器挂灯

热门文章

  1. 盘点≠走过场,哪些功能可以进行高效库存盘点?
  2. openstack controller ha测试环境搭建记录(二)——配置corosync和pacemaker
  3. 使用机器学习进行语言翻译:神经网络和seq2seq为何效果非凡?
  4. 虚幻引擎4(UE4)的基本操作Actor的操作
  5. 好多粉数据上报之360点睛平台ocpc API上报数据方法
  6. 蓝凌OA SSRF+JNDI远程命令执行
  7. 怎样进行结构化思维思考?
  8. 班得瑞[Bandari]音乐介绍
  9. 微信公众号开发工具类
  10. 谷歌浏览器翻译插件推荐——Google Chrome 插件推荐