创建Tensorflow的模型

在Android平台受到设备的限制,本身并不能训练模型,因此需要使用已有的模型。
在本文中将介绍如何将Tensorflow的模型转换成tflite模型,为Android设备可以使用。

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
import matplotlib.pyplot as plt# 读取训练用的输入特征和标签
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()x_train =x_train.reshape(x_train.shape[0],28,28,1).astype("float32")
x_test  = x_test.reshape(x_test.shape[0],28,28,1).astype("float32")# 输入特征归一化,减小计算量,将图片默认为0~255之间的数字,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0
class_name = ["T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子"]class FashionModel_CNN(tf.keras.Model):"""定义CNN网络结构"""def __init__(self):super().__init__()self.conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=[5,5], padding='valid',input_shape=(28,28,1),activation=tf.nn.relu)self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2,2],strides = 2)self.conv2 = tf.keras.layers.Conv2D(filters=64,kernel_size=[5,5],padding="same",activation = tf.nn.relu)self.pool2 = tf.keras.layers.MaxPool2D(pool_size = [2,2],strides = 2)self.flatten = Flatten() #tf.keras.layers.Reshape(target_shape=(28*28*64,))self.dense1 = tf.keras.layers.Dense(units = 128,activation = tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=10,activation = "softmax")def call(self,inputs):x = self.conv1(inputs)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.flatten(x)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return outputdef generate_nn(x_train,y_train,x_test,y_test):# 声明神经网络对象model = FashionModel_CNN()# 配置训练方法(优化器,损失函数,评测指标)model.compile(optimizer=tf.keras.optimizers.Adam(),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=[tf.keras.metrics.sparse_categorical_accuracy])# 执行训练过程model.fit(x_train, y_train,batch_size=32, epochs=10,validation_data=(x_test, y_test),validation_freq=1)# 打印网络结构和参数model.summary()return modelmodel = generate_nn(x_train,y_train,x_test,y_test)
tf.saved_model.save(model, "./forAndroid")

在Python应用tensorflow的模型并测试

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelfrom tensorflow.python.keras.preprocessing.image import load_img
import matplotlib.pyplot as plt
#加载从指定的目录中加载模型
model = tf.saved_model.load("./forAndroid")
# 读取训练用的输入特征和标签
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()x_train = x_train.reshape(x_train.shape[0],28,28,1).astype("float32")
x_test  = x_test.reshape(x_test.shape[0],28,28,1).astype("float32")# 输入特征归一化,减小计算量,将图片默认为0~255之间的数字,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0
class_name = ["T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子"]
#测试数据集的路径
path = "./test_data/exam_fashion/exam_fashion/"
images = ["%d.jpeg"%i for i in range(0,10)]
matrix = np.full(784,255.0).reshape(28,28)
images_data=[]for imgfile in images:img_name = "%s%s"%(path,imgfile)print(img_name)img = tf.keras.preprocessing.image.load_img(img_name,color_mode="grayscale",target_size=(28,28))img_data = tf.keras.preprocessing.image.img_to_array(img)img_data = matrix-img_data.reshape(28,28) #将numpy数组的数据从float64转换成float32images_data.append(img_data.astype('float32'))
#测试样本
images_data = np.array(images_data).reshape(10,28,28,1)
print(x_test.dtype,images_data.dtype)
y_pred = model(images_data)
index_list = np.argmax(y_pred,axis=1)
img_id = 0
index = 1for i in index_list:img_id +=1print(index,i,class_name[i])index = index+1

利用tflite_convert将tensorflow模型导出tflite

在控制台中执行:

tflite_convert __saved_model_dir=forAndroid __output_file=model.tflite

执行出现错误,错误如下:

2021-11-25 09:19:56.975579: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2021-11-25 09:19:56.979582: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-4CLUK38
2021-11-25 09:19:56.979752: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-4CLUK38
2021-11-25 09:19:56.991993: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):File "d:\anaconda3\envs\tensorflow\lib\runpy.py", line 194, in _run_module_as_mainreturn _run_code(code, main_globals, None,File "d:\anaconda3\envs\tensorflow\lib\runpy.py", line 87, in _run_codeexec(code, run_globals)File "D:\anaconda3\envs\tensorflow\Scripts\tflite_convert.exe\__main__.py", line 7, in <module>File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 697, in mainapp.run(main=run_main, argv=sys.argv[:1])File "d:\anaconda3\envs\tensorflow\lib\site-packages\absl\app.py", line 303, in run_run_main(main, args)File "d:\anaconda3\envs\tensorflow\lib\site-packages\absl\app.py", line 251, in _run_mainsys.exit(main(argv))File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 680, in run_main_convert_tf2_model(tflite_flags)File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 281, in _convert_tf2_modelconverter = lite.TFLiteConverterV2.from_saved_model(File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\lite.py", line 1348, in from_saved_modelsaved_model = _load(saved_model_dir, tags)File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 864, in loadresult = load_internal(export_dir, tags, options)["root"]File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 902, in load_internalloader = loader_cls(object_graph_proto, saved_model_proto, export_dir,File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 162, in __init__self._load_all()File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 259, in _load_allself._load_nodes()File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 448, in _load_nodesslot_variable = optimizer_object.add_slot(
AttributeError: '_UserObject' object has no attribute 'add_slot'

根据错误提示要求rebuild Tensorflow
在官方提供的解决方法是使用bazel对tflite_convert编译
形如:
bazel run tflite_convert saved_model_dir=目录 __output_file=目标.tflite
因为在windows10安装bazel的代价太大
因此调整解决思路

解决的方法

于是转而使用运行代码来实现转换,希望得到更多的信息
编辑如下代码:

import tensorflow as tf
#指定保存模型的目录
saved_model_dir = "forAndroid"
# 转换模型
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
#转换模型
tflite_model = converter.convert()
# 将转换的模型保持到指定的文件model.tflite
with open('model.tflite', 'wb') as f:f.write(tflite_model)

出现的问题

运行转换程序出现错误,错误的内容如下:
提示信息出现错误的日志

2021-11-25 09:28:46.532142: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2021-11-25 09:28:46.535042: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-4CLUK38
2021-11-25 09:28:46.535137: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-4CLUK38
2021-11-25 09:28:46.535300: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
......

日志的错误信息量很大,增加如下内容

import tensorflow as tf
import os
#指定日志的级别,设置为错误和警告信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
#指定保存模型的目录
saved_model_dir = "forAndroid"
# 转换模型
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
#转换模型
tflite_model = converter.convert()
# 将转换的模型保持到指定的文件model.tflite
with open('model.tflite', 'wb') as f:f.write(tflite_model)

忽视掉基本提示的警告,发现已经成功生成tflite文件。

关于将Tesorflow的SavedModel模型转换成tflite模型相关推荐

  1. darknet cpp weights模型转换成ONNX模型

    整理不易,如果觉得有用,记得点赞收藏和分享哦 1. 下载转换需要的代码文件 在下面地址下载代码文件 https://gitee.com/liangjiaxi2019/pytorch-YOLOv4 2. ...

  2. 【地平线开发板 模型转换】将pytorch生成的onnx模型转换成.bin模型

    文章目录 1 获取onnx模型 2 启动docker容器 3 onnx模型检查 3.1 为什么要检查? 3.2 如何操作 4 图像数据预处理 4.1 一些问题的思考 4.2 图片挑选与放置 4.2 使 ...

  3. 将Hugging Face模型转换成LibTorch模型

    Hugging Face的模型 以waifu-diffusion模型为例,给出的实现一般是基于diffuser库,示例代码如下: import torch from torch import auto ...

  4. assimp批量转模型_IGS模型批量转换成STL模型

    背景:做配载仪时需要批量的将IGS模型转换成STL模型 用到的方法:方法1:使用3D-tool或者其他工具进行转换 方法2:通过python使用FreeCAD的接口批量转换(当时师弟写的程序) Fre ...

  5. er图转换成关系模型的例题_有关数据库系统的练习题 E-R图的关系画图转换,,急需 谢谢了...

    展开全部 你看62616964757a686964616fe58685e5aeb931333332643239下下边的例子,你的问题就可以解决了. 设某商业集团数据库中有三个实体集.一是"商 ...

  6. Unity将内部模型转换成stl格式模型,用于3D打印机进行打印

    本章我们一起来看下怎样将unity中的fbx模型转成stl模型并且保存到本地. 原理:stl模型都是由三角面组成的,只要我们了解stl文件的格式,就能够轻松的将fbx模型转换成stl. 1.先获取到f ...

  7. 现代控制理论的matlab上机实验 将状态空间模型转换成传递函数模型(便于求各种响应)

    现代控制理论 用matlab将状态空间模型转换成传递函数模型(便于求各种响应) 例:matlab程序如下 A=[-21,19,-20;19,-21,20;40,-40,-40]; B=[0,1,2]' ...

  8. sketchup 图片转模型_不用CAD描图迅速将图片转换成su模型

    原标题:不用CAD描图迅速将图片转换成su模型 首先上图,说明一下大概的流程: 1.[ps阶段]首先选取一副木雕图片,背景为单色最佳.如果不是单色,则需要ps一下,将背景转换成白色,这个过程使用魔术棒 ...

  9. maya多边形建模怎样做曲面_maya将曲面模型转换成多边形模型

    在maya中,曲面建模跟多边形建模各有各的优势,但现在建模一般是以多边形建模为主.很多模型用曲面做很简单,用多边形做就会显得有点复杂,那么如果能够将曲面模型转,换成多边形模型就好了. 1.双击打开ma ...

  10. 将传统的照片在Autodesk® 123D™ Catch中转换成3D模型

    将传统的照片在Autodesk® 123D™ Catch中转换成3D模型 http://www.douban.com/group/Catch123D/ Autodesk 123D Catch 如何进行 ...

最新文章

  1. iOS实现动态区域裁剪图片
  2. 初识react-native
  3. Three.js中使用材质覆盖属性
  4. hive数据库numeric_hive中常用的函数
  5. 全景摄像技术大有可为
  6. 织梦dede 5.7系统基本参数无法修改保存,提示Token mismatch!
  7. jenkins配置用户权限
  8. MX250和MX350哪个好一点,区别和差距在哪里?
  9. 邓俊辉数据结构学习-7-BST
  10. 数学建模优化模型简单例题_数学建模案例分析--最优化方法建模7习题六
  11. 计算机四级软考数据库系统工程师教材
  12. pyspark中where条件使用,单一匹配及多条件匹配
  13. 【Proteus仿真】6位数码管秒计数器(0-999999S)
  14. 2015最好用的PHP开源建站系统
  15. 《涨知识啦34》-LED器件的I-V特性曲线
  16. mac系统通过ADB与scrcpy实现手机投屏
  17. 谷歌浏览器密码导入导出
  18. 学编程有什么用?零基础小白可以学吗?
  19. wms仓库管理系统的订单处理及流程
  20. 桂林银行携手华为,做“好山水”里的“好银行”

热门文章

  1. IGS提供的数据(转)
  2. SCAU高级语言程序设计--实验9 函数的应用(1)
  3. python断言语句_Python断言assert的用法代码解析
  4. android 谷歌室内定位,打造室内导航 谷歌发布WifiRttScan App测试室内定位
  5. attiny85(digispark)零延迟启动探究
  6. opencv 特征提取 -SIFT
  7. 【5G】5G中的CU和DU是什么?
  8. 编程入门——计算机硬件介绍
  9. Windows照片查看器无法打开此图片怎么办
  10. VS编译间接引用的DLL不一定输出