文章目录

  • 1.叶节点、中间节点、梯度计算
  • 2.叶子张量 leaf tensor (叶子节点) (detach)
    • 2.1 为什么需要叶子节点?
    • 2.2 detach()将节点剥离成叶子节点
    • 2.3 什么样节点会是叶子节点
    • 2.3 detach(),detach_() 的作用和区别
    • 2.4 clone()与detach()的区别
  • 5.optimizer.zero_grad()
  • 3.loss.backward()
  • 4.optimizer.step()

1.叶节点、中间节点、梯度计算

  • 所有属性requires_grad=False的张量是叶子节点(即:叶子张量、叶子节点张量)。
  • 对于属性requires_grad=True的张量可能是叶子节点张量,也可能不是叶子节点张量而是中间节点(中间节点张量)。如果该张量的属性requires_grad=True,而且是用于直接创建的,也即它的属性grad_fn=None,那么它就是叶子节点。如果该张量的属性requires_grad=True,但是它不是用户直接创建的,而是由其他张量经过某些运算操作产生的,那么它就不是叶子张量,而是中间节点张量,并且它的属性grad_fn不是None,比如:grad_fn=,这表示该张量是通过torch.mean()运算操作产生的,是中间结果,所以是中间节点张量,所以不是叶子节点张量。
  • 判断一个张量是不是叶子节点,可以通过它的属性is_leaf来查看。
  • 一个张量的属性requires_grad用来指示在反向传播时,是否需要为这个张量计算梯度。如果这个张量的属性requires_grad=False,那么就不需要为这个张量计算梯度,也就不需要为这个张量进行优化学习。
  • 在PyTorch的运算操作中,如果参加这个运算操作的所有输入张量的属性requires_grad都是False的话,那么这个运算操作产生的结果,即输出张量的属性requires_grad也是False,否则是True.。即输入的张量只要有一个需要求梯度(属性requires_grad=True),那么得到的结果张量也是需要求梯度的(属性requires_grad=True)。只有当所有的输入张量都不需要求梯度时,得到的结果张量才会不需要求梯度。
  • 对于属性requires_grad=True的张量,在反向传播时,会为该张量计算梯度.。但是pytorch的自动梯度机制不会为中间结果保存梯度,即只会为叶子节点计算的梯度保存起来,保存到该叶子节点张量的属性grad中,不会在中间节点张量的属性grad中保存这个张量的梯度,这是出于对效率的考虑,中间节点张量的属性grad是None。如果用户需要为中间节点保存梯度的话,可以让这个中间节点调用方法retain_grad(),这样梯度就会保存在这个中间节点的grad属性中。
  • 只有叶子节点有梯度值grad,非叶节点为None。只有非叶节点有grad_fn,叶节点为None。

2.叶子张量 leaf tensor (叶子节点) (detach)

  • 在Pytorch中,默认情况下,非叶节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留。只有叶节点的梯度值能够被保留下来。
  • 在Pytorch神经网络中,我们反向传播backward()就是为了求叶子节点的梯度。在pytorch中,神经网络层中的权值w的tensor均为叶子节点。它们的require_grad都是True,但它们都属于用户创建的,所以都是叶子节点。而反向传播backward()也就是为了求它们的梯度。
  • 在调用backward()时,只有当requires_grad和is_leaf同时为真时,才会计算节点的梯度值。

2.1 为什么需要叶子节点?

那些非叶子节点,是通过用户所定义的叶子节点的一系列运算生成的,也就是这些非叶子节点都是中间变量,一般情况下,用户不回去使用这些中间变量的导数,所以为了节省内存,它们在用完之后就被释放了。

在Pytorch的autograd机制中,当tensor的requires_grad值为True时,在backward()反向传播计算梯度时才会被计算。在所有的require_grad=True中,默认情况下,非叶子节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留(即调用loss.backward() 会将计算图的隐藏变量梯度清除)。默认情况下,只有叶子节点的梯度值能够被保留下来。被保留下来的叶子节点的梯度值会存入tensor的grad属性中,在 optimizer.step()过程中会更新叶子节点的data属性值,从而实现参数的更新。

2.2 detach()将节点剥离成叶子节点

如果需要使得某一个节点成为叶子节点,只需使用detach()即可将它从创建它的计算图中分离开来。即detach()函数的作用就是把一个节点从计算图中剥离,使其成为叶子节点。

2.3 什么样节点会是叶子节点

①所有requires_grad为False的张量,都约定俗成地归结为叶子张量。 就像我们训练模型的input,它们都是require_grad=False,因为他们不需要计算梯度(我们训练网络训练的是网络模型的权重,而不需要训练输入)。它们是一个计算图都是起始点。

②requires_grad为True的张量, 如果他们是由用户创建的,则它们是叶张量(leaf Tensor)。例如各种网络层,nn.Linear(), nn.Conv2d()等, 他们是用户创建的,而且其网络参数也需要训练,所以requires_grad=True这意味着它们不是运算的结果,因此gra_fn为None。

2.3 detach(),detach_() 的作用和区别

detach()

返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。即使之后重新将它的requires_grad置为true,它也不会具有梯度grad。这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播。

使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变。

detach_()

将一个tensor从创建它的图中分离,并把它设置成叶子tensor。其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子tensor是x,但是这个时候对m进行了m.detach_()操作,其实就是进行了两个操作:

将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x,m -> y,此时的m就变成了叶子结点。然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度。

detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的tensor

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的。但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

2.4 clone()与detach()的区别

clone与原tensor不共享内存,detach与原tensor共享内存。

clone支持梯度回传,detach不支持梯度回传。

如果想要非叶节点也保留梯度的话,可以用retain_grad()。

5.optimizer.zero_grad()

optimizer.zero_grad()清除了优化器中所有 x x x的 x . g r a d x.grad x.grad,在每次loss.backward()之前,不要忘记使用,否则之前的梯度将会累积,这通常不是我们所期望的(也不排除也有人需要利用这个功能)。

3.loss.backward()

损失函数loss定义了模型优劣的标准,loss越小,模型越好,常见的损失函数比如均方差MSE(Mean Square Error),MAE (Mean Absolute Error),交叉熵CE(Cross-entropy) 等。

loss.backward()故名思义,就是将损失loss 向输入侧进行反向传播,同时对于需要进行梯度计算的所有变量 x x x(requires_grad=True),计算梯度 d d x l o s s {\frac{d}{d x}}l o s s dxd​loss ,并将其累积到梯度 x . g r a d x.grad x.grad 中备用,即:

x . g r a d = x . g r a d + d d x l o s s x.grad=x.grad + {\frac{d}{d x}}l o s s x.grad=x.grad+dxd​loss

4.optimizer.step()

optimizer.step()是优化器对 x x x的值进行更新,以随机梯度下降SGD为例:学习率(learning rate, lr)来控制步幅,即: x = x − l r ∗ x . g r a d x = x - lr*x.grad x=x−lr∗x.grad,减号是由于要沿着梯度的反方向调整变量值以减少Cost。

Pytorch中的梯度知识总结相关推荐

  1. 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...

    在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...

  2. PyTorch中的梯度累积

    我们在训练神经网络的时候,超参数batch_size的大小会对模型最终效果产生很大的影响,通常的经验是,batch_size越小效果越差:batch_size越大模型越稳定.理想很丰满,现实很骨感,很 ...

  3. Pytorch中的梯度回传

    来源:知乎-歪杠小胀 作者:https://zhuanlan.zhihu.com/p/451441329 01 记录写这篇文章的初衷 最近在复现一篇论文的训练代码时,发现原论文中的总loss由多个lo ...

  4. PyTorch中的梯度计算1

    主要用pytorch,对其他的框架用的很少,而且也没经过系统的学习,对动态图和梯度求导没有个准确的认识.今天根据代码仔细的看一下. import torch from torch.autograd i ...

  5. 【深度学习】PyTorch 中的线性回归和梯度下降

    作者 | JNK789   编译 | Flin  来源 | analyticsvidhya 我们正在使用 Jupyter notebook 来运行我们的代码.我们建议在Google Colaborat ...

  6. backward()和zero_grad()在PyTorch中代表什么意思

    文章目录 问:`backward()`和`zero_grad()`是什么意思? backward() zero_grad() 问:求导和梯度什么关系 问:backward不是求导吗,和梯度有什么关系( ...

  7. Pytorch中的序列化容器-度消失和梯度爆炸-nn.Sequential-nn.BatchNorm1d-nn.Dropout

    Pytorch中的序列化容器-度消失和梯度爆炸-nn.Sequential-nn.BatchNorm1d-nn.Dropout 1. 梯度消失和梯度爆炸 在使用pytorch中的序列化 容器之前,我们 ...

  8. Pytorch中的向前计算(autograd)、梯度计算以及实现线性回归操作

    在整个Pytorch框架中, 所有的神经网络本质上都是一个autograd package(自动求导工具包) autograd package提供了一个对Tensors上所有的操作进行自动微分的功能. ...

  9. PyTorch中的叶节点、中间节点、梯度计算等知识点总结

    总结: 按照惯例,所有属性requires_grad=False的张量是叶子节点(即:叶子张量. 叶子节点张量). 对于属性requires_grad=True的张量可能是叶子节点张量也可能不是叶 子 ...

最新文章

  1. mvc ajax提交html标签,asp.net-mvc – 如何使用ajax get或post在带有参数的mvc中将数据从View传递到Controller...
  2. 百度资源管理平台 站长工具 批量添加主站域名 子站域名 域名主动推送
  3. java导出类_java导出excel工具类
  4. Codeforces 1109F. Sasha and Algorithm of Silence's Sounds
  5. 2022.2.28集成电子开关电路TWH8778
  6. 更换ubuntu的root的默认python版本
  7. 默认HotSpot最大直接内存大小
  8. wap移动网页开发rem用法
  9. NSUserDefaults 添加与删除
  10. django 集成个推_Django动态添加定时任务之djangocelery的使用
  11. office2019 使用
  12. mysql join与where_mysql中left join设置条件在on与where时的用法区别分析
  13. java中猜字母_Java有大神会写 猜字母游戏
  14. 服务器如何设置内网IP地址
  15. DNSPod-免费智能DNS解析服务商
  16. 实验二 —— 串口通信
  17. 介绍一个牛逼的Github项目
  18. jupyter notebook 中运行from scipy import stats之后报错FutureWarning:
  19. word里公式后面标号怎么对齐,Word里面公式后面的编号如何与公式最后一行对齐?...
  20. 无敌哥-创新设计思维

热门文章

  1. Android Handle用法
  2. Vi编辑器的常用命令1(文件内操作)
  3. 小程序---微信本地存储的方法
  4. xen(三)xl 工具使用
  5. B4A +GoLang 实现手机端webserver
  6. input框的输入事件
  7. pdf文件转换ppt可编辑_创建,转换和编辑PDF文件的免费工具
  8. Android MVVM封装,MVVM: 这是一个android MVVM 框架,基于谷歌dataBinding技术实现
  9. [2021.10.14][Android P]OpenCamera详细分析(Camera2+Hal3)
  10. 微信小程序实现手机号隐藏,用****代替