​Aurograd自动求导机制总结

PyTorch中,所有神经网络的核心是 autograd 包。autograd 包为tensor上的所有操作提供了自动求导机制。它是一个在运行时定义(define-by-run)的框架,这意味着反向传播是根据代码如何运行来决定的,并且每次迭代可以是不同的.
理解自动求导机制可以帮助我们编写更高效、简洁的程序,并且可以方便我们进行调试。
Aurigrad如何实现自动求导?
Autograd是一个反向自动微分系统。autograd 会记录一个图表,记录在执行操作时创建tensor数据的所有操作,并提供一个有向无环图,其叶子是输入tensor,根是输出tensor。通过从根到叶跟踪此图,就可以使用链式法则自动计算梯度。

Tensor

torch.Tensor 是这个包的核心类。如果设置它的属性 requires_grad 为 True,那么它将会追踪对于该tensor的所有操作。当完成计算后可以通过调用 .backward(),来自动计算所有的梯度。这个tensor的所有梯度将会自动累加到grad属性。
阻止一个tensor被跟踪历史,可以调用 .detach() 方法将其与计算历史分离,并阻止它未来的计算记录被跟踪。
Tensor 和 Function 互相连接生成了一个无圈图(acyclic graph),它编码了完整的计算历史。每个tensor都有一个 grad_fn 属性,该属性引用了创建 Tensor 自身的Function(除非这个tensor是手动创建的, grad_fn 是 None )。

import torch
x = torch.ones(2, 3, requires_grad=True)
print(x)

输出:

tensor([[1., 1., 1.],[1., 1., 1.]], requires_grad=True)

对这个tensor添加运算:

y = x + 2
print(y)
print(y.grad_fn)

输出:

tensor([[3., 3.],[3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x0000021F4129C850>

y是计算的结果,有grad_fn属性。
对y进行更多操作

z = y *  2
out = z.mean()print(z, out)

输出:

tensor([[6., 6.],[6., 6.]], grad_fn=<MulBackward0>) tensor(6., grad_fn=<MeanBackward0>)

.requires_grad_() 原地改变了现有tensor的 requires_grad 标志。如果没有指定的话,默认输入的这个标志是 False。

a = torch.randn(2, 2)
a = ((a * 2) / (a - 2))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)

输出:

False
True
<SumBackward0 object at 0x0000021F46A2AA00>

梯度

下面进行反向传播,因为out是一个标量,因此 out.backward()out.backward(torch.tensor(1.)) 等价。

out.backward()

输出导数 d(out)/dx

print(x.grad)

输出:

tensor([[0.3333, 0.3333, 0.3333],[0.3333, 0.3333, 0.3333]])

torch.autograd 是计算雅可比向量积的一个“引擎”。雅可比向量积的特性使外部梯度输入到具有非标量输出的模型中变得非常方便。
现在我们来看一个雅可比向量积的例子:

x = torch.randn(2, requires_grad=True)
print(x)
y = x * 2
while y.data.norm() < 1000:y = y * 2
print(y)

输出:

tensor([-1.2877, -0.5659], requires_grad=True)
tensor([-1318.5631,  -579.4386], grad_fn=<MulBackward0>)

在这种情况下,y 不再是标量。torch.autograd 不能直接计算完整的雅可比矩阵,但是如果我们只想要雅可比向量积,只需将这个向量作为参数传给 backward:

v = torch.tensor([1, 1], dtype=torch.float)
y.backward(v)
print(x.grad)

输出:

tensor([1024., 1024.])

还可以通过将代码块包装在 with torch.no_grad(): 中,来阻止autograd跟踪设置了 .requires_grad=True 的tensor的历史记录。

print(x)
print(x.requires_grad)
print((x ** 2).requires_grad)
with torch.no_grad():print((x ** 2).requires_grad)

输出:

tensor([-1.2877, -0.5659], requires_grad=True)
True
True
False

保存Tensor

一些操作需要在前向传递期间保存中间结果,以便执行反向传递。例如,函数y = x**2 ,需要保存输入的梯度。
可以使用 save_for_backward()在前向传播期间保存tensor并使用 saved_tensors在后向传递期间检索它们。

对于 PyTorch 定义的操作(例如torch.pow()),tensor会根据需要自动保存。grad_fn可以通过查找以前缀开头的属性来探索某个out保存了哪些tensor_saved。

import torch
x = torch.randn(2, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))
print(x is y.grad_fn._saved_self)

输出:

True
True

在前面的代码中,引用与x y.grad_fn._saved_self相同的 Tensor 对象。但有一些情况不太相同。例如:

x = torch.randn(2, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

输出:

True
False

在后台,为了防止引用循环,PyTorch在保存时打包了张量,并将其解包到不同的张量中以供读取。在这里,您从访问中获得的张量y.grad_fn._saved_result是一个不同的张量对象y(但它们仍然共享相同的存储)。
一个tensor是否会被打包到一个不同的张量对象中取决于它是否是它自己的grad_fn的输出。

局部禁用梯度计算

Python 有几种机制可以在本地禁用梯度计算:
要在整个代码块中禁用梯度,有no-grad和inference mode。为了从梯度计算中更细粒度地排除子图,可以设置requires_grad tensor的字段。
evaluation mode ( nn.Module.eval())方法不是用于禁用梯度计算,但经常因为名字与禁止梯度计算相混淆。

设置requires_grad

requires_grad
​requires_grad是一个标志,默认为 false除非包含在 nn.Parameter中,它允许从梯度计算中细粒度地排除子图。它在向前和向后传播中都生效:
在前向传播过程中,只有至少一个输入tensor需要 grad 时,才会将操作记录在后向图中。在后向传播.backward()过程中,只有requires_grad=True的叶tensors才会有梯度积累到它们的.grad字段中。

多线程 Autograd

autograd 引擎负责运行计算反向传递所需的所有反向操作。

#定义一个在不同线程中使用的train函数
def train_fn():x = torch.ones(2, 2, requires_grad=True)# forwardy = (x + 5) * (x + 5) * 0.1# backwardy.sum().backward()# 优化器更新# 编写线程代码来驱动train_fn
threads = []
for _ in range(10):p = threading.Thread(target=train_fn, args=())p.start()threads.append(p)for p in threads:p.join()
print(threads)

输出:

[<Thread(Thread-5, stopped 11328)>, <Thread(Thread-6, stopped 13548)>, <Thread(Thread-7, stopped 14440)>, <Thread(Thread-8, stopped 12720)>, <Thread(Thread-9, stopped 2416)>, <Thread(Thread-10, stopped 3820)>, <Thread(Thread-11, stopped 10688)>, <Thread(Thread-12, stopped 4620)>, <Thread(Thread-13, stopped 2200)>, <Thread(Thread-14, stopped 14916)>]

Aurograd自动求导机制还有许多强大的功能,可参考官方文档torch.autograd。

【PyTorch学习(三)】Aurograd自动求导机制总结相关推荐

  1. PyTorch的计算图和自动求导机制

    文章目录 PyTorch的计算图和自动求导机制 自动求导机制简介 自动求导机制实例 梯度函数的使用 计算图构建的启用和禁用 总结 PyTorch的计算图和自动求导机制 自动求导机制简介 PyTorch ...

  2. 深度学习修炼(三)——自动求导机制

    文章目录 致谢 3 自动求导机制 3.1 传播机制与计算图 3.1.1 前向传播 3.1.2 反向传播 3.2 自动求导 3.3 再来做一次 3.4 线性回归 3.4.1 回归 3.4.2 线性回归的 ...

  3. Pytorch学习(一)—— 自动求导机制

    现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API 进行 ...

  4. 【PyTorch基础教程2】自动求导机制(学不会来打我啊)

    文章目录 第一部分:深度学习和机器学习 一.机器学习任务 二.ML和DL区别 (1)数据加载 (2)模型实现 (3)训练过程 第二部分:Pytorch部分 一.学习资源 二.自动求导机制 2.1 to ...

  5. PyTorch 笔记Ⅱ——PyTorch 自动求导机制

    文章目录 Autograd: 自动求导机制 张量(Tensor) 梯度 使用PyTorch计算梯度数值 Autograd 简单的自动求导 复杂的自动求导 Autograd 过程解析 扩展Autogra ...

  6. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  7. pytorch如何计算导数_Pytorch的自动求导机制与使用方法(一)

    本文以线性模型为例,讲解线性模型的求解的pytorch梯度实现方法. 要注意几个问题:在PyTorch 0.4.0版本之后,Variable类已经被禁用了,所有的torch.Tensor与torch. ...

  8. Pytorch教程入门系列4----Autograd自动求导机制

    系列文章目录 文章目录 系列文章目录 前言 一.Autograd是什么? 二.Autograd的使用方法 1.在tensor中指定 2.重要属性 三.Autograd的进阶知识 1.动态计算图 2.梯 ...

  9. Pytorch Autograd (自动求导机制)

    Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logistic回归模型来介绍Pytorch的自动求导机制.首先, ...

最新文章

  1. seq2seq与Attention机制
  2. Android内存泄漏简介
  3. 如何分析案件的性质_律师如何综合分析一个案件
  4. sessionId与cookie 的关系(百度文库)
  5. APICloud App定制平台的操作指南
  6. 设计自己的基于Selenium 的自动化测试框架-Java版(1) - 为什么selenium还需要测试框架?...
  7. weak和assign的区别
  8. 一场农业“人机”对战,能否凿开农村致富新门路呢?
  9. 微软产品界面配色方案分析
  10. 小案例:王者荣耀战力查询系统(免费调用外部接口
  11. RTSP协议视频安防综合管理平台EasyNVR与海康萤石云平台运行机制差异对比说明
  12. day04 java学习
  13. java字符串去重复_java去除重复的字符串和移除不想要的字符串
  14. EasyRecovery最新版本Photo16电脑数据恢复软件下载
  15. Linux之文件/目录搜索
  16. int, long, long long类型的范围
  17. 15秒视频播放量超5500万,如何抢占涨粉又爆赞的流量密码?
  18. wps文字下载 wps2019怎么关掉内置浏览器?关闭内置浏览器步骤一览
  19. PCIe传输速率、吞吐量、PCLK计算方式
  20. Electron开发环境的搭建

热门文章

  1. 进化算法-人工蜂群(ABC)
  2. 下一代数据存储OneStorage闪亮登场,华为打造全场景智能的基石
  3. 注册表禁用和启用USB端口
  4. mapbox symbols 层级设置_Mapbox 地图样式规范
  5. 局域网实现PC、Pad、Android互联
  6. C#“调用的目标发生了异常”之终极解决办法
  7. vs2017 配置WTL10 出现“调用目标发生异常”
  8. 【电子取证篇】文件签名
  9. android 开机画面定制
  10. 对未来几年嵌入式行业发展的预测