PyTorch对Optimizer类的实现大部分都在Python上,只有计算用到了C++的部分,所以还是可以继续分析的。

总览

Optimizer类是所有具体优化器类的一个基类。下面一幅图表示一下。

这里我以SGD类为例自下而上地介绍一下。

Optimizer类中重要的成员变量只有两个,self.param_groups和self.state。

self.param_groups用于存储模型参数和优化器本身的一些参数(如学习率等)。

self.state则用于存储更新过程中模型参数对应的各种临时状态,如MSGD中每个参数需要对应一个动量。而每个参数可能不止需要对应一个临时状态。因此self.state是一个键值对类型为parameter:dict的有序字典。

Optimizer类中重要的方法只有一个 add_param_group,它是用来初始化self.param_groups的。

而self.state的初始化需要在某个具体的优化器类中进行。

self.param_groups如何初始化?

self.param_groups在optimizer类的__init__方法中初始化完成。

这里可以先看一下SGD类的初始化方法,它将lr,momentum等优化器参数打包成字典defaults,然后和模型参数params一起传入optimizer类的初始化方法中。

class SGD(Optimizer):def __init__(self, params, lr=required, momentum=0, dampening=0,weight_decay=0, nesterov=False):if lr is not required and lr < 0.0:raise ValueError("Invalid learning rate: {}".format(lr))if momentum < 0.0:raise ValueError("Invalid momentum value: {}".format(momentum))if weight_decay < 0.0:raise ValueError("Invalid weight_decay value: {}".format(weight_decay))#打包优化器参数defaults = dict(lr=lr, momentum=momentum, dampening=dampening,weight_decay=weight_decay, nesterov=nesterov)if nesterov and (momentum <= 0 or dampening != 0):raise ValueError("Nesterov momentum requires a momentum and zero dampening")super(SGD, self).__init__(params, defaults)

而在Optimizer类的初始化方法中,对于defaults,它只是将defaults存储起来。对于params则是先转换成列表形式,之后转换成一个由列表封装的字典。然后对这个字典执行self.add_param_group。

至此我们还是没有看到self.param_groups到底是怎么初始化的,所以需要继续看self.add_param_group这个方法。

注意区分这里的self.param_groups和param_groups。

class Optimizer(object):def __init__(self, params, defaults):torch._C._log_api_usage_once("python.optimizer")self.defaults = defaultsself._hook_for_profile()if isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " +torch.typename(params))#self.state初始化self.state = defaultdict(dict)self.param_groups = []param_groups = list(params)if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")#一般情况下param_groups[0]是一个parameters类#这里其实是在判断param_groups之前有没有被封装过。if not isinstance(param_groups[0], dict):param_groups = [{'params': param_groups}]#虽然是遍历操作,但是其实并不是遍历所有参数。for param_group in param_groups:#这里的param_group等价于{'params': param_groups}self.add_param_group(param_group)

add_param_group源码:

def add_param_group(self, param_group):r"""Add a param group to the :class:`Optimizer` s `param_groups`.This can be useful when fine tuning a pre-trained network as frozen layers can be madetrainable and added to the :class:`Optimizer` as training progresses.Args:param_group (dict): Specifies what Tensors should be optimized along with groupspecific optimization options."""#前面都不重要,都是一些边界条件的判断assert isinstance(param_group, dict), "param group must be a dict"params = param_group['params']if isinstance(params, torch.Tensor):param_group['params'] = [params]elif isinstance(params, set):raise TypeError('optimizer parameters need to be organized in ordered collections, but ''the ordering of tensors in sets will change between runs. Please use a list instead.')else:param_group['params'] = list(params)for param in param_group['params']:if not isinstance(param, torch.Tensor):raise TypeError("optimizer can only optimize Tensors, ""but one of the params is " + torch.typename(param))if not param.is_leaf:raise ValueError("can't optimize a non-leaf Tensor")#这里开始就是self.param_groups的初始化了#defaults在这里加入param_groupfor name, default in self.defaults.items():if default is required and name not in param_group:raise ValueError("parameter group didn't specify a value of required optimization parameter " +name)else:param_group.setdefault(name, default)params = param_group['params']if len(params) != len(set(params)):warnings.warn("optimizer contains a parameter group with duplicate parameters; ""in future, this will cause an error; ""see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)param_set = set()for group in self.param_groups:param_set.update(set(group['params']))if not param_set.isdisjoint(set(param_group['params'])):raise ValueError("some parameters appear in more than one parameter group")#把param_group这个字典放入self.param_groups这个空列表里#这样初始化就完成了self.param_groups.append(param_group)

add_param_group中对param_group的操作其实很简单,就是将之前的优化器参数self.defaults也放入param_group里。然后再把param_group存到self.param_groups里。

接下来从一个实际的例子看看是不是这样:

import torch
X = torch.tensor([1.0],requires_grad = True)
Y = torch.tensor([2.0],requires_grad = True)
optimizer = torch.optim.SGD([X,Y],lr =0.001)
print(optimizer.param_groups)
"""
输出结果:
[
{
'params': [tensor([1.], requires_grad=True), tensor([2.], requires_grad=True)],
'lr': 0.001,
'momentum': 0,
'dampening': 0,
'weight_decay': 0,
'nesterov': False
}
]
"""

self.state如何更新?

上面self.param_groups初始化过程介绍的差不多了,接下来考虑self.state的初始化和更新问题,因为之前说过self.state每一次迭代都会更新。而优化器的更新操作是放在step这个方法里的,但是optimizer基类并不会实现step这个方法,需要每一个子类自己去实现。所以我这里以SGD为例介绍一下优化器的更新流程。

传统的SGD肯定是不需要使用self.state的,PyTorch这里的SGD只有在带动量的情况下会需要使用self.state。动量的意思简单来说就是存储过去的梯度信息。这样相比于SGD只基于当前梯度进行更新,带动量的SGD可以基于当前+过去的梯度进行更新,收敛更快。

class SGD(Optimizer):
@torch.no_grad()
def step(self, closure=None):"""Performs a single optimization step.Args:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:with torch.enable_grad():loss = closure()for group in self.param_groups:#存储有梯度的参数params_with_grad = []#存储参数对应的梯度d_p_list = []#存储动量momentum_buffer_list = []#正则化系数weight_decay = group['weight_decay']#动量系数momentum = group['momentum']#忘了dampening = group['dampening']nesterov = group['nesterov']#学习率lr = group['lr']#从self.state中取出momentum_buffer#初始化momentum_buffer_list#注意此时的momentum_buffer只包含过去的梯度信息for p in group['params']:if p.grad is not None:params_with_grad.append(p)d_p_list.append(p.grad)#自动初始化为空字典state = self.state[p]if 'momentum_buffer' not in state:momentum_buffer_list.append(None)else:momentum_buffer_list.append(state['momentum_buffer'])#对参数更新#并且更新momentum_buffer#该函数执行完后momentum_buffer将包含过去+现在的梯度信息F.sgd(params_with_grad,d_p_list,momentum_buffer_list,weight_decay=weight_decay,momentum=momentum,lr=lr,dampening=dampening,nesterov=nesterov)#momentum_buffer_list是通过append复制操作得到state里的momentum_buffer的#所以虽然momentum_buffer_list已经更新了,但是state里的momentum_buffer还没更新#所以需要同步一下,便于下一次迭代继续从state里取momentum_buffer。# update momentum_buffers in statefor p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):#每个参数都对应一个state字典。#在这里,每个参数都对应一个动量state = self.state[p]state['momentum_buffer'] = momentum_bufferreturn loss

这里看到step方法被@torch.no_grad()装饰器修饰,是因为需要对叶子节点做inplace操作,我之前有这一部分的介绍,这里就不赘述了。

step方法对参数的更新主要分为三步:

第一步是从self.state中取出momentum_buffer转换成列表形式momentum_buffer_list。

第二步是对参数进行更新,同时momentum_buffer_list也会得到更新。

第三步是利用更新后的momentum_buffer_list对state中的momentum_buffer进行更新。

真正的更新操作都被放在了F.sgd里,这里我做了注释,大家有兴趣可以看一下。

def sgd(params: List[Tensor],d_p_list: List[Tensor],momentum_buffer_list: List[Optional[Tensor]],*,weight_decay: float,momentum: float,lr: float,dampening: float,nesterov: bool):r"""Functional API that performs SGD algorithm computation.See :class:`~torch.optim.SGD` for details."""#遍历所有参数for i, param in enumerate(params):#取出参数对应的梯度d_p = d_p_list[i]#梯度加上正则项if weight_decay != 0:d_p = d_p.add(param, alpha=weight_decay)#取出上一次迭代得到的动量(准备更新动量)if momentum != 0:buf = momentum_buffer_list[i]#第一次迭代时的动量初始化if buf is None:buf = torch.clone(d_p).detach()momentum_buffer_list[i] = buf#动量更新,都是inplace操作#mb*momentum_factor+(1-dampening)*gradelse:buf.mul_(momentum).add_(d_p, alpha=1 - dampening)# nesterov的更新方式if nesterov:d_p = d_p.add(buf, alpha=momentum)# 常规动量的更新方式else:d_p = buf#参数更新param.add_(d_p, alpha=-lr)

总结

这里只是对optimizer中和更新相关的源码进行了介绍,不过optimizer类中还有很多其他的方法,我目前都用不到,所以就暂时不看了。

PyTorch 源码分析:Optimizer类相关推荐

  1. spring Quartz 源码分析--触发器类CronTriggerBean源码剖析

    前面我们讲到了Quartz框架在项目中的实现,在Quartz中的重要API有两个重要的触发器类:CronTrigger 和SimpleTrigger 在Quartz框架中这两个触发器都继承了一个抽象基 ...

  2. 【SLAM学习笔记】11-ORB_SLAM3关键源码分析⑨ Optimizer(六)地图回环优化

    2021SC@SDUSC 目录 1.前言 2.代码分析 1.前言 这一部分代码量巨大,查阅了很多资料结合来看的代码,将分为以下部分进行分析 单帧优化 局部地图优化 全局优化 尺度与重力优化 sim3优 ...

  3. 【SLAM学习笔记】12-ORB_SLAM3关键源码分析⑩ Optimizer(七)地图融合优化

    2021SC@SDUSC 目录 1.前言 2.代码分析 1.前言 这一部分代码量巨大,查阅了很多资料结合来看的代码,将分为以下部分进行分析 单帧优化 局部地图优化 全局优化 尺度与重力优化 sim3优 ...

  4. 【SLAM学习笔记】6-ORB_SLAM3关键源码分析④ Optimizer(一)单帧优化

    2021SC@SDUSC 目录 1.前言 2.代码分析 1.前言 Optimizer是非常重要的代码文件!! 这一部分代码量巨大,查阅了很多资料结合来看的代码,将分为以下部分进行分析 1. 单帧优化 ...

  5. Pytorch源码分析

    目录 命名空间/类/方法/函数/变量 torch.autograd.Function中的ctx参数 DDP(DistributedDataParallel)的构造函数 torch.floor(inpu ...

  6. springSecurity源码分析——DelegatingFilterProxy类的作用

    http://www.cnblogs.com/hzhuxin/archive/2011/12/19/2293730.html 使用过springSecurity的朋友都知道,首先需要在web.xml进 ...

  7. MariaDB源码分析——CONNECT类

    当主线程accept新连接之后,会调用handle_accepted_socket函数,申请CONNECT类对象,调用create_new_thread函数,该函数为CONNECT类对象的thread ...

  8. Hadoop3.2.1 【 HDFS 】源码分析 :FSDirectory类解析

    Table of Contents 一.前言. 二.构造方法 三.常量 四.方法 一.前言. Namenode最重要的两个功能之一就是维护整个文件系统的目录树(即命名空间namesystem) . H ...

  9. 【Groovy】闭包 Closure ( 闭包类 Closure 简介 | this、owner、delegate 成员赋值及源码分析 )

    文章目录 总结 一.闭包类 Closure 简介 二.闭包类 Closure 中 this.owner.delegate 成员 源码分析 三.分析编译后的字节码文件内容 总结 在闭包中 , 打印 th ...

最新文章

  1. module r8169
  2. Es6 generator浅入浅出
  3. 北京科技计算机与通信工程学院,北京科技大学计算机与通信工程学院-任超
  4. python 表达式求值数据结构_python 数据结构与算法
  5. burp的intruder报错Payload set 1: Invalid number settings
  6. 【AI视野·今日NLP 自然语言处理论文速览 第九期】Thu, 17 Jun 2021
  7. qq html消息,类似于QQ新消息提醒-前端
  8. 什么是validationQuery
  9. windows服务初识
  10. 纯HTML的个人简历,真的超简单,有源码
  11. jrtplib 编译安装配置
  12. 实用供热空调设计手册_暖通空调设计与施工数据图表手册
  13. 文档类型字符集即HTML标签的语义化
  14. 中文的括号和英文的括号区别_工具推荐 含笔顺及英文的汉字书写练习纸
  15. 小白兔写话_小白兔写话二年级作文
  16. 怎样用word制作标准格式公文操作实例
  17. 解决机械+固态的电脑无法安装window10系统的问题
  18. 宽带猫各指示灯的含义
  19. GATOR齿轮---凯利讯半导体
  20. ESP32驱动震动传感器、MAX4466(感知层)

热门文章

  1. [转载]健康养肾的最佳动作(图)
  2. Loadrunner License
  3. VS2013运行OpenGL例子提示找不到GL/glew.h
  4. 如何理解“页式存储管理方案”
  5. HihoCoder 1835 K-Dimensional Foil II ICPC2018 北京网络赛
  6. Java实现网上招聘系统(Servlet+Jsp+Mybatis+Oracle 个人用户简历操作+企业用户简历筛查)
  7. html input dropdown,选择下拉插件-Dropdown.js
  8. Mastering ROS for Robotics Programming第二版阅读笔记
  9. [已解决] Adding visible gpu devecies:
  10. 强制删除|病毒清除好帮手Unlocker