将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)

1. 写在前面

  最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址:一步步做一个数字手势识别APP,源代码已经开源在github上,地址:Chinese-number-gestures-recognition),要把在PC端训练好的模型放到Android APP上,调研了下,谷歌发布了TensorFlow Lite可以把TensorFlow训练好的模型迁移到Android APP上,百度也发布了移动端深度学习框架mobile-deep-learning(MDL),这个框架应该是paddlepaddle的手机版,具体的细节没有了解过。因为对TensorFlow稍微熟悉些,因此就决定用TensorFlow来做。
  关于在PC端如何处理数据及训练模型,请参见博客:一步步做一个数字手势识别APP,代码已经开源在github上,上面有代码的说明和APP演示。这篇博客只介绍如何把TensorFlow训练好的模型迁移到Android Studio上进行APP的开发。

2. 模型训练注意事项

  第一步,首先在pc端训练模型的时候要模型保存为.pb模型,在保存的时候有一点非常非常重要,就是你待会再Android studio是使用这个模型用到哪个参数,那么你在保存pb模型的时候就把给哪个参数一个名字,再保存。否则,你在Android studio中很难拿出这个参数,因为TensorFlow Lite的fetch()函数是根据保存在pb模型中的名字去寻找这个参数的。(如果你已经训练好了模型,并且没有给参数名字,且你不想再训练模型了,那么你可以尝试下面的方法去找到你需要使用的变量的默认名字,见下面的代码):

#输出保存的模型中参数名字及对应的值
with tf.gfile.GFile('model_50_200_c3//./digital_gesture.pb', "rb") as f:  #读取模型数据graph_def = tf.GraphDef()graph_def.ParseFromString(f.read()) #得到模型中的计算图和数据
with tf.Graph().as_default() as graph:  # 这里的Graph()要有括号,不然会报TypeErrortf.import_graph_def(graph_def, name="")  #导入模型中的图到现在这个新的计算图中,不指定名字的话默认是 importfor op in graph.get_operations():  # 打印出图中的节点信息print(op.name, op.values())

这段代码打出的变量的名字以及对应的值。

言归正传,通常情况该你应该保存参数的时候都给参数一个指定的名字,如下面这样(通过name参数给变量指定名字),关于训练CNN的完整代码请参见下一篇博客或者github:

X = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_x")
y = tf.placeholder(tf.float32, [None, 11], name="input_y")
kp = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
lam = tf.placeholder(tf.float32, name="lamda")
#中间略过若干代码
z_fc2 = tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")
prob = tf.nn.softmax(z_fc2, name="probability")
pred = tf.argmax(prob, 1, output_type="int32", name="predict")
3. 在Android Studio中配置

  第二步,开始把pb模型移植到Android Studio上,网上绝大部分资料都是说用bazel重新编译模型生成依赖,这种方法难度太大。其实没必须这样做,TensorFlow Lite官方的例子中已经给我们展示了,我们其实只需要两个文件:libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。这两个文件我已经放到github上了,大家可以自行下载使用,下载地址:libandroid_tensorflow_inference_java.jar、libtensorflow_inference.so。

注:检神说,直接用aar依赖也可以,这个我没试过。。有兴趣的可以试一下。

准备工作已经完毕,下面正式开始Android Studio中的配置。

  首先把训练好的pb模型放到Android项目中app/src/main/assets下,若不存在assets目录,则自己新建一个。如图所示:

  其次,把刚刚下载的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目录下,如下图所示:

然后在app/build.gradle里进行如下配置:
  在defaultConfig里添加

multiDexEnabled truendk {abiFilters "armeabi-v7a"}

  在android里添加

 sourceSets {main {jni.srcDirs = []jniLibs.srcDirs = ['libs']}}

如图所示:

  在dependencies中添加libandroid_tensorflow_inference_java.jar,即:

implementation files('libs/libandroid_tensorflow_inference_java.jar')

如图所示:

至此,所有配置已经完成,下面是模型调用。

4. 在Android Studio中调用模型

在要用到模型的地方,首先要加载libtensorflow_inference.so库和初始化TensorFlowInferenceInterface对象,代码为:

TensorFlowInferenceInterface inferenceInterface;static {//加载libtensorflow_inference.so库文件System.loadLibrary("tensorflow_inference");Log.e("tensorflow","libtensorflow_inference.so库加载成功");}Classifier(AssetManager assetManager, String modePath) {//初始化TensorFlowInferenceInterface对象inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);Log.e("tf","TensoFlow模型文件加载成功");}

如图所示:

下面来多看一点东西,看看TensorFlow Lite里提供了哪几个接口,官网地址:Here’s what a typical Inference Library sequence looks like on Android.

// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);// Run the inference call.
inferenceInterface.run(outputNames, logStats);// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);

下面就可以愉快地使用模型了。放一段我调用模型的代码,以供大家参考:

public ArrayList predict(Bitmap bitmap){ArrayList<String> list = new ArrayList<>();float[] inputdata = getPixels(bitmap);for(int i = 0; i <30; ++i){Log.d("matrix",inputdata[i] + "");}inferenceInterface.feed(inputName, inputdata, 1, IMAGE_SIZE, IMAGE_SIZE, 3);//运行模型,run的参数必须是String[]类型String[] outputNames = new String[]{outputName,probabilityName,outlayerName};inferenceInterface.run(outputNames);//获取结果int[] labels = new int[1];inferenceInterface.fetch(outputName,labels);int label = labels[0];float[] prob = new float[11];inferenceInterface.fetch(probabilityName, prob);
//        float[] outlayer = new float[11];
//        inferenceInterface.fetch(outlayerName, outlayer);//        for(int i = 0; i <11; ++i)
//        {//            Log.d("matrix",outlayer[i] + "");
//        }for(int i = 0; i <11; ++i){Log.d("matrix",prob[i] + "");}DecimalFormat df = new DecimalFormat("0.000000");float label_prob = prob[label];//返回值list.add(Integer.toString(label));list.add(df.format(label_prob));return list;}

最后放一张做的数字手势识别APP的效果,全部代码,将会开源在github上,欢迎star。

再放一张碰运气的识别结果:

将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)相关推荐

  1. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

  2. 飞桨上线万能转换小工具,教你玩转TensorFlow、Caffe等模型迁移

    百度推出飞桨(PaddlePaddle)后,不少开发者开始转向国内的深度学习框架.但是从代码的转移谈何容易,之前的工作重写一遍不太现实,成千上万行代码的手工转换等于是在做一次二次开发. 现在,有个好消 ...

  3. 基于TensorFlow训练花朵识别模型的源码和Demo

    基于TensorFlow训练花朵识别模型的源码和Demo 转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684 下面就通过对 ...

  4. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  5. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

  6. 如何用java语言调用tensorflow训练好的模型

    1.TensorFlow的训练模型在Android和Java的应用及调用 2.tensorflow的python离线训练java在线预测方案 3.tensorflow训练的模型在java中的使用 4. ...

  7. 用TensorFlow训练第一个模型

    简述 下面有非常详细的代码注释 学习自莫凡大神给的demo https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/2-2 ...

  8. java调用tensorflow训练好的模型

    1. python的处理 整个模型的源码在此:https://github.com/shelleyHLX/tensorflow_java 多谢star 首先训练一个模型,代码如下 import ten ...

  9. Tensorflow object detection API 搭建自己的目标检测模型并迁移到Android上

    参考链接:https://blog.csdn.net/dy_guox/article/details/79111949 之前参考上述一系列博客在Windows10下面成功运行了TensorFlow A ...

最新文章

  1. Java学习笔记(十)--控制台输入输出
  2. mysql 使用不同引擎_mysql 不同引擎的比较
  3. 怎样用java写一个简单的文件复制程序
  4. 什么是 SAP Spartacus 里的 module augmentation
  5. java提高篇之详解内部类
  6. 现代软件工程课件 需求分析 如何提出靠谱的项目建议 NABCD
  7. 使用Spire.Barcode程序库生成二维码
  8. 码表的理解(ASCII,GBK,Unicode,UTF-8等)。
  9. 强制更新LYNC客户端的地址簿
  10. Camshift原理
  11. 【论文笔记】激光里程计网络 LO-Net:Deep Real-time Lidar Odometry2019
  12. PDF怎么编辑修改,如何编辑PDF文字内容
  13. HTML页面跳转的5种方法。
  14. PixiJS学习(5)几何图形
  15. CodeForces 19E 仙女fairy
  16. 女孩做妻子前应知道的10件事
  17. Tomcat 乱码问题解决方法
  18. 根轨迹超前校正matlab,[自动化] 基于根轨迹法的超前校正
  19. 基础的http协议构成
  20. 数学诺贝尔奖2008阿贝尔奖揭晓

热门文章

  1. iOS开发 3D-touch使用
  2. 阿里天池 Python 训练营1
  3. 09考研!艰难与希望?
  4. Docker 企业级实战青铜段位-崔健敏-专题视频课程
  5. 简述UITableView的属性和用法
  6. python调用excel的宏_在 Excel 中使用 Python 开发宏脚本
  7. RabbitMQ教程大全看这一篇就够了-java版本
  8. 计算机科学 教育部评估,教育部全国第四轮学科评估结果(A+、A类学校)汇总...
  9. Java生成树状结构返回结果
  10. 蛙泳、自由泳、仰泳、蝶泳,图解动画,教你游泳,不会游的看了包你学会!!!