目录

# 使用介绍 #

# 自动微分

# vmap 和 pmap

# JIT 编译

# 内部实现 #

# Trace 变换

# Jaxpr:JAX 中间表达式

# 总结 #

参考


JAX [1] 是 Google 推出的可以对 NumPy 和 Python 代码进行自动微分并跑到 GPU/TPU(Google 自研张量加速器)加速的机器学习库。Numpy [2] 是 Python 著名的数组运算库,官方版本只支持 CPU 运行(后面 Nvidia 推出的 CuPy 支持 GPU 加速,这里按住不表)。JAX 前身是 AutoGrad [3],2015 年哈佛大学来自物理系和 SEAS(工程与应用科学学院)的师生发表论文推出的支持 NumPy 程序自动求导的机器学习库。AutoGrad 提供和 NumPy 库一致的编程接口,用户导入 AutoGrad 就可以让原来写的 NumPy 程序拿来求导。JAX 在 2018 年将 XLA [4](Tensorflow 线性代数领域编译器)引入进来,使得 Python 程序可以通过 XLA 编译跑到 GPU/TPU 加速器上。简单地理解 JAX = NumPy + AutoGrad + XLA。可以说,XLA 加持下的 JAX,才真正具备了实施深度学习训练的基础和能力。

JAX Github: https://github.com/google/jax

JAX API Docs: https://jax.readthedocs.io/en/latest/


# 使用介绍 #

# 自动微分

JAX 提供兼容 NumPy 风格的接口,照顾用户原 NumPy 编程习惯。JAX 面向 Python 用户提供自动微分接口,包括生成梯度函数、求导等。

例 1:使用 jax.grad() 求导

from jax import graddef f(x):return x * x * xD_f = grad(f) # 3x^2
D2_f = grad(D_f) # 6x
D3_f = grad(D2_f) # 6f(1.0) # 1.0
D_f(1.0) # 3.0
D2_f(1.0) # 6.0
D3_f(0.0) # 6.0 (always)

jax.grad:只接受输出标量的原始函数 f,生成对应的梯度函数 ▼f▼f 接受和原始函数一样的入参 x,输出为参数梯度 dx▼f 亦可被 grad(),相当于对原始函数计算高阶梯度,但需满足一样的要求:输出为标量。如果被求导的函数计算结果不止一个数值,不能直接传给 grad()。需要先 reduce 成一个标量。

例 2:对数组函数求导

from jax import numpy as np
from jax import grad
import matplotlib.pyplot as pltdef f(x):return x * x * xD_f = grad(lambda x: np.sum(f(x)))
D2_f = grad(lambda x: np.sum(D_f(x)))
D3_f = grad(lambda x: np.sum(D2_f(x)))x = np.linspace(-1, 1, 200)
plt.plot(x, f(x), x, D_f(x), x, D2_f(x), x, D3_f(x))
plt.show()

和例 1 相比主要区别在于例 2 分别对函数(fD_fD2_f)结果进行求和(sum)再求导。函数 f 和它的一阶、二阶、三阶导函数曲线如下图所示。

JAX 支持不同模式自动微分。grad() 默认采取反向模式自动微分。另外显式指定模式的微分接口有 jax.vjp 和 jax.jvp

  • jax.vjp:反向模式自动微分。根据原始函数 f、输入 x 计算函数结果  y 并生成梯度函数 ▼f▼f 输入是 dy,输出是 dxgrad() 实现上底层调用 vjp(),可看做 vjp() 的一种特例。

  • jax.jvp:前向模式自动微分,根据原始函数 f、输入 x 和 dx 计算结果 y 和 dy。在函数输入参数数量少于或持平输出参数数量的情况下前向模式自动微分比反向模式更省内存,内存利用效率上更具优势 [5]。

jvp() 中微分计算和原始函数计算是同时完成的。多次调用可能存在对原始函数重复计算。JAX 提供前向模式自动微分的缓存优化接口 jax.linearize。该接口根据原始函数 f 和输入 x,计算函数结果 y 并生成导函数 f'。导函数 f' 输入是 dx,输出是 dylinearize() 实现上为前向模式 jvp() 加上 partial evaluation(缓存了原始函数计算过程数据),在内存占用方面更接近于反向模式自动微分,相对于前向模式来说还是比较耗内存的。

为方便对照列出这几个微分接口的形式化信息:

微分接口 类型签名
grad() (a -> b) -> a -> T a
value_and_grad() (a -> b) -> a -> (b, T a)
jvp() ((a -> b), a, T a) -> (b, T b)
vjp() ((a -> b), a) -> (b, (T b -> T a))
linearize() ((a -> b), a) -> (b, (T a -> T b))

vmap 和 pmap

jax.vmap 负责对函数进行向量化,可指定向量化维度。假设要对原始函数 f 进行自动微分,输入输出参数数量持平,并且要多次执行导函数 f'

方式 1:使用 jvp()。前向模式自动微分省内存。但,多次执行 jvp() 意味着多次计算原始函数。慢!

for in_tangent in in_tangents:y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

方式 2:使用 linearize(),只计算一次原始函数,比方式 1 的计算效率高。但是,如前面所说,linearize() 内存占用接近反向模式自动微分。耗内存!

y, f_jvp = jax.linearize(f, x)
for in_tangent in in_tangents:out_tangent = f_jvp(in_tangent)

方式 3:使用 jvp() 加 vmap()。利用前向模式自动微分省内存的特点优化方式 2 的内存占用问题,同时通过向量化相比方式 1 提高了计算效率。前提是提前已知导函数执行需要的这些输入数据。

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

jax.pmap 帮助实现 SPMD(即 single program, multiple data)编程,比如在 GPU 多个卡上并行计算,对用户屏蔽底层通信操作。调用 pmap() 之后,经过 JAX 编译可将数据和计算任务分布到多个设备上执行。

JIT 编译

JAX 通过 XLA 后端对 Python 函数进行 JIT 编译得到优化后的函数。被 JIT 的函数必须是纯函数,其中副作用代码只会执行一次。输入输出参数类型必须满足:数组、标量、容器(tuple、list、dict)的一种。

def f(x):return x * x * x
print(jax.jit(f)(1.0)) # 1.0
print(jax.jit(grad(f))(1.0)) # 3.0

JAX 要求参与微分、JIT 的必须是 Python 纯函数。使用 JAX 表达神经网络的计算过程都是由 Python 函数组成。假设 N 个全连接层串行构成神经网络结构。JAX 代码示意如下。

def loss(params, batch):inputs, targets = batchpreds = predict(params, inputs)# 损失函数return -np.mean(np.sum(preds * targets, axis=1))def predict(params, inputs):activations = inputsfor w, b in params[:-1]:# 前 (N-1) 层分别由带 bias 的矩阵乘 + tanh 激活函数组成outputs = np.dot(activations, w) + bactivations = np.tanh(outputs)# 第 N 层由带 bias 的矩阵乘组成final_w, final_b = params[-1]logits = np.dot(activations, final_w) + final_b# 归一层return logits - logsumexp(logits, axis=1, keepdims=True)@jit
def update(params, batch):# 求 loss 函数的参数梯度grads = grad(loss)(params, batch)# 依次分别更新权重和 biasreturn [(w - 0.001 * dw, b - 0.001 * db)for (w, b), (dw, db) in zip(params, grads)]params = ... # 初始化 N 层参数(权重和 bias)
for epoch in range(num_epochs):for _ in range(num_batches):# 参数随着 epoch 和 batch 迭代变化params = update(params, next(batches))# 基于最新参数数据评估精度(需要跑一遍前向 predict)test_acc = accuracy(params, (test_images, test_labels))

# 内部实现 #

JAX 提供的 grad()jvp()vmap()pmap() 等接口指定原始函数用于变换。最后用于真正执行计算求值的是变换生成的新函数,比如各种微分相关的函数、映射优化的函数。JAX 实现里,这些参与变换的原始函数以动态方式被记录并生成中间表达式。这一过程叫做 trace。JAX 生成的中间表达式叫 Jaxpr。Jaxpr 经过内部解释器执行变换。过程如下图所示 [1]。

前向模式自动微分、vmap 等情况下,trace 只需要携带少部分上下文信息即可变换生成新函数。但反向模式自动微分等情况下,trace 需要生成 Jaxpr 以记录更多信息,再变换生成新函数。JIT 编译例外,JIT 会生成 Jaxpr,但底层通过 XLA 编译生成二进制代码并运行,不再回到 Python 代码。

# Trace 变换

JAX 提供多种 tracer,包括 jvp tracer、vjp tracer、vmap tracer、jaxpr builder tracer 等。这些 tracer 是在 JAX 代码运行过程中工作的。做 trace 的过程同时是 Python 代码特例化过程。能够被 trace 的都是 JAX 要求导的变量信息和操作信息,其他信息包括 Python 原生控制流、自定义类型变量 & 操作、打印语句等不会被 trace。

JAX 被 trace 的 Python 数组可看做抽象的符号表示,只有类型和 shape 信息,没有具体元素数值。比如调用 JAX 提供的 numpy.sum(),不会立刻触发 sum 计算。直到需要访问 Python 数组数值时才会真正求值,相当于惰性求值(Lazy Evaluation)。

JAX 变换负责对 trace 结果执行求值,求值后得到的 JAX 数组包含具体元素数值。JAX 数组通过 to_py() 可主动转成 NumPy 数组。

# Jaxpr:JAX 中间表达式

Jaxpr 全称 JAX Program Representation,用于待变换的函数的内部表示。Jaxpr 是强类型的,函数式的,定义形式符合 ANF form。引入 Jaxpr 主要有两方面考虑:

  1. JIT 需要对 Python 代码建立这样的中间表示来完成动态编译和计算;

  2. 反向模式自动微分对原始函数进行反向传播。Jaxpr 程序表示也可以帮助实现这一点。

Jaxpr 生成时机分为两种方式,对应的 trace 变换方法亦不同:

  • 求值前生成(偏静态)。通过 trace 生成 Jaxpr。而 trace 和 Jaxpr 执行分别位于不同阶段(类似多阶段计算)。不支持依赖数据的控制流。JIT 适用于这种。

  • 求值时生成(偏动态)。trace 不急于生成 Jaxpr,直到最后变换时刻生成 Jaxpr。变换过程就像正常调用 Python 函数一样。允许依赖数据的控制流。定制 jvp 适用于这种。

Jaxpr 函数是强类型的、纯函数的表示,输入、输出都带有类型信息。函数输出只依赖输入,不依赖全局变量。Jaxpr 变量类型只能是数组、标量、容器(tuple、list、dict)的一种。Jaxpr 定义形式比较简单。

jaxpr ::= { lambda Var* ; Var+.let Eqn*in  [Expr+] }Eqn ::= let Var+ = Primitive [Param*] Expr+

JAX 内置常规数学原语和微分规则。

Primitive := add | sub | sin | mul | ...

如想查看函数对应的 Jaxpr,比如自动微分生成的新函数内部形式,可以使用 jax.make_jaxpr

def f(x):return x * xjax.make_jaxpr(f)(1.0)
""" { lambda  ; a.let b = mul a ain (b,) }
"""
jax.make_jaxpr(jax.grad(f))(1.0)
""" { lambda  ; a.let _ = mul a ab = mul 1.0 ac = mul 1.0 ad = add_any b cin (d,) }
"""
jax.make_jaxpr(jax.linearize(f, x)[1])(x)
"""
{ lambda a ; b.let c = mul b ad = mul a be = add_any c din (e,) }
"""

注意 linearize() 生成的函数捕获原始函数 f 的输入 x(亦即 linearize() 函数第二个参数),反映在 Jaxpr 就是 lambda a; 的变量捕获。


# 总结 #

JAX 在机器学习开发上以 NumPy 库 API 作为切入点,提供 AI 需要的自动微分和 JIT 编译加速功能。JAX 编程上贴近 Python 原生语法,体验类似 PyTorch。但与 PyTorch 明显不同的是,JAX 把 Python 代码限制到纯函数。PyTorch 对 tensor 对象求导,而 JAX 选择对函数求导,包括 NumPy 函数和其他 Python 函数。JAX 自动微分除了反向模式,还提供前向模式以及高阶混合模式微分。为了平衡自动微分计算和内存墙的问题,支持按需 checkpoint。

JAX 内部表示是纯函数式的,但考虑到 Python 语言高度动态性特点,对用户使用上有一些编程限制。比如 JAX 自动微分的 Python 函数只支持纯函数,要求用户自行保证这一点。如用户代码写了副作用,可能经过 JAX 变换生成的函数执行结果不符合期望。因 JAX trace 函数为纯函数,当全局变量、配置信息发生变化,可能需要重新 trace。JAX 只能对固定类型进行微分求导,不支持自定义类型如 class 变量等分析和微分求导。JAX 最新版本是 v0.2。相对于其他 Python AI 框架,JAX 用户数和受到的关注度偏少。最近大热的 Alpha Fold2 开源项目里有 JAX 的身影,有希望给 JAX 带一波热度。

参考

[1] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake Vander-Plas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transfor-mations of Python+NumPy programs, 2018.

[2] NumPy - https://numpy.org/

[3] HIPS/Autograd - https://github.com/HIPS/autograd

[4] XLA: Optimizing Compiler for Machine Learning  |  TensorFlow - https://www.tensorflow.org/xla

[5] Atilim Gunes Baydin, Barak A. Pearlmutter, and Alexey AndreyevichRadul. Automatic differentiation in machine learning: a survey. CoRR,abs/1502.05767, 2015.


技术分享 | 能微分会加速的 NumPy —— JAX相关推荐

  1. 技术分享 | 从自动微分到可微编程语言设计(三)

    摘要 自动微分(Automatic Differentiation,AD)是一种对计算机程序进行高效准确求导的技术,一直被广泛应用于计算流体力学.大气科学.工业设计仿真优化等领域.而近年来,机器学习技 ...

  2. 【线上分享】探讨TensorRT加速AI模型的简易方案:以图像超分为例

    AI模型近年来广泛应用于图像.视频处理,在超分.降噪.插帧等应用中展现了良好效果.由于图像AI模型的计算量大,即便部署在GPU上,有时仍达不到理想的运行速度.为此,NVIDIA推出了TensorRT, ...

  3. 技术分享 | CodeReview主要Review什么?

    源宝导读:Code Review, 意即代码审查,是指一种有意识和系统的召集其他程序员来检查彼此的代码是否有错误的地方. 在敏捷团队中推行CodeReview, 可以帮助团队快速成长.本文将分享在&q ...

  4. 【华为云技术分享】“技术-经济范式”视角下的开源软件演进剖析-part 1

    前言 以互联网为代表的信息技术的迅猛发展对整个经济体系产生了巨大的影响.信息技术的发展一方面使知识的积累和传播更加迅速,知识爆炸性的增长:另一方面,使信息的获取变得越来越容易,信息交流的强度逐渐增加, ...

  5. 又拍云黄慧攀QCon 2016技术分享:直播平台架构与实施

    QCon 2016全球软件开发大会日前在北京落下帷幕,作为全球顶级技术盛会,自2007年首次举办以来,已经有超万名高级技术人员参加过QCon大会.本届大会主题为"升级你的软件思维" ...

  6. 【干货】2021年技术趋势:全球企业加速数字化转型-德勤.pdf(附下载链接)

    大家好,我是文文(微信:sscbg2020),今天给大家分享德勤发布的干货报告<2021年技术趋势:全球企业加速数字化转型.pdf>,对技术趋势感兴趣的伙伴别错过啦! 本年度报告的主题是韧 ...

  7. Unity游戏帧同步技术分享篇【01】帧同步解决方案概述

    前言: 1.0 帧同步原理与简介 A.什么是帧同步? 帧同步是一种前后端数据同步的方式,一般应用于对实时性要求很高的网络游戏. 其基本实现流程及思路可以概括为: 1.所有客户端每帧上传操作指令集到服务 ...

  8. 腾讯技术分享:微信小程序音视频技术背后的故事

    1.引言 微信小程序自2017年1月9日正式对外公布以来,越来越受到关注和重视,小程序上的各种技术体验也越来越丰富.而音视频作为高速移动网络时代下增长最快的应用形式之一,在微信小程序中也当然不能错过. ...

  9. 《从PPTV网络视频,到PPIO区块链分布式存储》 -- 同济创业谷PPIO CodeTalks区块链技术分享会

    摘要:2019年11月26日,同济创业谷与 PPIO CodeTalks 联合举办了<创新X - 区块链与创新创业>区块链技术分享会,本期我们为读者带来主题分享 -- <从PPTV网 ...

最新文章

  1. WCF服务重构实录(上)
  2. max(min)-device-width和max(min)-width的区别
  3. 中科院计算所关于“木兰”语言问题处理情况说明
  4. weblogic cluster error-----Could not= open connection with host: 127.0.0.1
  5. Kali-Linux虚拟机安装提示
  6. 深度学习资料汇总(满满的干货)
  7. 前端学习(3286):Aop
  8. 彻底解决zend studio 下 assignment in condition警告
  9. Facebook推出高速光网络技术将共享
  10. 锐起无盘服务器客户机不同步,使用批处理判断锐起无盘客户机是否为超级用户状态...
  11. php中hexdec,PHP hexdec()函数
  12. python中nx_python在nx在Python3中使用asyncio库进行快速数据抓取的教程
  13. HyperV使用主机摄像头
  14. Webmin未经身份验证的远程代码执行-墨者学院
  15. axios发送x-www-form-urlencoded格式数据
  16. 思维精进01:罗辑思维2019跨年演讲--小趋势
  17. 2020-SIGIR- Lightgcn: Simplifying and powering graph convolution network for recommendation
  18. linux硬盘盘符更改,linux更改emc磁盘盘符
  19. 【解决方案】电力巡检进入智能化时代,无人机+EasyDSS开启智能巡检新模式
  20. a1 抛光等级spi_SPI美国标准(抛光等级)

热门文章

  1. swoole等多进程下的 mysql has gone away 解决方案
  2. 我的复习计划(很有借鉴意义)
  3. springboot jpa汉服销售系统源码+论文+ppt+开题报告+任务书(论文的内容不一样,修改下即可
  4. LOL钓鱼网站实战渗透
  5. 项目管理系列---任务管理工具深度分析
  6. 细粒度图像分类论文研读-2019
  7. 富士康高管三年受贿逾千万 回扣牵出管理难题
  8. 我的助理辞职了—刘苏
  9. 怎么把文字生成图片?三款ai绘画生成器分享
  10. 一步一个脚印,其实真的不慌