损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()
目录
回答一:
回答二:
回答三:
传统的训练函数,一个batch是这么训练的:
使用梯度累加是这么写的:
回答一:
一句话,用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而最小化(或最大化)损失函数E(x)
这种算法使用各参数的梯度值来最小化或最大化损失函数E(x)。最常用的一阶优化算法是梯度下降。
回答二:
各位答主都答到了zero_grad()的好处。那我试着从code的角度解释解释,一般来说是如下模板
optimizer.zero_grad() ## 梯度清零
preds = model(inputs) ## inference
loss = criterion(preds, targets) ## 求解loss
loss.backward() ## 反向传播求解梯度
optimizer.step() ## 更新权重参数
- 由于pytorch的动态计算图,当我们使用loss.backward()和opimizer.step()进行梯度下降更新参数的时候,梯度并不会自动清零。并且这两个操作是独立操作。
- backward():反向传播求解梯度。
- step():更新权重参数。
基于以上几点,正好说明了pytorch的一个特点是每一步都是独立功能的操作,因此也就有需要梯度清零的说法,如若不显示的进行optimizer.zero_grad()这一步操作,backward()的时候就会累加梯度,也就有了各位答主所说到的梯度累加这种trick。(当你GPU显存较少时,你又想要调大batch-size,此时你就可以利用PyTorch的这个性质进行梯度的累加来进行backward。)
作者:Gary
链接:https://www.zhihu.com/question/303070254/answer/573504133
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
回答三:
这种模式可以让梯度玩出更多花样,比如说梯度累加(gradient accumulation)
传统的训练函数,一个batch是这么训练的:
for i,(images,target) in enumerate(train_loader):# 1. input outputimages = images.cuda(non_blocking=True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)outputs = model(images)loss = criterion(outputs,target)# 2. backwardoptimizer.zero_grad() # reset gradientloss.backward()optimizer.step()
- 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
- optimizer.zero_grad() 清空过往梯度;
- loss.backward() 反向传播,计算当前梯度;
- optimizer.step() 根据梯度更新网络参数
简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络
使用梯度累加是这么写的:
for i,(images,target) in enumerate(train_loader):# 1. input outputimages = images.cuda(non_blocking=True)target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)outputs = model(images)loss = criterion(outputs,target)# 2.1 loss regularizationloss = loss/accumulation_steps # 2.2 back propagationloss.backward()# 3. update parameters of netif((i+1)%accumulation_steps)==0:# optimizer the netoptimizer.step() # update parameters of netoptimizer.zero_grad() # reset gradient
- 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
- loss.backward() 反向传播,计算当前梯度;
- 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
- 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。
作者:Pascal
链接:https://www.zhihu.com/question/303070254/answer/573037166
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()相关推荐
- PyTorch的损失函数和优化器
文章目录 PyTorch的损失函数和优化器 损失函数 优化器 总结 PyTorch的损失函数和优化器 损失函数 一般来说,PyTorch的损失函数有两种形式:函数形式和模块形式.前者调用的是torch ...
- Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用
Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...
- 系统总结深度学习的主要的损失函数和优化器
本次博客是深度学习的作业:系统总结了深度学习中监督学习的主要的损失函数,并指出各自的适应条件和优缺点.觉得很有意义,就记录下来供大家参考. 文章目录 损失函数 损失函数的作用 损失函数的分类 基于距离 ...
- 自定义----损失函数与优化器
损失函数与优化器 1.相关知识点 内置的计算loss的值,小写直接计算交叉熵,大写返回一个可调用的对象 #内置的计算loss的值,小写直接计算交叉熵,大写返回一个可调用的对象 ls=tf.keras. ...
- 【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.pytorch里自动求导的基础概念 1.1.自动求导 requires_grad=True 1.2.求导 requ ...
- pytorch学习 -- 反向传播backward
pytorch学习 – 反向传播backward 入门学习pytorch,在查看pytorch代码时对autograd的backard函数的使用和具体原理有些疑惑,在查看相关文章和学习后,简单说下我自 ...
- 人工智能-作业1:PyTorch实现反向传播
人工智能-作业1:PyTorch实现反向传播 人工智能-作业1:PyTorch实现反向传播 环境配置: 计算过程 反向传播 PyTorch Autograd自动求导 人工智能-作业1:PyTorch实 ...
- 通俗理解神经网络BP反向传播算法
转载自 通俗理解神经网络BP反向传播算法 通俗理解神经网络BP反向传播算法 在学习深度学习相关知识,无疑都是从神经网络开始入手,在神经网络对参数的学习算法bp算法,接触了很多次,每一次查找资料学习 ...
- 【2021-2022 春学期】人工智能-作业1:PyTorch实现反向传播
1 安装pycharm,配置好python开发环境 PyCharm 安装教程(Windows) | 菜鸟教程 (runoob.com) 2 安装pytorch PyTorch 3 使用pytorch实 ...
最新文章
- [数据结构] 时间复杂度计算
- 博客页面运行代码demo测试
- map端join和reduce端join的区别
- gdc服务器维修公司,gdc服务器阵列架坏了
- 圣思园java.doc_北京圣思园java课堂笔记.doc
- Spring-Cloud-Config消息总线和高可用
- 《JavaScript高级程序设计2》学习笔记——BOM
- Atitit.常用语言的常用内部api 以及API兼容性对源码级别可移植的重要性 总结
- 这么多年的土豆都白吃了!土豆还能这么做,太香了
- 互联网公司的“江湖绰号”盘点,你知道几个?
- linux刻录光盘空间不足,解决Linux下刻录光盘问题
- Mysql中date和datetime的区别
- 三消游戏死局算法的解析
- Kobe -接小球游戏
- PS不能直接拖入图片的解决办法
- 全景krpano相关问题解答
- 网上银行系统1:系统分析
- OpenCV--直线拟合fitLine及求两直线对称轴
- cropper.js 实现HTML5 裁剪上传头像
- 情感日记:想念从未断绝