点击关注我哦

autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现的数学基础;Tensor构成的动态计算图是使用pytorch的实现的结构。

backward()函数

backward()是通过将参数(默认为1x1单位张量)通过反向图追踪所有对于该张量的操作,使用链式求导法则从根张量追溯到每个叶子节点以计算梯度。下图描述了pytorch对于函数z = (a + b)(b - c)构建的计算图,以及从根节点z到叶子节点a,b,c的求导过程:

注意:计算图已经在前向传递过程中已经被动态创建了,反向传播仅使用已存在的计算图计算梯度并将其存储在叶子节点中。

为了节约内存,在每一轮迭代完成后,计算图就会被释放,若需要多次调用backward()方法,则需要在使用时添加retain_graph=True,否则会报如下错误:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

若我们在使用过程中,仅仅想求得某个节点的梯度,而非整个图的梯度,则需要用到Tensor的.grad属性,如下列代码所示:

import torch# 创建计算图x = torch.tensor(1.0, requires_grad = True)z = x ** 3# 计算梯度z.backward() print(x.grad.data)

需要注意的是:当调用z.backward()时,将自动计算z.backward(torch.tensor(1.0)),其中 torch.tensor(1.0)是用于终止连式法则梯度乘法的外部梯度。可以将此作为输入传递给MulBackward函数,以进一步计算x的梯度。

在上述的示例中,我们给出了标量对向量的求导过程,那么当向量对向量进行求导时呢?例如,需要计算梯度的张量x和y如下:

x = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0 , 1.0 , 7.0], requires_grad = True)z = x * y

此时调用z.backward()函数将会报如下错误:

RuntimeError: grad can be implicitly created only for scalar outputs

错误提示我们只能应用于标量输出。若我们想对向量z进行梯度计算,先了解一下Jacobian矩阵。

Jacobian矩阵和向量

从数学角度上来讲:雅克比矩阵是基于函数对所有变量一阶偏导数的数值矩阵,当输入个数等于输出个数时又称为雅克比行列式。

而autograd类在实际运用的过程中也是通过计算雅克比向量积实现对向量梯度的计算。简单来说,雅可比矩阵是代表两个向量的所有可能偏导数的矩阵,可以用于求一个向量相对于另一个向量的梯度。

注:在此过程中,PyTorch不会显式构造整个Jacobian矩阵,而是直接计算Jacobian矢量积,这种计算方式更为简便。

如果向量X = [x1,x2,… xn]通过函数f(X)= [f1,f2,… fn]计算其他向量,假设f对于x的每个一阶偏导数都存在,则f(X)相对于X的梯度矩阵为:

假设待计算梯度的张量X为:X = [x1,x2,… xn](机器学习模型的权重),X可以进行一些运算以形成向量Y:Y = f(X)= [y1,y2,… ym]。然后,使用Y来计算标量损失l。假设向量v恰好是标量损失l相对于向量Y的梯度,则:

此时,向量v则被称为grad_tensor,即梯度张量。并将其作为参数传递给backward()函数。为了获得损失l相对于权重X的梯度,将Jacobian矩阵J与向量v相乘,得到最终梯度:

综上所述,pytorch在使用计算图求导的过程中整体可以分为以下两种情况:

1. 若标量对向量求导,则可以直接调用backward()函数;

2. 若向量A对向量B求导,则先求得向量A对于向量B的Jacobian矩阵,并将其与grad_tensors对应的矩阵进行点乘计算得到最终梯度。

·  END  ·

RECOMMEND推荐阅读

1. 效率提升的软件大礼包

2. 深度学习——入门PyTorch(一)

3. 深度学习——入门PyTorch(二)

4. PyTorch入门——autograd(一)

5. PyTorch入门——autograd(二)

pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)相关推荐

  1. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  2. pytorch 矩阵相乘_编译PyTorch静态库

    背景 众所周知,PyTorch项目作为一个C++工程,是基于CMake进行构建的.然而当你想基于CMake来构建PyTorch静态库时,你会发现: 静态编译相关的文档不全: CMake文件bug太多, ...

  3. pytorch 矩阵相乘_深入浅出PyTorch(算子篇)

    Tensor 自从张量(Tensor)计算这个概念出现后,神经网络的算法就可以看作是一系列的张量计算.所谓的张量,它原本是个数学概念,表示各种向量或者数值之间的关系.PyTorch的张量(torch. ...

  4. python多分类混淆矩阵代码_深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)...

    深度学习自学记录(3)--两种多分类混淆矩阵的Python实现(含代码),矩阵,样本,模型,类别,真实 深度学习自学记录(3)--两种多分类混淆矩阵的Python实现(含代码) 深度学习自学记录(3) ...

  5. 深度学习 --- 优化入门三(梯度消失和激活函数ReLU)

    前两篇的优化主要是针对梯度的存在的问题,如鞍点,局部最优,梯度悬崖这些问题的优化,本节将详细探讨梯度消失问题,梯度消失问题在BP的网络里详细的介绍过(兴趣有请的查看我的这篇文章),然后主要精力介绍Ru ...

  6. unet是残差网络吗_深度学习系列(三)卷积神经网络模型(ResNet、ResNeXt、DenseNet、DenceUnet)...

    深度学习系列(三)卷积神经网络模型(ResNet.ResNeXt.DenseNet.Dence Unet) 内容目录 1.ResNet2.ResNeXt3.DenseNet4.Dence Unet 1 ...

  7. nin神经网络_深度学习基础(三)NIN_Network In Network

    该论文提出了一种新颖的深度网络结构,称为"Network In Network"(NIN),以增强模型对感受野内local patches的辨别能力.与传统的CNNs相比,NIN主 ...

  8. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  9. 深度学习Pytorch框架

    深度学习Pytorch框架 文章目录 深度学习Pytorch框架 前言 1. Pytorch命令之``nn.Sequential`` 2. Pytorch命令之``nn.Conv2d`` 3. Pyt ...

最新文章

  1. AJAX实用教程——获取博客园博文列表
  2. 华为LTE 模块AT 命令拨号上网流程
  3. mysql半同步复制
  4. mysql索引优化实际例子_MySQL索引优化的实际案例分析
  5. 解决 idea 复制jsp 文件过来页面报404
  6. 小程序日历插件的使用
  7. MFC学习笔记(1)
  8. Ubuntu环境下下载Android-SDK-Linux之后使用adb连接设备报错
  9. linux设置进程开机启动,Linux应用程序开机自动启动设置方法
  10. python单例模式学习
  11. ES6模板字符串中使用变量
  12. TRUNK理论与配置实验
  13. DaZeng:Vue全家桶实现小米商城(二)
  14. 腾讯云数据库TDSQL-C(原CynosDB)的外网访问配置
  15. OCR:ECCV 2020 论文了解
  16. 纪念丹尼斯——C语言之父
  17. 经济数据预测 | Python实现CNN-LSTM股票价格预测时间序列预测
  18. Linux mount 命令
  19. Unix 文件系统的核心目录总结
  20. 小波卷积网络Multi-level Wavelet-CNN for Image Restoration论文阅读笔记

热门文章

  1. java搭建线程池框架,JAVA线程池管理及分布式HADOOP调度框架搭建
  2. mysql 8.0认证失败_Node.js无法对MySQL 8.0进行身份验证
  3. idea中event log_【JavaScript 教程】事件——Event 对象
  4. python魔法方法与函数_在Python中画图(基于Jupyter notebook的魔法函数)
  5. html异形轮播,异形滚动
  6. c++二维数组指针详解
  7. matplotlib使用GridSpec自定义子图位置 (非对称的子图)
  8. OpenCV绘图和注释
  9. 从‘一边拉琴,一边哭’,看什么是真正的兴趣
  10. MySQL流浪记(七)—— MySQL删除表数据