前言

Tensorflow2之后,训练保存的模型也有所变化,基于Keras接口搭建的网络模型默认保存的模型是h5格式的,而之前的模型格式是pb。Tensorflow2的h5格式的模型转换成tflite格式模型非常方便。本教程就是介绍如何使用Tensorflow2的Keras接口训练分类模型并使用Tensorflow Lite部署到Android设备上。

本教程源码:https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification

训练和转换模型

以下是使用Tensorflow2的keras搭建的一个MobileNetV2模型并训练自定义数据集,本教程主要是介绍如何在Android设备上使用Tensorflow Lite部署分类模型,所以关于训练模型只是简单介绍,代码并不完整。通过下面的训练模型,我们最终会得到一个mobilenet_v2.h5模型。

import os
import tensorflow as tf
import reader
import config as cfg# 获取模型
input_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, cfg.IMAGE_CHANNEL)
model = tf.keras.Sequential([tf.keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, pooling='max'),tf.keras.layers.Dense(units=cfg.CLASS_DIM, activation='softmax')])
model.summary()# 获取训练数据
train_data = reader.train_reader(data_list_path=cfg.TRAIN_LIST_PATH, batch_size=cfg.BATCH_SIZE)# 定义训练参数
model.compile(optimizer=tf.keras.optimizers.RMSprop(),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])# 开始训练
model.fit(train_data, epochs=cfg.EPOCH_SUM, workers=4)# 保存h5模型
if not os.path.exists(os.path.dirname(cfg.H5_MODEL_PATH)):os.makedirs(os.path.dirname(cfg.H5_MODEL_PATH))
model.save(filepath=cfg.H5_MODEL_PATH)
print('saved h5 model!')

通过上面得到的mobilenet_v2.h5模型,我们需要转换为tflite格式的模型,在Tensorflow2之后,这个转换就变动很简单了,通过下面的几行代码即可完成转换,最终我们会得到一个mobilenet_v2.tflite模型。

import tensorflow as tf
import config as cfg# 加载模型
model = tf.keras.models.load_model(cfg.H5_MODEL_PATH)# 生成非量化的tflite模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(cfg.TFLITE_MODEL_FILE, 'wb').write(tflite_model)
print('saved tflite model!')

如果保存的模型格式不是h5,而是tf格式的,如下代码,保存的模型是tf格式的。

import tensorflow as tfmodel = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3))model.save(filepath='mobilenet_v2', save_format='tf')

如果是tf格式的模型,那需要使用以下转换模型的方式。

import tensorflow as tfconverter = tf.lite.TFLiteConverter.from_saved_model('mobilenet_v2')
tflite_model = converter.convert()
open("mobilenet_v2.tflite", "wb").write(tflite_model)

在部署到Android中可能需要到输入输出层的名称,通过下面代码可以获取到输入输出层的名称和shape。

import tensorflow as tfmodel_path = 'models/mobilenet_v2.tflite'interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()# 获取输入和输出张量。
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()print(input_details)
print(output_details)

部署到Android设备

首先要在build.gradle导入这三个库,如果不使用GPU可以只导入两个库。

implementation 'org.tensorflow:tensorflow-lite:2.3.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0-rc1'

在以前还需要在android下添加以下代码,避免在打包apk的是对模型有压缩操作,损坏模型。现在好像不加也没有关系,但是为了安全起见,还是添加上去。

    aaptOptions {noCompress "tflite"}

复制转换的预测模型到app/src/main/assets目录下,还有类别的标签,每一行对应一个标签名称。

Tensorflow Lite工具

编写一个TFLiteClassificationUtil工具类,关于Tensorflow Lite的操作都在这里完成,如加载模型、预测。在构造方法中,通过参数传递的模型路径加载模型,在加载模型的时候配置预测信息,例如是否使用Android底层神经网络APINnApiDelegate或者是否使用GPUGpuDelegate,同时获取网络的输入输出层。有了tensorflow-lite-support库,数据预处理就变得非常简单,通过ImageProcessor创建一个数据预处理的工具,之后在预测之前使用这个工具对图像进行预处理,处理速度还是挺快的,要注意的是图像的均值IMAGE_MEAN和标准差IMAGE_STD,因为在训练的时候图像预处理可能不一样的,有些读者出现在电脑上准确率很高,但在手机上准确率很低,多数情况下就是这个图像预处理做得不对。

private static final float[] IMAGE_MEAN = new float[]{128.0f, 128.0f, 128.0f};
private static final float[] IMAGE_STD = new float[]{128.0f, 128.0f, 128.0f};public TFLiteClassificationUtil(String modelPath) throws Exception {File file = new File(modelPath);if (!file.exists()) {throw new Exception("model file is not exists!");}try {Interpreter.Options options = new Interpreter.Options();// 使用多线程预测options.setNumThreads(NUM_THREADS);// 使用Android自带的API或者GPU加速NnApiDelegate delegate = new NnApiDelegate();
//            GpuDelegate delegate = new GpuDelegate();options.addDelegate(delegate);tflite = new Interpreter(file, options);// 获取输入,shape为{1, height, width, 3}int[] imageShape = tflite.getInputTensor(tflite.getInputIndex("input_1")).shape();DataType imageDataType = tflite.getInputTensor(tflite.getInputIndex("input_1")).dataType();inputImageBuffer = new TensorImage(imageDataType);// 获取输入,shape为{1, NUM_CLASSES}int[] probabilityShape = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).shape();DataType probabilityDataType = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).dataType();outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);// 添加图像预处理方式imageProcessor = new ImageProcessor.Builder().add(new ResizeOp(imageShape[1], imageShape[2], ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)).add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD)).build();} catch (Exception e) {e.printStackTrace();throw new Exception("load model fail!");}
}

为了兼容图片路径和Bitmap格式的图片预测,这里创建了两个重载方法,它们都是通过调用predict()

public int predictImage(String image_path) throws Exception {if (!new File(image_path).exists()) {throw new Exception("image file is not exists!");}FileInputStream fis = new FileInputStream(image_path);Bitmap bitmap = BitmapFactory.decodeStream(fis);int result = predictImage(bitmap);if (bitmap.isRecycled()) {bitmap.recycle();}return result;
}public int predictImage(Bitmap bitmap) throws Exception {return predict(bitmap);
}

这里创建一个获取最大概率值,并把下标返回的方法,其实就是获取概率最大的预测标签。

public static int getMaxResult(float[] result) {float probability = 0;int r = 0;for (int i = 0; i < result.length; i++) {if (probability < result[i]) {probability = result[i];r = i;}}return r;
}

这个方法就是Tensorflow Lite执行预测的最后一步,通过执行tflite.run()对输入的数据进行预测并得到预测结果,通过解析获取到最大的概率的预测标签,并返回。到这里Tensorflow Lite的工具就完成了。

private int predict(Bitmap bmp) throws Exception {inputImageBuffer = loadImage(bmp);try {tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());} catch (Exception e) {throw new Exception("predict image fail! log:" + e);}float[] results = outputProbabilityBuffer.getFloatArray();Log.d(TAG, Arrays.toString(results));return getMaxResult(results);
}

选择图片预测

本教程会有两个页面,一个是选择图片进行预测的页面,另一个是使用相机实时预测并显示预测结果。以下为activity_main.xml的代码,通过按钮选择图片,并在该页面显示图片和预测结果。

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"xmlns:app="http://schemas.android.com/apk/res-auto"xmlns:tools="http://schemas.android.com/tools"android:layout_width="match_parent"android:layout_height="match_parent"android:orientation="vertical"tools:context=".MainActivity"><ImageViewandroid:id="@+id/image_view"android:layout_width="match_parent"android:layout_height="400dp" /><TextViewandroid:id="@+id/result_text"android:layout_width="match_parent"android:layout_height="wrap_content"android:layout_below="@id/image_view"android:text="识别结果"android:textSize="16sp" /><LinearLayoutandroid:layout_width="match_parent"android:layout_height="wrap_content"android:layout_alignParentBottom="true"android:orientation="horizontal"><Buttonandroid:id="@+id/select_img_btn"android:layout_width="0dp"android:layout_height="wrap_content"android:layout_weight="1"android:text="选择照片" /><Buttonandroid:id="@+id/open_camera"android:layout_width="0dp"android:layout_height="wrap_content"android:layout_weight="1"android:text="实时预测" /></LinearLayout></RelativeLayout>

MainActivity.java中,进入到页面我们就要先加载模型,我们是把模型放在Android项目的assets目录的,但是Tensorflow Lite并不建议直接在assets读取模型,所以我们需要把模型复制到一个缓存目录,然后再从缓存目录加载模型,同时还有读取标签名,标签名称按照训练的label顺序存放在assets的label_list.txt,以下为实现代码。

classNames = Utils.ReadListFromFile(getAssets(), "label_list.txt");
String classificationModelPath = getCacheDir().getAbsolutePath() + File.separator + "mobilenet_v2.tflite";
Utils.copyFileFromAsset(MainActivity.this, "mobilenet_v2.tflite", classificationModelPath);
try {tfLiteClassificationUtil = new TFLiteClassificationUtil(classificationModelPath);Toast.makeText(MainActivity.this, "模型加载成功!", Toast.LENGTH_SHORT).show();
} catch (Exception e) {Toast.makeText(MainActivity.this, "模型加载失败!", Toast.LENGTH_SHORT).show();e.printStackTrace();finish();
}

添加两个按钮点击事件,可以选择打开相册读取图片进行预测,或者打开另一个Activity进行调用摄像头实时识别。

Button selectImgBtn = findViewById(R.id.select_img_btn);
Button openCamera = findViewById(R.id.open_camera);
imageView = findViewById(R.id.image_view);
textView = findViewById(R.id.result_text);
selectImgBtn.setOnClickListener(new View.OnClickListener() {@Overridepublic void onClick(View v) {// 打开相册Intent intent = new Intent(Intent.ACTION_PICK);intent.setType("image/*");startActivityForResult(intent, 1);}
});
openCamera.setOnClickListener(new View.OnClickListener() {@Overridepublic void onClick(View v) {// 打开实时拍摄识别页面Intent intent = new Intent(MainActivity.this, CameraActivity.class);startActivity(intent);}
});

当打开相册选择照片之后,回到原来的页面,在下面这个回调方法中获取选择图片的Uri,通过Uri可以获取到图片的绝对路径。如果Android8以上的设备获取不到图片,需要在AndroidManifest.xml配置文件中的application添加android:requestLegacyExternalStorage="true"。拿到图片路径之后,调用TFLiteClassificationUtil类中的predictImage()方法预测并获取预测值,在页面上显示预测的标签、对应标签的名称、概率值和预测时间。

@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {super.onActivityResult(requestCode, resultCode, data);String image_path;if (resultCode == Activity.RESULT_OK) {if (requestCode == 1) {if (data == null) {Log.w("onActivityResult", "user photo data is null");return;}Uri image_uri = data.getData();image_path = getPathFromURI(MainActivity.this, image_uri);try {// 预测图像FileInputStream fis = new FileInputStream(image_path);imageView.setImageBitmap(BitmapFactory.decodeStream(fis));long start = System.currentTimeMillis();float[] result = tfLiteClassificationUtil.predictImage(image_path);long end = System.currentTimeMillis();String show_text = "预测结果标签:" + (int) result[0] +"\n名称:" +  classNames[(int) result[0]] +"\n概率:" + result[1] +"\n时间:" + (end - start) + "ms";textView.setText(show_text);} catch (Exception e) {e.printStackTrace();}}}
}

上面获取的Uri可以通过下面这个方法把Url转换成绝对路径。

// get photo from Uri
public static String getPathFromURI(Context context, Uri uri) {String result;Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);if (cursor == null) {result = uri.getPath();} else {cursor.moveToFirst();int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);result = cursor.getString(idx);cursor.close();}return result;
}

摄像头实时预测

在调用相机实时预测我就不再介绍了,原理都差不多,具体可以查看https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification中的源代码。核心代码如下,创建一个子线程,子线程中不断从摄像头预览的AutoFitTextureView上获取图像,并执行预测,并在页面上显示预测的标签、对应标签的名称、概率值和预测时间。每一次预测完成之后都立即获取图片继续预测,只要预测速度够快,就可以看成实时预测。

private Runnable periodicClassify =new Runnable() {@Overridepublic void run() {synchronized (lock) {if (runClassifier) {// 开始预测前要判断相机是否已经准备好if (getApplicationContext() != null && mCameraDevice != null && tfLiteClassificationUtil != null) {predict();}}}if (mInferThread != null && mInferHandler != null && mCaptureHandler != null && mCaptureThread != null) {mInferHandler.post(periodicClassify);}}};// 预测相机捕获的图像
private void predict() {// 获取相机捕获的图像Bitmap bitmap = mTextureView.getBitmap();try {// 预测图像long start = System.currentTimeMillis();float[] result = tfLiteClassificationUtil.predictImage(bitmap);long end = System.currentTimeMillis();String show_text = "预测结果标签:" + (int) result[0] +"\n名称:" +  classNames[(int) result[0]] +"\n概率:" + result[1] +"\n时间:" + (end - start) + "ms";textView.setText(show_text);} catch (Exception e) {e.printStackTrace();}
}

本项目中使用的了读取图片的权限和打开相机的权限,所以不要忘记在AndroidManifest.xml添加以下权限申请。

<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

如果是Android 6 以上的设备还要动态申请权限。

    // check had permissionprivate boolean hasPermission() {if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {return checkSelfPermission(Manifest.permission.CAMERA) == PackageManager.PERMISSION_GRANTED &&checkSelfPermission(Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&checkSelfPermission(Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;} else {return true;}}// request permissionprivate void requestPermission() {if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {requestPermissions(new String[]{Manifest.permission.CAMERA,Manifest.permission.READ_EXTERNAL_STORAGE,Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);}}

选择图片识别效果图:

相机实时识别效果图:

基于Tensorflow2 Lite在Android手机上实现图像分类相关推荐

  1. 基于Tensorflow2 Lite在Android手机上实现时间序列温度预测(二)

    前言 Tensorflow2之后,训练保存的模型也有所变化,基于Keras接口搭建的网络模型默认保存的模型是h5格式的,而之前的模型格式是pb.Tensorflow2的h5格式的模型转换成tflite ...

  2. 基于MNN在Android手机上实现图像分类

    原文博客:Doi技术团队 链接地址:https://blog.doiduoyi.com/authors/1584446358138 初心:记录优秀的Doi技术团队学习经历 本文链接:基于MNN在And ...

  3. 如何在Android手机上进行Google Map的开发。

    1.题记 提起谷歌Map相信大家都不会陌生,那进入我们今天的话题,如何在Android手机上进行Google Map的开发. 2.Map应用程序的开发 2.1 准备工作 2.1.1 申请Android ...

  4. 您可以在Windows PC或Android手机上使用iMessage吗?

    Denys Prykhodov/Shutterstock.comDenys Prykhodov / Shutterstock.com Want iMessage for Android or Wind ...

  5. Android 手机上安装并运行 Ubuntu 12.04

    Android 手机上安装并运行 Ubuntu 12.04 2012 年 8 月 16 日  by  DawnDIY in  Android |  7 Comments Android 是基于Linu ...

  6. 如何在 Android 手机上实现抓包?

    如何在 Android 手机上实现抓包? http://www.zhihu.com/question/20467503 我想知道某个应用究竟在数据提交到哪里,提交了什么. 网上的教程太复杂,不想麻烦. ...

  7. Android 浏览器内核浅谈,基于WebKit内核的Android手机浏览器的性能研究与加载优化...

    南京邮电大学 硕士 2017 基于WebKit内核的Android手机浏览器的性能研究与加载优化 Performance Research and Load Optimization of Andro ...

  8. android手机进行android开发,如何在Android手机上进行自动化测试(上)

    版权声明:允许转载,但转载必须保留原链接:请勿用作商业或者非法用途 前言 通过阅读本节教程,你将了解到以下内容: 如何在脚本代码中.运行脚本时指定手机 如何填写--device Android:/// ...

  9. unity3d shader编程中GrabPass 在某些android手机上失效的解决方案

    unity3d shader编程中GrabPass 在某些android手机上失效的解决方案 参考文章: (1)unity3d shader编程中GrabPass 在某些android手机上失效的解决 ...

最新文章

  1. ZooKeeper概述
  2. Yarn已过时!Kubeflow实现机器学习调度平台才是未来
  3. 如何使用Azure API管理服务?
  4. 排序算法(插入、快速、归并)java实现
  5. HDU 1848 Fibonacci again and again
  6. 设计模式---读书笔记
  7. 已跳过全部重新生成_2020年最新跳对公技术1+5,1+10,5+50(必读)
  8. 50 CO配置-控制-获利能力分析-维护经营关注点
  9. 7-102 单词首字母大写 (15 分)
  10. 重学前端学习笔记(十三)--浏览器工作解析(三)
  11. 自定义控件2.第一个自定义view
  12. linux启动python项目_java项目部署Linux服务器几种启动方式总结经验
  13. 鸿蒙系统下载地址_华为鸿蒙系统下载
  14. 二维码软件如何扫描二维码打开网页
  15. Spring学习资料
  16. kali安装后详细配置
  17. Android 最常用的设计模式二 安卓Rxjava源码分析—观察者模式Observer(有实例)
  18. 【爬虫1】爬虫和反爬虫介绍
  19. 图片验证码的逻辑实现
  20. 使用 OpenWhisk 自建 Serverless 服务

热门文章

  1. 怎样知道android的手机号码,查自己手机号码怎么查 教你五种方法【图文教程】...
  2. win10安装ubuntu16.04双系统+详细步骤实现
  3. C语言——leetcode69——X的平方根
  4. html设计动画小黄人,CSS3画出小黄人并实现动画效果!
  5. 一个简单的循环往复的动画效果
  6. 参考文档:《基于多目标算法的冷热电联供型综合能源系统运行优化》
  7. 华为禁止系统更新的方法
  8. YOLOv5报错AssertionError:Label class 1 exceeds nc=1 in yolo/dataset.ymal Possible class labels are 0-0
  9. 如何理解:ListString list=new ArrayListString();为甚麼要声明为List 而不是ArrayListString?
  10. 茶学领域如何用的上计算机,计算机视觉图像理技术在茶学领域应用方法的研究.pdf...