引言

本着“凡我不能创造的,我就不能理解”的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导。

要深入理解深度学习,从零开始创建的经验非常重要,从自己可以理解的角度出发,尽量不适用外部完备的框架前提下,实现我们想要的模型。本系列文章的宗旨就是通过这样的过程,让大家切实掌握深度学习底层实现,而不是仅做一个调包侠。
本系列文章首发于微信公众号:JavaNLP

本文介绍自动求导的基础知识——计算图。

计算图

我们知道,反向传播是模型训练的途径。而反向传播是基于求导的,有没有想过像Keras和PyTorch这种工具是如何做到自动求导的。答案就是计算图,只要掌握了计算图的知识,我们就能自己开发一个自动求导工具。

计算图是一种描述函数的工具,可以可视化为有向图结构。其中节点为Tensor(向量/张量),有向边为操作。

在深度学习中比较常见的例子是类似
y=f(g(h(x)))u=h(x)v=g(u)y=f(v)y = f (g(h(x))) \\ u = h(x) \quad v= g(u) \quad y=f(v) y=f(g(h(x)))u=h(x)v=g(u)y=f(v)

x可以看成是输入,y可以看成是输出,中间经过了3次变换。

有时我们的函数有多个参数(比如乘法就有两个参数),假设我们要计算e=(a+b)∗(b+1)e = (a+b) * (b+1)e=(a+b)∗(b+1)​​,它的计算图如下:

这里令c=a+b;d=b+1c = a+ b; \quad d = b + 1c=a+b;d=b+1,为了完整性,也画出了常量111。

有了这个计算图,我们就就可以很容易的计算出eee的值。比如令a=2,b=1a=2,b=1a=2,b=1

当然,我们这么辛苦的画出这个图,主要不是为了沿着箭头方向进行计算的。而是为了求导,也就是计算梯度。

计算图上的梯度

回顾一下链式法则

我们重点来看下多路径的链式法则,即上面说的全导数。

我们要计算e=(a+b)∗(b+1)e = (a+b) * (b+1)e=(a+b)∗(b+1)中∂e/∂b\partial e/ \partial b∂e/∂b。

c=a+b;d=b+1c = a+ b; \quad d = b + 1c=a+b;d=b+1。

类似上图中的ttt,bbb也影响了两个因子。因此有
∂e∂b=∂e∂c⋅∂c∂b+∂e∂d⋅∂d∂b\frac{\partial e}{\partial b} = \frac{\partial e}{\partial c} \cdot \frac{\partial c}{\partial b} + \frac{\partial e}{\partial d} \cdot \frac{\partial d}{\partial b} ∂b∂e​=∂c∂e​⋅∂b∂c​+∂d∂e​⋅∂b∂d​
要计算偏导数,我们先把每个箭头的偏导数计算出来。

我们先填入计算出来的式子:

根据e=c∗dc=a+bd=b+1e = c * d \quad c = a+ b \quad d = b+ 1e=c∗dc=a+bd=b+1以及求导公式不难得到上面的结果。

此时,要计算∂e/∂b\partial e/ \partial b∂e/∂b,只需要找出所有从bbb到eee​到路径,然后把相同路径上的值相乘,不同路径上的值相加(连线相乘,分线相加)。

就可以得到:∂e/∂b=1∗(b+1)+1∗(a+b)\partial e/ \partial b = 1*(b+1) + 1*(a+b)∂e/∂b=1∗(b+1)+1∗(a+b)

此时,代入a=2,b=1a=2,b=1a=2,b=1。

先计算出b+1=2b+1=2b+1=2,再计算a+b=3a+b=3a+b=3,所以∂e/∂b=1∗2+1∗3=5\partial e/ \partial b = 1*2 + 1 *3=5∂e/∂b=1∗2+1∗3=5

反向模式

如果要同时计算eee对aaa和bbb​的偏导数,我们需要反向模式(Reverse mode)来同时计算它们。

反向就是从顶点开始,这里从eee开始,也从梯度等于111开始。

从顶点到bbb,通过有两条路径,如山古同橙色箭头所示。达到ccc时的梯度为1∗2=21 * 2 =21∗2=2;到达ddd时的梯度为1∗3=31 * 3=31∗3=3。

ccc和ddd到bbb的梯度都是111。根据相同路径相乘,不同路径相加。到bbb到梯度为2+3=52+3=52+3=5。

此时,计算eee到aaa的就简单了,我们已经知道了eee到ccc到梯度为222,由于eee到aaa只有一条路径,因此直接相乘得∂e/∂a=2∗1=2\partial e / \partial a =2 * 1=2∂e/∂a=2∗1=2。

如果你的函数只有一个输出,由需要同时计算大量的不同值的偏导数时,用反向模式就比较快。

而这恰恰非常适合于我们计算损失函数的梯度,因为损失函数的输出就是一个标量。

总结

我们已经了解了计算图的基础知识,下篇文章就来看一下常见操作的计算图。

从零实现深度学习框架——自动求导神器计算图相关推荐

  1. python学习框架图-从零搭建深度学习框架(二)用Python实现计算图和自动微分

    我们在上一篇文章<从零搭建深度学习框架(一)用NumPy实现GAN>中用Python+NumPy实现了一个简单的GAN模型,并大致设想了一下深度学习框架需要实现的主要功能.其中,不确定性最 ...

  2. 从零实现深度学习框架——GloVe从理论到实战

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  3. 从零实现深度学习框架——Seq2Seq从理论到实战【实战】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  4. 从零实现深度学习框架——RNN从理论到实战【理论】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  5. 从零实现深度学习框架——深入浅出Word2vec(下)

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导. 要深入理解深度学 ...

  6. 从零实现深度学习框架——从共现矩阵到点互信息

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  7. 从零实现深度学习框架——LSTM从理论到实战【理论】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  8. 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门

    1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...

  9. 饮水思源--浅析深度学习框架设计中的关键技术

    点击上方"深度学习大讲堂"可订阅哦! 编者按:如果把深度学习比作一座城,框架则是这座城中的水路系统,而基于拓扑图的计算恰似城中水的流动,这种流动赋予了这座城以生命.一个优雅的框架在 ...

  10. np实现sigmoid_使用numpy实现一个深度学习框架

    为了理解深度学习框架的大致机理,决定使用numpy实现一个简单的神经网络框架 深度学习框架我觉得最重要的是实现了链式求导法则,而计算图就是建立在链式求导法则之上的,目前大多数深度学习是基于反向传播思想 ...

最新文章

  1. Java项目:学生管理系统(无库版)(java+打印控制台)
  2. 当你用钥匙开不开门时
  3. 2016-2017 ACM-ICPC Pacific Northwest Regional Contest (Div. 2) 【部分题解】
  4. 服务器mysql数据库安装教程视频教程_MySQL数据库管理系统安装实际操作_MySQL教程视频 - 动力节点...
  5. java ceilingentry_java.util.TreeMap.ceilingKey()
  6. C++局部变量和全局变量的初始化
  7. NAND FLASH读写原理
  8. 谷歌浏览器整个网页截图方法
  9. 关于如何把用手机查看原型
  10. 计算机格式化命令符号,格式化c盘命令是什么 格式化c盘会怎么样【图文】
  11. 安装kubernetes k8s v1.16.0 国内环境
  12. MySQL数据库创建表一系列操作
  13. 7月18日自助装机配置专家点评
  14. 查看mysql是否区分大小写
  15. Android高级工程师进阶学习,分享PDF高清版
  16. 腾讯云发布php项目,利用腾讯云服务器进行微校开放平台开发
  17. 全新整理:微软、谷歌、百度等公司经典面试100题[第1-60题]
  18. 网校系统是怎样搭建的?
  19. IntelliJ IDEA 之 配置JDK 的 4种方式
  20. 输出100-200之间所有的素数(素数:只能被1和自己本身整除的数)

热门文章

  1. 高并发模拟( 测试 )
  2. 邀请您加入移动开发专家联盟
  3. 多核cpu的特殊中断
  4. 白鹭引擎 - 事件机制 ( Event, addEventListener, dispatchEvent )
  5. 基于Wi-Fi的HID注射器,利用WHID攻击实验
  6. 利用python获取nginx服务的ip以及流量统计信息
  7. vmware实现小型局域网实验环境
  8. js可以选择时间的日历控件
  9. Access导入SQL2005
  10. (Origin教程)在图片和表格中插入Latex公式