torch.optim.lr_scheduler.MultiStepLR()用法研究 台阶/阶梯学习率
torch.optim.lr_scheduler.
MultiStepLR
(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
我自已用代码研究了一遍MultiStepLR()中的last_epoch参数,发现就是个垃圾。
结论:
①last_epoch就是个鸡肋的东西 经过评论区大佬的指点,我现在确定了last_epoch的用法:last_epoch表示已经走了多少个epoch,下一个milestone减去last_epoch就是需要的epoch数
(评论区原话:last_epoch是有用的,简单来说,就是所有学习率都要提前last_epoch开始进行变化。举个例子假如我设置原始lr=0.1,milestones=[5, 15], gamma=0.5,last_epoch=0此时epoch=5时lr才会变为0.05,epoch=15时lr变为0.025。当修改last_epoch=3后,epoch=2,lr就会变为0.05,epoch=12,lr变为0.025.一般默认last_epoch=0)
②会在milestone的时候乘以gamma的平方
实验代码如下:
1、首先是默认配置:
import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1]
1 [0.1]
2 [0.0010000000000000002] # 此处乘的是gamma的平方
3 [0.010000000000000002]
4 [0.010000000000000002]
5 [0.00010000000000000003] # 此处乘的是gamma的平方
6 [0.0010000000000000002]
7 [0.0010000000000000002]
8 [0.0010000000000000002]
2、设置last_epoch=-1。
和前面1、是一样的,因为函数的默认值就是last_epoch=-1
import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = -1for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1]
1 [0.1]
2 [0.1]
3 [0.0010000000000000002]
4 [0.010000000000000002]
5 [0.010000000000000002]
6 [0.00010000000000000003]
7 [0.0010000000000000002]
8 [0.0010000000000000002]
3、设置last_epoch=4。
把第1个epoch的learning_rate设置为0.1,但是按照模型已经更新到了第4个epoch开始执行。
后来理解了一下,感觉就是:在last_epoch处将learning_rate重新设置为初始值,而且也是从last_epoch处继续进行运行;所以就要求你手动把learning_rate设为上一次模型停止的时候对应的learning_rate值,即last_epoch处对应的learning_rate。
import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.1] # 0相当于第4个eopch
1 [0.0010000000000000002] # 1相当于第5个epoch所以乘以gamma的平方
2 [0.010000000000000002]
3 [0.010000000000000002]
4 [0.010000000000000002]
5 [0.010000000000000002]
6 [0.010000000000000002]
7 [0.010000000000000002]
8 [0.010000000000000002] # 因为4在3的后面,只有一个6这个milestones,所以只更新了一次
4、设置last_epoch=4,并且将scheduler.step()改为schduler.step(epoch)。也是不对
import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()scheduler.step(epoch)# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.0010000000000000002]
1 [0.0010000000000000002]
2 [0.0010000000000000002]
3 [0.00010000000000000003]
4 [0.0010000000000000002]
5 [0.0010000000000000002]
6 [0.00010000000000000003]
7 [0.0010000000000000002]
8 [0.0010000000000000002]
我试过了,无论把last_epoch改为多少,输出都是上面这个。证明scheduler.step()里面是一定不能加epoch的
5、接3,当last_epoch大于milestones的某些值时,会自动跳过这些值
import torch
import torchvisionlearing_rate = 0.1
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=learing_rate,momentum=0.9,weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 8], gamma=0.1)
scheduler.last_epoch = 4for epoch in range(9):optimizer.step()# scheduler.step(epoch)scheduler.step()# print(optimizer.get_lr())print(epoch, scheduler.get_lr())返回:
0 [0.0010000000000000002]
1 [0.010000000000000002]
2 [0.010000000000000002]
3 [0.00010000000000000003]
4 [0.0010000000000000002]
5 [0.0010000000000000002]
6 [0.0010000000000000002]
7 [0.0010000000000000002]
8 [0.0010000000000000002]
对比第3小节和第5小节的例子可以发现,在3中的例子中,4大于3,所以把3跳过了,直接在milestone=6的时候调整的learning_rate
torch.optim.lr_scheduler.MultiStepLR()用法研究 台阶/阶梯学习率相关推荐
- Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类
当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能.所使用的类 class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer ...
- pytorch中调整学习率: torch.optim.lr_scheduler
文章翻译自:https://pytorch.org/docs/stable/optim.html torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法. t ...
- torch.optim.lr_scheduler.LambdaLR与OneCycleLR
目录 LambdaLR 输出 OneCycleLR 输出 LambdaLR 函数接口: LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=Fa ...
- class torch.optim.lr_scheduler.ExponentialLR
参考链接: class torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False) 配 ...
- class torch.optim.lr_scheduler.StepLR
参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...
- class torch.optim.lr_scheduler.LambdaLR
参考链接: class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False) 配套 ...
- ImportError: cannot import name ‘SAVE_STATE_WARNING‘ from ‘torch.optim.lr_scheduler‘ (/home/jsj/anac
from transformers import BertModel 报错 ImportError: cannot import name 'SAVE_STATE_WARNING' from 't ...
- torch.optim.lr_scheduler.StepLR()函数
1 目的 在训练的开始阶段, 使用的 LR 较大, loss 可以下降的较快, 但是随着训练的轮数越来越多, loss 越来越接近 global min, 若不降低 LR , 就会导致 loss 在最 ...
- torch.load、torch.save、torch.optim.Adam的用法
目录 一.保存模型-torch.save() 1.只保存model的权重 2.保存多项内容 二.加载模型-torch.load() 1.从本地模型中读取数据 2.加载上一步读取的数据 load_sta ...
- pytorch torch.optim.lr_scheduler 各种使用和解释
https://blog.csdn.net/baoxin1100/article/details/107446538
最新文章
- Python,OpenCV中的非局部均值去噪(Non-Local Means Denoising)
- 剑指offer:链表中环的入口结点
- 信道编码之编码理论依据
- 安卓获取手机网络强度_USB调试和USB网络共享,安卓有线投屏究竟选哪个?
- eclipse中的感叹号和x号解决方法
- HackerRank Nimble Game
- nginx php mysql 部署_Linux+Nginx+Mysql+Php运维部署
- python做数据库界面_python数据库界面设计
- 在香蕉派 Banana Pi BPI-M1上使用 开源 OxOffice Impress
- Markdown编辑公式
- python防止sql注入的方法_python解决sql注入以及特殊字符
- python写入Excel时,将路径或链接以超链接的方式写入
- flutter json转对象_在 Flutter 使用 Redux 来共享状态和管理单一数据
- Thread与Runnable的区别
- UVa 10870 - Recurrences 矩阵快速幂
- ai图像处理软件集大成者:Leawo PhotoIns Pro中文版介绍
- 计算机cmp代表什么意思,CMP是什么
- Datawhale组队学习周报(第025周)
- [VB.NET]如何设置随机数的种子
- python实现视频ai换脸_python 实现 AI 换脸
热门文章
- sklearn,SVM 和文本分类
- 《算法图解》第八章之贪婪算法
- 190430每日一句
- 181113每日一句
- Atitit 乔姆斯基分类 语言的分类 目录 1.1. 0 –递归可枚举语法	1 1.2. 1 –上下文相关的语法 自然语言	1 1.3. 2 –上下文无关的语法 gpl编程语言	1 1.4. 3
- Atitit maven 常见类库配置法 maven common lib jar v2 t88 目录 1. Express DSL COMMON	2 1.1. Ognl	2 1.2. veloci
- Atitit 面试问题高难度问题 回答不上来的分析应对法 目录 1. 问题分析法	1 1.1. 判断是否超出自己范围的,直接回复超出自己范围了	1 1.2. 根据生活中的解决方法,大概说下解决模式
- Atititi. naming spec 联系人命名与remark备注指南规范v5 r99.docx
- Atitit 机器视觉图像处理与机器学习概论2017版 attilax著
- Atitit atiuse软件系列