在用pytorch训练模型时,通常会在遍历epochs的每一轮batach的过程中依次用到以下三个函数

  • optimizer.zero_grad();
  • loss.backward();
  • optimizer.step()
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)for epoch in range(1, epochs):for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad() # 梯度清零preds = model(inputs) # 利用模型求解预测值loss = criterion(preds, labels) # 求解lossloss.backward() # 反向传播求解梯度optimizer.step() # 更新权重参数

总得来说,这三个函数的作用是:

  • 先将梯度值归零,:optimizer.zero_grad();
  • 然后反向传播计算得到每个参数的梯度值:loss.backward();
  • 最后通过梯度下降执行一步参数更新:optimizer.step();

由于pytorch的动态计算图,当我们使用loss.backward()和opimizer.step()进行梯度下降更新参数的时候,梯度并不会自动清零。并且这两个操作是独立操作。

基于以上几点,正好说明了pytorch的一个特点是每一步都是独立功能的操作,因此也就有需要梯度清零的说法,如若不显示的进行optimizer.zero_grad()这一步操作,backward()的时候就会累加梯度,也就有了梯度累加这种trick。

接下来将通过源码分别理解这三个函数的具体实现过程。在此之前,先简要说明一下函数中常见的参数变量:

  • param_groups:Optimizer类在实例化时会在构造函数中创建一个param_groups列表,列表中有num_groups个长度为6的param_group字典(num_groups取决于你定义optimizer时传入了几组参数),每个param_group包含了 [‘params’, ‘lr’, ‘momentum’, ‘dampening’, ‘weight_decay’, ‘nesterov’] 这6组键值对。

  • param_group[‘params’]:由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。

一、optimizer.zero_grad()【参数的梯度值归零】

    def zero_grad(self):r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()p.grad.zero_()

optimizer.zero_grad()函数会遍历模型的所有参数,通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。

因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。

二、loss.backward()【计算参数的梯度值】

PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

更具体地说,损失函数loss是由模型的所有权重 www 经过一系列运算得到的,若某个 www 的 requires_grads 为True,则 www 的所有上层参数(后面层的权重 www)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个 www 的梯度值,并保存到该w的.grad属性中

如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。

三、optimizer.step()【更新参数的值】

以SGD为例,torch.optim.SGD().step()源码如下:

    def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss

step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。

注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。

四、梯度累加

传统的训练函数,一个batch是这么训练的:

for i, (image, label) in enumerate(train_loader):# 1. input outputpred = model(image)loss = criterion(pred, label)# 2. backwardoptimizer.zero_grad()   # reset gradientloss.backward()optimizer.step()
  • 获取 loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  • optimizer.zero_grad() 清空过往梯度;
  • loss.backward() 反向传播,计算当前梯度;
  • optimizer.step() 根据梯度更新网络参数

简单的说就是进来一个 batch 的数据,计算一次梯度,更新一次网络

使用梯度累加是这么写的:

for i,(image, label) in enumerate(train_loader):# 1. input outputpred = model(image)loss = criterion(pred, label)# 2.1 loss regularizationloss = loss / accumulation_steps  # 2.2 back propagationloss.backward()# 3. update parameters of netif (i+1) % accumulation_steps == 0:# optimizer the netoptimizer.step()        # update parameters of netoptimizer.zero_grad()   # reset gradient
  • 获取 loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  • loss.backward() 反向传播,计算当前梯度;
  • 多次循环步骤 1-2,不清空梯度,使梯度累加在已有梯度上;
  • 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

一定条件下,batch_size 越大训练效果越好,梯度累加则实现了 batchsize 的变相扩大,如果accumulation_steps 为 8,则batch_size ‘变相’ 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。




参考资料:
理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理
PyTorch中在反向传播前为什么要手动将梯度清零?

Pytorch:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用相关推荐

  1. 为什么pytorch mode = sequential() 为何model(input)这样调用就直接执行了forward

    pytorch mode = sequential() 为何model(input)这样调用就直接执行了forward https://www.cnblogs.com/ailitao/p/117875 ...

  2. 理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()和optimizer.step()三个函数,如下所示: ...

  3. 【PyTorch】语言模型/Language model

    1 模型描述 (1)语言模型的定义,来自于维基百科 统计式的语言模型是一个几率分布.语言模型提供上下文来区分听起来相似的单词和短语.例如,短语"再给我两份葱,让我把记忆煎成饼"和& ...

  4. optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用

    optimizer.zero_grad,loss.backward,optimizer.step 用法介绍 optimizer.zero_grad(): loss.backward(): optimi ...

  5. pytorch 动态调整学习率,学习率自动下降,根据loss下降

    0 为什么引入学习率衰减? 我们都知道几乎所有的神经网络采取的是梯度下降法来对模型进行最优化,其中标准的权重更新公式: W+=α∗gradient W+=\alpha * \text { gradie ...

  6. loss.backward(),scheduler(), optimizer.step()的作用

    在用pytorch训练模型时(pytorch1.1及以上版本),通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()和optimizer. ...

  7. 常用损失函数总结(L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margi)

    常用损失函数总结(L1 loss.L2 loss.Negative Log-Likelihood loss.Cross-Entropy loss.Hinge Embedding loss.Margi) ...

  8. NLP文本情感分析:测试集loss比训练集loss大很多,训练集效果好测试集效果差的原因

    NLP情感分析:测试集loss比训练集loss大很多 一.前言 二.原因 一.前言 最近在学习神经网络自然语言处理的相关知识,发现运行的之后测试集的loss比训练集的loss大很多,而accuracy ...

  9. ACR Loss: Adaptive Coordinate-based Regression Loss for Face Alignment

    ACR Loss: Adaptive Coordinate-based Regression Loss for Face Alignment Introduction 背景介绍 相关方法 提出的方法 ...

最新文章

  1. RabbitMQ快速入门--介绍和安装
  2. Docker 容器 和 虚拟机 的异同
  3. LeetCode MySQL 1321. 餐馆营业额变化增长(over窗口函数)
  4. StudyTonight 中文系列教程【翻译完成】
  5. HTML中的表格和表单控件详解
  6. mysql 写入性能瓶颈_如何通过性能调优突破MySQL数据库性能瓶颈?
  7. Objective-C与JavaScript交互的那些事
  8. 查询服务器硬件配置的命令
  9. SaasSaaS架构设计之构建Multi-Tenant应用
  10. python第一天环境搭建及基本数据类型与条件语句学习
  11. fat32源码c语言,FAT32文件系统基本原理与数据恢复编程
  12. serv-u 用户使用sftp登录 时间显示不对_宜家中国电商化之路步履蹒跚 或因忽视消费者的使用习惯...
  13. 让 Flutter 在鸿蒙系统上跑起来
  14. ANSNP中线安防 安科瑞 时丽花
  15. C# 通过UDP 远程监控摄像头
  16. 输入一个由小写字母组成的字符串, 按照26个字母表顺序进行排序,打印排序后的字符串;
  17. android境外支付
  18. 洛谷刷题笔记 地球人口承载力估计
  19. ubuntu下docker的lnmp(二) 安装php-fpm之 下载镜像启动容器
  20. CRC16 校验函数

热门文章

  1. 文件搜索工具(Python实现)
  2. 重庆大学计算机在职研究生,重庆大学在职研究生招生学院_重庆大学在职研究生_学校查找_在职研究生教育信息网...
  3. 破译营销最优解,2018E-UP效果营销案例大赛终审完美收官
  4. 四川一度智信:网店养词技巧
  5. Flutter免费视频第二季-常用组件讲解
  6. 怎么在苹果手机计算机上打字,苹果手机怎样设置打字是中文的
  7. c语言编程TLC2543AD采集,51单片机驱动12位AD转换TLC2543电路图+程序
  8. Python模拟Tim登录界面
  9. Ubuntu通过wine安装QQ无法输入账号怎么办?
  10. Python数据分析高薪实战第四天 构建国产电视剧评分数据集