pytorch会自动求导,但是当遇到无法自动求导的时候,需要自己认为定义求导过程,这个时候就涉及到要定义自己的forward和backward函数。

举例如下:

看到这里,大家应该会有很多疑问,比如:

1:ctx.save_for_backward和ctx.saved_tensors的含义

2:backward中各个计算函数的意义,以及backward的输入参数grad_out是什么,以及grad_out包含哪些数据。

针对以上问题,我们一个个解答:

第一个问题:百度吧,答案很多!!!!

第二个问题:拿上面这个例子来看,我们定义了一个类似于线性层的东西,但注意这不是线性层,因为我们是直接把输入和weight用*来做点对点的乘法的,所以这不是我们通常情况下的线性层。

但是这么看也费劲,我们写一个网络,把这个函数加到网络中去,再完整的跑一遍看吧!

测试代码:

结果如下:

现在,来进行解答:

首先,backward函数的返回值,就是对应着forward里面的参数的梯度,也就是说,forward函数里面有几个输入参数,那么backward函数的输出就要有几个!为什么是这样?

我们首先要理解backward的输入grad_out,为什么backward的参数就是一个,因为这是根据链式法则来的

比如,我们定义三个函数H(对应上面网络中linear1),F(自定义函数xjm_inter),D(对应上面网络中linear2),定义一个输入x(对应上面输入a),定义一个输出y(对应上面输出b):

y = D(F(H(X)))

现在,我们求y对x的偏导,那么:

dy/dx = dy/dD * dD/dF * dF/dH * dH/dx

好吧看到这里你可能还是不懂,为什么backward的参数就是一个grad_out!!

我们韩式以上面则个函数为例子,但是,我们现在不求y对x的导数,我们假设F函数有一个叶子节点(或者说requires_grad=True)的参数w1,现在我们要求y对w1的导数:

所以dy/dw1 = dy/dD *dD/dF * dF/dw1。

那么此时,F就是我们上面代码中自定义的xjm_inter函数,则 grad_out = dy/dD *dD/dF。

怎么理解呢,根据链式法则,我们呢所定义的网络中的每一层都是一个单独的函数,所以函数中的变量的最终求导其实只取决于该函数本身,链式法则求导传递过来的其实永远都知识一个值,这就是为什么backward函数的输出只有一个。

那么我们的backward要实现什么样的功能呢?说到这里,大家应该大概能明白了,就是实现当前层那的梯度计算,并进行返回,所以,这也是为什么backward的返回值要和forward的输入值一一对应,否则会报错。

pytorch自定义forward和backward函数相关推荐

  1. pytorch 1.9.0 backward函数解释以及报错(RuntimeError: grad can be implicitly created only for scalar outputs)

    文章目录 官方文档 简单示例 示例1 示例2(报错(RuntimeError: grad can be implicitly created only for scalar outputs)解决方法) ...

  2. torch 的 forward 和 backward

    Criterions有其forward和backward函数 https://github.com/torch/nn/blob/master/doc/criterion.md Module也有其for ...

  3. pytorch自定义函数实现自动梯度

    Motivation 构建模型有时需要使用自定义的函数,为了不影响模型的反向传播,需要实现自动梯度计算(即把自定义函数嵌入计算图). 实现 要点: 将函数定义为类,需继承自torch.autograd ...

  4. Pytorch中backward函数

    backward函数是反向求导数,使用链式法则求导,如果对非标量y求导,函数需要额外指定grad_tensors,grad_tensors的shape必须和y的相同. import torch fro ...

  5. pytorch自定义模型执行过程

    使用pytorch定义自己的模型是继承nn.Module实现的.在__init__方法中定义需要初始化的参数,一般把网络中具有可学习参数的层放在这里定义.forward方法实现模型的功能,实现各个层之 ...

  6. pytorch自定义算子 native_functions.yaml

    pytorch自定义算子 native_functions.yaml 在pytorch的文件夹中搜索native_functions.yaml,可以看到,所有pytorch原生的函数都是在这里注册的 ...

  7. java如何给一个链表定义和传值_如何在CUDA中为Transformer编写一个PyTorch自定义层...

    如今,深度学习模型处于持续的演进中,它们正变得庞大而复杂.研究者们通常通过组合现有的 TensorFlow 或 PyTorch 操作符来发现新的架构.然而,有时候,我们可能需要通过自定义的操作符来实现 ...

  8. Pytorch自定义数据集

    简述 Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西. 往往网络上给的demo都是基于torch自带的MNIST的相关类.所以,为了解决使用其他的数据集,在查阅了torch ...

  9. [转]一文解释PyTorch求导相关 (backward, autograd.grad)

    PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果:而TensorFlow是静态图. 在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation) 运 ...

最新文章

  1. 2017年html5行业报告,云适配发布2017 HTML5开发者生态报告 期待更多行业标准
  2. D3.tsv与D3.csv加载数据
  3. Solr集群搭建,zookeeper集群搭建,Solr分片管理,Solr集群下的DataImport,分词配置。...
  4. python数据分析知识点_Python数据分析--Pandas知识点(三)
  5. HTTP 如何传输大文件
  6. ORACLE查看某个表的索引状态
  7. python线程按照顺序执行_Python3多线程之间的执行顺序问题
  8. differential forms
  9. ff7重制版青魔法_狂父重制版发布+妖精的尾巴首次打折¥244+最终幻想4解锁国区新增中文...
  10. mysql decimal被四舍五入_MySQL之ROUND函数四舍五入的陷阱
  11. Android 为View实现双击效果
  12. python自动化办公入门-[Python] 自动化办公 docx操作Word基础代码
  13. linux下搭建python机器学习环境
  14. “康园圈--互联网+校园平台“项目之Sprint3
  15. mysql 修改字段为主键自增_给MySQL中某表增加一个新字段,设为主键值为自动增长。...
  16. cad转dxf格式文件太大_如何将DWG DXF互转,一招教你解决难题
  17. ubuntu18.04系统无法正常连接网络解决办法
  18. 几何分布及其期望计算
  19. IT资讯精选(2022-09-11)
  20. 自己电脑中安装黑群辉NAS

热门文章

  1. 基于Robotics Toolbox的机械臂工作空间求解
  2. JAVA开发基础之使用IDEA导出JAR包
  3. 10月18---10月20号第一周总结
  4. Spark 学习入门教程
  5. 企业信息化建设都包括哪些方面?
  6. 对称加密与非对称加密的区别
  7. solr DIH 设置定时索引
  8. 字符串的初始化(详解)
  9. 分布式与集群的区别?
  10. 使用MATLAB搭建用于时间序列分类的1DCNN模型