pytorch——计算图与动态图机制
1、计算图
计算图是用来描述运算的有向无环图;
计算图有两个主要元素:结点(Node)和边(Edge);
结点表示数据,如向量、矩阵、张量,边表示运算,如加减乘除卷积等;
用计算图表示:y=(x+w)∗(w+1)y = (x + w) * (w + 1)y=(x+w)∗(w+1)
令a=x+wa=x+wa=x+w,b=w+1b=w+1b=w+1,y=a∗by=a*by=a∗b,那么得到的计算图如下所示:
采用计算图来描述运算的好处不仅仅是让运算更加简洁,还有一个更加重要的作用是使梯度求导更加方便。举个例子,看一下y对w求导的一个过程。
计算图与梯度求导
y=(x+w)∗(w+1)y = (x + w) * (w + 1)y=(x+w)∗(w+1)
a=x+wa=x+wa=x+w b=w+1b=w+1b=w+1
y=a∗by=a*by=a∗b
∂y∂w=∂y∂a∂a∂w+∂y∂b∂b∂w\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}∂w∂y=∂a∂y∂w∂a+∂b∂y∂w∂b
=b∗1+a∗1=b*1+a*1=b∗1+a∗1
=b+a=b+a=b+a
=(w+1)+(x+w)=(w+1)+(x+w)=(w+1)+(x+w)
=2∗w+x+1=2*w+x+1=2∗w+x+1
=2∗1+2+1=2*1+2+1=2∗1+2+1
=5=5=5
通过链式求导可以知道,利用计算图推导得到的推导结果如下图所示:
通过分析可以知道,y对w求导就是在计算图中找到所有y到w的路径,把路径上的导数进行求和。
利用代码看一下y对w求导之后w的梯度是否是上面计算得到的。具体的代码如下所示:
import torchw = torch.tensor([1.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
x = torch.tensor([2.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为Truea = torch.add(w, x) # a = w + x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * by.backward() #对y进行反向传播
print(w.grad) #输出w的梯度
得到的结果为5,证明了上面的结论。
在第一篇博文中讲张量的属性的时候,讲到与梯度相关的四个属性的时候,有一个is_leaf,也就是叶子节点,叶子节点的功能是指示张量是否是叶子节点。
叶子节点:用户创建的结点称为叶子结点,如X与W;
is_leaf:指示张量是否为叶子节点;
叶子节点是整个计算图的根基,例如前面求导的计算图,在前向传导中的a、b和y都要依据创建的叶子节点x和w进行计算的。同样,在反向传播过程中,所有梯度的计算都要依赖叶子节点。
设置叶子节点主要是为了节省内存,在梯度反向传播结束之后,非叶子节点的梯度都会被释放掉。可以根据代码分析一下非叶子节点a、b和y的梯度情况。
import torchw = torch.tensor([1.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
x = torch.tensor([2.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为Truea = torch.add(w, x) # a = w + x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * by.backward() #对y进行反向传播
print(w.grad) #输出w的梯度#查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf) #输出为True True False False False,只有前面两个是叶子节点#查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad) #输出为tensor([5.]) tensor([2.]) None None None,因为非叶子节点都被释放掉了
如果想使用非叶子结点梯度,可以使用pytorch中的retain_grad()。例如对上面代码中的a执行相关操作a.retain_grad(),则a的梯度会被保留下来,具体的代码如下所示:
import torchw = torch.tensor([1.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
x = torch.tensor([2.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为Truea = torch.add(w, x) # a = w + x
a.retain_grad() #保存非叶子结点a的梯度,输出为tensor([5.]) tensor([2.]) tensor([2.]) None None
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * by.backward() #对y进行反向传播
print(w.grad) #输出w的梯度#查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)#查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)
torch.Tensor中还有一个属性为grad_fn,grad_fn的作用是记录创建该张量时所用的方法(函数),该属性在梯度反向传播的时候用到。例如在上面提到的例子中,y.grad_fn = ,y在反向传播的时候会记录y是用乘法得到的,所用在求解a和b的梯度的时候就会用到乘法的求导法则去求解a和b的梯度。同样,对于a有a.grad_fn=,对于b有b.grad_fn=,由于a和b是通过加法得到的,所以grad_fn都是AddBackword0。可以通过代码观看各个变量的属性。
import torchw = torch.tensor([1.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
x = torch.tensor([2.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为Truea = torch.add(w, x) # a = w + x
a.retain_grad()
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * by.backward() #对y进行反向传播
print(w.grad) #输出w的梯度# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)#上面代码的输出结果为
grad_fn:None None <AddBackward0 object at 0x000001EEAA829308> <AddBackward0 object at 0x000001EE9C051548> <MulBackward0 object at 0x000001EE9C29F948>
可以看到w和x的grad_fn都是None,因为w和x都是用户创建的,没有通过任何方法任何函数去生成这两个张量,所以两个叶子节点的属性为None,这些属性都是在梯度求导中用到的。
2、pytorch的动态图机制
动态图 vs 静态图
动态图:pytorch使用的,运算与搭建同时进行;灵活,易调节。
静态图:tensorflow使用的,先搭建图,后运算;高效,不灵活。
根据计算图搭建方式,可将计算图分为动态图和静态图。
为了尽快理解静态图和动态图的区别,这里举一个例子。假如我们去新马泰旅游,如果我们是跟团的话就是静态图,如果是自驾游的话就是动态图。跟团的意思是路线都已经计划好了,也就是先建图后运算。自驾游的话可以根据实际情况实际调整。下面分别列举tensorflow的静态图例子和pytorch的动态图实例进行简单理解。
在上面这个图中,框框代表的就是节点,带箭头的线代表边。tensorflow使用的是静态图,是先将图搭建好之后,再input数据进去。
pytorch使用的是动态图,具体的操作如下代码:
W_h = torch.randn(20, 20, requires_grad=True) #先创建四个张量
W_x = torch.randn(20, 10, requires_grad=True)
x = torch.randn(1, 10)
prev_h = torch.randn(1, 20)h2h = torch.mm(W_h, prev_h.t()) #将W_h和prev_h进行相乘,得到一个新张量h2h
i2h = torch.mm(W_x, x.t()) #将W_x和x进行相乘,等到一个新张量i2h
next_h = h2h + i2h #创建加法操作
next_h = next_h.tanh() #使用激活函数loss = next_h.sum() #计算损失函数
loss.backward() #梯度反向传播
上面代码对应的动态图过程就是下面的图
动态图的搭建是根据每一步的计算搭建的,而tensorflow是先搭建所有的计算图之后,再把数据输入进去。这就是动态图和静态图的区别。
pytorch——计算图与动态图机制相关推荐
- PyTorch框架学习四——计算图与动态图机制
PyTorch框架学习四--计算图与动态图机制 一.计算图 二.动态图与静态图 三.torch.autograd 1.torch.autograd.backward() 2.torch.autogra ...
- 【PyTorch 】静态图与动态图机制
[PyTorch 学习笔记] 1.4 静态图与动态图机制 - 知乎 PyTorch 的动态图机制 PyTorch 采用的是动态图机制 (Dynamic Computational Graph),而 T ...
- 【Torch笔记】计算图与动态图
[Torch笔记]计算图与动态图 1 什么是计算图? 计算图(Computational Graph)是用来 描述运算 的有向无环图,主要由节点和边组成.节点表示数据,如向量.矩阵.张量,边表示运算, ...
- pytorch入门学习(四)-----计算图与动态图
计算图: 用来描述运算的有向无环图有两个主要元素,结点note 边edge结点表示数据,如向量,矩阵,张量边表示运算,如加减乘除使用计算图主要是为了求导方便, 只需要沿着计算图的方向找到需要求导对象的 ...
- PyTorch 的 Autograd、计算图、叶子张量、inplace 操作、动态图,静态图(来自知乎)
本博文来自:https://zhuanlan.zhihu.com/p/69294347 非常感谢此博主! PyTorch 作为一个深度学习平台,在深度学习任务中比 NumPy 这个科学计算库强在哪里呢 ...
- 一文详解pytorch的“动态图”与“自动微分”技术
前言 众所周知,Pytorch是一个非常流行且深受好评的深度学习训练框架.这与它的两大特性"动态图"."自动微分"有非常大的关系."动态图" ...
- 【深度学习】村通网之——谈谈Tensorflow Eager Execution机制之静态图和动态图的区别(一)
文章目录 前言 介绍 搭建静态图 搭建动态图 前言 随着TensorFlow 1.4 Eager Execution的出现,TensorFlow的使用出现了革命性的变化. 介绍 我很早就听说过这样一句 ...
- 动态图 vs 静态图
动态图 动态图意味着计算图的构建和计算同时发生(define by run).这种机制由于能够实时得到中间结果的值,使得调试更加容易,同时我们将大脑中的想法转化为代码方案也变得更加容易,对于编程实现来 ...
- 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门
1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...
最新文章
- OpenAI新发现:GPT-3做小学数学题能得55分,验证胜过微调!
- (Oracle学习笔记) Oracle概述
- window的onresize执行多次的解决方法
- 20200817-Mysql 底层数据结构及Explain详解
- 【Servlet】Listener监听器
- Happy Necklace
- 博文视点大讲堂第20期——Windows 7来了
- 外部库依赖以及 编译
- SAXReader解析xml文件
- 日记侠:要赚钱千万别多想立刻开干
- kubernetes架构及核心概念
- 9.2 多元微分学及应用——偏导数
- 大厂Java研发岗位要求你清楚吗?
- Jenkins容器由于虚拟内存不足导致的异常退出
- 创新认知 基于LPC1114单片机的传感器使用
- 英语心理测试脸型软件,心理测试:脸型分析自己
- 基于C语言的学生选课系统
- 个人网站引入B站视频播放,个人博客播放B站视频。【1080P】
- 自己做量化交易软件(44)小白量化实战17--利用小白量化金融模块在迅投QMT极速策略交易系统上仿大智慧指标回测及实战交易设计
- CF1139C Edgy TreesDFS求连通块大小、思维
热门文章
- [Java] Scanner(new File( )) 从文件输入内容
- 企业文件服务器(samba)配置案例一
- Linux/Windows/MacOS各个操作系统下推荐应用集合
- 你解决的问题比你编写的代码更重要! 1
- 让开发人员变平庸的八个习惯,看看你中了几条
- 国内最火5款Java微服务开源项目
- HomeBrew 更换为国内源--提高brew命令操作速度
- CCF 201503-2 数字排序
- 【Python】基本统计值计算
- C#LeetCode刷题之#557-反转字符串中的单词 III(Reverse Words in a String III)