torch.optim.lr_scheduler.MultiStepLR(optimizermilestonesgamma=0.1last_epoch=-1verbose=False)

我自已用代码研究了一遍MultiStepLR()中的last_epoch参数,发现就是个垃圾。

结论:

①last_epoch就是个鸡肋的东西   经过评论区大佬的指点,我现在确定了last_epoch的用法:last_epoch表示已经走了多少个epoch,下一个milestone减去last_epoch就是需要的epoch数

(评论区原话:last_epoch是有用的,简单来说,就是所有学习率都要提前last_epoch开始进行变化。举个例子假如我设置原始lr=0.1,milestones=[5, 15], gamma=0.5,last_epoch=0此时epoch=5时lr才会变为0.05,epoch=15时lr变为0.025。当修改last_epoch=3后,epoch=2,lr就会变为0.05,epoch=12,lr变为0.025.一般默认last_epoch=0)

②会在milestone的时候乘以gamma的平方

实验代码如下:

1、首先是默认配置:

import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1]
1 [0.1]
2 [0.0010000000000000002]   # 此处乘的是gamma的平方
3 [0.010000000000000002]
4 [0.010000000000000002]
5 [0.00010000000000000003]   # 此处乘的是gamma的平方
6 [0.0010000000000000002]
7 [0.0010000000000000002]
8 [0.0010000000000000002]

2、设置last_epoch=-1。
和前面1、是一样的,因为函数的默认值就是last_epoch=-1

import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = -1for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1]
1 [0.1]
2 [0.1]
3 [0.0010000000000000002]
4 [0.010000000000000002]
5 [0.010000000000000002]
6 [0.00010000000000000003]
7 [0.0010000000000000002]
8 [0.0010000000000000002]

3、设置last_epoch=4。
把第1个epoch的learning_rate设置为0.1,但是按照模型已经更新到了第4个epoch开始执行。

后来理解了一下,感觉就是:在last_epoch处将learning_rate重新设置为初始值,而且也是从last_epoch处继续进行运行;所以就要求你手动把learning_rate设为上一次模型停止的时候对应的learning_rate值,即last_epoch处对应的learning_rate

import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1] # 0相当于第4个eopch
1 [0.0010000000000000002]  # 1相当于第5个epoch所以乘以gamma的平方
2 [0.010000000000000002]
3 [0.010000000000000002]
4 [0.010000000000000002]
5 [0.010000000000000002]
6 [0.010000000000000002]
7 [0.010000000000000002]
8 [0.010000000000000002]  # 因为4在3的后面,只有一个6这个milestones,所以只更新了一次

4、设置last_epoch=4,并且将scheduler.step()改为schduler.step(epoch)。也是不对

import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()scheduler.step(epoch)# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.0010000000000000002]
1 [0.0010000000000000002]
2 [0.0010000000000000002]
3 [0.00010000000000000003]
4 [0.0010000000000000002]
5 [0.0010000000000000002]
6 [0.00010000000000000003]
7 [0.0010000000000000002]
8 [0.0010000000000000002]

我试过了,无论把last_epoch改为多少,输出都是上面这个。证明scheduler.step()里面是一定不能加epoch的

5、接3,当last_epoch大于milestones的某些值时,会自动跳过这些值

import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 8], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()# scheduler.step(epoch)scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.0010000000000000002]
1 [0.010000000000000002]
2 [0.010000000000000002]
3 [0.00010000000000000003]
4 [0.0010000000000000002]
5 [0.0010000000000000002]
6 [0.0010000000000000002]
7 [0.0010000000000000002]
8 [0.0010000000000000002]

对比第3小节和第5小节的例子可以发现,在3中的例子中,4大于3,所以把3跳过了,直接在milestone=6的时候调整的learning_rate

torch.optim.lr_scheduler.MultiStepLR()用法研究 台阶/阶梯学习率相关推荐

  1. Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类

    当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能.所使用的类 class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer ...

  2. pytorch中调整学习率: torch.optim.lr_scheduler

    文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...

  3. torch.optim.lr_scheduler.LambdaLR与OneCycleLR

    目录 LambdaLR 输出 OneCycleLR 输出 LambdaLR 函数接口: LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=Fa ...

  4. class torch.optim.lr_scheduler.ExponentialLR

    参考链接: class torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False) 配 ...

  5. class torch.optim.lr_scheduler.StepLR

    参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...

  6. class torch.optim.lr_scheduler.LambdaLR

    参考链接: class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False) 配套 ...

  7. ImportError: cannot import name ‘SAVE_STATE_WARNING‘ from ‘torch.optim.lr_scheduler‘ (/home/jsj/anac

    from transformers import BertModel 报错   ImportError: cannot import name 'SAVE_STATE_WARNING' from 't ...

  8. torch.optim.lr_scheduler.StepLR()函数

    1 目的 在训练的开始阶段, 使用的 LR 较大, loss 可以下降的较快, 但是随着训练的轮数越来越多, loss 越来越接近 global min, 若不降低 LR , 就会导致 loss 在最 ...

  9. torch.load、torch.save、torch.optim.Adam的用法

    目录 一.保存模型-torch.save() 1.只保存model的权重 2.保存多项内容 二.加载模型-torch.load() 1.从本地模型中读取数据 2.加载上一步读取的数据 load_sta ...

  10. pytorch torch.optim.lr_scheduler 各种使用和解释

    https://blog.csdn.net/baoxin1100/article/details/107446538

最新文章

  1. Python,OpenCV中的非局部均值去噪(Non-Local Means Denoising)
  2. 剑指offer:链表中环的入口结点
  3. 信道编码之编码理论依据
  4. 安卓获取手机网络强度_USB调试和USB网络共享,安卓有线投屏究竟选哪个?
  5. eclipse中的感叹号和x号解决方法
  6. HackerRank Nimble Game
  7. nginx php mysql 部署_Linux+Nginx+Mysql+Php运维部署
  8. python做数据库界面_python数据库界面设计
  9. 在香蕉派 Banana Pi BPI-M1上使用 开源 OxOffice Impress
  10. Markdown编辑公式
  11. python防止sql注入的方法_python解决sql注入以及特殊字符
  12. python写入Excel时,将路径或链接以超链接的方式写入
  13. flutter json转对象_在 Flutter 使用 Redux 来共享状态和管理单一数据
  14. Thread与Runnable的区别
  15. UVa 10870 - Recurrences 矩阵快速幂
  16. ai图像处理软件集大成者:Leawo PhotoIns Pro中文版介绍
  17. 计算机cmp代表什么意思,CMP是什么
  18. Datawhale组队学习周报(第025周)
  19. [VB.NET]如何设置随机数的种子
  20. python实现视频ai换脸_python 实现 AI 换脸

热门文章

  1. sklearn,SVM 和文本分类
  2. 《算法图解》第八章之贪婪算法
  3. 190430每日一句
  4. 181113每日一句
  5. Atitit 乔姆斯基分类 语言的分类 目录 1.1. 0 –递归可枚举语法 1 1.2. 1 –上下文相关的语法 自然语言 1 1.3. 2 –上下文无关的语法 gpl编程语言 1 1.4. 3
  6. Atitit maven 常见类库配置法 maven common lib jar v2 t88 目录 1. Express DSL COMMON 2 1.1. Ognl 2 1.2. veloci
  7. Atitit 面试问题高难度问题 回答不上来的分析应对法 目录 1. 问题分析法 1 1.1. 判断是否超出自己范围的,直接回复超出自己范围了 1 1.2. 根据生活中的解决方法,大概说下解决模式
  8. Atititi. naming spec 联系人命名与remark备注指南规范v5 r99.docx
  9. Atitit 机器视觉图像处理与机器学习概论2017版 attilax著
  10. Atitit atiuse软件系列