pytorch中lr_scheduler的使用
torch.optim.lr_scheduler.StepLR
- 代码
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet
import matplotlib.pyplot as pltmodel = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=0.05)# lr_scheduler.StepLR()
# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05 if epoch < 30
# lr = 0.005 if 30 <= epoch < 60
# lr = 0.0005 if 60 <= epoch < 90scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
plt.figure()
x = list(range(100))
y = []
for epoch in range(100):scheduler.step()lr = scheduler.get_lr()print(epoch, scheduler.get_lr()[0])y.append(scheduler.get_lr()[0])plt.plot(x, y)
0<epoch<30, lr = 0.05
30<=epoch<60, lr = 0.005
60<=epoch<90, lr = 0.0005
torch.optim.lr_scheduler.MultiStepLR
与StepLR
相比,MultiStepLR
可以设置指定的区间
- 代码
# ---------------------------------------------------------------
# 可以指定区间
# lr_scheduler.MultiStepLR()
# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05 if epoch < 30
# lr = 0.005 if 30 <= epoch < 80
# lr = 0.0005 if epoch >= 80
print()
plt.figure()
y.clear()
scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 80], 0.1)
for epoch in range(100):scheduler.step()print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))y.append(scheduler.get_lr()[0])plt.plot(x, y)
plt.show()
torch.optim.lr_scheduler.ExponentialLR
指数衰减
- 代码
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
print()
plt.figure()
y.clear()
for epoch in range(100):scheduler.step()print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))y.append(scheduler.get_lr()[0])plt.plot(x, y)
plt.show()
pytorch中lr_scheduler的使用相关推荐
- pytorch中调整学习率的lr_scheduler机制
pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...
- Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau
Pytorch中的学习率调整:lr_scheduler,ReduceLROnPlateau torch.optim.lr_scheduler:该方法中提供了多种基于epoch训练次数进行学习率调整的方 ...
- pytorch中的学习率与优化器【lr_scheduler与optimizer】
pytorch中优化器的使用流程大致为: for input, target in dataset:optimizer.zero_grad()output = model(input)loss = l ...
- 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型
作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...
- Lesson 15.2 学习率调度在PyTorch中的实现方法
Lesson 15.2 学习率调度在PyTorch中的实现方法 学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...
- PyTorch 中 torch.optim优化器的使用
一.优化器基本使用方法 建立优化器实例 循环: 清空梯度 向前传播 计算Loss 反向传播 更新参数 示例: from torch import optim input = ..... optimiz ...
- Pytorch中的学习率调整方法
在梯度下降更新参数的时,我们往往需要定义一个学习率来控制参数更新的步幅大小,常用的学习率有0.01.0.001以及0.0001等,学习率越大则参数更新越大.一般来说,我们希望在训练初期学习率大一些,使 ...
- pytorch中如何处理RNN输入变长序列padding
一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...
- PyTorch中的MIT ADE20K数据集的语义分割
PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...
最新文章
- PHP设计模式——迭代模式
- PHP Warning: date(): It is not safe to rely on the system's timezone settings
- 无聊博文之:用同余的语言阐述欧几里德算法
- 用户登陆_「python学习笔记」用户登陆需求实现(for/if/str知识点)
- Vue2+VueRouter2+webpack 构建项目实战(四)接通api,先渲染个列表
- Android 根证书管理与证书验证
- .NET 6新特性试用 | 模式匹配之Extended Property Patterns
- mysql update field_mysql-更新表与另一个选择,但字段是SUM(someField)
- CSS清除默认样式,聪明人已经收藏了!
- 前端学习(3004):vue+element今日头条管理--使用form表单
- arcgis for android sdk下载地址,Arcgis Runtime sdk for android 授权
- c语言根据变量作用域不同分为,C语言中不同变量的访问方式
- tomcat 请求超时_高并发环境下如何优化Tomcat性能?看完我懂了!
- Java设计模式-设计模式概述
- Intel vt-d技术资料收集
- 一元(多元)线性回归分析之R语言实现
- linux下重装显卡驱动
- 学术论文写作之引言(Introduction)怎么写
- mysql时间自动填充_Mysql自动设置时间(自动获取时间,填充时间)
- 网络间谍:你的共享文件夹网络监视器
热门文章
- 语音识别系统报告_2018-2024年中国语音识别系统行业市场发展格局及投资价值评估研究报告_中国产业信息网...
- 语音支持英语_语音识别英语_英语语音评分 - 云+社区 - 腾讯云
- jvm压缩指针原理以及32g内存压缩指针失效详解
- dio设置自定义post请求_Flutter中的http网络请求
- matlab simout,每日学习Matlab(2)
- flex实现水平垂直居中
- 保存自动修复_CAD小技巧:怎样将自动保存的图形复原
- html读取json换行无效,前端Json换行显示
- Sublime Text3搭建go运行环境
- Eclipse中修改SVN地址