【pytorch】optimizer(优化器)的使用详解
目录
- 1 创建一个 Optimizer
- 一个简单的例子:求目标函数的最小值
- Per-parameter 的优化器
- 2 Taking an optimization step 开始优化
- optimizer.step(closure)
- 常见的几种优化器
- 如何调整 lr?
- 优化器的保存和读取
本文介绍 torch.optim
包常见的使用方法和使用技巧。
1 创建一个 Optimizer
要构造一个Optimizer,你必须给它一个包含参数(所有参数都应该是 Variable s
)的可迭代对象来优化。然后,您可以指定特定于优化器的选项,如学习率、权值衰减等。
from torch.autograd import Variable
import torch.optim as optim# Variable 的创建
tensor = torch.FloatTensor([[1,2],[3,4]]) # build a tensor
var1 = Variable(tensor, requires_grad=True) # build a variable, usually for compute gradients
var2 = Variable(tensor+1, requires_grad=True)
model = model()
# 构造 Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001) # 还可以对 Variable 进行优化哦~
一个简单的例子:求目标函数的最小值
假设 x = 4 为起点,求 y = (x-5)^2
的最小值:
from torch.autograd import Variable
import torch.optim as optim
# Variable 的创建
tensor = torch.FloatTensor([[4]]) # build a tensor
x = Variable(tensor, requires_grad=True) # build a variable, usually for compute gradientsoptimizer = optim.Adam([x], lr=0.1) # 还可以对 Variable 进行优化哦~
for i in range(100):optimizer.zero_grad()y = (x - 5)*(x - 5) # 因为 x 的值不断在优化,所以 y 的定义式要放在这里y.backward()optimizer.step()print(x)
Per-parameter 的优化器
有时候,我们会使用例如 pre-trained model 这样的模型,用其特征提取模块并连接自己设计的 classifier
层。这时候需要对不同的层使用不同的 lr,具体操作如下:
首先,模型的设计可以采用这样的结构,*layers是一个列表。
optim.SGD([ {'params': model.features.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9) # model.features 的lr 是 1e-2, model.classifier 是 1e-3,momentum=0.9针对所有层都有效。
2 Taking an optimization step 开始优化
所有优化器都实现一个step()方法,该方法更新参数。它有两种用法:
optimizer.step()
这是大多数优化器支持的简化版本。该函数可以在梯度计算完成后调用,例如使用 backward()
。
for input, target in dataset:optimizer.zero_grad() # 这一步很重要output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
optimizer.step(closure)
一些优化算法,如共轭梯度和LBFGS需要多次重新计算函数,所以您必须传入一个闭包,允许它们重新计算您的函数。闭包应该清除梯度,计算损失,并返回它。
常见的几种优化器
具体参数设定请参阅:https://pytorch.org/docs/stable/optim.html#algorithms
如何调整 lr?
torch.optim.lr_scheduler
提供几种方法,以调整学习速率的基础上的时间数。torch.optim.lr_scheduler.ReduceLROnPlateau
允许根据评估指标,动态降低学习率(这里不做介绍)。
import torch.nn as nn
from torch.utils.data import DataLoader,TensorDataset
model = nn.Parameter(torch.randn(2, 1, requires_grad=True))
optimizer = optim.SGD([model], 0.1)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) # 指数衰减,每一轮变为上一轮的 0.9
x = torch.randn(10,2)
y = torch.randn(10,1)
dataset = TensorDataset(x, y)
dataset = DataLoader(dataset)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = input * modelloss = (output - target).sqrt().mean()loss.backward()optimizer.step()scheduler.step()
验证 lr 降低的效果:
当然,也可以手动地在每一轮中设置 lr 并创建新的优化器。
来看另一种 lr 衰减的方法,使用2个lr衰减方法的叠加。
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler1.step()scheduler2.step()
优化器的保存和读取
有时候训练到一半,需要建立 checkpoint ,随时保存模型和优化器状态,和模型的读取、保存一样,优化器的使用方法如下:
para_dict = optimizer.state_dict()
optimizer.load_state_dict(para_dict)
参考:
https://pytorch.org/docs/stable/optim.html
【pytorch】optimizer(优化器)的使用详解相关推荐
- 【pytorch 优化器】ReduceLROnPlateau详解
说明 torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbos ...
- pytorch 7 optimizer 优化器 加速训练
pytorch 7 optimizer 优化器 加速训练 import torch import torch.utils.data as Data import torch.nn.functional ...
- PyTorch 实现批训练和 Optimizer 优化器
批训练 import torch import torch.utils.data as DataBATCH_SIZE = 5x = torch.linspace(1, 10, 10) # this i ...
- Pytorch autograd.grad与autograd.backward详解
Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...
- [源码解析] PyTorch分布式优化器(1)----基石篇
[源码解析] PyTorch分布式优化器(1)----基石篇 文章目录 [源码解析] PyTorch分布式优化器(1)----基石篇 0x00 摘要 0x01 从问题出发 1.1 示例 1.2 问题点 ...
- PyTorch中的torch.nn.Parameter() 详解
PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...
- python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解
关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...
- pytorch教程之nn.Module类详解——使用Module类来自定义网络层
前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...
- 【PyTorch教程】PyTorch分布式并行模块DistributedDataParallel(DDP)详解
本期目录 DDP简介 1. 单卡训练回顾 2. 与DataParallel比较 1)DataParallel 2)DistributedDataParallel 3. 多卡DDP训练 本章的重点是学习 ...
- 加速神经网络训练方法及不同Optimizer优化器性能比较
本篇博客主要介绍几种加速神经网络训练的方法. 我们知道,在训练样本非常多的情况下,如果一次性把所有的样本送入神经网络,每迭代一次更新网络参数,这样的效率是很低的.为什么?因为梯度下降法参数更新的公式一 ...
最新文章
- 12 Java面向对象之多态
- 开源生态也难逃“卡脖子”危机?中国AI开发者的警醒和突围
- 是谁“偷吃”了硬盘中的3GB空间
- linux 统计日志最多的ip,统计nginx日志里访问次数最多的前十个IP
- 解密华为云原生媒体网络如何保障实时音视频服务质量
- python操作excel_使用Python操作Excel时必学的3个库
- oracle中存储过程 =,oracle中的存储过程使用
- typora用什么文档管理_会展经济与管理专业自考本科毕业后有什么用
- 2017.8.7 GT考试 思考记录
- 未来教育计算机二级答案19,2019年3月计算机二级MSOffice提分试题及答案019
- 小米温控配置不见了_小米11值得买吗?目前看来功耗很高啊?
- Spring Boot 集成 WebSocket,轻松实现信息推送!
- C++/QT控制通过VISA控制硬件设备,超级容易学会的控制硬件方法
- 基于Thinkphp6+Element的插件化后台管理系统
- 智能开关如何实现双控
- 同一网段两台电脑共享文件
- less css 视频教程
- 理解实时音视频聊天中的延时问题一篇就够
- 做小红书推广快速涨粉的技巧_云媒易
- 巧用千寻位置GNSS软件|CAD功能全解析