• 原文地址:How to Use TensorFlow Mobile in Android Apps
  • 原文作者:Ashraff Hathibelagal
  • 译文出自:掘金翻译计划
  • 本文永久链接:github.com/xitu/gold-m…
  • 译者:luochen
  • 校对者:ALVINYEH LeeSniper

TensorFlow 是当今最流行的机器学习框架之一,您利用它可以轻松创建和训练深度模型 —— 通常也称为深度前馈神经网络,这些模型可以解决各种复杂问题,如图像分类、目标检测和自然语言理解。TensorFlow Mobile 是一个旨在帮助您在移动应用中利用这些模型的库。

在本教程中,我将向您展示如何在 Android Studio 项目中使用 TensorFlow Mobile。

前期准备

为了能够跟上教程,您需要做的是:

  • Android Studio 3.0 或更高版本
  • TensorFlow 1.5.0 或更高版本
  • 一台能够运行 API level 21 或更高的安卓设备
  • 以及对 TensorFlow 框架的基本了解

1、创建模型

在我们开始使用 TensorFlow Mobile 之前,我们需要一个已经训练好的 TensorFlow 模型。我们现在创建一个。

我们的模型将非常基础,类似于异或门,接受两个输入,它们可以是零或一,然后有一个输出。如果两个输入相同,则输出为零。此外,因为它将是一个深度模型,它将有两个隐藏层,一个有四个神经元,另一个有三个神经元。您可以自由改变隐藏层的数量以及它们包含的神经元的数量。

为了保持本教程的简洁,我们将使用 TFLearn,这是一个很受欢迎的 TensorFlow 封装框架,它提供更加直接而简洁的 API,而不是直接使用低级别的 TensorFlow API。如果您还没安装它,请使用以下命令将其安装在 TensorFlow 虚拟环境中:

pip install tflearn
复制代码

要开始创建模型,最好在空目录中先新建一个名为 create_model.py 的 Python 脚本,然后使用您最喜欢的文本编辑器打开它。

在文件里,我们需要做的第一件事是导入 TFLearn API。

import tflearn
复制代码

接下来,我们必须创建训练数据。对于我们的简单模型,只有四种可能的输入和输出,类似于异或门真值表的内容。

X = [[0, 0],[0, 1],[1, 0],[1, 1]
]Y = [[0],  # Desired output for inputs 0, 0[1],  # Desired output for inputs 0, 1[1],  # Desired output for inputs 1, 0[0]   # Desired output for inputs 1, 1
]
复制代码

为隐藏层中的所有神经元分配初始权重时,最好的做法通常是使用从均匀分布中产生的随机数。可以使用 uniform() 方法生成这些值。

weights = tflearn.initializations.uniform(minval = -1, maxval = 1)
复制代码

此时,我们可以开始构建神经网络层。要创建输入层,我们必须使用 input_data() 方法,它允许我们指定网络可以接受的输入数量。一旦输入层准备就绪,我们可以多次调用 fully_connected() 方法来向网络添加更多层。

# 输入层
net = tflearn.input_data(shape = [None, 2],name = 'my_input'
)# 隐藏层
net = tflearn.fully_connected(net, 4,activation = 'sigmoid',weights_init = weights
)
net = tflearn.fully_connected(net, 3,activation = 'sigmoid',weights_init = weights
)# 输出层
net = tflearn.fully_connected(net, 1,activation = 'sigmoid', weights_init = weights,name = 'my_output'
)
复制代码

注意,在上面的代码中,我们赋予了输入层和输出层有意义的名称。这么做很重要,因为我们在使用安卓应用中的网络时需要它们。还要注意隐藏层和输出层使用了 sigmoid 激活函数。您可以试试其他激活函数,例如 softmaxtanhrelu

作为我们网络的最后一层,我们必须使用 regression() 函数创建一个回归层,该函数需要一些超参数作为其参数,例如网络的学习率以及它应该使用的优化器和损失函数。以下代码向您展示了如何使用随机梯度下降(简称 SGD)作为优化器函数,均方误差作为损失函数:

net = tflearn.regression(net,learning_rate = 2,optimizer = 'sgd',loss = 'mean_square'
)
复制代码

接下来,为了让 TFLearn 框架知道我们的网络模型实际上是一个深度神经网络模型,我们须要调用 DNN() 函数。

model = tflearn.DNN(net)
复制代码

模型现在已经准备好了。我们现在要做的就是使用我们之前创建的训练数据进行训练。因此,调用模型的 fit() 方法,并指定训练数据与训练周期。由于训练数据非常小,我们的模型将需要数千次迭代才能达到合理的精度。

model.fit(X, Y, 5000)
复制代码

一旦训练完成,我们可以调用模型的 predict() 方法来检查它是否生成期望的输出。以下代码展示了如何检查所有有效输入的输出:

print("1 XOR 0 = %f" % model.predict([[1,0]]).item(0))
print("1 XOR 1 = %f" % model.predict([[1,1]]).item(0))
print("0 XOR 1 = %f" % model.predict([[0,1]]).item(0))
print("0 XOR 0 = %f" % model.predict([[0,0]]).item(0))
复制代码

如果现在运行 Python 脚本,您应该看到如下所示的输出:

请注意,输出不会完全是 0 或 1。而是接近 0 或 1 的浮点数。因此,在使用输出时,可能需要使用 Python 的 round() 函数。

除非我们在训练后明确保存模型,否则只要程序结束,我们就会失去模型。幸运的是,对于 TFLearn,只需调用 save() 方法即可保存模型。但是,为了能够在 TensorFlow Mobile 中使用保存的模型,在保存之前,我们必须确保移除所有训练相关的操作。这些操作都在 tf.GraphKeys.TRAIN_OPS 集合中。以下代码展示了怎么去移除相关操作:

# 移除训练相关的操作
with net.graph.as_default():del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]# 保存模型
model.save('xor.tflearn')
复制代码

如果您再次运行该脚本,您会发现它会生成检查点文件、元数据文件、索引文件和数据文件,所有这些文件一起使用时可以快速重建我们训练好的模型。

2、固化模型

除了保存模型外,我们还必须先固化模型,然后才能将其与 TensorFlow Mobile 配合使用。正如您可能已经猜到的那样,固化模型的过程涉及将其所有变量转换为常量。此外,固化模型必须是符合 Google Protocol Buffers 序列化格式的单个二进制文件。

新建一个名为 freeze_model.py 的 Python 脚本,并使用文本编辑器打开它。我们将在这个文件中编写固化的模型代码来。

由于 TFLearn 没有任何固化模型的功能,我们现在必须直接使用 TensorFlow API。通过将以下行添加到文件来导入它们:

import tensorflow as tf
复制代码

整个脚本里面,我们将使用单个 TensorFlow 会话。我们使用 Session 类的构造函数创建会话。

with tf.Session() as session:# 代码的其他部分在这
复制代码

此时,我们必须通过调用 import_meta_graph() 函数并将模型的元数据文件的名称传递给它来创建 Saver 对象,除了返回 Saver 对象外,import_meta_graph() 函数还会自动将模型的图定义添加到会话的图定义中。

一旦创建了保存器(saver),我们可以通过调用 restore() 方法来初始化图定义中存在的所有变量,该方法需要包含模型最新检查点文件的目录路径。

my_saver = tf.train.import_meta_graph('xor.tflearn.meta')
my_saver.restore(session, tf.train.latest_checkpoint('.'))
复制代码

此时,我们可以调用 convert_variables_to_constants() 函数来创建一个固化的图定义,其中模型的所有变量都替换成常量。作为其输入,函数需要当前会话、当前会话的图定义以及包含模型输出层名称的列表。

frozen_graph = tf.graph_util.convert_variables_to_constants(session,session.graph_def,['my_output/Sigmoid']
)
复制代码

调用固化图定义的 SerializeToString() 方法为我们提供了模型的二进制 protobuf 表示。通过使用 Python 基本的文件 I/O,我建议您把它保存为一个名为 frozen_model.pb 的文件。

with open('frozen_model.pb', 'wb') as f:f.write(frozen_graph.SerializeToString())
复制代码

现在可以运行脚本来生成固化模型。

我们现在拥有开始使用 TensorFlow Mobile 所需的一切。

3、Android Studio 项目设置

TensorFlow Mobile 库可在 JCenter 上使用,所以我们可以直接将它添加为 app 模块 build.gradle 文件中的 implementation 依赖项。

implementation 'org.tensorflow:tensorflow-android:1.7.0'
复制代码

要把固化的模型添加到项目中,请将 frozen_model.pb 文件放置到项目的 assets 文件夹中。

4、初始化 TensorFlow 接口

TensorFlow Mobile 提供了一个简单的接口,我们可以使用它与我们的固化模型进行交互。要创建接口,请使用 TensorFlowInferenceInterface 类的构造函数,该类需要一个 AssetManager 实例和固化模型的文件名。

thread {val tfInterface = TensorFlowInferenceInterface(assets,"frozen_model.pb")// More code here
}
复制代码

在上面的代码中,您可以看到我们正在产生一个新的线程。这是为了确保应用的 UI 保持响应,虽然不必要,但建议这样做。

为了保证 TensorFlow Mobile 能够正确读取我们模型的文件,现在让我们尝试打印模型图中所有操作的名称。为了得到对图的引用,我们可以使用接口的 graph() 方法,并获取所有操作,即图的 operations() 方法。以下代码告诉您该怎么做:

val graph = tfInterface.graph()
graph.operations().forEach {println(it.name())
}
复制代码

如果现在运行该应用,则应该能够看到在 Android Studio 的 Logcat 窗口中打印的十几个操作名称。如果固化模型时没有出错,我们可以在这些名称中找到输入和输出层的名称:my_input/Xmy_output/Sigmoid

5、使用模型

为了用模型进行预测,我们将数据输入到输入层,在输出层得到数据。将数据输入到输入层需要使用接口的 feed() 方法,该方法需要输入层的名称、含有输入数据的数组以及数组的维数。以下代码展示如何将数字 01 输入到输入层:

tfInterface.feed("my_input/X",floatArrayOf(0f, 1f), 1, 2)
复制代码

数据加载到输入层后,我们必须使用 run() 方法进行推断操作,该方法需要输出层的名称。一旦操作完成,输出层将包含模型的预测。为了将预测结果加载到 Kotlin 数组中,我们可以使用 fetch() 方法。以下代码显示了如何执行此操作:

tfInterface.run(arrayOf("my_output/Sigmoid"))val output = floatArrayOf(-1f)
tfInterface.fetch("my_output/Sigmoid", output)
复制代码

您现在可以运行该应用来查看模型的预测是否正确。

可以更改输入到输入层的数字,以确认模型的预测始终正确。

总结

您现在知道如何创建一个简单的 TensorFlow 模型以及在安卓应用上通过 TensorFlow Mobile 去使用该模型。不过不必拘泥于自己的模型,用您今天学到的东西,使用更大的模型对您来说应该没有任何问题。例如 MobileNet 以及 Inception,这些都可以在 TensorFlow 的 模型园 里找到。但是请注意,这些模型会使 APK 更大,从而给使用低端设备的用户造成问题。

要了解有关 TensorFlow Mobile 的更多信息,请参阅 官方文档.


掘金翻译计划 是一个翻译优质互联网技术文章的社区,文章来源为 掘金 上的英文分享文章。内容覆盖 Android、iOS、前端、后端、区块链、产品、设计、人工智能等领域,想要查看更多优质译文请持续关注 掘金翻译计划、官方微博、知乎专栏。

[译] 如何在安卓应用中使用 TensorFlow Mobile相关推荐

  1. 豁然开朗篇:安卓开发中关于虚拟机那些事

    彻底搞懂虚拟机这一块,看这一篇就够了 前言 作为豁然开朗篇的最终篇,本文要讲解的是虚拟机这块,因为在之前讲解内存与线程的时候,一直都会牵涉到虚拟机和指令集这块,所以,为了让大家再豁然开朗多一次,本文会 ...

  2. node sqlite 插入数据_安卓手机中的应用数据都保存在哪些文件中?

    随笔 知识 案例 声音 其他 编者按 手机取证,品牌是一方面,从操作系统入手是另外一个渠道.手机中的重要数据基本上都以轻量数据库的形式保存在本地,也就是经常讲的sqlite db文件中. 从推特上得知 ...

  3. python 加载动图_在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...

    大数据文摘授权转载自数据派THU 作者:MOHD SANAD ZAKI RIZVI 本文主要介绍了: TensorFlow.js (deeplearn.js)使我们能够在浏览器中构建机器学习和深度学习 ...

  4. Android安卓开发中图片缩放讲解

    安卓开发中应用到图片的处理时候,我们通常会怎么缩放操作呢,来看下面的两种做法: 方法1:按固定比例进行缩放 在开发一些软件,如新闻客户端,很多时候要显示图片的缩略图,由于手机屏幕限制,一般情况下,我们 ...

  5. 在Data Collector中使用TensorFlow进行实时机器学习

    导言 只有当业务方面的用户和应用程序能够从一系列来源访问原始和聚合数据,并及时生成数据驱动时,才能实现现代DataOps平台的真正价值.借助机器学习,分析师和数据科学家可以利用TensorFlow等技 ...

  6. anconda安装后命令行中安装tensorflow报错

    现象  anconda安装后命令行中安装tensorflow报错 pip install --upgrade --ignore-installed tensorflow-gpu Building wh ...

  7. 在Win10 Anaconda中安装Tensorflow

    离完成上一篇文章有近1年了.2016年发生了太多的事情,从而没能坚持哪怕是每月一篇这样的频率.终于在2017年的1月份抽出几天搞出了一些东西.一路坑洼,赶紧记录下来. 2016年初就开始看深度学习的东 ...

  8. 独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)

    作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...

  9. Android Studio安卓开发中使用json来作为网络数据传输格式

    如果你是在安卓开发中并且使用android studio,要使用json来作为数据传输的格式,那么下面是我的一些经验. 一开始我在android studio中导入那6个包,那6个包找了非常久,因为放 ...

最新文章

  1. 图片视角转换 cv2.warpPerspective
  2. Scala中没有break和continue, 如何退出循环
  3. 4.弹性网络( Elastic Net)
  4. Awk by Example--转载
  5. Serverless 解惑——函数计算如何访问 PostgreSQL 数据库
  6. 单片机最小系统制作记录
  7. 一家公司干了8年的程序员的年终总结
  8. 七夕新浪漫,让AI黑科技带你们提前看看爱情的结晶
  9. aws mysql链接_AWS Lambda和MySQL连接处理
  10. DFS(深度优先遍历)走迷宫算法
  11. pythonsuper多重继承_解决python super()调用多重继承函数的问题
  12. 一种基于频域滤波法消除干扰项与角谱法重构技术的数字全息显微台阶形貌测量实例分析
  13. 三星2610打印机故障INTERNAL ERROR - Incomplete Session by time out
  14. React的调和过程(Reconcilliation)
  15. 开头的单词_学Z字母本义和引申义,初高中Z开头的单词几分钟全部轻松记忆!...
  16. 小程序修改单页面的背景颜色
  17. CAD的图层过滤器有什么用?
  18. matlab整数规划--简单入门
  19. 【C语言刷题】青蛙跳台阶
  20. Bearer Token的相关定义与使用方法

热门文章

  1. layui form模块
  2. vb 关于commondialog的多选
  3. nginx与IIS服务器搭建集群实现负载均衡(一)
  4. 程序员高效技巧系列 -- 完全脱离鼠标的终端
  5. 如何判断web应用是否添加到主屏幕
  6. 传输层协议TCP和UDP
  7. JAVA总裁--Java数组基础知识
  8. C# ToString() 参数大全
  9. 下班前网上搜集的方法哈哈
  10. Nginx+DNS负载均衡实现