最新的
原文:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

JAX快速入门


首先解答一个问题:JAX是什么?

简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用Python搞科学计算或机器学习,没人离得开它。但是numpy不支持GPU或其他硬件加速器,也没有对backpropagation的内置支持,再加上Python本身的速度限制,所以很少有人会在生产环境下直接用numpy训练或部署深度学习模型。这也是为什么会出现Theano, TensorFlow, Caffe等深度学习框架的原因。但是numpy有其独特的优势:底层、灵活、调试方便、API稳定且为大家所熟悉(与MATLAB一脉相承),深受研究者的青睐。JAX的主要出发点就是将numpy的以上优势与硬件加速结合。现在已经开源的JAX ( https://github.com/google/jax) 就是通过GPU (CUDA)来实现硬件加速。出自:https://www.zhihu.com/question/306496943/answer/557876584

小宋说:JAX 其实就是一个支持加速器(GPU 和 TPU)的科学计算库(numpy, scipy)和神经网络库(提供relu,sigmoid, conv 等),相较于PyTorch与TensorFlow更加灵活,通用性更佳。这也是笔者推荐学习和做这个翻译工作的原因,带着大家一起去学习掌握这个框架。

由于笔者非英语专业,有些内荣难免翻译有误,欢迎大家批评指正。对于有些笔者不确定的翻译,采用下划线加括号引用原词的方式来补充,例如:自动微分differentiation

官方定义:JAX是CPU,GPU和TPU上的NumPy,具有出色的自动差分differentiation),可用于高性能机器学习研究。

作为更新版本的Autograd,JAX可以自动微分本机Python和NumPy代码。它可以通过Python的大部分功能(包括循环,if,递归和闭包)进行微分,甚至可以采用派生类的派生类。它支持反向模式和正向模式微分,并且两者可以任意顺序组成。

新功能是JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行您的NumPy代码。默认情况下,编译是在后台进行的,而库调用将得到及时的编译和执行。但是,JAX甚至允许您使用单功能API即时将自己的Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此您无需离开Python即可表达复杂的算法并获得最佳性能。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

乘法矩阵

在以下示例中,我们将生成随机数据。NumPy和JAX之间的一大区别是生成随机数的方式。有关更多详细信息,请参见JAX中的Common Gotchas。

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427-0.6713536  -0.59086424  0.73168874  0.56730247]

乘以两个大矩阵。

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我们补充说,block_until_ready因为默认情况下JAX使用异步执行(请参见异步调度)。

JAX NumPy函数可在常规NumPy数组上使用。

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

这样比较慢,因为它每次都必须将数据传输到GPU。您可以使用来确保NDArray由设备内存支持device_put()

from jax import device_putx = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

的输出device_put()仍然像NDArray一样,但是它仅在需要打印,绘图,保存(printing, plotting, saving)到磁盘,分支等需要它们的值时才将值复制回CPU。的行为device_put()等效于函数,但是速度更快。jit(lambda x: x)

如果您有GPU(或TPU!),这些调用将在加速器上运行,并且可能比在CPU上快得多。

x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX不仅仅是一个由GPU支持的NumPy。它还带有一些程序转换,这些转换在编写数字代码时很有用。目前,主要有三个:

  • jit(),以加快您的代码

  • grad(),用于求梯度(derivatives)

  • vmap(),用于自动矢量化或批处理。

让我们一一介绍。我们还将最终以有趣的方式编写这些内容。

利用jit()加快功能

JAX在GPU上透明运行(如果没有,则在CPU上运行,而TPU即将推出!)。但是,在上面的示例中,JAX一次将内核分配给GPU一次操作。如果我们有一系列操作,则可以使用@jit装饰器使用XLA一起编译多个操作。让我们尝试一下。

def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们可以使用加快速度@jit,它将在第一次selu调用jit-compile并将其之后缓存。

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

通过 grad()计算梯度

除了评估数值函数外,我们还希望对其进行转换。一种转变是自动微分。在JAX中,就像在Autograd中一样,您可以使用grad()函数来计算梯度。

def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]

让我们以极限微分(finite differences)验证我们的结果是正确的。

def first_finite_differences(f, x):eps = 1e-3return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)for v in jnp.eye(len(x))])print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]

求解梯度可以通过简单调用grad()grad()jit()可以任意混合。在上面的示例中,我们先抖动sum_logistic然后取其派生词。我们继续深入学习实验:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.035325594

对于更高级的autodiff,可以将其jax.vjp()用于反向模式矢量雅各比积和jax.jvp()正向模式雅可比矢量积。两者可以彼此任意组合,也可以与其他JAX转换任意组合。这是组合它们以构成有效计算完整的Hessian矩阵的函数的一种方法:

from jax import jacfwd, jacrev
def hessian(fun):return jit(jacfwd(jacrev(fun)))

自动向量化 vmap()

JAX在其API中还有另一种转换,您可能会发现它有用:vmap()向量化映射。它具有沿数组轴映射函数的熟悉语义( familiar semantics),但不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与组合时jit(),它的速度可以与手动添加批处理尺寸一样快。

我们将使用一个简单的示例,并使用将矩阵向量乘积提升为矩阵矩阵乘积vmap()。尽管在这种特定情况下很容易手动完成此操作,但是相同的技术可以应用于更复杂的功能。

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))def apply_matrix(v):return jnp.dot(mat, v)

给定诸如之类的功能apply_matrix,我们可以在Python中循环执行批处理维度,但是这样做的性能通常很差。

def naively_batched_apply_matrix(v_batched):return jnp.stack([apply_matrix(v) for v in v_batched])print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched

4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我们知道如何手动批处理此操作。在这种情况下,jnp.dot透明地处理额外的批次尺寸。

@jit
def batched_apply_matrix(v_batched):return jnp.dot(v_batched, mat.T)print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched

51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

但是,假设没有批处理支持,我们的功能更加复杂。我们可以用来vmap()自动添加批处理支持。

@jit
def vmap_batched_apply_matrix(v_batched):return vmap(apply_matrix)(v_batched)print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap

79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

当然,vmap()可以与任意组成jit()grad()和任何其它JAX变换。

这只是JAX可以做的事情。我们很高兴看到您的操作!

『JAX中文文档』JAX快速入门相关推荐

  1. 以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明

    以太坊智能合约开发,Web3.js API 中文文档 ethereum web3.js入门说明 为了让你的Ðapp运行上以太坊,一种选择是使用web3.js library提供的web3.对象.底层实 ...

  2. keras中文文档学习笔记—快速上手keras

    keras的核心数据结构是"model",其中最主要的是Sequential模型: Sequential模型调用 from keras.model import Sequentia ...

  3. php sequelize,Sequelize 中文文档 v4 - Getting started - 入门

    Getting started - 入门 此系列文章的应用示例已发布于 GitHub: sequelize-docs-Zh-CN. 可以 Fork 帮助改进或 Star 关注更新. 欢迎 Star. ...

  4. 最新 | Python 官方中文文档正式发布!

    点击上方"AI有道",选择"置顶"公众号 重磅干货,第一时间送达 千呼万唤始出来!Python 官方文档终于发布中文版了!受英语困扰的小伙伴终于可以更轻松地阅读 ...

  5. web3.js 中文文档 入门

    web3.js 中文文档 v1.3.4 入门(Getting Started) web3.js是包含以太坊生态系统功能的模块集合. web3-eth用于以太坊区块链和智能合约. web3-shh是针对 ...

  6. Hugo中文文档 快速开始

    Hugo中文文档 快速开始 安装Hugo 1. 二进制安装(推荐:简单.快速) 到 Hugo Releases 下载对应的操作系统版本的Hugo二进制文件(hugo或者hugo.exe) Mac下直接 ...

  7. Bootstrap 一篇就够 快速入门使用(中文文档)

    目录 一.Bootstrap 简介 什么是 Bootstrap? 历史 为什么使用 Bootstrap? Bootstrap 包的内容 在线实例 Bootstrap 实例 更多实例 Bootstrap ...

  8. Babel 是什么?· Babel 中文文档

    Babel 是一个 JavaScript 编译器 Babel 是一个工具链,主要用于将 ECMAScript 2015+ 版本的代码转换为向后兼容的 JavaScript 语法,以便能够运行在当前和旧 ...

  9. springboot中文文档_登顶 Github 的 Spring Boot 仓库!艿艿写的最肝系列

    源码精品专栏 中文详细注释的开源项目 RPC 框架 Dubbo 源码解析 网络应用框架 Netty 源码解析 消息中间件 RocketMQ 源码解析 数据库中间件 Sharding-JDBC 和 My ...

最新文章

  1. from torchvision import _C解决办法
  2. 推荐系统笔记:无任何限制的矩阵分解
  3. 【OpenCV3】棋盘格角点检测与绘制——cv::findChessboardCorners()与cv::drawChessboardCorners()详解
  4. Windows Print Spooler服务最新漏洞CVE-2021-34527详细分析
  5. [转载] 武汉天河机场大巴时刻及路线
  6. React中的CSS——styled-components
  7. hdu2609 How many
  8. Asp.net开发环境的设置所遇到的问题
  9. 创建组件“ovalshape”失败_Django的forms组件检验字段\渲染模板
  10. 怎么解log方程_微观动力学解合成氨催化反应TOF
  11. SQLi LABS Less-8
  12. 【牛客2021暑假多校10】Train Wreck(出栈顺序,建树,优先队列维护)
  13. 5.2 各种类型的Attention: 原理、计算流程
  14. Unity资源热更-Addressables总结(一)
  15. excel两列数据对比找不同_数据相差太大在Excel图表对比柱形图,那是你不会次坐标设置!...
  16. 唯样商城:常见电阻种类
  17. w7系统您的计算机无法启动,Windows7旗舰版启动不了怎么办?电脑无法正常启动Windows7解决方法...
  18. 根据广播星历计算GNSS卫星在瞬时地球坐标系中的坐标
  19. docker安装wechat微信、wxwork企业微信脚本整理
  20. 重新开始噼里啪啦写小文字啦~

热门文章

  1. 乐动体育里斯本网路峰会!「吹哨者」史诺登警告:科技巨擘权力太大
  2. 算法题(模板)——N个球放入M个盒子中
  3. linux分区不格式化能挂栽吗,linux硬盘分区、格式化与挂载
  4. adb怎么连接Genymotion虚拟机
  5. Mac隐藏终端的用户名和主机名
  6. mysql insert into on_MySQL之INSERT INTO ON DUPLICATE KEY UPDATE用法详解 | 夕辞
  7. 继推出科创板,证监会将统筹推进新三板创业板改革...
  8. wamp启动后显示黄色的图标解决方法
  9. 深度学习 - 26.TF TF2.x tf.feature_column 详解
  10. 四年巨亏49亿,第四范式四闯IPO