1 用法介绍

 pytorch中torch.autograd.grad函数主要用于计算并返回输出相对于输入的梯度总和,具体的参数作用如下所示:

torch.tril(input, diagonal=0, *, out=None) ⟶\longrightarrow⟶Tensor

  • outputs(sequence of Tensor):表示微分函数的输出
  • inputs (sequence of Tensor):表示微分函数的输入
  • grad_outputs (sequence of Tensor):表示“向量-雅克比矩阵”的向量
  • retain_graph (bool, optional):表示是否需要将计算图释放掉,当计算二阶导数时需要设置为True
  • create_graph (bool, optional):表示是否需要将梯度将会加入到计算图中,当计算高阶导数或者其他计算时会将其设置为需要设置为True
  • allow_unused (bool, optional):表示是否只返回输入的梯度,而不返回其他叶子节点的梯度

2 实例讲解

 以下给出了具体的二阶导数解析解的数学实例

给定一个向量x=(x1,x2)⊤{\bf{x}}=(x_1,x_2)^{\top}x=(x1​,x2​)⊤,可以得到向量y=(y1,y2)⊤=(x12,x22)⊤{\bf{y}}=(y_1,y_2)^{\top}=(x^2_1,x^2_2)^{\top}y=(y1​,y2​)⊤=(x12​,x22​)⊤。对向量y{\bf{y}}y的元素求平均可以得到损失函数loss1\mathrm{loss}_1loss1​为:loss1(x)=mean(y)=x12+x222\mathrm{loss}_1({\bf{x}})=\mathrm{mean}({\bf{y}})=\frac{x_1^2+x^2_2}{2}loss1​(x)=mean(y)=2x12​+x22​​向量y{\bf{y}}y元素的分量分别对x{\bf{x}}x求偏导,然后相加求平均得到损失函数loss2\mathrm{loss}_2loss2​为{h1(x)=∂y1∂x=(2x1,0)⊤h2(x)=∂y2∂x=(0,2x2)⊤,loss2(x)=mean(h1(x1)−h2(x2))=x1−x2\left\{\begin{aligned}h_1({\bf{x}})&=\frac{\partial y_1}{\partial {\bf{x}}}=(2x_1,0)^{\top}\\h_2({\bf{x}})&=\frac{\partial y_2}{\partial {\bf{x}}}=(0,2x_2)^{\top}\end{aligned}\right.,\quad \mathrm{loss}_2({\bf{x}})=\mathrm{mean}(h_1({\bf{x}}_1)-h_2({\bf{x}}_2))=x_1-x_2⎩⎨⎧​h1​(x)h2​(x)​=∂x∂y1​​=(2x1​,0)⊤=∂x∂y2​​=(0,2x2​)⊤​,loss2​(x)=mean(h1​(x1​)−h2​(x2​))=x1​−x2​将损失函数loss1\mathrm{loss}_1loss1​与损失函数loss2\mathrm{loss}_2loss2​相加可以得到loss(x)=loss1(x)+loss2(x)=x12+x222+x1−x2\mathrm{loss}({\bf{x}})=\mathrm{loss}_1({\bf{x}})+\mathrm{loss}_2({\bf{x}})=\frac{x_1^2+x_2^2}{2}+x_1-x_2loss(x)=loss1​(x)+loss2​(x)=2x12​+x22​​+x1​−x2​最终损失函数loss\mathrm{loss}loss对向量x{\bf{x}}x的偏导数为∂loss∂x=(x1+1,x2−1)⊤\frac{\partial {\mathrm{loss}}}{\partial{{\bf{x}}}}=(x_1+1,x_2-1)^{\top}∂x∂loss​=(x1​+1,x2​−1)⊤

以下为用pytorch实现二阶导数相对应的代码实例:

import torchx = torch.tensor([5.0, 7.0], requires_grad=True)
y = x**2loss1 = torch.mean(y)h1 = torch.autograd.grad(y[0], x, retain_graph = True, create_graph=True)
h2 = torch.autograd.grad(y[1], x, retain_graph = True, create_graph=True)
loss2 = torch.mean(h1[0] - h2[0])loss = loss1 + loss2result = torch.autograd.grad(loss, x)
print(result)

当向量x{\bf{x}}x取值为(5,7)⊤(5,7)^{\top}(5,7)⊤时,根据数学解析解得到的二阶导数为(6,6)⊤(6,6)^{\top}(6,6)⊤,对应的代码运行的实验结果也为(6,6)(6,6)(6,6)。

torch.autograd.grad求二阶导数相关推荐

  1. torch.autograd学习系列之torch.autograd.grad()函数学习

    前言:上一次我们学习了torch.autograd.backward()方法,这是一个计算反向过程的核心方法,没看过的小伙伴可以去看看 传送门:https://blog.csdn.net/Li7819 ...

  2. [转]一文解释PyTorch求导相关 (backward, autograd.grad)

    PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果:而TensorFlow是静态图. 在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation) 运 ...

  3. 【Torch笔记】autograd自动求导系统

    [Torch笔记]autograd自动求导系统 Pytorch 提供的自动求导系统 autograd,我们不需要手动地去计算梯度,只需要搭建好前向传播的计算图,然后使用 autograd 计算梯度即可 ...

  4. python grad_torch.autograd.grad()函数用法示例

    目录 一.函数解释 如果输入x,输出是y,则求y关于x的导数(梯度): def grad(outputs, inputs, grad_outputs=None, retain_graph=None, ...

  5. Pytorch autograd.grad与autograd.backward详解

    Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...

  6. PyTorch 1.0 中文文档:torch.autograd

    译者:gfjiangly torch.autograd 提供类和函数,实现任意标量值函数的自动微分. 它要求对已有代码的最小改变-你仅需要用requires_grad=True关键字为需要计算梯度的声 ...

  7. pytorch求导总结(torch.autograd)

    1.Autograd 求导机制 我们在用神经网络求解PDE时, 经常要用到输出值对输入变量(不是Weights和Biases)求导: 例如在训练WGAN-GP 时, 也会用到网络对输入变量的求导,py ...

  8. 使用torch.autograd.function解决dist.all_gather不能反向传播问题

    1. 问题来源 最近在用mmcv复现Partial FC模型,看到源码中,有单独写的前向反向传播,甚是疑惑- 源码: # Features all-gather total_features = to ...

  9. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

最新文章

  1. postmaster.c 中的 ListenAddresses
  2. linux内核分两种,Linux内核版本
  3. 压缩可以卸载吗_不可错过!螺杆压缩机故障分析详解(2)
  4. android通知栏半透明,Android开发实现透明通知栏
  5. Xcode的一些控制台命令
  6. vue --- [全家桶]vue-router
  7. letsencrypt 自动续期不关闭nginx
  8. BZOJ1010 [HNOI2008]玩具装箱
  9. Atom 编辑器安装 linter-eslint 插件,并配置使其支持 vue 文件中的 js 格式校验
  10. Android Studio报错:Could not download kotlin-reflect.jar (org.jetbrains.kotlin:kotlin-reflect:1.3.61)
  11. 使用Python下载电视剧(二):下载ts片段
  12. Yalmip最优化求解器+matlab | 教程(一)
  13. 2022 Google翻译修复工具 V1.3 【谷歌浏览器无法翻译网页问题解决】
  14. (模拟)HDU - 5857 Median
  15. 快速在网站跳转支付宝付款链接
  16. html颜色代码错误,HTML颜色代码表
  17. 批量执行ABAQUS的inp文件——整理
  18. Azure媒体服务的Apple FairPlay流功能正式上线
  19. 75 道 JavaScript 面试题
  20. 转:solr 从数据库导入数据,全量索引和增量索引(实例配置原理)

热门文章

  1. 【代码质量】-阿里巴巴java开发手册(代码质量提升神器)学习笔记
  2. 吟诵,不为吟诵 - 徐健顺
  3. 自学Python:快速查找文件或文件夹
  4. lammps一对一课程学习大纲
  5. MLB的选秀会有哪些规定和流程·棒球6号位
  6. 鸿蒙招聘店铺主是真的吗,为什么很多店铺门口贴着招聘,然而进去问都说招满了,但是招聘的内容还是放在那不收走?...
  7. as 贪食蛇小游戏(一)
  8. 程序员眼中的编程语言和操作系统
  9. Google:我能把文本变成音乐,但这个 AI 模型不能对外发布!
  10. 微信小程序的详细登录(上)