Tensorflow2.0(Keras)转换TFlite
Tensorflow 2.0(Keras)转换TFlite
目录
Tensorflow 2.0(Keras)转换TFlite
1. TensorFlow Lite 指南
(1)TensorFlow Lite 转换器
(2)转换量化模型
(3)兼容的算子:Compatible operations
2. 转换脚本(Python版本)
1. TensorFlow Lite 指南
(1)TensorFlow Lite 转换器
https://tensorflow.google.cn/lite/guide/ops_select
为了能够转换包含 TensorFlow 运算符的 TensorFlow Lite 模型,可使用位于 TensorFlow Lite 转换器 中的 target_spec.supported_ops
参数。target_spec.supported_ops
的可选值如下:
TFLITE_BUILTINS
- 使用 TensorFlow Lite 内置运算符转换模型。SELECT_TF_OPS
- 使用 TensorFlow 运算符转换模型。已经支持的 TensorFlow 运算符的完整列表可以在白名单lite/delegates/flex/whitelisted_flex_ops.cc
中查看。
注意:target_spec.supported_ops
是之前 Python API 中的 target_ops
。
我们优先推荐使用 TFLITE_BUILTINS
转换模型,然后是同时使用 TFLITE_BUILTINS,SELECT_TF_OPS
,最后是只使用 SELECT_TF_OPS
。同时使用两个选项(也就是 TFLITE_BUILTINS,SELECT_TF_OPS
)会用 TensorFlow Lite 内置的运算符去转换支持的运算符。有些 TensorFlow 运算符 TensorFlow Lite 只支持部分用法,这时可以使用 SELECT_TF_OPS
选项来避免这种局限性。
(2)转换量化模型
训练后:针对特定 CPU 型号的量化模型
创建小模型的最简单方法是在推理期间将权重量化为 8 位并“在运行中”量化输入/激活。这具有延迟优势,但优先考虑减小尺寸。
在转换期间,将 optimizations 标志设置为针对大小进行优化:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
训练过程中:仅用于整数执行的量化模型
仅用于整数执行的量化模型获得具有更低延迟,更小尺寸和仅针对整数加速器兼容模型的模型。目前,这需要训练具有"假量化"节点的模型 。
转换图表:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
对于全整数模型,输入为 uint8。mean 和 std_dev values 指定在训练模型时这些 UINT8 的值是如何值映射到输入的浮点值。
mean 是 0 到 255 之间的整数值,映射到浮点数 0.0f。std_dev = 255 /(float_max - float_min)
对于大多数用户,我们建议使用训练后量化。我们正在研究用于后期训练和训练量化的新工具,我们希望这将简化生成量化模型。
(3)兼容的算子:Compatible operations
https://tensorflow.google.cn/lite/guide/ops_compatibility
2. 转换脚本(Python版本)
以下脚本实现将Tensorflow2.0(Keras)保存的模型(建议保存为*.h5的格式),转换TFlite模型
建议版本信息:
tensorboard==2.0.2
tensorflow-estimator==2.0.1
tensorflow-gpu==2.0.0
# -*- coding: utf-8 -*-
"""
# --------------------------------------------------------
# @Author : panjq
# @E-mail : pan_jinquan@163.com
# @Date : 2020-02-05 11:01:49
# --------------------------------------------------------
"""
import os
import numpy as np
import glob
import cv2
import argparse
import tensorflow as tfprint("TF version:{}".format(tf.__version__))
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)def converer_keras_to_tflite_v1(keras_path, outputs_layer=None, out_tflite=None):""":param keras_path: keras *.h5 files:param outputs_layer:param out_tflite: output *.tflite file:return:"""model_dir = os.path.dirname(keras_path)model_name = os.path.basename(keras_path)[:-len(".h5")]# 加载keras模型, 结构打印model_keras = tf.keras.models.load_model(keras_path)print(model_keras.summary())# 从keras模型中提取fc1层, 需先保存成新keras模型, 再转换成tflitemodel_embedding = tf.keras.models.Model(inputs=model_keras.input,outputs=model_keras.get_layer(outputs_layer).output)print(model_embedding.summary())keras_file = os.path.join(model_dir, "{}_{}.h5".format(model_name, outputs_layer))tf.keras.models.Model.save(model_embedding, keras_file)# converter = tf.lite.TocoConverter.from_keras_model_file(keras_file)converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) # tf1.3# converter = tf.lite.TFLiteConverter.from_keras_model(model_keras) # tf2.0tflite_model = converter.convert()if not out_tflite:out_tflite = os.path.join(model_dir, "{}_{}.tflite".format(model_name, outputs_layer))open(out_tflite, "wb").write(tflite_model)print("successfully convert to tflite done")print("save model at: {}".format(out_tflite))def converer_keras_to_tflite_v2(keras_path, outputs_layer=None, out_tflite=None, optimize=False, quantization=False):""":param keras_path: keras *.h5 files:param outputs_layer: default last layer:param out_tflite: output *.tflite file:param optimize:return:"""if not os.path.exists(keras_path):raise Exception("Error:{}".format(keras_path))model_dir = os.path.dirname(keras_path)model_name = os.path.basename(keras_path)[:-len(".h5")]# 加载keras模型, 结构打印# model = tf.keras.models.load_model(keras_path)model = tf.keras.models.load_model(model_path, custom_objects={'tf': tf}, compile=False)print(model.summary())if outputs_layer:# 从keras模型中提取层,转换成tflitemodel = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer(outputs_layer).output)# outputs = [model.output["bbox"],model.output["scores"]]# model = tf.keras.models.Model(inputs=model.input, outputs=outputs)print(model.summary())# converter = tf.lite.TocoConverter.from_keras_model_file(keras_file)# converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) # tf1.3converter = tf.lite.TFLiteConverter.from_keras_model(model) # tf2.0prefix = [model_name, outputs_layer]# converter.allow_custom_ops = False# converter.experimental_new_converter = True""""https://tensorflow.google.cn/lite/guide/ops_select我们优先推荐使用 TFLITE_BUILTINS 转换模型,然后是同时使用 TFLITE_BUILTINS,SELECT_TF_OPS ,最后是只使用 SELECT_TF_OPS。同时使用两个选项(也就是 TFLITE_BUILTINS,SELECT_TF_OPS)会用 TensorFlow Lite 内置的运算符去转换支持的运算符。有些 TensorFlow 运算符 TensorFlow Lite 只支持部分用法,这时可以使用 SELECT_TF_OPS 选项来避免这种局限性。"""# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,# tf.lite.OpsSet.SELECT_TF_OPS]# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]if optimize:print("weight quantization")# Enforce full integer quantization for all ops and use int input/outputconverter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]prefix += ["optimize"]else:converter.optimizations = [tf.lite.Optimize.DEFAULT]if quantization == "int8":converter.representative_dataset = representative_dataset_gen# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]# converter.inference_input_type = tf.int8 # or tf.uint8# converter.inference_output_type = tf.int8 # or tf.uint8converter.target_spec.supported_types = [tf.int8]elif quantization == "float16":converter.target_spec.supported_types = [tf.float16]prefix += [quantization]if not out_tflite:prefix = [str(n) for n in prefix if n]prefix = "_".join(prefix)out_tflite = os.path.join(model_dir, "{}.tflite".format(prefix))tflite_model = converter.convert()open(out_tflite, "wb").write(tflite_model)print("successfully convert to tflite done")print("save model at: {}".format(out_tflite))def representative_dataset_gen():"""# 生成代表性数据集:return:"""image_dir = '/home/dm/panjinquan3/FaceDetector/tf-yolov3-detection/data/finger_images/'input_size = [320, 320]imgSet = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]for img_path in imgSet:orig_image = cv2.imread(img_path)rgb_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)image_tensor = cv2.resize(rgb_image, dsize=tuple(input_size))image_tensor = np.asarray(image_tensor / 255.0, dtype=np.float32)image_tensor = image_tensor[np.newaxis, :]yield [image_tensor]def parse_args():# weights_path = "../yolov3-micro.h5"weights_path = "../yolov3-micro_freeze_head.h5"parser = argparse.ArgumentParser()parser.add_argument("-c", "--model_path", help="model_path", default=weights_path, type=str)parser.add_argument("--outputs_layer", help="outputs_layer", default=None, type=str)parser.add_argument("-o", "--out_tflite", help="out tflite model path", default=None, type=str)parser.add_argument("-opt", "--optimize", help="optimize model", default=True, type=bool)parser.add_argument("-q", "--quantization", help="quantization model", default="float16", type=str)return parser.parse_args()def get_model(model_path):if not model_path:model_path = os.path.join(os.getcwd(), "*.h5")model_list = glob.glob(model_path)else:model_list = [model_path]return model_listdef unsupport_tflite_op():"""tf.shape,tf.Sizeerror: 'tf.Size' op is neither a custom op nor a flex operror: 'tf.Softmax' op is neither a custom op nor a flex op===========================================================tf.Softmax-->tf.nn.softmax"""passif __name__ == '__main__':args = parse_args()# outputs_layer = "fc1"outputs_layer = args.outputs_layermodel_list = get_model(args.model_path)out_tflite = args.out_tfliteoptimize = args.optimizequantization = args.quantizationfor model_path in model_list:converer_keras_to_tflite_v2(model_path, outputs_layer, out_tflite=out_tflite, optimize=optimize,quantization=quantization)# converer_keras_to_tflite_v1(keras_model, outputs_layer, out_tflite=None)
3. 常见问题解决方法
'tf.ResizeNearestNeighbor' op is neither a custom op nor a flex op
<unknown>:0: error: failed while converting: 'main': Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): ResizeNearestNeighbor.
解决方法: https://panjinquan.blog.csdn.net/article/details/107360700
Quantization not yet supported for op: LEAKY_RELU
Tensorflow2.0(Keras)转换TFlite相关推荐
- TensorFlow2.0: keras.metrics的使用
keras.metrics中有两个api函数可以简化准确率acc和损失值loss的计算.其分别是metrics.Accuracy( )和metrics.Mean( ). 一.建立测量尺 #建立测量尺 ...
- TensorFlow2.0 Keras多层感知器模型imdb情感分类
# 下载 import urllib.request import os import tarfileurl = 'http://ai.stanford.edu/~amaas/data/sentime ...
- tensorflow2.0 Keras VGG16 VGG19 系列 代码实现
模型介绍参看:博文 VGG16 迁移模型 先看看标准答案 import tensorflow as tf from tensorflow import kerasbase_model = keras. ...
- 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战
基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...
- 基于TensorFlow2.0的摄像头数字识别
import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...
- 【深度学习】(6) tensorflow2.0使用keras高层API
各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...
- Tensorflow2.0学习(五) — Keras基础应用(IMDb电影集情感分析)
今天这一节内容是关于Keras应用分析的最后一节,在熟悉了Keras的基础知识之后,下面几节我们就可以正式接触Tensorflow2.0.根据博主多处查阅,最终还是发现Tensorflow的官方教程好 ...
- 【TensorFlow2.0】以后我们再也离不开Keras了?
TensorFlow2.0 Alpha版已经发布,在2.0中最重要的API或者说到处都出现的API是谁,那无疑是Keras.因此用过2.0的人都会吐槽全世界都是Keras.今天我们就来说说Keras这 ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization
<<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...
最新文章
- R语言sunburst图(sunburst plot)可视化实战:使用sunburstR包和ggplot2包进行可视化
- controller与servlet区别
- 抗侧力构件弹性位移如何计算_穿心棒法盖梁施工计算书(工字钢)
- 【知识强化】第二章 进程管理 2.1 进程与线程
- K8S 基于NFS实现文件集群间共享
- aspectj 获取方法入参_深入探索编译插桩技术(二、AspectJ)
- 2020牛客国庆集训派对day2 F题 Java大数处理
- 前端学习(1269):axios的拦截器
- 使用栈将递归函数转化为非递归函数_栈(Stack)及其应用-Python实现
- 关于svn、git生成版本号脚本的改进
- 微软设计套装 Expression Studio 4 (Ultimate+Web Pro+Encoder Pro) 最新版下载
- 控件必须放在具有 runat=server 的窗体标记内错误的解决方法
- CentOS云主机安全之新增ssh登录账户、禁止ROOT登陆
- Java实现自己想要的代码生成器!
- 2021年下半年《信息系统项目管理师》真题
- 零食店投资?市场成本风险分析
- 产业安全专家谈丨数字经济高速发展,数据要素安全该如何保障?
- 网传三星手机大半夜黑屏乱码,原因竟然是闰四月?
- 中国AI觉醒 阿里王坚:云智能将成为大趋势
- WPF学习笔记16 BookDemo 2