模型训练技巧——warm up
1. pytorch 中学习率的调节策略
(1)等间隔调整学习率 StepLR
(2)按需调整学习率 MultiStepLR
(3)指数衰减调整学习率 ExponentialLR
(4)余弦退火调整学习率 CosineAnnealingLR
(5)自适应调整学习率 ReduceLROnPlateau
(6)自定义调整学习率 LambdaLR
每种学习率的参数详解,见博文:pytorch 学习率参数详解
2. 论文中和比赛中学习率的调节策略
然而在顶会论文和知名比赛中,作者一般都不会直接使用上述学习率调整策略,而是先预热模型(warm up), 即以一个很小的学习率逐步上升到设定的学习率,这样做会使模型的最终收敛效果更好。
下面,小编以warm up + CosineAnnealingLR来实现学习率的调整。训练过程中学习率的变化过程如图中红色曲线所示:
![](/assets/blank.gif)
3. 代码实现
首先,写一个warm up的类,重写get_lr方法。
import torch
from torch.optim.lr_scheduler import _LRSchedulerclass WarmUpLR(_LRScheduler):"""warmup_training learning rate schedulerArgs:optimizer: optimzier(e.g. SGD)total_iters: totoal_iters of warmup phase"""def __init__(self, optimizer, total_iters, last_epoch=-1):self.total_iters = total_iterssuper().__init__(optimizer, last_epoch)def get_lr(self):"""we will use the first m batches, and set the learningrate to base_lr * m / total_iters"""return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
在训练代码中使用:
criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)warmup_epoch = 5scheduler = CosineAnnealingLR(optimizer, 100 - warmup_epoch)iter_per_epoch = len(train_dataset)warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * warmup_epoch)for epoch in range(1, max_epoch+1):if epoch >= warmup_epoch:scheduler.step()learn_rate = scheduler.get_lr()[0]print("Learn_rate:%s" % learn_rate)test(epoch, net, valloader, criterion)train(epoch, net, trainloader, optimizer, criterion, warmup_scheduler)
在train函数中的修改:
for (inputs, targets) in tqdm(trainloader):if epoch < 5:warmup_scheduler.step()warm_lr = warmup_scheduler.get_lr()print("warm_lr:%s" % warm_lr)inputs, targets = inputs.to(device), targets.to(device)
4. 总结
在论文中和比赛中一般都会用到warm up技巧,特别是在模型难收敛的任务中。在论文中,MultiStepLR和CosineAnnealingLR两种学习率调节策略用得较多。在知名竞赛中,ReduceLROnPlateau学习率调整策略用得较多。小编在工程项目中是怎么用的呢?一般用warm up结合上述三种调节策略都尝试一遍,最终哪个模型的精度高就用哪个模型。很多情况下,三个模型的精度差不多,精度差距在±0.5%以内。
模型训练技巧——warm up相关推荐
- 大模型训练技巧|单卡多卡|训练性能评测
原视频:[单卡.多卡 BERT.GPT2 训练性能[100亿模型计划]] 此笔记主要参考了李沐老师的视频,感兴趣的同学也可以去看视频- 视频较长,这里放上笔记,与大家分享- 大模型对于计算资源的要求越 ...
- 高效又稳定的ChatGPT大模型训练技巧总结,让训练事半功倍!
文|python 前言 近期,ChatGPT成为了全网热议的话题.ChatGPT是一种基于大规模语言模型技术(LLM, large language model)实现的人机对话工具.现在主流的大规模语 ...
- 李宏毅老师《机器学习》课程笔记-2.1模型训练技巧
注:本文是我学习李宏毅老师<机器学习>课程 2021/2022 的笔记(课程网站 ),文中图片除了两幅是我自己绘制外,其余图片均来自课程 PPT.欢迎交流和多多指教,谢谢! 文章目录 Le ...
- 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧
1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...
- ML(10) - 模型训练技巧
模型技巧 交叉验证 Pipeline 网格搜索 偏差(Bias)和方差(Variance) 模型正则化(Regularization) 正则化基本概念 正则化种类(scikit-learn) 交叉验证 ...
- 深度学习模型训练技巧
博主以前都是拿别人的模型别人的数据做做分类啊,做做目标检测,搞搞学习,最近由于导师的工程需求,自己构造网络,用自己的数据来跑网络,才发现模型训练真的是很有讲究,很有技巧在里面,最直接的几个超参数的设置 ...
- 计算机视觉中的数据预处理与模型训练技巧总结
来源丨机器学习小王子,转载自丨极市平台 导读 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tric ...
- 【干货】计算机视觉中的数据预处理与模型训练技巧总结
来源丨机器学习小王子 编辑丨极市平台 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tricks. ...
- 模型训练技巧:warmup学习率策略
1.什么是warmup 学习率的设置 - 不同阶段不同值:上升 -> 平稳 -> 下降 由于神经网络在刚开始训练的时候是非常不稳定的,因此刚开始的学习率应当设置得很低很低,这样可以保证网络 ...
最新文章
- 微信小程序发送模板消息,php发送模板消息
- Loadrunner11如何使用非IE浏览器录制脚本
- 使用LeNet对于旋转数字进行识别:合并数字集合
- 第三篇:属性_第二节:控件属性在页面及源码中的表示方式
- struts深入理解之登录示例的源码跟踪
- 转载:浏览器开发系列第一篇:如何获取最新chromium源码
- as it exceeds the max of 500KB._IT狂人第一季 | 如何考察员工
- python入门经典27版_【python】编程语言入门经典100例--27
- 1.5 本地库与中央库
- Ubuntu18.04忘记密码解决
- 鲲鹏920的服务器芯片,鲲鹏920芯片是什么芯片
- 计算机上显示找不到无线网络连接,电脑怎么找不到无线网络? 笔记本找不到无线网络如何解决?...
- 航空航天行业工作站应用---EDA仿真计算工作站
- Altium Designer中PCB画多层板(4、6、8...层)
- k30pro杀进程严重怎么解决_命运2掉帧严重怎么解决?GoLink免费加速器助力玩家稳定畅玩...
- Web应用程序 [/XXX_war_exploded] 注册了JDBC驱动程序 [com.mysql.cj.jdbc.Driver],但在Web应用程序停止时无法注销它。
- Selective Search算法-候选框生成
- 4、弱电工程FTTH网络的分光建设及分光比设计
- MGN:Learning Discriminative Features with Multiple Granularities for Person Re-Identification阅读笔记
- 服务器内存不足导致程序(tomcat)崩溃
热门文章
- Java操作pdf的工具类itextpdf
- 计算机组成原理实验 ram_如何对计算机的RAM超频
- 第30章 MySQL 序列使用教程
- java swing 聊天气泡_Java Swing中的聊天气泡
- cpout引脚是干什么的_单片机引脚的定义与功能详解
- 中控,I/O端口,继电器,红外接口,编码器,解码器,主机,名词解释
- 详细分析大型web系统各个子系统架构图 纯干货!
- 我的第一个web开发环境:基于eclipse java EE 的java web系统搭建
- SeqTrack: Sequence to Sequence Learning for Visual Object Tracking
- Kafka数据导入导出