PyTorch 自动微分示例
autograd 包是 PyTorch 中所有神经网络的核心。首先简要地介绍,然后训练第一个神经网络。autograd 软件包为 Tensors 上的所有算子提供自动微分。这是一个由运行定义的框架,以代码运行方式定义后向传播,并且每次迭代都可以不同。从 tensor 和 gradients 来举一些例子。
1、TENSOR
torch.Tensor 是包的核心类。如果将其属性 .requires_grad 设置为 True,则会开始跟踪针对 tensor 的所有操作。完成计算后,可以调用 .backward() 来自动计算所有梯度。该张量的梯度将累积到 .grad 属性中。
要停止 tensor 历史记录的跟踪,可以调用 .detach(),将其与计算历史记录分离,防止将来的计算被跟踪。
要停止跟踪历史记录(和使用内存),还可以将代码块使用 with torch.no_grad(): 包装起来。在评估模型时,这是特别有用,因为模型在训练阶段具有 requires_grad = True ,可训练参数有利于调参,但在评估阶段不需要梯度。
还有一个类对于 autograd 实现非常重要那就是 Function。Tensor 和 Function 互相连接并构建一个非循环图,保存整个完整的计算过程的历史信息。每个张量都有一个 .grad_fn 属性,保存着创建了张量的 Function 的引用,(如果用户自己创建张量,则g rad_fn 是 None )。
如果想计算导数,可以调用 Tensor.backward()。如果 Tensor 是标量(即包含一个元素数据),则不需要指定任何参数backward(),如果有更多元素,则需要指定一个gradient 参数来指定张量的形状。
import torch
创建一个张量,设置 requires_grad=True 来跟踪相关的计算
x = torch.ones(2, 2, requires_grad=True)
print(x)
输出:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
针对张量做一个操作
y = x + 2
print(y)
输出:
tensor([[3., 3.],
[3., 3.]], grad_fn=)
y 作为操作的结果被创建,所以它有 grad_fn
print(y.grad_fn)

输出:
<AddBackward0 object at 0x7fe1db427470>

针对 y 做更多的操作:

z = y * y * 3
out = z.mean()

print(z, out)
输出:
tensor([[27., 27.],
[27., 27.]], grad_fn=)
tensor(27., grad_fn=)
.requires_grad_( … ) 会改变张量的 requires_grad 标记。输入的标记默认为 False ,如果没有提供相应的参数。

a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)
输出:
False
True
<SumBackward0 object at 0x7fe1db427dd8>

梯度:
现在后向传播,因为输出包含了一个标量,out.backward() 等同于out.backward(torch.tensor(1.))。
out.backward()
打印梯度 d(out)/dx
print(x.grad)

输出:
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
原理解释:

看一个雅可比向量积的例子:
x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
y = y * 2

print(y)
输出:
tensor([ -444.6791, 762.9810, -1690.0941], grad_fn=)
在这种情况下,y 不再是一个标量。torch.autograd 不能够直接计算整个雅可比,但是如果想要雅可比向量积,需要简单的传递向量给 backward 作为参数。
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)

print(x.grad)

输出:
tensor([1.0240e+02, 1.0240e+03, 1.0240e-01])

可以通过将代码包裹在 with torch.no_grad(),来停止对从跟踪历史中 的 .requires_grad=True 的张量自动求导。
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
print((x ** 2).requires_grad)
输出:

True
True
False

PyTorch 自动微分示例相关推荐

  1. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  2. 《20天吃透Pytorch》Pytorch自动微分机制学习

    自动微分机制 Pytorch一般通过反向传播 backward 方法 实现这种求梯度计算.该方法求得的梯度将存在对应自变量张量的grad属性下. 除此之外,也能够调用torch.autograd.gr ...

  3. pytorch 自动微分基础原理

    PyTorch 的Autograd功能是 PyTorch 灵活快速地构建机器学习项目的一部分.它允许在复杂计算中快速轻松地计算多个偏导数(也称为 梯度) .该操作是基于反向传播的神经网络学习的核心. ...

  4. pytorch自动微分,反向传播(一)

    1.张量计算补充 2.计算图(Computational Graph) Pytorch中autograd的底层采用了计算图,计算图是一种有向无环图(DAG),用于记录算子与变量之间的关系. 下图为z= ...

  5. 一文详解pytorch的“动态图”与“自动微分”技术

    前言 众所周知,Pytorch是一个非常流行且深受好评的深度学习训练框架.这与它的两大特性"动态图"."自动微分"有非常大的关系."动态图" ...

  6. Pytorch自动求梯度

    求梯度 微分 Pytorch自动微分 微分 通常我们见到的微分方法有两种: ·符号微分法: ·数值微分法: Pytorch自动微分 对于一个Tensor,如果它的属性requires_grad 设置为 ...

  7. 深度学习利器之自动微分(2)

    深度学习利器之自动微分(2) 文章目录 深度学习利器之自动微分(2) 0x00 摘要 0x01 前情回顾 0x02 自动微分 2.1 分解计算 2.2 计算模式 2.3 样例 2.4 前向模式(For ...

  8. 深度学习利器之自动微分(1)

    深度学习利器之自动微分(1) 文章目录 深度学习利器之自动微分(1) 0x00 摘要 0.1 缘起 0.2 自动微分 0x01 基本概念 1.1 机器学习 1.2 深度学习 1.3 损失函数 1.4 ...

  9. 自动微分 ​​​​​​​​​​​​​​Automatic Differentiation

    目录 一.概述 二.原理 2.1 前向模式 2.2 后向模式 2.3 前向 VS 反向 三.Pytorch自动微分举例 四.Ref 记录自动微分的知识点. 一.概述 计算机实现微分功能, 有以下四种方 ...

最新文章

  1. 在ASP.NET中操作文件的例子
  2. 学习html5系列之比较典型的div滥用
  3. C++ QT中的QSound使用方法
  4. 分享一篇关于使用阿里云消息队列中遇到的坑
  5. Anytime项目开发记录0
  6. ElasticSearch常用命令记录
  7. 如何用面对对象来做一个躁动的小球?
  8. java登录界面命令_Java命令行界面(第18部分):JCLAP
  9. eclipse打开文件所在目录
  10. C#LeetCode刷题-并查集
  11. idea插件Lombok
  12. 父亲去年喂猪挣了21万
  13. RP2836 板卡信息标识
  14. js jquery 判断元素是否在数组内
  15. 我设计的目录结构如此清楚,你为什么也会错
  16. 舞台音效控制软件_舞台音乐控制软件下载
  17. 去掉whatsns问答系统页面底部隐藏的官网链接
  18. Python:正则表达式 re.sub()替换功能
  19. 数据中台你想知道的都在这里!
  20. 【网络互联技术】(三) 网络互联基础。

热门文章

  1. 在kotlin companion object中读取spring boot配置文件,静态类使用@Value注解配置
  2. Go 学习笔记(84)— Go 项目目录结构
  3. SpringBoot2.x 不反回空值属性
  4. 【Sql Server】DateBase-事务
  5. 【Docker】Ubuntu18.04国内源安装Docker-准备工作(一)
  6. virtualenv创建虚拟环境为主_多版本
  7. Connecting to (DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=localhost)(PORT=1521))) TNS-12541: TNS:no li
  8. LLVM与Clang局部架构与语法分析
  9. ViewGroup的Touch事件分发(源码分析)
  10. 2021年大数据ELK(二十五):添加Elasticsearch数据源