找了一晚上warmup资料,有用的很少,基本都是互相转载,要不就是讲的很空泛,代码没有可使用的价值。但是最后我还是解决了,于是写一个warmup教程造福大家,这里抛砖引玉了。

一、介绍GradualWarmupScheduler

GradualWarmupScheduler(optimizer, multiplier, total_epoch, after_scheduler)

参数解释
optimizer:优化器
multiplier:当multiplier=1.0时,学习率lr从0开始增到base_lr为止,当multiplier大于1.0时,学习率lr从base_lr开始增到base_lr*multiplier为止。multiplier不能小于1.0。
【那么base_lr又是什么?】
【就是传入优化器的lr,例如optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay),base_lr就是learning_rate】
total_epoch:在total_epoch个epoch后达到目标学习率,也就是warmup持续的代数
after_scheduler:在经过total_epoch个epoch以后,所使用的学习率策略。

如果想了解更多细节,比如每个epoch的步长计算方式,可以查看源代码实现链接::pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py

二、如何使用GradualWarmupScheduler

这里以我的代码为例,简单讲解下如何使用GradualWarmupScheduler。

  1. 如果没有warmup_scheduler包的话,需要安装:pip install warmup_scheduler

  2. 我先实现了optimizer,schedular_r(这里我定义的策略是如果测试准确率【mode=‘max’】连续三代不上升【patience=3】,则学习率变为原学习率的0.1倍【factor=0.1】),最后再实现schedular。

  3. schedular的含义是经过10代【total_epoch=10】warm up,学习率由0.01(base_lr)逐渐上升至0.1【multiplier=10】,从第11代开始学习率策略将按照schedular_r进行衰减【after_scheduler=schedular_r】,也就是我第二点中介绍的。

  4. 值得注意的是schedular.step(metrics=test_acc)是在每个epoch进行迭代,且由于我后续使用的策略是ReduceLROnPlateau,所以这里需要传入一个参数metrics=test_acc。

from warmup_scheduler import GradualWarmupSchedulerdef train(net, device, epochs, learning_rate, weight_decay):optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)# if loss do not change for 5 epochs, change lr*0.1schedular_r = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True, eps=1e-5)schedular = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=schedular_r)#schedular = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=5)initepoch = 0loss = nn.CrossEntropyLoss()best_test_acc = 0for epoch in range(initepoch, epochs):  # loop over the dataset multiple timesnet.train()timestart = time.time()running_loss = 0.0total = 0correct = 0#print(optimizer.param_groups[0]['lr'])for i, data in tqdm(enumerate(train_iter)):# get the inputsinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)l = loss(outputs, labels)l.backward()optimizer.step()running_loss += l.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()train_acc = 100.0 * correct / totalprint('epoch %d, loss: %.4f,tran Acc: %.3f%%,time:%3f sec, lr: %.7f'% (epoch+1, running_loss, train_acc, time.time() - timestart, optimizer.param_groups[0]['lr']))print(schedular.last_epoch)# testnet.eval()total = 0correct = 0with torch.no_grad():for data in tqdm(test_iter):images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# print(outputs.shape)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100.0 * correct / totalprint('test Acc: %.3f%%' % (test_acc))# if epoch > 30:#     torch.save(net.state_dict(), '/root/Desktop/cifar-100/checkpoint_512_512_100/' + str(test_acc) + '_Resnet18.pth')if test_acc > best_test_acc:print('find best! save at checkpoint/cnn_best.pth')best_test_acc = test_accbest_epoch = epochtorch.save(net.state_dict(),'/root/Desktop/exps/cifar100/model_512_100_IF10/best_' + str(best_test_acc) + '_' + str(train_acc) + '_resnet34.pth')schedular.step(metrics=test_acc)print('Finished Training')print('best test acc epoch: %d' % epoch+1)

【介绍+代码实现】使用GradualWarmupScheduler进行学习率预热相关推荐

  1. 【学习率预热】Warm up

    1.什么是warm up   Warmup是在ResNet论文中提到的一种学习率预热的方法,它在训练开始的时候先选择使用一个较小的学习率,训练了一些epoches或者steps(比如4个epoches ...

  2. MATLAB 画三维长方体 介绍+代码

    MATLAB 画三维长方体 介绍+代码 在做机械臂三维避障仿真时可能用到对空间障碍物进行描述,一般用长方体,圆柱体等描述,以下是两种画长方体的程序,第一种是指定长方体的八个顶点坐标,第二种是指定长方体 ...

  3. 十九.激光和惯导LIO-SLAM框架学习之项目工程代码介绍---代码框架和一些文件解释

    专栏系列文章如下: 一:Tixiao Shan最新力作LVI-SAM(Lio-SAM+Vins-Mono),基于视觉-激光-惯导里程计的SLAM框架,环境搭建和跑通过程_goldqiu的博客-CSDN ...

  4. Web中html个人介绍代码,web开发工程师自我介绍示例

    web开发工程师自我介绍示例 Web前端开发工程师,主要职责是利用(X)HTML/CSS/JavaScript/Flash等各种Web技术进行客户端产品的开发.小编整理了web开发工程师自我介绍示例, ...

  5. 23.Flink-高级特性-新特性-Streaming Flie Sink\介绍\代码演示\Flink-高级特性-新特性-FlinkSQL整合Hive\添加依赖和jar包和配置

    23.Flink-高级特性-新特性-Streaming Flie Sink 23.1.介绍 23.2.代码演示 24.Flink-高级特性-新特性-FlinkSQL整合Hive 24.1.介绍 24. ...

  6. GaitPart学习笔记(主干网络论文介绍+代码讲解)

    论文 1 Introduction ​ 先前的研究都是将人体的整个步态作为网络输入进行特征提取,而本文最大的亮点在于发现人体步态的不同部分在形状以及行走时的移动模式上具有显著的区别,并且这些信息将为网 ...

  7. 时序预测工具库(Prophet)介绍+代码

    时序预测工具库(Prophet) 一.Prophet 简介 二.Prophet 适用场景 三.Prophet 算法的输入输出 四.Prophet 算法原理 五.与机器学习算法的对比 六.代码 6.1 ...

  8. java个人介绍代码_个人项目WC(Java)

    1.WC项目要求 wc.exe 是一个常见的工具,它能统计文本文件的字符数.单词数和行数.这个项目要求写一个命令行程序,模仿已有wc.exe 的功能,并加以扩充,给出某程序设计语言源文件的字符数.单词 ...

  9. 学习率预热warmup

    学习率衰减 学习率:权重更新的控制因子 训练神经网络的常规策略是初始时使用较大的学习率(好处是能使网络收敛迅速),随着训练的进行,学习率衰减: warmup是什么? 在训练初期,loss很大,因此计算 ...

最新文章

  1. Python之闭包、装饰器及相关习题练习
  2. nessus国内用户不让免费使用了!
  3. ffmpeg 声音参数_ffmpeg转换参数和压缩输出大小的比率
  4. fedora 安装Linux源码,如何在 Fedora 29/30 上安装 VS Code
  5. 《程序员面试宝典》精华 面向对象部分
  6. python禁用警告
  7. Python:for的多种写法
  8. 被全球 iPhone 用户讨伐 49 天后,苹果终于为 iOS 带来手动关闭降频功能!
  9. Pikachu实验重现2(Sql的注入)
  10. 微信红包c语言程序,微信抢红包软件的C语言原理
  11. tp6 api请求返回参数统一配置方法
  12. img居中以及等比缩放
  13. python不定积分教学_python使用sympy不定积分入门及求解
  14. Android API19 设置Alarm闹钟
  15. 一文带你看透通知短信
  16. funcode小游戏暑假大作业,开源,新颖,游戏名:凿空,免费。
  17. 以太网PHY接口直连设计
  18. webrtc 视频编码格式及参数配置
  19. 手机浏览计算机以查找驱动程序,win7手机驱动安装失败怎么办
  20. PywebIO 轻松制作一个数据大屏,代码只需100行

热门文章

  1. android指南针Demo,谁有安卓简易指南针的DEmo
  2. 基于单目视觉的平面目标定位和坐标测量 (下) - 相机姿态估计和目标测量
  3. doctor技术基础
  4. 淘丞相将微博链接转为淘宝直达是怎么实现的?
  5. 他们为什么离开微软? 创业热情驱动
  6. 家具生产设备_家具生产线
  7. 嵌入式Linux工程师的成长经历
  8. ansible中的加密
  9. pta 计算机通信(并查集)
  10. 科普|AGV自动运输车的不同导航方式以及优缺点