一. TensorFlow Lite

TensorFlow Lite介绍.jpeg

TensorFlow Lite特性.jpeg

TensorFlow Lite使用.jpeg

TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。

我们知道大多数的 AI 是在云端运算的,但是在移动端使用 AI 具有无网络延迟、响应更加及时、数据隐私等特性。

对于离线的场合,云端的 AI 就无法使用了,而此时可以在移动设备中使用 TensorFlow Lite。

二. tflite 格式

TensorFlow 生成的模型是无法直接给移动端使用的,需要离线转换成.tflite文件格式。

tflite 存储格式是 flatbuffers。

FlatBuffers 是由Google开源的一个免费软件库,用于实现序列化格式。它类似于Protocol Buffers、Thrift、Apache Avro。

因此,如果要给移动端使用的话,必须把 TensorFlow 训练好的 protobuf 模型文件转换成 FlatBuffers 格式。官方提供了 toco 来实现模型格式的转换。

三. 常用的 Java API

TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。

而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。后面的例子会看到如何使用 Interpreter。

四. TensorFlow Lite + mnist 数据集实现识别手写数字

mnist 是手写数字图片数据集,包含60000张训练样本和10000张测试样本。
测试集也是同样比例的手写数字数据。每张图片有28x28个像素点构成,每个像素点用一个灰度值表示,这里是将28x28的像素展开为一个一维的行向量(每行784个值)。

mnist 数据集获取地址:http://yann.lecun.com/exdb/mnist/

下面的 demo 中已经包含了 mnist.tflite 模型文件。(如果没有的话,需要自己训练保存成pb文件,再转换成tflite 格式)

对于一个识别类,首先需要初始化 TensorFlow Lite 解释器,以及输入、输出。

    // The tensorflow lite fileprivate lateinit var tflite: Interpreter// Input byte bufferprivate lateinit var inputBuffer: ByteBuffer// Output array [batch_size, 10]private lateinit var mnistOutput: Array<FloatArray>init {try {tflite = Interpreter(loadModelFile(activity))inputBuffer = ByteBuffer.allocateDirect(BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)inputBuffer.order(ByteOrder.nativeOrder())mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) }Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.")} catch (e: IOException) {Log.e(TAG, "IOException loading the tflite file failed.")}}

从 asserts 文件中加载 mnist.tflite 模型:

    /*** Load the model file from the assets folder*/@Throws(IOException::class)private fun loadModelFile(activity: Activity): MappedByteBuffer {val fileDescriptor = activity.assets.openFd(MODEL_PATH)val inputStream = FileInputStream(fileDescriptor.fileDescriptor)val fileChannel = inputStream.channelval startOffset = fileDescriptor.startOffsetval declaredLength = fileDescriptor.declaredLengthreturn fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)}

真正识别手写数字是在 classify() 方法:

val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))

classify() 方法包含了预处理用于初始化 inputBuffer、运行 mnist 模型、识别出数字。

    /*** Classifies the number with the mnist model.** @param bitmap* @return the identified number*/fun classify(bitmap: Bitmap): Int {if (tflite == null) {Log.e(TAG, "Image classifier has not been initialized; Skipped.")}preProcess(bitmap)runModel()return postProcess()}/*** Converts it into the Byte Buffer to feed into the model** @param bitmap*/private fun preProcess(bitmap: Bitmap?) {if (bitmap == null || inputBuffer == null) {return}// Reset the image datainputBuffer.rewind()val width = bitmap.widthval height = bitmap.height// The bitmap shape should be 28 x 28val pixels = IntArray(width * height)bitmap.getPixels(pixels, 0, width, 0, 0, width, height)for (i in pixels.indices) {// Set 0 for white and 255 for black pixelsval pixel = pixels[i]// The color of the input is black so the blue channel will be 0xFF.val channel = pixel and 0xffinputBuffer.putFloat((0xff - channel).toFloat())}}/*** Run the TFLite model*/private fun runModel() = tflite.run(inputBuffer, mnistOutput)/*** Go through the output and find the number that was identified.** @return the number that was identified (returns -1 if one wasn't found)*/private fun postProcess(): Int {for (i in 0 until mnistOutput[0].size) {val value = mnistOutput[0][i]if (value == 1f) {return i}}return -1}

对于 Android 有一个地方需要注意,必须在 app 模块的 build.gradle 中添加如下的语句,否则无法加载模型。

android {......aaptOptions {noCompress "tflite"}
}

demo 运行效果如下:

识别手写数字5.png

识别手写数字7.png

何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX

五. 总结

本文只是 TF Lite 的初探,很多细节并没有详细阐述。应该会在未来的文章中详细介绍。

本文 demo 的 github 地址:https://github.com/xiaobingchan/TFLite-MnistDemo

当然,也可以跑一下官方的例子:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/examples/android/app

Android TensorFlow Lite 深度学习识别手写数字mnist demo相关推荐

  1. 【Get】用深度学习识别手写数字

    前置参考读物: <机器学习,看完就明白了>传送门 获取数据源 训练数据直接使用开源的手写数据集MNIST. MNIST数据集是一个开源的手写数据库.它提供了大量的数据样本作为训练集和验证集 ...

  2. 深度学习(4)--手写数字mnist实现

    前面两节,讲述了梯度下降和方向传播的原理,这里我通过mnist训练来讲述下python的实现方法 头文件 numpy用于矩阵运算,random用于数据集的shuffle,mnist_loader 用于 ...

  3. 【深度学习】手写数字识别Tensorflow2实验报告

    实验一:手写数字识别 一.实验目的 利用深度学习实现手写数字识别,当输入一张手写图片后,能够准确的识别出该图片中数字是几.输出内容是0.1.2.3.4.5.6.7.8.9的其中一个. 二.实验原理 ( ...

  4. 基于深度学习的手写数字识别算法Python实现

    摘 要 深度学习是传统机器学习下的一个分支,得益于近些年来计算机硬件计算能力质的飞跃,使得深度学习成为了当下热门之一.手写数字识别更是深度学习入门的经典案例,学习和理解其背后的原理对于深度学习的理解有 ...

  5. 基于深度学习的手写数字识别、python实现

    基于深度学习的手写数字识别.python实现 一.what is 深度学习 二.加深层可以减少网络的参数数量 三.深度学习的手写数字识别 一.what is 深度学习 深度学习是加深了层的深度神经网络 ...

  6. 基于深度学习的手写数字识别Matlab实现

    基于深度学习的手写数字识别Matlab实现 1.网络设计 2. 训练方法 3.实验结果 4.实验结果分析 5.结论 1.网络设计 1.1 CNN(特征提取网络+分类网络) 随着深度学习的迅猛发展,其应 ...

  7. 03_深度学习实现手写数字识别(python)

    本次项目采用了多种模型进行测试,并尝试策略来提升模型的泛化能力,最终取得了99.67%的准确率,并采用pyqt5来制作可视化GUI界面进行呈现.具体代码已经开源. 代码详情见附录 1简介 早在1998 ...

  8. Python基于深度学习的手写数字识别

    Python基于深度学习的手写数字识别 1.代码的功能和运行方法 2. 网络设计 3.训练方法 4.实验结果分析 5.结论 1.代码的功能和运行方法 代码可以实现任意数字0-9的识别,只需要将图片载入 ...

  9. 百度飞桨PaddelePaddle-21天零基础实践深度学习-【手写数字任务】2

    百度飞桨PaddelePaddle-21天零基础实践深度学习-[手写数字任务]2 模型设计 网络结构 损失函数 训练配置 优化算法 模型设计 网络结构 全连接神经网络 经典的全连接神经网络来包含四层网 ...

最新文章

  1. 各种小的 dp (精)
  2. CoreNLP请求超时 runtime out
  3. idea提交新项目到远程git创库
  4. java文件下载并添加水印_Java下载文件加文字水印(Excel、PDF、图片)
  5. 互联网日报 | 小米11取消随机附送充电器;苏宁30周年发庆生红包;2021年全国两会召开时间确定...
  6. python--反射机制
  7. 腾讯云TStack与IBM LinuxONE互认证
  8. 云网融合个人浅析(一)
  9. ubuntu PHP Cannot adopt OID in UCD-SNMP-MIB
  10. WeX5 V3.6 正式版核心特性
  11. 算法面试必备-----数据分析常见面试题
  12. 留学Essay写作方法从哪里学习?
  13. 嵌入式Linux中间件,高可用性(HA)和嵌入式管理中间件:Enea Element详解
  14. CRM客户关系管理系统让企业在竞争中脱颖而出
  15. 小白建网站,该如何入手?
  16. 在微信小程序中编写金额摇奖效果
  17. 计算机动画算法与编程基础pdf,清华大学 计算机动画算法与编程基础2-图形绘制课件.ppt...
  18. Data Types in the Kernel [LDD3 11]
  19. 概率和统计是一回事么?
  20. 解决springmvc返回json数据IE出现文件下载和json数据中文乱码问题

热门文章

  1. linux 标准 GPIO 操作
  2. 转: 加快Android编译速度
  3. 将数据库表导入到solr索引
  4. GridView 中 Bind和Eval的区别详解
  5. android 广告栏效果,实现android广告栏效果
  6. Java工具实现无水印批量下载
  7. java中的Iterator和Iterable 区别
  8. 802d简明调试手册_SINUMERIK-828D简明调试手册.pdf
  9. 计算机基础应用的培养活动记录,小学少年宫计算机兴趣小组活动记录表
  10. python自然语言处理书籍_精通Python自然语言处理pdf