文章目录

  • 引言
  • 一、什么是优化器?
  • 二、optimizer的基本属性
  • 三、optimizer的基本方法
  • 四、方法实例
    • 1.optimizer.step()
    • 2. optimizer.zero_grad()
    • 3. optimizer.add_param_group()
    • 4. optimizer.state_dict()
    • 5. optimizer.load_state_dict()
  • 五、优化器中的常用参数
    • 1.learning rate 学习率
    • 2.momentum 动量
  • 六、Pytorch十种优化器简介

引言

  本文学习优化器optimizer的基本属性、基本方法和作用

一、什么是优化器?

  pytorch的优化器:管理更新模型中可学习参数的值,使得模型输出更接近真实标签。通俗一点,就是采样梯度更新模型的可学习参数,使得损失减小。

二、optimizer的基本属性

class Optimizer(object):def __init__(self, params, defaults):self.defaults = defaultsself.state = defaultdict(dict)self.param_groups = []...param_groups = [{'params': param_groups}]
  • defaults:优化器超参数
  • state:参数的缓存,如momentum的缓存
  • params_groups:管理的参数组
  • _step_count:记录更新次数,学习率调整中使用

三、optimizer的基本方法

class Optimizer(object):def __init__(self, params, defaults):self.defaults = defaultsself.state = defaultdict(dict)self.param_groups = []...param_groups = [{'params': param_groups}]def zero_grad(self):for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()# 清零p.grad.zero_()def add_param_group(self, param_group):for group in self.param_groups:param_set.update(set(group['params’]))...        self.param_groups.append(param_group)def state_dict(self):...return {'state': packed_state,'param_groups': param_groups, }def load_state_dict(self, state_dict):...
  • zero_grad():清空所管理参数的梯度
    pytorch特性:张量梯度不自动清零,会将张量梯度累加;因此,需要在使用完梯度之后,或者在反向传播前,将梯度自动清零
  • step():执行一步更新
  • add_param_group():添加参数组,例如:可以为特征提取层与全连接层设置不同的学习率或者别的超参数
  • state_dict():获取优化器当前状态信息字典
    长时间的训练,会隔一段时间保存当前的状态信息,用来在断点的时候恢复训练,避免由于意外的原因导致模型的终止
  • load_state_dict() :加载状态信息字典

四、方法实例

1.optimizer.step()

import torch
import random
import numpy as np
import torch.optim as optimdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
# 梯度一步更新
optimizer.step()
print("weight after step:{}".format(weight.data))
weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])

2. optimizer.zero_grad()

import torch
import random
import numpy as np
import torch.optim as optimdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("weight before step:{}".format(weight.data))
# 梯度一步更新
optimizer.step()
print("weight after step:{}".format(weight.data))
# 地址相同
print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))print("weight.grad is {}\n".format(weight.grad))
# 将梯度清零
optimizer.zero_grad()
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))
weight before step:tensor([[0.6614, 0.2669],[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614,  0.1669],[-0.0383,  0.5213]])
weight in optimizer:2063731163904
weight in weight:2063731163904
weight.grad is tensor([[1., 1.],[1., 1.]])
after optimizer.zero_grad(), weight.grad is
tensor([[0., 0.],[0., 0.]])

3. optimizer.add_param_group()

import torch
import random
import numpy as np
import torch.optim as optimdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)print("optimizer.param_groups is\n{}".format(optimizer.param_groups))w2 = torch.randn((3, 3), requires_grad=True)
# 添加参数组,设置不同参数有不同的学习率
optimizer.add_param_group({"params": w2, 'lr': 0.0001})print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661, -1.5228],[ 0.3817, -1.0276, -0.5631],[-0.8923, -0.0583, -0.1955]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

4. optimizer.state_dict()

import torch
import random
import numpy as np
import torch.optim as optimdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
# 用于保存优化器的状态信息,通常用于断点的续训练
opt_state_dict = optimizer.state_dict()print("state_dict before step:\n", opt_state_dict)for i in range(10):optimizer.step()
# 获取优化器当前状态信息字典
print("state_dict after step:\n", optimizer.state_dict())
# 将状态信息保存下来
torch.save(optimizer.state_dict(), os.path.join('..', "optimizer_state_dict.pkl"))
state_dict before step:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}]}
state_dict after step:{'state': {0: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}]}

5. optimizer.load_state_dict()

import torch
import random
import numpy as np
import torch.optim as optimdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed(1)  # 设置随机种子weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))optimizer = optim.SGD([weight], lr=0.1)# 加载文件
state_dict = torch.load(os.path.join('..', "optimizer_state_dict.pkl"))print("state_dict before load state:\n", optimizer.state_dict())
# 加载状态信息字典
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())
state_dict before load state:{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}]}
state_dict after load state:{'state': {0: {'momentum_buffer': tensor([[6.5132, 6.5132],[6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}]}

五、优化器中的常用参数

1.learning rate 学习率

梯度下降:

PyTorch学习—13.优化器optimizer的概念及常用优化器相关推荐

  1. Python基础学习(13)—面向对象2(特殊方法,分装及装饰器)

    面向对象(2) 1.1 特殊方法(魔术方法) 1 特殊方法例如__init__的都是以__开头__结尾的方法, 特殊方法会在特定的时候自动调用,init会在对象创建以后立即执行并且init会对新创建的 ...

  2. MySQL中级优化教程(一)——SQL常用优化工具及explain语句的使用

    序言: 说来惭愧,java学了两年,期间虽在博客上记了一些东西,可也不曾写过什么系统的教程,前一段时间开始学习MySQL数据库优化相关的知识,就想着趁着这个机会好好整理一份电子档出来,即方便自己之后回 ...

  3. Java常用的垃圾收集器_JVM垃圾收集算法及常用垃圾收集器

    程序计数器.虚拟机栈.本地方法栈随线程而生随线程而灭,栈帧分配多少内存在类结构确定后就确定了.垃圾回收针对的是Java堆和方法区. 一:对象已死吗 在垃圾收集器进行回收前,第一件事就是确定这些对象哪些 ...

  4. 妈耶,讲得好详细,十分钟彻底看懂深度学习常用优化器SGD、RMSProp、Adam详解分析

    深度学习常用优化器学习总结 常用优化器 SGD RMS Prop Adam 常用优化器 SGD 基本思想:通过当前梯度和历史梯度共同调节梯度的方向和大小 我们首先根据pytorch官方文档上的这个流程 ...

  5. 陷波器介绍_50Hz工频信号陷波器设计

    文章目录 学习目标: 基本概念: 基本原理: 参数的具体计算及选择: 具体计算 # 陷波器的意义 学习目标: 了解陷波器的基本概念 掌握50HZ工频陷波器的基本电路图 基本概念: 提示:这里可以添加要 ...

  6. python装饰器类-基于类的python装饰器

    python装饰器函数其实是这样一个接口约束,它必须接受一个callable对象作为参数,然后返回一个callable对象.在Python中一般callable对象都是函数,但也有例外.只要某个对象重 ...

  7. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  8. Pytorch优化器Optimizer

    优化器Optimizer 什么是优化器 pytorch的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签 导数:函数在指定坐标轴上的变化率 方向导数:指定方向上的变化率(二元及以上函 ...

  9. 【PyTorch基础教程9】优化器optimizer和训练过程

    学习总结 (1)每个优化器都是一个类,一定要进行实例化才能使用,比如: class Net(nn.Moddule):··· net = Net() optim = torch.optim.SGD(ne ...

  10. pytorch学习笔记十二:优化器

    前言 机器学习中的五个步骤:数据 --> 模型 --> 损失函数 --> 优化器 --> 迭代训练,通过前向传播,得到模型的输出和真实标签之间的差异,也就是损失函数,有了损失函 ...

最新文章

  1. keras 的 example 文件 babi_rnn.py 解析
  2. Kali Linux 安全渗透教程第七更 大学霸1.4.3安装至VMware Workstation
  3. Charles使用1
  4. 光纤激光切机计算机无法启动,激光切割机不出光,如何解决?
  5. C++学习笔记7[指针]
  6. [html] 你认为写出什么样的html代码才是好代码呢?
  7. java httpcomponents_java – 如何使用Apache httpcomponents从NHttpRequ...
  8. this.$modal.confirm 自定义按钮关闭_【勤哲资料】7.6 自定义打印
  9. 快速排序方法——python实现
  10. 解决实例化Servlet类[com.mu.servlet.HelloServlet]异常
  11. 机器学习、深度神经网络的认识与结论
  12. kylin cube 增量和全量
  13. 5星评价,1位小数显示
  14. android硬件加速默认,Android的硬件加速
  15. layui爬坑之路——input value属性绑定函数返回值
  16. BZOJ3202 [Sdoi2013]项链
  17. 奇异值分解的几何原理
  18. 程序员白piao服务器。大派送
  19. IGBT学习记录(一)
  20. ROM制作---原生安卓国内适配部分修改点

热门文章

  1. 功能选中jquery实现全选反选功能
  2. Silverlight.XNA(C#)跨平台3D游戏研发手记:(七)向Windows Phone移植之双向交互
  3. C++中-运算符与.运算符的具体使用
  4. android studio connot resolve
  5. python+request+Excel做接口自动化测试
  6. 前两篇转载别人的精彩文章,自己也总结一下python split的用法吧!
  7. cocos2dx阴影层的实现
  8. WPF依赖属性(续)(4)依赖属性与数据绑定
  9. (解决MATLAB的使用问题)避免每次双击.m文件都会自动打开一个matlab程序
  10. 使用hiredis接口(Synchronous API)编写redis客户端