在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()optimizer.step()三个函数,如下所示:

model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)for epoch in range(1, epochs):for i, (inputs, labels) in enumerate(train_loader):output= model(inputs)loss = criterion(output, labels)# compute gradient and do SGD stepoptimizer.zero_grad()loss.backward()optimizer.step()

总得来说,这三个函数的作用是先将梯度归零(optimizer.zero_grad()),然后反向传播计算得到每个参数的梯度值(loss.backward()),最后通过梯度下降执行一步参数更新(optimizer.step())

接下来将通过源码分别理解这三个函数的具体实现过程。在此之前,先简要说明一下函数中常见的参数变量:

param_groups:Optimizer类在实例化时会在构造函数中创建一个param_groups列表,列表中有num_groups个长度为6的param_group字典(num_groups取决于你定义optimizer时传入了几组参数),每个param_group包含了 ['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'] 这6组键值对。

param_group['params']:由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。

一、optimizer.zero_grad():

    def zero_grad(self):r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()p.grad.zero_()

optimizer.zero_grad()函数会遍历模型的所有参数,通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。

因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。

二、loss.backward():

PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。

如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。

三、optimizer.step():

以SGD为例,torch.optim.SGD().step()源码如下:

    def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss

step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。

注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。

参考:https://www.cnblogs.com/Thinker-pcw/p/9630367.html

理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理相关推荐

  1. optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用

    optimizer.zero_grad,loss.backward,optimizer.step 用法介绍 optimizer.zero_grad(): loss.backward(): optimi ...

  2. 梯度值与参数更新optimizer.zero_grad(),loss.backward、和optimizer.step()、lr_scheduler.step原理解析

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward.和optimizer.step().lr_schedule ...

  3. Pytorch:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

    在用pytorch训练模型时,通常会在遍历epochs的每一轮batach的过程中依次用到以下三个函数 optimizer.zero_grad(): loss.backward(): optimize ...

  4. Loss.backward()

    若只有一个loss outputs = model(inputs) loss = criterrion(outputs,target)optimizer.zero_grad() loss.backwa ...

  5. optimizer.zero_grad(),loss.backward(),optimizer.step()的作用原理

    目录 前言 一.optimizer.zero_grad() 二. loss.backward() 三.optimizer.step() 前言 在用pytorch训练模型时,通常会在遍历epochs的过 ...

  6. pytorch-->optimizer.zero_grad()、loss.backward()、optimizer.step()和scheduler.step()

    优化器optimizer的作用 优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数值的作用. 一般来说,以下三个函数的使用顺序如下: # compute gradient ...

  7. loss.backward(),scheduler(), optimizer.step()的作用

    在用pytorch训练模型时(pytorch1.1及以上版本),通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()和optimizer. ...

  8. Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

    引言 一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.s ...

  9. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

最新文章

  1. Android studio 查看sha1
  2. 【C 语言】文件操作 ( 配置文件读写 | 读取配置文件 | 函数接口形参 | 读取配置文件的逐行遍历操作 | 读取一行文本 | 查找字符 | 删除字符串前后空格 )
  3. 静态方法-应用场景和定义方式
  4. 统计字符串中每种字符类型的个数demo
  5. [密码学基础][每个信息安全博士生应该知道的52件事][Bristol Cryptography][第30篇]大致简述密钥协商中的BR安全定义
  6. 【转】带你玩转Visual Studio——02.带你新建一个工程
  7. 利用Excel进行成对(配对)T检验
  8. 草稿 断开式绑定combobox
  9. 2019-04-02
  10. static和const关键字
  11. 循环链表简单操作 C++
  12. 读取xml节点的数据总结(.net 2.0)
  13. 51nod1264线段相交
  14. Centos 安装配置 Dynamips
  15. mc冒险者传说java_我的世界冒险者传说1.9
  16. 廊坊金彩教育:店铺详情页设计要点
  17. Promise学习:基础入门
  18. 免费图片转pdf的方法?学会图片转pdf很重要
  19. Zend框架:Zend_Nosql_Mong​​o组件建议
  20. node.js学习笔记3 express基本使用、托管静态资源、express中间件

热门文章

  1. 线性稳压电源和开关电源到底有什么区别
  2. 仓库温度湿度控制措施_仓库温湿度管理规定_仓库温湿度监测管理制度
  3. Http 调用netty 服务,服务调用客户端,伪同步响应.ProtoBuf 解决粘包,半包问题.
  4. flutter中好用的Widget-CupertinoPicker
  5. mysql面试题1313
  6. ftp服务器密码为空,ftp服务器设置为无账号密码
  7. 利用7z程序压缩、解压
  8. 用 JS 进行 Base64 编码、解码
  9. 软件测试 之Web项目实战
  10. matlab_颜色矩阵三原色