Gradual warmup lr schedule--pytorch
Gradually warm-up(increasing) learning rate for pytorch’s optimizer. Proposed in ‘Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour’.
# from:https://github.com/ildoonet/pytorch-gradual-warmup-lr
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateauclass GradualWarmupScheduler(_LRScheduler):""" Gradually warm-up(increasing) learning rate in optimizer.Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.Args:optimizer (Optimizer): Wrapped optimizer.multiplier: target learning rate = base lr * multipliertotal_epoch: target learning rate is reached at total_epoch, graduallyafter_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)"""def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):self.multiplier = multiplierif self.multiplier <= 1.:raise ValueError('multiplier should be greater than 1.')self.total_epoch = total_epochself.after_scheduler = after_schedulerself.finished = Falsesuper().__init__(optimizer)def get_lr(self):if self.last_epoch > self.total_epoch:if self.after_scheduler:if not self.finished:self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]self.finished = Truereturn self.after_scheduler.get_lr()return [base_lr * self.multiplier for base_lr in self.base_lrs]return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]def step_ReduceLROnPlateau(self, metrics, epoch=None):if epoch is None:epoch = self.last_epoch + 1self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginningif self.last_epoch <= self.total_epoch:warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):param_group['lr'] = lrelse:if epoch is None:self.after_scheduler.step(metrics, None)else:self.after_scheduler.step(metrics, epoch - self.total_epoch)def step(self, epoch=None, metrics=None):if type(self.after_scheduler) != ReduceLROnPlateau:if self.finished and self.after_scheduler:if epoch is None:self.after_scheduler.step(None)else:self.after_scheduler.step(epoch - self.total_epoch)else:return super(GradualWarmupScheduler, self).step(epoch)else:self.step_ReduceLROnPlateau(metrics, epoch)if __name__ == '__main__':scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_epoch)scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine)for epoch in range(train_epoch):scheduler_warmup.step() # 10 epoch warmup, after that schedule as scheduler_plateau...
Gradual warmup lr schedule--pytorch相关推荐
- Gradual Warmup Scheduler
Gradual Warmup Scheduler 1.介绍 从最开始的小学习率开始,每个iteration增大一点,直到最初设置的比较大的学习率. 其学习率如上图变化所示. 2.实现 这将使用到tor ...
- warmup lr+CosineAnnealingLR策略
warmup lr策略就是在网络训练初期用比较小的学习率,线性增长到初始设定的学习率. 大概就是下面这个趋势,从0上升到0.01,再按照正常的学习率调整策略训练. import torch from ...
- 【trick 5】warmup —— 一种学习率调优方法
目录 一.warmup定义 二.为什么使用warmup 2.1.理性分析 2.2.感性分析 三.常用的warmup 3.1.Constant Warmup 3.2.Linner Warmup 3.3. ...
- PyTorch学习率 warmup + 余弦退火
PyTorch学习率 warmup + 余弦退火 Pytorch 余弦退火 PyTorch内置了很多学习率策略,详情请参考torch.optim - PyTorch 1.10.1 documentat ...
- Warm-up pytorch代码
train_cfg = dict(warmup = 5,lr = [0.004, 0.002, 0.0004, 0.00004, 0.000004],gamma = 0.1,end_lr = 1e-6 ...
- pytorch apex +ddp 分布式训练+数据优化
1.DDP代码添加讲解 https://blog.csdn.net/cdknight_happy/article/details/108262595 2.apex 官网 apex + ddp + 数据 ...
- PyTorch—计算机视觉目标检测 mmdetection
一.前言 商汤和港中文联合开源了 mmdetection-基于 PyTorch 的开源目标检测工具包. 工具包支持 Mask RCNN 等多种流行的检测框架,读者可在 PyTorch 环境下测试不同的 ...
- 目标检测的Tricks | 【Trick5】学习率调优方法——warmup
如有错误,恳请指出. 文章目录 1. warmup理论概要 2. warmup实现代码 1. warmup理论概要 warmup定义: 在模型训练之初选用较小的学习率,训练一段时间之后(如:10epo ...
- Warmup 原理与实现
背景介绍 在神经网络训练过程中,学习率是一个很重要的超参数,学习率的选择对于网络的训练结果有较大的影响. 理论上,如果学习率设置过小,则会出现收敛速度过慢的情况:如果学习率设置过大,则容易错过局部最优 ...
最新文章
- 【LeetCode 剑指offer刷题】树题16:Kth Smallest Element in a BST
- C++:cin、cin.getline()、getline()的用法
- 【PP物料】物料主档备忘录
- 测试结果表明开车打手机比酒后开车更危险
- 电子商务概论_走进经管优质线上课堂(二)之电子商务概论
- django的admin
- C# 静态变量及静态函数
- 配置Becon frame中的Carplay的Interworking和Vendor Specific字段信息
- 电脑使用小常识(2):新手装软件指南,防止流氓软件
- 小程序scroll-view,滚动到最低_小程序滚动到底部
- 全闪存存储的数据库加速场景应用
- WPF x:Key标签
- 纸张折叠多少次能够和珠穆朗玛峰峰一样高
- 计算机二级Python考试心得分享体会总结
- 小度计算机模式,小度机器人怎么用 小度机器人使用教程-电脑教程
- iOS第三方直播的集成
- 如何在arxiv上面发论文
- 我的世界java什么村民卖地图_1.11/1.11.2关于探险地图及制图师村民的一些机制介绍...
- 揭秘:中国企业家十大顶级圈子
- 织梦dede乐信短信插件
热门文章
- 数字化转型顶层设计怎么做?建筑央企数字化转型给出答案
- 2007软件英雄会暨CSDN社区英雄榜颁奖典礼邀请人员名单
- 银河麒麟 安装PL2303GC USB转串口驱动
- 使用多线程实现多客户端的连接(通过Socket实现TCP编程)
- c盘空间不足的一些删除办法
- firewalld防火墙(二)实验案例:ip地址伪装,端口转发
- Vue项目中出现Loading chunk {n} failed问题的解决方法
- python高端实现各国GDP动态轮换图
- 武汉市征集人工智能领域技术成果等通知-2022年申请时间及条件
- C++ 输入输出(cin cout)加速/效率优化