pytorch-backword函数的理解

函数:\(tensor.backward(params)\)

这个params的维度一定要和tensor的一致,因为tensor如果是一个向量y = [y1,y2,y3],那么传入的params=[a1,a2,a3],这三个值是系数,那么是什么的系数呢?
假定对x =[ x1,x2]求导,那么我们知道,
\(dy/dx\) 为:
第一列: \(dy1/dx1,dy2/dx1,dy3/dx1\)
第二列:\(dy1/dx2, dy2/dx2,dy3/dx2\)
从而 \(dy/dx\)是一个3行2列的矩阵,每一列对应了对x1的导数,每一列也就是\(x1\)的梯度向量
而反向计算的时候,并不是返回这个矩阵,而是返回这个矩阵每列的和作为梯度,也就是:\(dy1/dx1+dy2/dx1+dy3/dx1\) 是y对x1的梯度
这就好理解了,系数为\(params=[a1,a2,a3]\)就对应了这加和的三项!也就是,对\(x1\)的梯度实际上是\(a1*dy1/dx1+a2*dy2/dx1+a3*dy3/dx1\)
而输出y是标量的时候,就不需要了,默认的就是\(1.\)

自己重写backward函数时,要写上一个grad_output参数,这个参数就是上面提到的params

这个grad_output参数究竟是什么呢?下面作出解释:
是这样的,假如网络有两层, h = h(x),y = y(h)
你可以计算\(dy/dx\),这样,y.backward(),因为\(dy/dy=1\),那么,backward的参数就可以省略
如果计算h.backward(),因为你想求的是\(dy/dx\),(这才是输出对于输入的梯度),那么,计算图中的y = y(h)就没有考虑到
因为\(dy/dx = dy/dh * dh/dx\),h.backward()求得是\(dh/dx\),那么你必须传入之前的梯度\(dy/dh\)才行,也就是说,h.backward(params=dy/dh)这里面的参数就是\(dy/dh\)

这就好理解了,如果我们自己实现了一层,继承自Function,自己实现静态方法forwardbackward时,backward必须有个grad_output参数,这个参数就是计算图中输出对该自定义层的梯度,这样才能求出对输入的梯度。

另外,假设定义的层计算出的是y,调用的就是y.backward(grad_output),这个里面的参数的维度必须和y是相同的。这也就是为什么前面提到对于输出是多维的,会有个“系数”的原因,这个系数就是后向传播时,该层之前的梯度的累积,这样与本层再累积,才实现了完整的链式法则,最终求出outinput的梯度。

另外,自定义实现forwardbackward时,两函数的输入输出是有要求的,即forward的输入必须和~的return相对应,如forwardinput有个w参数,那么backwardreturn就必须在对应的位置返回grad_w,因为只有这样,才能够对相应的输入参数梯度下降。

转载于:https://www.cnblogs.com/duye/p/9913602.html

【pytorch】pytorch-backward()的理解相关推荐

  1. Pytorch的backward()相关理解

    Pytorch的backward()相关理解 最近一直在用pytorch做GAN相关的实验,pytorch 框架灵活易用,很适合学术界开展研究工作.  这两天遇到了一些模型参数寻优的问题,才发现自己对 ...

  2. 【Pytorch】backward()简单理解

    backward()是反向传播求梯度,具体实现过程如下 import torchx=torch.tensor([1,2,3],requires_grad=True,dtype=torch.double ...

  3. Pytorch的backward()与optim.setp()的理解

    @Xia Pytorch的backward()与optim.setp()的理解 backward()与optim.setp()一直对这两个函数他们之间的调用不是很清楚,花点时间应该是搞明白了. 先看最 ...

  4. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  5. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  6. [pytorch] Pytorch入门

    Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...

  7. 深度理解Pytorch中backward()

    转自https://blog.csdn.net/douhaoexia/article/details/78821428 接触pytorch很久了,也自认为对 backward 方法有一定了解,但看了这 ...

  8. 高效理解pytorch的backward需要scalar outputs

    利用backward时 , 可能经常遇到错误 RuntimeError: grad can be implicitly created only for scalar outputs 理解的最好方式就 ...

  9. pytorch lstm crf 代码理解 重点

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  10. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

最新文章

  1. Data Lake Analytics + OSS数据文件格式处理大全
  2. Linux下添加DB2用户
  3. STL 之find,find_if,find_end,find_first_of
  4. 移动端点击屏幕按钮闪现的灰色底框
  5. 【结论】取石子游戏(jzoj 1211)
  6. Redis配置文件redis.config详解以及关闭Redis服务
  7. 错误的艺术!20个创意的404错误页面设计
  8. python切割图片文字_Python+opencv 实现图片文字的分割的方法示例
  9. NHibernate Inheritance Mapping 继承映射
  10. 如何更改static控件的字体大小
  11. 【QT】简单易学的QT安装教程
  12. 基于python的表情识别_python表情识别
  13. 用flash做古诗动画_Flash制作跟我学 用遮罩技术制作古诗动画-FLASH课件制作(FLASH课件制作教程)-flash课件吧(湖北金鹰)...
  14. java.gg_JAVA公文管理系统
  15. 阿里云来担保商标注册申请,担保有哪些程序(详细教程)
  16. adpcb 添加差分对_在AD中PCB设计常用规则——差分规则设置?
  17. 深度linux deepin 内存,【转载】深度Deepin国产操作系统使用体验报告!
  18. c语言csp字符串,骇人听闻的 CSP
  19. [c语言编程入门]迭代法求平方根
  20. 根据单头价格清单(核价单),更新单身出货明细的单价

热门文章

  1. cocos2d-x在win7下的android交叉编译环境
  2. 安装Macports遇到的问题和PATH设置
  3. Android程序如何在代码中改变图片原有的颜色
  4. Tempdb数据库详细介绍
  5. RunnableException与CheckedException
  6. A Strange Bitcoin Transaction
  7. 如何使用JavaScript Math.floor生成范围内的随机整数-已解决
  8. 背景图自适应屏幕居中显示,且不变形
  9. 离线安装k8s 1.9.0
  10. java高并发编程(二)