用于 Keras 用户使用的 TensorFlow.js layers API

TensorFlow.js 的Layers API以Keras为模型。考虑到 JavaScript 和 Python 之间的差异,我们努力使Layers API 与Keras 类似。这让具有使用Python开发Keras模型经验的用户可以更轻松地将项目迁移到 JavaScript中的TensorFlow.js Layers。例如,以下 Keras 代码转换为 JavaScript:

# Python:
import keras
import numpy as np# 建立并编译模型.
model = keras.Sequential()
model.add(keras.layers.Dense(units=1, input_shape=[1]))
model.compile(optimizer='sgd', loss='mean_squared_error')# 生成一些用于训练的数据.
xs = np.array([[1], [2], [3], [4]])
ys = np.array([[1], [3], [5], [7]])# 用 fit() 训练模型.
model.fit(xs, ys, epochs=1000)# 用 predict() 推理.
print(model.predict(np.array([[5]])))
// JavaScript:
import * as tf from '@tensorlowjs/tfjs';// 建立并编译模型.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});// 生成一些用于训练的数据.
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);// 用 fit() 训练模型.
await model.fit(xs, ys, {epochs: 1000});// 用 predict() 推理.
model.predict(tf.tensor2d([[5]], [1, 1])).print();

但是,我们希望在本文档中说明并解释一些差异。一旦理解了这些差异及其背后的基本原理,将您的程序从Python 迁移到JavaScript(或反向迁移)应该会是一种相对平稳的体验。

构造函数将 JavaScript 对象作为配置

比较上面示例中的以下 Python 和 JavaScript 代码:它们都创建了一个全连接层。

# Python:
keras.layers.Dense(units=1, inputShape=[1])
// JavaScript:
tf.layers.dense({units: 1, inputShape: [1]});

JavaScript函数在Python 函数中没有等效的关键字参数。我们希望避免在 JavaScript 中实现构造函数选项作为位置参数,这对于记忆和使用具有大量关键字参数的构造函数(如LSTM尤其麻烦 。这就是我们使用JavaScript 配置对象的原因。这些对象提供与Python关键字参数相同的位置不变性和灵活性。

Model 类的一些方法(例如,Model.compile())也将 JavaScript 配置对象作为输入。但是,请记住 Model.fit()、Model.evaluate() 和 Model.predict() 略有不同。因为这些方法将强制 x(feature 特征)和 y(label 标签或 target 目标)数据作为输入;x 和 y 是与后续配置对象分开的位置参数,属于关键字参数。例如:

Model.fit()是异步的

Model.fit() 是用户在Tensorflow.js中执行模型训练的主要方法。这个方法往往是长时间运行的(持续数秒或数分钟)。因此,我们利用了JavaScript语言的“异步”特性。所以在浏览器中运行时,这样使用此函数就不会阻塞主UI线程。这和JavaScript中其他可能长期运行的函数类似,例如async获取。需要注意async是一个在python中不存在的构造。当fit()方法在keras中返回一个历史对象, 在JavaScript中fit()方法的对应项返回一个包含训练历史的Promise这个应答可以await(等待),也可以与then()方法一起使用。

TensorFlow.js 中没有 NumPy

Python Keras 用户经常使用NumPy来执行基本的数值和数组的操作,例如在上面的示例中生成 2D 张量。

# Python:
xs = np.array([[1], [2], [3], [4]])

在 TensorFlow.js 中,这种基本的数字的操作是使用包本身完成的。例如:

// JavaScript:
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);

该 tf.* 命名空间还提供数组和线性代数的operations(操作),如矩阵乘法。有关更多信息,请参阅 TensorFlow.js核心文档。

使用factory(工厂)方法,而不是构造函数

Python 中的这一行(来自上面的例子)是一个构造函数调用:

# Python:
model = keras.Sequential()

如果严格转换为 JavaScript,则等效构造函数调用将如下所示:

// JavaScript:
const model = new tf.Sequential();  // 不! 要! 这! 样! 做!

然而,我们决定不使用“new”构造函数,因为 1)“new”关键字会使代码更加膨胀;2)“new”构造函数被视为 JavaScript 的“bad part”:一个潜在的陷阱,如在JavaScript: the Good Parts.中的争论。要在 TensorFlow.js 中创建模型和 Layer ,可以调用被称为 lowerCamelCase(小驼峰命名)的工厂方法,例如:

// JavaScript:
const model = tf.sequential();const layer = tf.layers.batchNormalization({axis: 1});

选项字符串值为小驼峰命名,而不是 snake_case

在 JavaScript 中,与 Python 相比,更常见的是使用小驼峰作为符号名称(例如,Google JavaScript Style Guide),而 Python 中 snake_case 很常见(例如,在 Keras 中)。因此,我们决定使用小驼峰命名作为选项的字符串值,包括以下内容:

  • DataFormat,例如,channelsFirst 而不是 channels_first
  • Initializer,例如,glorotNormal 而不是 glorot_normal
  • Loss and metrics,例如,meanSquaredError 而不是 mean_squared_error,categoricalCrossentropy 而不是 categorical_crossentropy。

例如,如上例所示:

// JavaScript:
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

对于模型序列化和反序列化,请放心。请放心。TensorFlow.js 的内部机制确保正确处理 JSON 对象中的 snake_case ,例如,从 Python Keras 加载预训练模型时。

使用 apply() 运行 Layer 对象,而不是将其作为函数调用

在 Keras 中,Layer 对象定义了__call__方法。因此,用户可以通过将对象作为函数调用来调用 Layer 的逻辑,例如:

# Python:
my_input = keras.Input(shape=[2, 4])
flatten = keras.layers.Flatten()print(flatten(my_input).shape)

这个 Python 语法糖在 TensorFlow.js 中以 apply() 方法实现:

// JavaScript:
const myInput = tf.input{shape: [2, 4]});
const flatten = tf.layers.flatten();console.log(flatten.apply(myInput).shape);

Layer.apply() 支持对具体 Tensor(张量)的命令式(Eager)执行

目前,在 Keras 中,__call__方法只能对(Python)TensorFlow 的 tf.Tensor 对象进行操作(假设 TensorFlow 是后端),这些对象是符号化的并且不包含实际的数值。这就是上一节中的示例中所显示的内容。但是,在 TensorFlow.js 中,Layer 的 apply() 方法可以在符号和命令模式下运行。如果用 SymbolicTensor 调用 apply()(类似于 tf.Tensor)调用,则返回值将为 SymbolicTensor。这通常发生在模型构建期间。但是如果用实际的具体 Tensor(张量)值调用 apply(),将返回一个具体的 Tensor(张量)。例如:

// JavaScript:
const flatten = tf.layers.flatten();flatten.apply(tf.ones([2, 3, 4])).print();

这个特性让人联想到(Python)TensorFlow 的Eager Execution。它在模型开发期间提供了更大的交互性和可调试性,并且为组成动态神经网络打开了大门。

Optimizers(优化器)在 train.* 下,而不是 optimizers.*

在 Keras 中,Optimizer(优化器)对象的构造函数位于 keras.optimizers.* 命名空间下。在 TensorFlow.js Layer 中,Optimizer(优化器)的工厂方法位于 tf.train.* 命名空间下。例如:

# Python:
my_sgd = keras.optimizers.sgd(lr=0.2)
// JavaScript:
const mySGD = tf.train.sgd({lr: 0.2});

loadLayersModel() 从 URL 加载,而不是 HDF5 文件

在 Keras 中,模型通常保存为 HDF5(.h5)文件,然后可以使用 keras.models.load_model()方法加载 。该方法采用 .h5 文件的路径。TensorFlow.js 中的 load_model() 对应的是tf.loadLayersModel()。由于 HDF5 文件格式对浏览器并不友好,因此 tf.loadLayersModel() 采用 TensorFlow.js 特定的格式。tf.lloadLayersModel() 将 model.json 文件作为其输入参数。可以使用 tensorflowjs 的 pip 包从 Keras HDF5 文件转换 model.json。

// JavaScript:
const model = await tf.loadLayersModel('https://foo.bar/model.json');

还要注意的是tf.loadLayersModel()返回的是tf.Model的应答`。

通常,tf.Model在 TensorFlow.js中保存和加载分别使用tf.Model.savetf.loadLayersModel方法。我们将这些 API 设计为类似于Kerasthe save and load_model API。但是浏览器环境与 Keras 等主要深度学习框架运行的后端环境完全不同,特别是用于持久化和传输数据的路由数组中。因此,TensorFlow.js 和 Keras 中的 save/load API 之间存在一些有趣的差异。有关更多详细信息,请参阅我们关于 保存和加载tf.Model的教程。

fitDataset()训练模型使用tf.data.Dataset对象

在python版本的tensorflow keras中, 一个模型可以使用Dataset对象进行训练。模型的fit()方法直接接受这样的对象。一个Tensorflow.js方法可以使用相当于Dataset对象的Javascript进行训练,详见TensorFlow.js的tf.data API文档。然而,与python不同, 基于Dataset的训练是通过一个专门的方法来完成的这个方法称之为fitDataset。fit() 只针对基于Tensor(张量)的模型训练。

Layer(层)对象和Model(模型)对象的内存管理

TensorFlow.js在浏览器中的WebGL上运行,其中层和模型对象的权重由WebGL纹理支持。然而WebGL并不支持内置的垃圾收集。在推理和训练的过程中,Layer(层)和Model(模型)对象为用户在内部管理Tensor(张量)内存。但是它们也允许用户清理它们以释放它们占用的WebGL内存。这对于在单页加载过程中创建和释放许多模型实例的情况很有用。想要清理一个Layer(层)和Model(模型)对象,使用dispose() 方法。

用于 Keras 用户使用的 TensorFlow.js layers API相关推荐

  1. 8 适用于 Keras 用户的 TensorFlow.js 层 API

    TensorFlow.js 的 Layers API 以 Keras 为模型.考虑到 JavaScript 与 Python 之间的差异,我们努力使 ​​Layers API​​ 与 Keras 类似 ...

  2. python 加载动图_在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...

    大数据文摘授权转载自数据派THU 作者:MOHD SANAD ZAKI RIZVI 本文主要介绍了: TensorFlow.js (deeplearn.js)使我们能够在浏览器中构建机器学习和深度学习 ...

  3. 独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)

    作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...

  4. linux tensorflow demo_独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...

    作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...

  5. 来自前端开发者的灵魂发问:TensorFlow.js 好学吗?

    本文作者 蔡善清(Shanqing Cai),谷歌公司软件工程师,深度参与了 TensorFlow 和 TensorFlow.js 的开发工作.从清华大学毕业后,他前往约翰斯 · 霍普金斯大学和麻省理 ...

  6. 如何使用 TensorFlow.js 自动化 Chrome 恐龙游戏?

    本文为 AI 研习社编译的技术博客,原标题 : Using TensorFlow.js to Automate the Chrome Dinosaur Game (part 1) 作者 | Aayus ...

  7. tensorflow 迁移学习_基于 TensorFlow.js 1.5 的迁移学习图像分类器

    在黑胡桃社区的体验案例中,有一个"人工智能教练",它其实是一个自定义的图像分类器.使用 TensorFlow.js 这个强大而灵活的 Javascript 机器学习库可以很轻松地构 ...

  8. 使用 Colab 在 tf.keras 中训练模型,并使用 TensorFlow.js 在浏览器中运行

    文 / Zaid Alyafeai 我们将创建一个简单的工具来识别图纸并输出当前图纸的名称. 此应用程序将直接在浏览器上运行,无需任何安装.我们会使用 Google Colab 来训练模型,并使用 T ...

  9. 使用Keras,TensorFlow.js,Node.js和Firebase构建,训练和部署Book Recommender系统(第2部分)

    Welcome back to the second part of our recommender engine tutorial series. In the first part, you le ...

最新文章

  1. SAP SD 微观研究之如何得到Customer List?
  2. pip 安装依赖包 报错 No matching distribution found for pandas
  3. python怎么写文件-Python 读写文件
  4. Matlab中varargin函数
  5. 035 函数和代码复用小结
  6. python编程入门第九讲,第九讲作业---函数
  7. 把Scala代码当作脚本运行
  8. Spring框架 注解
  9. php写入word文档内容,如何在PHP中读取和写入WORD文档
  10. 一步一步写一个简单通用的makefile(一)
  11. 基于QT框架的离线词典应用程序
  12. CSS 网页定位与布局
  13. Facebook新模型SEER|图像预训练的内卷
  14. 近期Domino相关产品要闻速览
  15. 微信公众号笔记---本地调试微信接口
  16. 魔兽正式服5区服务器互通信息,魔兽世界怀旧服付费转服能跨区吗
  17. 在公共卫生领域GIS系统的应用范畴
  18. SpaceBuilder 1.0RC源代码提供下载
  19. 十、MYSQL数据库的条件查询
  20. C#找到最小的整数X,同时满足:X是2019的整倍数,X的每一位数字是奇数

热门文章

  1. host切换工具、修改HOST不用重启IE
  2. 用虚拟机VMware安装雪豹提示:当前主机无法支持64位操作系统
  3. 关于CSS浮动(float,clear)的通俗讲解(经验分享)
  4. Ubuntu终端(terminal)及Thunderbird邮件客户端常用的快捷键
  5. UIKeyboard键盘相关知识点
  6. C++实现MD5加密
  7. 一个账号,防止多设备登陆
  8. 【今日CV 视觉论文速览】28 Nov 2018
  9. Java—接口(工厂模式代理模式)
  10. 窗体跳转传值 1130