机器之心整理

参与:思、Jamin

一直以来,自动微分都在 DL 框架背后默默地运行着,本文希望探讨它到底是什么,通过 JAX,自动微分又能怎么用。

自动微分现在已经是深度学习框架的标配,我们写的任何模型都需要靠自动微分机制分配模型损失信息,从而更新模型。在广阔的科学世界中,自动微分也是必不可少的。说到底,大多数算法都是由基本数学运算与基本函数组建的。

在 ICLR 2020 的一篇 Oral 论文中(满分 8/8/8),图宾根大学的研究者表示,目前深度学习框架中的自动微分模块只会计算批量数据反传梯度,但批量梯度的方差、海塞矩阵等其它量也很重要,它们可以在计算梯度的过程中快速算出来。

目前自动微分框架只计算出梯度,因此就限定了研究方向只能放在梯度下降变体之上,而不能做更广的探讨。为此,研究者构建了 BACKPACK,它建立在 PyTorch 之上,还扩展了自动微分与反向传播能获得的信息。

选自论文 BACKPACK,arXiv:1912.10985。

除此之外,Julia Computing 团队去年 7 月份也发表了一份论文,提出了可微编程系统,它能将自动微分内嵌于 Julia 语言,从而将其作为第一级的语言特性。由于广泛的科学计算和机器学习领域都需要线性代数的支持,因此这种可微编程能成为更加通用的一种模式。

从这些前沿研究可以清晰地感受到,自动微分越来越重要。

自动微分是什么

在数学与计算代数学中,自动微分也被称为微分算法或数值微分。它是一种数值计算的方式,用来计算因变量对某个自变量的导数。此外,它还是一种计算机程序,与我们手动计算微分的「分析法」不太一样。

自动微分基于一个事实,即每一个计算机程序,不论它有多么复杂,都是在执行加减乘除这一系列基本算数运算,以及指数、对数、三角函数这类初等函数运算。通过将链式求导法则应用到这些运算上,我们能以任意精度自动地计算导数,而且最多只比原始程序多一个常数级的运算。

一般而言会存在两种不同的自动微分模式,即前向累积梯度(前向模式)和反向累计梯度(反向模式)。前向累积会指定从内到外的链式法则遍历路径,即先计算 d_w1/d_x,再计算 d_w2/d_w1,最后计算 dy/dw_2。

反向梯度累积正好相反,它会先计算 dy/dw_2,然后计算 d_w2/d_w1,最后计算 d_w1/d_x。这是我们最为熟悉的反向传播模式,它非常符合「沿模型误差反向传播」这一直观思路。

如图所示,两种自动微分模式都在求 dy/dx,只不过根据链式法则展开的形式不太一样。

来一个实例:误差传播

在统计学上,由于变量含有误差,使得函数也含有误差,我们将其称之为误差传播。阐述这种关系的定律叫做误差传播定律。

先定义一个函数 q(x,y) ,我们想通过 q 传递 x 与 y 的不确定性信息,即 ????_x 与 ????_y。最直接的方式是随机采样 x 与 y,并计算 q 的值,然后查看它的分布。这就是「传播不确定性」这个概念的意义。

误差传播的积分公式可以是一个近似值, q(x,y) 的一般表达式可以写为:

如果我们定义一个特殊案例,即 q(x,y)=x±y,那么总不确定性可以写为:

对于特例 q(x,y)=xy 与 q(x,y)=x/y ,不确定性分别为 (σ_q/q)^2 = (σ_x/x)^2+(σ_y/y)^2 与 σ_q=(x/y)* sqrt((σ_x/x)^2+(σ_y/y)^2)。

我们可以尝试这些方法,并对比根据这些近似公式算出来的反传误差,以及实际发生的反传误差。

实战 JAX 自动微分

Jax 是谷歌开源的一个科学计算库,能对 Python 程序与 NumPy 运算执行自动微分,而且能够在 GPU 和 TPU 上运行,具有很高的性能。

如下先导入 JAX,然后用三行代码就能定义之前给出的反传不确定性度量。

from jax *import* grad, jacfwd
import jax.numpy *as* npdef error_prop_jax_gen(q,x,dx):jac = jacfwd(q)return np.sqrt(np.sum(np.power(jac(x)*dx,2)))

这里计算的反传梯度是根据 jax 完成的,后面的反传误差会直接通过公式计算,并对比两者。

1. 配置两个具有不确定性的观察值

我们需要使用 x 与 y 作为符号推理,但可以把它们都储存在数组 x 中,x[0]=x、x[1]=y。

x_ = np.array([2.,3.])
dx_ = np.array([.1,.1])

2. 加减法

在 ????(????,????)=????±???? 这一特例情况下,误差传播公式可以简化为

上图所示,通过误差传播公式计算出来的值与 JAX 计算出来的是一致地。

3. 乘除法

在 ????(????,????)=???????? 与 ????(????,????)=????/???? 这两种特例中,误差传播公式可以写为:

4. 幂

对于特例 ????(????,????)=????^????*????^????,传播公式可以表示为:

我们可以写成

JAX 的使用非常多样,甚至能直接使用它搭建神经网络。例如 JAXnet 框架,它是一个基于 JAX 的深度学习库,它的 API 提供了便利的模型搭建体验。比如说,以下代码就能建个神经网络:

from jaxnet import *net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)

此外,不久之前,DeepMind 也发布了两个新库:在 Jax 上进行面向对象开发 的 Haiku 和 Jax 上的强化学习库 RLax。JAX 这样的通用自动微分库也许能在更广泛的领域发挥作用。

自动微分到底是什么?这里有一份自我简述相关推荐

  1. 深度学习利器之自动微分(2)

    深度学习利器之自动微分(2) 文章目录 深度学习利器之自动微分(2) 0x00 摘要 0x01 前情回顾 0x02 自动微分 2.1 分解计算 2.2 计算模式 2.3 样例 2.4 前向模式(For ...

  2. 深度学习利器之自动微分(1)

    深度学习利器之自动微分(1) 文章目录 深度学习利器之自动微分(1) 0x00 摘要 0.1 缘起 0.2 自动微分 0x01 基本概念 1.1 机器学习 1.2 深度学习 1.3 损失函数 1.4 ...

  3. 自动微分(Automatic Differentiation)

    目录 什么是自动微分 手动求解法 数值微分法 符号微分法 自动微分法 自动微分Forward Mode 自动微分Reverse Mode 参考引用 现代深度学习系统中(比如MXNet, TensorF ...

  4. 自动微分(Automatic Differentiation)简介

    现代深度学习系统中(比如MXNet, TensorFlow等)都用到了一种技术--自动微分.在此之前,机器学习社区中很少发挥这个利器,一般都是用Backpropagation进行梯度求解,然后进行SG ...

  5. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  6. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  7. MindSpore:自动微分

    MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...

  8. 只知道TF和PyTorch还不够,快来看看怎么从PyTorch转向自动微分神器JAX

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转载自:机器之心 说到当前的深度学习框架,我们往往绕不开 Ten ...

  9. 反式自动微分autodiff是什么?反向传播(Back Propagation)是什么?它是如何工作的?反向传播与反式自动微分autodiff有什么区别?

    反式自动微分autodiff是什么?反向传播(Back Propagation)是什么?它是如何工作的?反向传播与反式自动微分autodiff有什么区别? 目录

最新文章

  1. 1. spring boot起步之Hello World【从零开始学Spring Boot】
  2. Linux 刻录光盘
  3. random()模块随机函数的用法总结
  4. 1106 Lowest Price in Supply Chain(甲级)
  5. 统计学习方法-李航(3)
  6. 11Linux_vmtools
  7. sql 自定义函数 示例_SQL Server SESSION_CONTEXT()函数与示例
  8. 苹果智能音箱HomePod跳票了,上市日期推迟到明年
  9. lumen mysql 事务_数据库事务不执行回滚?
  10. 下载卫星影像数据流程
  11. 信息学奥赛与大学计算机课程,为什么要学信息学奥赛(NOIP)
  12. echarts 词云图
  13. 怎么冻结表格前几行和前几列_如何冻结表格前几列
  14. 2019 年第 27 周 DApp 影响力排行榜 | TokenInsight
  15. Bootstrap-button btn样式
  16. 【框架思路】python如何读取excel文件内容?如何获取excel文件的路径及sheet名称?
  17. 青龙 金手指教程每天低保保姆安装教程
  18. 知道创宇区块链实验室受邀参加“2021 CCF中国区块链技术大会”
  19. python 计算图像结构张量(Structure_tensor)
  20. 【瑞模网】fbx模型压缩成gltf格式

热门文章

  1. 【复盘】小朋友的奇思妙想
  2. 技术图文:如何实现 DataTable 与模型类 List 的相互转换?
  3. Python 来分析,堪比“唐探系列”!B站9.5分好评如潮!
  4. Q 版老黄带着硬核技术再登场,有点可爱,很有东西
  5. 2021年浅谈多任务学习
  6. 多画面、实时投票,这场上了一晚热搜的超级晚,背后的技术出圈了
  7. CSDN湘苗培优|保持热情,告别平庸
  8. AI安全最全“排雷图”来了!腾讯发布业内首个AI安全攻击矩阵
  9. 无人机巡逻喊话、疫情排查、送药消毒,抗疫战中机器人化身钢铁战士!
  10. 把自己朝九晚五的工作自动化了,有错吗?