前言

最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读。
于是在gayhub上找到了这样一份教程《Pytorch模型训练实用教程》,写得不错,特此根据它来再学习一下Pytorch。
仓库地址:https://github.com/TingsongYu/PyTorch_Tutorial

优化器概念

Pytorch提供了优化器Optimizer,作为基类,它有下面几种方法

param_groups

调用param_groups,可以查看一个优化器的参数组,其包含了每一层的权值,偏置,学习率等参数。

调用实例:

# coding: utf-8import torch
import torch.optim as optimw1 = torch.randn(2, 2)
w1.requires_grad = Truew2 = torch.randn(2, 2)
w2.requires_grad = Truew3 = torch.randn(2, 2)
w3.requires_grad = True# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')# 两个参数组
optimizer_2 = optim.SGD([{'params': w1, 'lr': 0.1},{'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)

zero_grad()

功能:将梯度清零。
调用示例:

# coding: utf-8import torch
import torch.optim as optim# ----------------------------------- zero_gradw1 = torch.randn(2, 2)
w1.requires_grad = Truew2 = torch.randn(2, 2)
w2.requires_grad = Trueoptimizer = optim.SGD([w1, w2], lr=0.001, momentum=0.9)optimizer.param_groups[0]['params'][0].grad = torch.randn(2, 2)print('参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad, '\n')  # 参数组,第一个参数(w1)的梯度optimizer.zero_grad()
print('执行zero_grad()之后,参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad)  # 参数组,第一个参数(w1)的梯度

state_dict()

功能:获取模型当前的参数,以一个有序字典形式返回。
调用示例:

# coding: utf-8import torch.nn as nn
import torch.nn.functional as F# ----------------------------------- state_dict
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 1, 3)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(1 * 3 * 3, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 1 * 3 * 3)x = F.relu(self.fc1(x))return xnet = Net()# 获取网络当前参数
net_state_dict = net.state_dict()print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():print('参数名: ', key, '\t大小: ',  value.shape)

load_state_dict(state_dict)

功能:将 state_dict 中的参数加载到当前网络。
调用示例:

# coding: utf-8import torch
import torch.nn as nn
import torch.nn.functional as F# ----------------------------------- load_state_dictclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 1, 3)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(1 * 3 * 3, 2)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 1 * 3 * 3)x = F.relu(self.fc1(x))return xdef zero_param(self):for m in self.modules():if isinstance(m, nn.Conv2d):torch.nn.init.constant_(m.weight.data, 0)if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.Linear):torch.nn.init.constant_(m.weight.data, 0)m.bias.data.zero_()
net = Net()# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl')   # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])

add_param_group()

功能:给 optimizer 管理的参数组中增加一组参数,可为该组参数定制 lr, momentum, weight_decay等。
调用示例:

# coding: utf-8import torch
import torch.optim as optim# ----------------------------------- add_param_groupw1 = torch.randn(2, 2)
w1.requires_grad = Truew2 = torch.randn(2, 2)
w2.requires_grad = Truew3 = torch.randn(2, 2)
w3.requires_grad = True# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({'params': w3, 'lr': 0.001, 'momentum': 0.8})print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

step()

功能:执行一步权值更新。

优化器汇总

torch.optim.SGD

torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
功能:
可实现 SGD 优化算法,带动量 SGD 优化算法,带 NAG(Nesterov accelerated gradient)动量 SGD 优化算法。

参数:
params(iterable)- 参数组(参数组的概念请查看 3.2 优化器基类:Optimizer),优化器
要管理的那部分参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
momentum(float)- 动量,通常设置为 0.9,0.8
dampening(float)- dampening for momentum ,暂时不了其功能,在源码中是这样用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening 必须为 0.
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数
nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)

torch.optim.ASGD

torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)

功能:
ASGD 也成为 SAG,均表示随机平均梯度下降

参数:
params(iterable)- 参数组(参数组的概念请查看 3.1 优化器基类:Optimizer),优化器要优化的那些参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
lambd(float)- 衰减项,默认值 1e-4。
alpha(float)- power for eta update ,默认值 0.75。
t0(float)- point at which to start averaging,默认值 1e6。
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数。

torch.optim.Rprop

torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))
功能:
实现 Rprop 优化方法(弹性反向传播),该优化方法适用于 full-batch,不适用于 mini-batch。

torch.optim.Adagrad

torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)

功能:
实现 Adagrad 优化方法(Adaptive Gradient),Adagrad 是一种自适应优化方法,是自适应的为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。缺点是训练后期,学习率过小,因为 Adagrad 累加之前所有的梯度平方作为分母。

torch.optim.Adadelta

torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
功能:
实现 Adadelta 优化方法。Adadelta 是 Adagrad 的改进。Adadelta 分母中采用距离当前时间点比较近的累计项,这可以避免在训练后期,学习率过小。

torch.optim.RMSprop

torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
功能:
实现 RMSprop 优化方法,RMS 是均方根(root meam square)的意思。RMSprop 采用均方根作为分
母,可缓解 Adagrad 学习率下降较快的问题。

torch.optim.Adam(AMSGrad)

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

功能:
实现 Adam(Adaptive Moment Estimation)优化方法。Adam 是一种自适应学习率的优化方法,Adam 利用梯度的一阶矩估计和二阶矩估计动态的调整学习率。

torch.optim.Adamax

torch.optim.Adamax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
功能:
实现 Adamax 优化方法。Adamax 是对 Adam 增加了一个学习率上限的概念,所以也称之为 Adamax。

torch.optim.SparseAdam

torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
功能:
针对稀疏张量的一种Adam优化方法。

torch.optim.LBFGS

torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-05, tolerance_change=1e-09, history_size=100, line_search_fn=None)

功能:
实现 L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)优化方法。L-BFGS 属于拟牛顿算法。L-BFGS 是对 BFGS 的改进,特点就是节省内存。

学习率调整

为了让学习率能够随着模型的训练进行动态调整,Pytorch提供了下列一些学习率调整方法。

lr_scheduler.StepLR

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
功能:
等间隔调整学习率,调整倍数为 gamma 倍,调整间隔为 step_size。间隔单位是step。
参数:
step_size(int)- 学习率下降间隔数,若为 30,则会在 30、60、90…个 step 时,将学习率调整为 lr*gamma。
gamma(float)- 学习率调整倍数,默认为0.1倍,即下降10倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。

lr_scheduler.MultiStepLR

torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
功能:
按设定的间隔调整学习率。这个方法适合后期调试使用,观察 loss 曲线,为每个实验定制学习率调整时机。
参数:
milestones(list)- 一个 list,每一个元素代表何时调整学习率,list 元素必须是递增的。如 milestones=[30,80,120]
gamma(float)- 学习率调整倍数,默认为 0.1 倍,即下降 10 倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。

lr_scheduler.ExponentialLR

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
功能:
按指数衰减调整学习率,调整公式: lr = lr * gammaepoch
参数:
gamma- 学习率调整倍数的底,指数为 epoch,即 gamma
epoch
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。

lr_scheduler.CosineAnnealingLR

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
功能:
以余弦函数为周期,并在每个周期最大值时重新设置学习率。
参数:
T_max(int)- 一次学习率周期的迭代次数,即 T_max 个 epoch 之后重新设置学习率。
eta_min(float)- 最小学习率,即在一个周期中,学习率最小会下降到 eta_min,默认值为 0。

lr_scheduler.ReduceLROnPlateau

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=‘min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=‘rel’, cooldown=0, min_lr=0, eps=1e-08)
功能:
当某指标不再变化(下降或升高),调整学习率,这是非常实用的学习率调整策略。

参数:
mode(str)- 模式选择,有 min 和 max 两种模式,min 表示当指标不再降低(如监测loss),max 表示当指标不再升高(如监测 accuracy)。
factor(float)- 学习率调整倍数(等同于其它方法的 gamma),即学习率更新为 lr = lr * factor
patience(int)- 直译——“耐心”,即忍受该指标多少个 step 不变化,当忍无可忍时,调整学习率。
verbose(bool)- 是否打印学习率信息
threshold_mode(str)- 选择判断指标是否达最优的模式,有两种模式,rel 和 abs
cooldown(int)- “冷却时间“,当调整学习率之后,让学习率调整策略冷静一下,让模型再训练一段时间,再重启监测模式。
min_lr(float or list)- 学习率下限,可为 float,或者 list,当有多个参数组时,可用 list 进行设置。
eps(float)- 学习率衰减的最小值,当学习率变化小于 eps 时,则不调整学习率。

lr_scheduler.LambdaLR

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
功能:
为不同参数组设定不同学习率调整策略。

参数:
lr_lambda(function or list)- 一个计算学习率调整倍数的函数,输入通常为 step,当有多个参数组时,设为 list。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。

Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整相关推荐

  1. 《PyTorch模型训练实用教程》—学习笔记

    文章目录 前言 数据 Dataset类 DataLoader类 transform 裁剪-Crop 翻转和旋转-Flip and Rotation 图像变换 对transforms操作,使数据增强更灵 ...

  2. PyTorch 模型训练实用教程(附代码)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用 Py ...

  3. 快速上手笔记,PyTorch模型训练实用教程(附代码)

    机器之心发布 作者:余霆嵩 前言 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超 TensorFlow 的趋势.PyTorch 能在短时间内被众多研究人员和工程师接受并 ...

  4. PyTorch 模型训练实用教程(六):监控模型——可视化

    本章将介绍如何在 PyTorch 中使用 TensorBoardX 对神经网络进行统计可视化,如Loss 曲线.Accuracy 曲线.卷积核可视化.权值直方图及多分位数折线图.特征图可视化.梯度直方 ...

  5. 【深度学习】基于PyTorch的模型训练实用教程之数据处理

    [深度学习]基于PyTorch的模型训练实用教程之数据处理 文章目录 1 transforms 的二十二个方法 2 数据加载和预处理教程 3 torchvision 4 如何用Pytorch进行文本预 ...

  6. 【资源下载】《Pytorch模型训练实现教程》(附下载链接)

    前言 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超 TensorFlow 的趋势.PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如 ...

  7. 【详解】模型优化技巧之优化器和学习率调整

    目录 PyTorch十大优化器 1 torch.optim.SGD 2 torch.optim.ASGD 3 torch.optim.Rprop 4 torch.optim.Adagrad 5 tor ...

  8. 《深入理解java虚拟机》学习笔记四/垃圾收集器GC学习/一

    Grabage Collection      GC GC要完毕的三件事情: 哪些内存须要回收? 什么时候回收? 怎样回收? 内存运行时区域的各个部分中: 程序计数器.虚拟机栈.本地方法栈这3个区域随 ...

  9. 系统学习Pytorch笔记七:优化器和学习率调整策略

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

最新文章

  1. 电脑基础操作_学打碟基础技术 - 数码打碟篇
  2. Scala与Java交互
  3. 网站SEO优化之如何提升网站的流量?
  4. [Swift]LeetCode934. 最短的桥 | Shortest Bridge
  5. 并发下HashMap头插会造成死循环情况说明
  6. mysql为什么要转es_MySQL用得好好的,为什么要转ES?
  7. nginx+ssl+pm2 部署 nodejs 服务
  8. c语言调用android surface,Android GUI SurfaceFlinger
  9. global语句(python学习手册422页)
  10. 交换排序之冒泡排序和快速排序
  11. UVA11729 Commando War【贪心】
  12. linux中的__setup的作用
  13. 租房退租时,房东不退押金怎么办?
  14. 机器人最新天赋符文天赋加点图_机器人天赋符文s9
  15. 采样点数和采样频率的区别
  16. 八个常见Java项目,献给初学编程的你!
  17. python查询缺失值所在位置使用scipy_在稀疏lil_matrix(Scipy / Python)中查找最大值及其索引...
  18. 关系网络lbs的应用_LBS中国起步:探索空间关系的商业化
  19. 计算机网络行业规范的主要内容,计算机网络专业论文
  20. Unity Shader 之 简单实现物体被黑洞吸收吞噬(或者从黑洞中出来)的效果

热门文章

  1. 自考总结(运筹学和管经)
  2. PHP快递100的物流接口快递单号查询
  3. slim.fully_connected()
  4. 编程图记(1): 引言
  5. 看图记设计模式【一】,设计模式是什么?设计模式的组成。
  6. 5.10.1 操作查询之生成表查询
  7. oracle设置密码复杂度、设置oracle超时退出的功能
  8. MATLAB判断是不是质数,matlab 如何表示一个数是不是质数,如题 。
  9. windows服务安装完后自动启动
  10. 朱其罡:推进主流芯片代码进主干,完善OpenHarmony芯片生态支撑