将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)
将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)
1. 写在前面
最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址:一步步做一个数字手势识别APP,源代码已经开源在github上,地址:Chinese-number-gestures-recognition),要把在PC端训练好的模型放到Android APP上,调研了下,谷歌发布了TensorFlow Lite可以把TensorFlow训练好的模型迁移到Android APP上,百度也发布了移动端深度学习框架mobile-deep-learning(MDL),这个框架应该是paddlepaddle的手机版,具体的细节没有了解过。因为对TensorFlow稍微熟悉些,因此就决定用TensorFlow来做。
关于在PC端如何处理数据及训练模型,请参见博客:一步步做一个数字手势识别APP,代码已经开源在github上,上面有代码的说明和APP演示。这篇博客只介绍如何把TensorFlow训练好的模型迁移到Android Studio上进行APP的开发。
2. 模型训练注意事项
第一步,首先在pc端训练模型的时候要模型保存为.pb模型,在保存的时候有一点非常非常重要,就是你待会再Android studio是使用这个模型用到哪个参数,那么你在保存pb模型的时候就把给哪个参数一个名字,再保存。否则,你在Android studio中很难拿出这个参数,因为TensorFlow Lite的fetch()函数是根据保存在pb模型中的名字去寻找这个参数的。(如果你已经训练好了模型,并且没有给参数名字,且你不想再训练模型了,那么你可以尝试下面的方法去找到你需要使用的变量的默认名字,见下面的代码):
#输出保存的模型中参数名字及对应的值
with tf.gfile.GFile('model_50_200_c3//./digital_gesture.pb', "rb") as f: #读取模型数据graph_def = tf.GraphDef()graph_def.ParseFromString(f.read()) #得到模型中的计算图和数据
with tf.Graph().as_default() as graph: # 这里的Graph()要有括号,不然会报TypeErrortf.import_graph_def(graph_def, name="") #导入模型中的图到现在这个新的计算图中,不指定名字的话默认是 importfor op in graph.get_operations(): # 打印出图中的节点信息print(op.name, op.values())
这段代码打出的变量的名字以及对应的值。
言归正传,通常情况该你应该保存参数的时候都给参数一个指定的名字,如下面这样(通过name参数给变量指定名字),关于训练CNN的完整代码请参见下一篇博客或者github:
X = tf.placeholder(tf.float32, [None, 64, 64, 3], name="input_x")
y = tf.placeholder(tf.float32, [None, 11], name="input_y")
kp = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")
lam = tf.placeholder(tf.float32, name="lamda")
#中间略过若干代码
z_fc2 = tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")
prob = tf.nn.softmax(z_fc2, name="probability")
pred = tf.argmax(prob, 1, output_type="int32", name="predict")
3. 在Android Studio中配置
第二步,开始把pb模型移植到Android Studio上,网上绝大部分资料都是说用bazel重新编译模型生成依赖,这种方法难度太大。其实没必须这样做,TensorFlow Lite官方的例子中已经给我们展示了,我们其实只需要两个文件:libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。这两个文件我已经放到github上了,大家可以自行下载使用,下载地址:libandroid_tensorflow_inference_java.jar、libtensorflow_inference.so。
注:检神说,直接用aar依赖也可以,这个我没试过。。有兴趣的可以试一下。
准备工作已经完毕,下面正式开始Android Studio中的配置。
首先把训练好的pb模型放到Android项目中app/src/main/assets下,若不存在assets目录,则自己新建一个。如图所示:
其次,把刚刚下载的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目录下,如下图所示:
然后在app/build.gradle里进行如下配置:
在defaultConfig里添加
multiDexEnabled truendk {abiFilters "armeabi-v7a"}
在android里添加
sourceSets {main {jni.srcDirs = []jniLibs.srcDirs = ['libs']}}
如图所示:
在dependencies中添加libandroid_tensorflow_inference_java.jar,即:
implementation files('libs/libandroid_tensorflow_inference_java.jar')
如图所示:
至此,所有配置已经完成,下面是模型调用。
4. 在Android Studio中调用模型
在要用到模型的地方,首先要加载libtensorflow_inference.so库和初始化TensorFlowInferenceInterface对象,代码为:
TensorFlowInferenceInterface inferenceInterface;static {//加载libtensorflow_inference.so库文件System.loadLibrary("tensorflow_inference");Log.e("tensorflow","libtensorflow_inference.so库加载成功");}Classifier(AssetManager assetManager, String modePath) {//初始化TensorFlowInferenceInterface对象inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);Log.e("tf","TensoFlow模型文件加载成功");}
如图所示:
下面来多看一点东西,看看TensorFlow Lite里提供了哪几个接口,官网地址:Here’s what a typical Inference Library sequence looks like on Android.
// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);// Run the inference call.
inferenceInterface.run(outputNames, logStats);// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);
下面就可以愉快地使用模型了。放一段我调用模型的代码,以供大家参考:
public ArrayList predict(Bitmap bitmap){ArrayList<String> list = new ArrayList<>();float[] inputdata = getPixels(bitmap);for(int i = 0; i <30; ++i){Log.d("matrix",inputdata[i] + "");}inferenceInterface.feed(inputName, inputdata, 1, IMAGE_SIZE, IMAGE_SIZE, 3);//运行模型,run的参数必须是String[]类型String[] outputNames = new String[]{outputName,probabilityName,outlayerName};inferenceInterface.run(outputNames);//获取结果int[] labels = new int[1];inferenceInterface.fetch(outputName,labels);int label = labels[0];float[] prob = new float[11];inferenceInterface.fetch(probabilityName, prob);
// float[] outlayer = new float[11];
// inferenceInterface.fetch(outlayerName, outlayer);// for(int i = 0; i <11; ++i)
// {// Log.d("matrix",outlayer[i] + "");
// }for(int i = 0; i <11; ++i){Log.d("matrix",prob[i] + "");}DecimalFormat df = new DecimalFormat("0.000000");float label_prob = prob[label];//返回值list.add(Integer.toString(label));list.add(df.format(label_prob));return list;}
最后放一张做的数字手势识别APP的效果,全部代码,将会开源在github上,欢迎star。
再放一张碰运气的识别结果:
将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)相关推荐
- 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...
- 飞桨上线万能转换小工具,教你玩转TensorFlow、Caffe等模型迁移
百度推出飞桨(PaddlePaddle)后,不少开发者开始转向国内的深度学习框架.但是从代码的转移谈何容易,之前的工作重写一遍不太现实,成千上万行代码的手工转换等于是在做一次二次开发. 现在,有个好消 ...
- 基于TensorFlow训练花朵识别模型的源码和Demo
基于TensorFlow训练花朵识别模型的源码和Demo 转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684 下面就通过对 ...
- 使用PaddleFluid和TensorFlow训练序列标注模型
专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...
- java加载tensorflow训练的PB模型记录
java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...
- 如何用java语言调用tensorflow训练好的模型
1.TensorFlow的训练模型在Android和Java的应用及调用 2.tensorflow的python离线训练java在线预测方案 3.tensorflow训练的模型在java中的使用 4. ...
- 用TensorFlow训练第一个模型
简述 下面有非常详细的代码注释 学习自莫凡大神给的demo https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/2-2 ...
- java调用tensorflow训练好的模型
1. python的处理 整个模型的源码在此:https://github.com/shelleyHLX/tensorflow_java 多谢star 首先训练一个模型,代码如下 import ten ...
- Tensorflow object detection API 搭建自己的目标检测模型并迁移到Android上
参考链接:https://blog.csdn.net/dy_guox/article/details/79111949 之前参考上述一系列博客在Windows10下面成功运行了TensorFlow A ...
最新文章
- Java学习笔记(十)--控制台输入输出
- mysql 使用不同引擎_mysql 不同引擎的比较
- 怎样用java写一个简单的文件复制程序
- 什么是 SAP Spartacus 里的 module augmentation
- java提高篇之详解内部类
- 现代软件工程课件 需求分析 如何提出靠谱的项目建议 NABCD
- 使用Spire.Barcode程序库生成二维码
- 码表的理解(ASCII,GBK,Unicode,UTF-8等)。
- 强制更新LYNC客户端的地址簿
- Camshift原理
- 【论文笔记】激光里程计网络 LO-Net:Deep Real-time Lidar Odometry2019
- PDF怎么编辑修改,如何编辑PDF文字内容
- HTML页面跳转的5种方法。
- PixiJS学习(5)几何图形
- CodeForces 19E 仙女fairy
- 女孩做妻子前应知道的10件事
- Tomcat 乱码问题解决方法
- 根轨迹超前校正matlab,[自动化] 基于根轨迹法的超前校正
- 基础的http协议构成
- 数学诺贝尔奖2008阿贝尔奖揭晓
热门文章
- iOS开发 3D-touch使用
- 阿里天池 Python 训练营1
- 09考研!艰难与希望?
- Docker 企业级实战青铜段位-崔健敏-专题视频课程
- 简述UITableView的属性和用法
- python调用excel的宏_在 Excel 中使用 Python 开发宏脚本
- RabbitMQ教程大全看这一篇就够了-java版本
- 计算机科学 教育部评估,教育部全国第四轮学科评估结果(A+、A类学校)汇总...
- Java生成树状结构返回结果
- 蛙泳、自由泳、仰泳、蝶泳,图解动画,教你游泳,不会游的看了包你学会!!!