PyTorch框架学习十三——优化器

  • 一、优化器
  • 二、Optimizer类
    • 1.基本属性
    • 2.基本方法
  • 三、学习率与动量
    • 1.学习率learning rate
    • 2.动量、冲量Momentum
  • 四、十种常见的优化器(简单罗列)

上次笔记简单介绍了一下损失函数的概念以及18种常用的损失函数,这次笔记介绍优化器的相关知识以及PyTorch中的使用。

一、优化器

PyTorch中的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签。

导数:函数在指定坐标轴上的变化率。
方向导数:指定方向上的变化率。
梯度:一个向量,方向为方向导数取得最大值的方向。

二、Optimizer类

1.基本属性

  1. defaults:优化器超参数,包含优化选项的默认值的dict(当参数组没有指定这些值时使用)。
  2. state:参数的缓存,如momentum的缓存。
  3. param_groups:管理的参数组,形式上是列表,每个元素都是一个字典。

2.基本方法

(1)zero_grad():清空所管理的参数的梯度。因为PyTorch中张量梯度不会自动清零。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
optimizer.step()        # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))print("weight.grad is {}\n".format(weight.grad))
optimizer.zero_grad()
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))

结果如下:

weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])
weight in optimizer:1314236528344
weight in weight:1314236528344weight.grad is tensor([[1., 1.],[1., 1.]])after optimizer.zero_grad(), weight.grad is
tensor([[0., 0.],[0., 0.]])

(2) step():执行一步优化更新。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
optimizer.step()        # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))

结果如下:

weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])

(3) add_param_group():添加参数组。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("optimizer.param_groups is\n{}".format(optimizer.param_groups))w2 = torch.randn((3, 3), requires_grad=True)optimizer.add_param_group({"params": w2, 'lr': 0.0001})print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

结果如下:

optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661, -1.5228],[ 0.3817, -1.0276, -0.5631],[-0.8923, -0.0583, -0.1955]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

(4)state_dict():获取优化器当前状态信息字典。

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()print("state_dict before step:\n", opt_state_dict)for i in range(10):optimizer.step()print("state_dict after step:\n", optimizer.state_dict())torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

结果如下:

state_dict before step:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]}
state_dict after step:{'state': {2872948098296: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2872948098296]}]}

获取到了优化器当前状态的信息字典,其中那个2872948098296是存放权重的地址,并将这些参数信息保存为一个pkl文件:

(5)load_state_dict():加载状态信息字典。

optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))print("state_dict before load state:\n", optimizer.state_dict())
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())

从刚刚保存参数的pkl文件中读取参数赋给一个新的空的优化器,结果为:

state_dict before load state:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]}
state_dict after load state:{'state': {1838346925624: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [1838346925624]}]}

注:state_dict()与load_state_dict()一般经常用于模型训练中的保存和读取模型参数,防止断电等突发情况导致模型训练强行中断而前功尽弃。

三、学习率与动量

1.学习率learning rate

梯度下降:
其中LR就是学习率,作用是控制更新的步伐,如果太大可能导致模型无法收敛或者是梯度爆炸,如果太小可能使得训练时间过长,需要调节。

2.动量、冲量Momentum

结合当前梯度与上一次更新信息,用于当前更新。
PyTorch中梯度下降的更新公式为:

其中:

  • Wi:第i次更新的参数。
  • lr:学习率。
  • Vi:更新量。
  • m:momentum系数。
  • g(Wi):Wi的梯度。

举个例子:

100这个时刻的更新量不仅与当前梯度有关,还与之前的梯度有关,只是越以前的对当前时刻的影响就越小。

momentum的作用主要是可以加速收敛。

四、十种常见的优化器(简单罗列)

目前对优化器的了解还不多,以后会继续跟进,这里就简单罗列一下:

  1. optim.SGD:随机梯度下降法
  2. optim.Adagrad:自适应学习率梯度下降法
  3. optim.RMSprop:Adagrad的改进
  4. optim.Adadelta:Adagrad的改进
  5. optim.Adam:RMSprop结合Momentum
  6. optim.Adamax:Adam增加学习率上限
  7. optim.SparseAdam:稀疏版的Adam
  8. optim.ASGD:随机平均梯度下降法
  9. optim.Rprop:弹性反向传播
  10. optim.LBFGS :BFGS的改进

PyTorch框架学习十三——优化器相关推荐

  1. pytorch学习十 ---- 优化器

    1.什么是优化器? 首先我们回忆一下机器学习的五大模块:数据.模型.损失函数.优化器.迭代训练 在损失函数中我们会得到一个loss值,即真实标签与预测标签的差异值,对于loss我们通常会采用pytor ...

  2. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  3. PyTorch框架学习十九——模型加载与保存

    PyTorch框架学习十九--模型加载与保存 一.序列化与反序列化 二.PyTorch中的序列化与反序列化 1.torch.save 2.torch.load 三.模型的保存 1.方法一:保存整个Mo ...

  4. PyTorch框架学习十七——Batch Normalization

    PyTorch框架学习十七--Batch Normalization 一.BN的概念 二.Internal Covariate Shift(ICS) 三.BN的一个应用案例 四.PyTorch中BN的 ...

  5. PyTorch框架学习十六——正则化与Dropout

    PyTorch框架学习十六--正则化与Dropout 一.泛化误差 二.L2正则化与权值衰减 三.正则化之Dropout 补充: 这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout. 一 ...

  6. PyTorch框架学习十五——可视化工具TensorBoard

    PyTorch框架学习十五--可视化工具TensorBoard 一.TensorBoard简介 二.TensorBoard安装及测试 三.TensorBoard的使用 1.add_scalar() 2 ...

  7. PyTorch框架学习十四——学习率调整策略

    PyTorch框架学习十四--学习率调整策略 一._LRScheduler类 二.六种常见的学习率调整策略 1.StepLR 2.MultiStepLR 3.ExponentialLR 4.Cosin ...

  8. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  9. PyTorch的损失函数和优化器

    文章目录 PyTorch的损失函数和优化器 损失函数 优化器 总结 PyTorch的损失函数和优化器 损失函数 一般来说,PyTorch的损失函数有两种形式:函数形式和模块形式.前者调用的是torch ...

最新文章

  1. 超市账单管理系统之-------登录
  2. php 小炒花生,炒花生的做法_炒花生怎么做_炒花生的家常做法
  3. StringUtils之equals
  4. 65. Leetcode 153. 寻找旋转排序数组中的最小值 (二分查找-局部有序)
  5. HoloLens开发手记-全息Hologram
  6. 怎么覆盖默认样式_图形设计软件cdr教程:设置默认字体
  7. 分布式数据库中间件概念
  8. 写论文参考文献,如何查看一些书籍的随书光盘?如何查找一些书籍的原文阅读?如何高效合理的 运用高等学校数字图书馆、大学图书馆? 这里将给你答案
  9. wacom win10 未连接任何设备 驱动的问题 解决影拓3手绘板等老设备驱动无法在win10使用的问题
  10. 逃脱只会部署集群系列 —— k8s集群的网络模型与跨主机通信
  11. 按照斗地主的规则,完成洗牌发牌的动作
  12. RxJava Observer与Subscriber的关系
  13. git学习笔记(三)—— 远程仓库
  14. QT报错:“pure virtual method called; terminate called without an active exception“
  15. 拒绝访问html,为什么IE常常出现拒绝访问 ie拒绝访问的原因及解决方法
  16. IP地址分类 三类IP地址 IPV4
  17. python读取docx文件_Python 实现docx文件的读写操作
  18. IP、子网、等如何计算?
  19. Android 获取设备各种信息以及其它
  20. html设置input圆角矩形_html5中关于input用法的改变

热门文章

  1. reg类型变量综合电路_Verilog中reg型变量的综合效果(待补充)
  2. 多线程 流水线 java_Java Lock锁多线程中实现流水线任务
  3. mysql slow log 分析工具_mysql slow log分析工具的比较
  4. 数据缺失、混乱、重复怎么办?最全数据清洗指南让你所向披靡
  5. 【公益】开放一台Eureka注册中心给各位Spring Cloud爱好者
  6. 论文浅尝 | 利用Lattice LSTM的最优中文命名实体识别方法
  7. 论文浅尝 | 从 6 篇顶会论文看「知识图谱」领域最新研究进展 | 解读 代码
  8. 知识图谱最新权威综述论文解读:实体发现
  9. 【HTML5】Server-Sent服务器发送事件
  10. 将单词的首字母改为大写