1. pytorch 中学习率的调节策略

(1)等间隔调整学习率 StepLR

(2)按需调整学习率 MultiStepLR

(3)指数衰减调整学习率 ExponentialLR

(4)余弦退火调整学习率 CosineAnnealingLR

(5)自适应调整学习率 ReduceLROnPlateau

(6)自定义调整学习率 LambdaLR

每种学习率的参数详解,见博文:pytorch 学习率参数详解

2. 论文中和比赛中学习率的调节策略

然而在顶会论文和知名比赛中,作者一般都不会直接使用上述学习率调整策略,而是先预热模型(warm up), 即以一个很小的学习率逐步上升到设定的学习率,这样做会使模型的最终收敛效果更好。

下面,小编以warm up + CosineAnnealingLR来实现学习率的调整。训练过程中学习率的变化过程如图中红色曲线所示:

Caption

3. 代码实现

首先,写一个warm up的类,重写get_lr方法。

import torch
from torch.optim.lr_scheduler import _LRSchedulerclass WarmUpLR(_LRScheduler):"""warmup_training learning rate schedulerArgs:optimizer: optimzier(e.g. SGD)total_iters: totoal_iters of warmup phase"""def __init__(self, optimizer, total_iters, last_epoch=-1):self.total_iters = total_iterssuper().__init__(optimizer, last_epoch)def get_lr(self):"""we will use the first m batches, and set the learningrate to base_lr * m / total_iters"""return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

在训练代码中使用:

    criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)warmup_epoch = 5scheduler = CosineAnnealingLR(optimizer, 100 - warmup_epoch)iter_per_epoch = len(train_dataset)warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * warmup_epoch)for epoch in range(1, max_epoch+1):if epoch >= warmup_epoch:scheduler.step()learn_rate = scheduler.get_lr()[0]print("Learn_rate:%s" % learn_rate)test(epoch, net, valloader, criterion)train(epoch, net, trainloader, optimizer, criterion, warmup_scheduler)

在train函数中的修改:

for (inputs, targets) in tqdm(trainloader):if epoch < 5:warmup_scheduler.step()warm_lr = warmup_scheduler.get_lr()print("warm_lr:%s" % warm_lr)inputs, targets = inputs.to(device), targets.to(device)

4. 总结

在论文中和比赛中一般都会用到warm up技巧,特别是在模型难收敛的任务中。在论文中,MultiStepLR和CosineAnnealingLR两种学习率调节策略用得较多。在知名竞赛中,ReduceLROnPlateau学习率调整策略用得较多。小编在工程项目中是怎么用的呢?一般用warm up结合上述三种调节策略都尝试一遍,最终哪个模型的精度高就用哪个模型。很多情况下,三个模型的精度差不多,精度差距在±0.5%以内。

模型训练技巧——warm up相关推荐

  1. 大模型训练技巧|单卡多卡|训练性能评测

    原视频:[单卡.多卡 BERT.GPT2 训练性能[100亿模型计划]] 此笔记主要参考了李沐老师的视频,感兴趣的同学也可以去看视频- 视频较长,这里放上笔记,与大家分享- 大模型对于计算资源的要求越 ...

  2. 高效又稳定的ChatGPT大模型训练技巧总结,让训练事半功倍!

    文|python 前言 近期,ChatGPT成为了全网热议的话题.ChatGPT是一种基于大规模语言模型技术(LLM, large language model)实现的人机对话工具.现在主流的大规模语 ...

  3. 李宏毅老师《机器学习》课程笔记-2.1模型训练技巧

    注:本文是我学习李宏毅老师<机器学习>课程 2021/2022 的笔记(课程网站 ),文中图片除了两幅是我自己绘制外,其余图片均来自课程 PPT.欢迎交流和多多指教,谢谢! 文章目录 Le ...

  4. 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

    1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...

  5. ML(10) - 模型训练技巧

    模型技巧 交叉验证 Pipeline 网格搜索 偏差(Bias)和方差(Variance) 模型正则化(Regularization) 正则化基本概念 正则化种类(scikit-learn) 交叉验证 ...

  6. 深度学习模型训练技巧

    博主以前都是拿别人的模型别人的数据做做分类啊,做做目标检测,搞搞学习,最近由于导师的工程需求,自己构造网络,用自己的数据来跑网络,才发现模型训练真的是很有讲究,很有技巧在里面,最直接的几个超参数的设置 ...

  7. 计算机视觉中的数据预处理与模型训练技巧总结

    来源丨机器学习小王子,转载自丨极市平台 导读 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tric ...

  8. 【干货】计算机视觉中的数据预处理与模型训练技巧总结

    来源丨机器学习小王子 编辑丨极市平台 针对图像分类任务提升准确率的方法主要有两条:一个是模型的修改,另一个是各种数据处理和训练的技巧.本文在精读论文的基础上,总结了图像分类任务的11个tricks. ...

  9. 模型训练技巧:warmup学习率策略

    1.什么是warmup 学习率的设置 - 不同阶段不同值:上升 -> 平稳 -> 下降 由于神经网络在刚开始训练的时候是非常不稳定的,因此刚开始的学习率应当设置得很低很低,这样可以保证网络 ...

最新文章

  1. 微信小程序发送模板消息,php发送模板消息
  2. Loadrunner11如何使用非IE浏览器录制脚本
  3. 使用LeNet对于旋转数字进行识别:合并数字集合
  4. 第三篇:属性_第二节:控件属性在页面及源码中的表示方式
  5. struts深入理解之登录示例的源码跟踪
  6. 转载:浏览器开发系列第一篇:如何获取最新chromium源码
  7. as it exceeds the max of 500KB._IT狂人第一季 | 如何考察员工
  8. python入门经典27版_【python】编程语言入门经典100例--27
  9. 1.5 本地库与中央库
  10. Ubuntu18.04忘记密码解决
  11. 鲲鹏920的服务器芯片,鲲鹏920芯片是什么芯片
  12. 计算机上显示找不到无线网络连接,电脑怎么找不到无线网络? 笔记本找不到无线网络如何解决?...
  13. 航空航天行业工作站应用---EDA仿真计算工作站
  14. Altium Designer中PCB画多层板(4、6、8...层)
  15. k30pro杀进程严重怎么解决_命运2掉帧严重怎么解决?GoLink免费加速器助力玩家稳定畅玩...
  16. Web应用程序 [/XXX_war_exploded] 注册了JDBC驱动程序 [com.mysql.cj.jdbc.Driver],但在Web应用程序停止时无法注销它。
  17. Selective Search算法-候选框生成
  18. 4、弱电工程FTTH网络的分光建设及分光比设计
  19. MGN:Learning Discriminative Features with Multiple Granularities for Person Re-Identification阅读笔记
  20. 服务器内存不足导致程序(tomcat)崩溃

热门文章

  1. Java操作pdf的工具类itextpdf
  2. 计算机组成原理实验 ram_如何对计算机的RAM超频
  3. 第30章 MySQL 序列使用教程
  4. java swing 聊天气泡_Java Swing中的聊天气泡
  5. cpout引脚是干什么的_单片机引脚的定义与功能详解
  6. 中控,I/O端口,继电器,红外接口,编码器,解码器,主机,名词解释
  7. 详细分析大型web系统各个子系统架构图 纯干货!
  8. 我的第一个web开发环境:基于eclipse java EE 的java web系统搭建
  9. SeqTrack: Sequence to Sequence Learning for Visual Object Tracking
  10. Kafka数据导入导出