目录

回答一:

回答二:

回答三:

传统的训练函数,一个batch是这么训练的:

使用梯度累加是这么写的:


回答一:

一句话,用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而最小化(或最大化)损失函数E(x)

这种算法使用各参数的梯度值来最小化或最大化损失函数E(x)。最常用的一阶优化算法是梯度下降。


回答二:

各位答主都答到了zero_grad()的好处。那我试着从code的角度解释解释,一般来说是如下模板

optimizer.zero_grad()             ## 梯度清零
preds = model(inputs)             ## inference
loss = criterion(preds, targets)  ## 求解loss
loss.backward()                   ## 反向传播求解梯度
optimizer.step()                  ## 更新权重参数
  1. 由于pytorch的动态计算图,当我们使用loss.backward()和opimizer.step()进行梯度下降更新参数的时候,梯度并不会自动清零。并且这两个操作是独立操作
  2. backward():反向传播求解梯度。
  3. step():更新权重参数。

基于以上几点,正好说明了pytorch的一个特点是每一步都是独立功能的操作,因此也就有需要梯度清零的说法,如若不显示的进行optimizer.zero_grad()这一步操作,backward()的时候就会累加梯度,也就有了各位答主所说到的梯度累加这种trick。(当你GPU显存较少时,你又想要调大batch-size,此时你就可以利用PyTorch的这个性质进行梯度的累加来进行backward。

作者:Gary
链接:https://www.zhihu.com/question/303070254/answer/573504133
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。


回答三:

这种模式可以让梯度玩出更多花样,比如说梯度累加(gradient accumulation)

传统的训练函数,一个batch是这么训练的:

for i,(images,target) in enumerate(train_loader):# 1. input outputimages = images.cuda(non_blocking=True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)outputs = model(images)loss = criterion(outputs,target)# 2. backwardoptimizer.zero_grad()   # reset gradientloss.backward()optimizer.step()            
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. optimizer.zero_grad() 清空过往梯度;
  3. loss.backward() 反向传播,计算当前梯度
  4. optimizer.step() 根据梯度更新网络参数

简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络

使用梯度累加是这么写的:

for i,(images,target) in enumerate(train_loader):# 1. input outputimages = images.cuda(non_blocking=True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)outputs = model(images)loss = criterion(outputs,target)# 2.1 loss regularizationloss = loss/accumulation_steps   # 2.2 back propagationloss.backward()# 3. update parameters of netif((i+1)%accumulation_steps)==0:# optimizer the netoptimizer.step()        # update parameters of netoptimizer.zero_grad()   # reset gradient
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. loss.backward() 反向传播,计算当前梯度;
  3. 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
  4. 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。

作者:Pascal
链接:https://www.zhihu.com/question/303070254/answer/573037166
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()相关推荐

  1. PyTorch的损失函数和优化器

    文章目录 PyTorch的损失函数和优化器 损失函数 优化器 总结 PyTorch的损失函数和优化器 损失函数 一般来说,PyTorch的损失函数有两种形式:函数形式和模块形式.前者调用的是torch ...

  2. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  3. 系统总结深度学习的主要的损失函数和优化器

    本次博客是深度学习的作业:系统总结了深度学习中监督学习的主要的损失函数,并指出各自的适应条件和优缺点.觉得很有意义,就记录下来供大家参考. 文章目录 损失函数 损失函数的作用 损失函数的分类 基于距离 ...

  4. 自定义----损失函数与优化器

    损失函数与优化器 1.相关知识点 内置的计算loss的值,小写直接计算交叉熵,大写返回一个可调用的对象 #内置的计算loss的值,小写直接计算交叉熵,大写返回一个可调用的对象 ls=tf.keras. ...

  5. 【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.pytorch里自动求导的基础概念 1.1.自动求导 requires_grad=True 1.2.求导 requ ...

  6. pytorch学习 -- 反向传播backward

    pytorch学习 – 反向传播backward 入门学习pytorch,在查看pytorch代码时对autograd的backard函数的使用和具体原理有些疑惑,在查看相关文章和学习后,简单说下我自 ...

  7. 人工智能-作业1:PyTorch实现反向传播

    人工智能-作业1:PyTorch实现反向传播 人工智能-作业1:PyTorch实现反向传播 环境配置: 计算过程 反向传播 PyTorch Autograd自动求导 人工智能-作业1:PyTorch实 ...

  8. ​通俗理解神经网络BP反向传播算法

    转载自  ​通俗理解神经网络BP反向传播算法 通俗理解神经网络BP反向传播算法 在学习深度学习相关知识,无疑都是从神经网络开始入手,在神经网络对参数的学习算法bp算法,接触了很多次,每一次查找资料学习 ...

  9. 【2021-2022 春学期】人工智能-作业1:PyTorch实现反向传播

    1 安装pycharm,配置好python开发环境 PyCharm 安装教程(Windows) | 菜鸟教程 (runoob.com) 2 安装pytorch PyTorch 3 使用pytorch实 ...

最新文章

  1. [数据结构] 时间复杂度计算
  2. 博客页面运行代码demo测试
  3. map端join和reduce端join的区别
  4. gdc服务器维修公司,gdc服务器阵列架坏了
  5. 圣思园java.doc_北京圣思园java课堂笔记.doc
  6. Spring-Cloud-Config消息总线和高可用
  7. 《JavaScript高级程序设计2》学习笔记——BOM
  8. Atitit.常用语言的常用内部api 以及API兼容性对源码级别可移植的重要性 总结
  9. 这么多年的土豆都白吃了!土豆还能这么做,太香了
  10. 互联网公司的“江湖绰号”盘点,你知道几个?
  11. linux刻录光盘空间不足,解决Linux下刻录光盘问题
  12. Mysql中date和datetime的区别
  13. 三消游戏死局算法的解析
  14. Kobe -接小球游戏
  15. PS不能直接拖入图片的解决办法
  16. 全景krpano相关问题解答
  17. 网上银行系统1:系统分析
  18. OpenCV--直线拟合fitLine及求两直线对称轴
  19. cropper.js 实现HTML5 裁剪上传头像
  20. 情感日记:想念从未断绝

热门文章

  1. 结构与算法(02):队列和栈结构
  2. scrapy框架架构
  3. Leetcode PHP题解--D84 371. Sum of Two Integers
  4. 常用的前端跨域的几种方式
  5. 好程序员分享居中一个float元素
  6. View、Text、Button的drawableLeft左侧图片自定义大小
  7. redis源码阅读--hashTable
  8. 基于spring自动注入及AOP的表单二次提交验证
  9. Oracle 客户端工具
  10. 思杰VDI十终极结构图及总结