『JAX中文文档』JAX快速入门
原文:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
JAX快速入门
首先解答一个问题:JAX是什么?
简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用Python搞科学计算或机器学习,没人离得开它。但是numpy不支持GPU或其他硬件加速器,也没有对backpropagation的内置支持,再加上Python本身的速度限制,所以很少有人会在生产环境下直接用numpy训练或部署深度学习模型。这也是为什么会出现Theano, TensorFlow, Caffe等深度学习框架的原因。但是numpy有其独特的优势:底层、灵活、调试方便、API稳定且为大家所熟悉(与MATLAB一脉相承),深受研究者的青睐。JAX的主要出发点就是将numpy的以上优势与硬件加速结合。现在已经开源的JAX ( https://github.com/google/jax) 就是通过GPU (CUDA)来实现硬件加速。出自:https://www.zhihu.com/question/306496943/answer/557876584
小宋说:JAX 其实就是一个支持加速器(GPU 和 TPU)的科学计算库(numpy, scipy)和神经网络库(提供relu,sigmoid, conv 等),相较于PyTorch与TensorFlow更加灵活,通用性更佳。这也是笔者推荐学习和做这个翻译工作的原因,带着大家一起去学习掌握这个框架。
由于笔者非英语专业,有些内荣难免翻译有误,欢迎大家批评指正。对于有些笔者不确定的翻译,采用下划线加括号引用原词的方式来补充,例如:自动微分(differentiation)
官方定义:JAX是CPU,GPU和TPU上的NumPy,具有出色的自动差分(differentiation),可用于高性能机器学习研究。
作为更新版本的Autograd,JAX可以自动微分本机Python和NumPy代码。它可以通过Python的大部分功能(包括循环,if,递归和闭包)进行微分,甚至可以采用派生类的派生类。它支持反向模式和正向模式微分,并且两者可以任意顺序组成。
新功能是JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行您的NumPy代码。默认情况下,编译是在后台进行的,而库调用将得到及时的编译和执行。但是,JAX甚至允许您使用单功能API即时将自己的Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此您无需离开Python即可表达复杂的算法并获得最佳性能。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
乘法矩阵
在以下示例中,我们将生成随机数据。NumPy和JAX之间的一大区别是生成随机数的方式。有关更多详细信息,请参见JAX中的Common Gotchas。
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.372111 0.2642311 -0.18252774 -0.7368198 -0.44030386 -0.15214427-0.6713536 -0.59086424 0.73168874 0.56730247]
乘以两个大矩阵。
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
我们补充说,block_until_ready
因为默认情况下JAX使用异步执行(请参见异步调度)。
JAX NumPy函数可在常规NumPy数组上使用。
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
这样比较慢,因为它每次都必须将数据传输到GPU。您可以使用来确保NDArray由设备内存支持device_put()
。
from jax import device_putx = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
的输出device_put()
仍然像NDArray一样,但是它仅在需要打印,绘图,保存(printing, plotting, saving)到磁盘,分支等需要它们的值时才将值复制回CPU。的行为device_put()
等效于函数,但是速度更快。jit(lambda x: x)
如果您有GPU(或TPU!),这些调用将在加速器上运行,并且可能比在CPU上快得多。
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX不仅仅是一个由GPU支持的NumPy。它还带有一些程序转换,这些转换在编写数字代码时很有用。目前,主要有三个:
jit()
,以加快您的代码grad()
,用于求梯度(derivatives)vmap()
,用于自动矢量化或批处理。
让我们一一介绍。我们还将最终以有趣的方式编写这些内容。
利用jit()
加快功能
JAX在GPU上透明运行(如果没有,则在CPU上运行,而TPU即将推出!)。但是,在上面的示例中,JAX一次将内核分配给GPU一次操作。如果我们有一系列操作,则可以使用@jit
装饰器使用XLA一起编译多个操作。让我们尝试一下。
def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
我们可以使用加快速度@jit
,它将在第一次selu
调用jit-compile并将其之后缓存。
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
通过 grad()计算梯度
除了评估数值函数外,我们还希望对其进行转换。一种转变是自动微分。在JAX中,就像在Autograd中一样,您可以使用grad()
函数来计算梯度。
def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
让我们以极限微分(finite differences)验证我们的结果是正确的。
def first_finite_differences(f, x):eps = 1e-3return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)for v in jnp.eye(len(x))])print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569 0.10502338]
求解梯度可以通过简单调用grad()
。grad()
并jit()
可以任意混合。在上面的示例中,我们先抖动sum_logistic
然后取其派生词。我们继续深入学习实验:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.035325594
对于更高级的autodiff,可以将其jax.vjp()
用于反向模式矢量雅各比积和jax.jvp()
正向模式雅可比矢量积。两者可以彼此任意组合,也可以与其他JAX转换任意组合。这是组合它们以构成有效计算完整的Hessian矩阵的函数的一种方法:
from jax import jacfwd, jacrev
def hessian(fun):return jit(jacfwd(jacrev(fun)))
自动向量化 vmap()
JAX在其API中还有另一种转换,您可能会发现它有用:vmap()
向量化映射。它具有沿数组轴映射函数的熟悉语义( familiar semantics),但不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与组合时jit()
,它的速度可以与手动添加批处理尺寸一样快。
我们将使用一个简单的示例,并使用将矩阵向量乘积提升为矩阵矩阵乘积vmap()
。尽管在这种特定情况下很容易手动完成此操作,但是相同的技术可以应用于更复杂的功能。
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))def apply_matrix(v):return jnp.dot(mat, v)
给定诸如之类的功能apply_matrix
,我们可以在Python中循环执行批处理维度,但是这样做的性能通常很差。
def naively_batched_apply_matrix(v_batched):return jnp.stack([apply_matrix(v) for v in v_batched])print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
我们知道如何手动批处理此操作。在这种情况下,jnp.dot
透明地处理额外的批次尺寸。
@jit
def batched_apply_matrix(v_batched):return jnp.dot(v_batched, mat.T)print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
但是,假设没有批处理支持,我们的功能更加复杂。我们可以用来vmap()
自动添加批处理支持。
@jit
def vmap_batched_apply_matrix(v_batched):return vmap(apply_matrix)(v_batched)print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
当然,vmap()
可以与任意组成jit()
,grad()
和任何其它JAX变换。
这只是JAX可以做的事情。我们很高兴看到您的操作!
『JAX中文文档』JAX快速入门相关推荐
- 以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明
以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明 为了让你的Ðapp运行上以太坊,一种选择是使用web3.js library提供的web3.对象.底层实 ...
- keras中文文档学习笔记—快速上手keras
keras的核心数据结构是"model",其中最主要的是Sequential模型: Sequential模型调用 from keras.model import Sequentia ...
- php sequelize,Sequelize 中文文档 v4 - Getting started - 入门
Getting started - 入门 此系列文章的应用示例已发布于 GitHub: sequelize-docs-Zh-CN. 可以 Fork 帮助改进或 Star 关注更新. 欢迎 Star. ...
- 最新 | Python 官方中文文档正式发布!
点击上方"AI有道",选择"置顶"公众号 重磅干货,第一时间送达 千呼万唤始出来!Python 官方文档终于发布中文版了!受英语困扰的小伙伴终于可以更轻松地阅读 ...
- web3.js 中文文档 入门
web3.js 中文文档 v1.3.4 入门(Getting Started) web3.js是包含以太坊生态系统功能的模块集合. web3-eth用于以太坊区块链和智能合约. web3-shh是针对 ...
- Hugo中文文档 快速开始
Hugo中文文档 快速开始 安装Hugo 1. 二进制安装(推荐:简单.快速) 到 Hugo Releases 下载对应的操作系统版本的Hugo二进制文件(hugo或者hugo.exe) Mac下直接 ...
- Bootstrap 一篇就够 快速入门使用(中文文档)
目录 一.Bootstrap 简介 什么是 Bootstrap? 历史 为什么使用 Bootstrap? Bootstrap 包的内容 在线实例 Bootstrap 实例 更多实例 Bootstrap ...
- Babel 是什么?· Babel 中文文档
Babel 是一个 JavaScript 编译器 Babel 是一个工具链,主要用于将 ECMAScript 2015+ 版本的代码转换为向后兼容的 JavaScript 语法,以便能够运行在当前和旧 ...
- springboot中文文档_登顶 Github 的 Spring Boot 仓库!艿艿写的最肝系列
源码精品专栏 中文详细注释的开源项目 RPC 框架 Dubbo 源码解析 网络应用框架 Netty 源码解析 消息中间件 RocketMQ 源码解析 数据库中间件 Sharding-JDBC 和 My ...
最新文章
- from torchvision import _C解决办法
- 推荐系统笔记:无任何限制的矩阵分解
- 【OpenCV3】棋盘格角点检测与绘制——cv::findChessboardCorners()与cv::drawChessboardCorners()详解
- Windows Print Spooler服务最新漏洞CVE-2021-34527详细分析
- [转载] 武汉天河机场大巴时刻及路线
- React中的CSS——styled-components
- hdu2609 How many
- Asp.net开发环境的设置所遇到的问题
- 创建组件“ovalshape”失败_Django的forms组件检验字段\渲染模板
- 怎么解log方程_微观动力学解合成氨催化反应TOF
- SQLi LABS Less-8
- 【牛客2021暑假多校10】Train Wreck(出栈顺序,建树,优先队列维护)
- 5.2 各种类型的Attention: 原理、计算流程
- Unity资源热更-Addressables总结(一)
- excel两列数据对比找不同_数据相差太大在Excel图表对比柱形图,那是你不会次坐标设置!...
- 唯样商城:常见电阻种类
- w7系统您的计算机无法启动,Windows7旗舰版启动不了怎么办?电脑无法正常启动Windows7解决方法...
- 根据广播星历计算GNSS卫星在瞬时地球坐标系中的坐标
- docker安装wechat微信、wxwork企业微信脚本整理
- 重新开始噼里啪啦写小文字啦~
热门文章
- 乐动体育里斯本网路峰会!「吹哨者」史诺登警告:科技巨擘权力太大
- 算法题(模板)——N个球放入M个盒子中
- linux分区不格式化能挂栽吗,linux硬盘分区、格式化与挂载
- adb怎么连接Genymotion虚拟机
- Mac隐藏终端的用户名和主机名
- mysql insert into on_MySQL之INSERT INTO ON DUPLICATE KEY UPDATE用法详解 | 夕辞
- 继推出科创板,证监会将统筹推进新三板创业板改革...
- wamp启动后显示黄色的图标解决方法
- 深度学习 - 26.TF TF2.x tf.feature_column 详解
- 四年巨亏49亿,第四范式四闯IPO