Tensorflow 2.0(Keras)转换TFlite

目录

Tensorflow 2.0(Keras)转换TFlite

1. TensorFlow Lite 指南

(1)TensorFlow Lite 转换器

(2)转换量化模型

(3)兼容的算子:Compatible operations

2. 转换脚本(Python版本)


1. TensorFlow Lite 指南

(1)TensorFlow Lite 转换器

https://tensorflow.google.cn/lite/guide/ops_select

为了能够转换包含 TensorFlow 运算符的 TensorFlow Lite 模型,可使用位于 TensorFlow Lite 转换器 中的 target_spec.supported_ops 参数。target_spec.supported_ops 的可选值如下:

  • TFLITE_BUILTINS - 使用 TensorFlow Lite 内置运算符转换模型。
  • SELECT_TF_OPS - 使用 TensorFlow 运算符转换模型。已经支持的 TensorFlow 运算符的完整列表可以在白名单lite/delegates/flex/whitelisted_flex_ops.cc 中查看。

注意:target_spec.supported_ops 是之前 Python API 中的 target_ops

我们优先推荐使用 TFLITE_BUILTINS 转换模型,然后是同时使用 TFLITE_BUILTINS,SELECT_TF_OPS ,最后是只使用 SELECT_TF_OPS。同时使用两个选项(也就是 TFLITE_BUILTINS,SELECT_TF_OPS)会用 TensorFlow Lite 内置的运算符去转换支持的运算符。有些 TensorFlow 运算符 TensorFlow Lite 只支持部分用法,这时可以使用 SELECT_TF_OPS 选项来避免这种局限性。

(2)转换量化模型

训练后:针对特定 CPU 型号的量化模型

创建小模型的最简单方法是在推理期间将权重量化为 8 位并“在运行中”量化输入/激活。这具有延迟优势,但优先考虑减小尺寸。

在转换期间,将 optimizations 标志设置为针对大小进行优化:

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()

训练过程中:仅用于整数执行的量化模型

仅用于整数执行的量化模型获得具有更低延迟,更小尺寸和仅针对整数加速器兼容模型的模型。目前,这需要训练具有"假量化"节点的模型 。

转换图表:

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
converter.quantized_input_stats = {input_arrays[0] : (0., 1.)}  # mean, std_dev
tflite_model = converter.convert()

对于全整数模型,输入为 uint8。mean 和 std_dev values 指定在训练模型时这些 UINT8 的值是如何值映射到输入的浮点值。

mean 是 0 到 255 之间的整数值,映射到浮点数 0.0f。std_dev = 255 /(float_max - float_min)

对于大多数用户,我们建议使用训练后量化。我们正在研究用于后期训练和训练量化的新工具,我们希望这将简化生成量化模型。

(3)兼容的算子:Compatible operations

https://tensorflow.google.cn/lite/guide/ops_compatibility


2. 转换脚本(Python版本)

以下脚本实现将Tensorflow2.0(Keras)保存的模型(建议保存为*.h5的格式),转换TFlite模型

建议版本信息:

tensorboard==2.0.2                
tensorflow-estimator==2.0.1                   
tensorflow-gpu==2.0.0

# -*- coding: utf-8 -*-
"""
# --------------------------------------------------------
# @Author : panjq
# @E-mail : pan_jinquan@163.com
# @Date   : 2020-02-05 11:01:49
# --------------------------------------------------------
"""
import os
import numpy as np
import glob
import cv2
import argparse
import tensorflow as tfprint("TF version:{}".format(tf.__version__))
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)def converer_keras_to_tflite_v1(keras_path, outputs_layer=None, out_tflite=None):""":param keras_path: keras *.h5 files:param outputs_layer:param out_tflite: output *.tflite file:return:"""model_dir = os.path.dirname(keras_path)model_name = os.path.basename(keras_path)[:-len(".h5")]# 加载keras模型, 结构打印model_keras = tf.keras.models.load_model(keras_path)print(model_keras.summary())# 从keras模型中提取fc1层, 需先保存成新keras模型, 再转换成tflitemodel_embedding = tf.keras.models.Model(inputs=model_keras.input,outputs=model_keras.get_layer(outputs_layer).output)print(model_embedding.summary())keras_file = os.path.join(model_dir, "{}_{}.h5".format(model_name, outputs_layer))tf.keras.models.Model.save(model_embedding, keras_file)# converter = tf.lite.TocoConverter.from_keras_model_file(keras_file)converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)  # tf1.3# converter = tf.lite.TFLiteConverter.from_keras_model(model_keras)  # tf2.0tflite_model = converter.convert()if not out_tflite:out_tflite = os.path.join(model_dir, "{}_{}.tflite".format(model_name, outputs_layer))open(out_tflite, "wb").write(tflite_model)print("successfully convert to tflite done")print("save model at: {}".format(out_tflite))def converer_keras_to_tflite_v2(keras_path, outputs_layer=None, out_tflite=None, optimize=False, quantization=False):""":param keras_path: keras *.h5 files:param outputs_layer: default last layer:param out_tflite: output *.tflite file:param optimize:return:"""if not os.path.exists(keras_path):raise Exception("Error:{}".format(keras_path))model_dir = os.path.dirname(keras_path)model_name = os.path.basename(keras_path)[:-len(".h5")]# 加载keras模型, 结构打印# model = tf.keras.models.load_model(keras_path)model = tf.keras.models.load_model(model_path, custom_objects={'tf': tf}, compile=False)print(model.summary())if outputs_layer:# 从keras模型中提取层,转换成tflitemodel = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer(outputs_layer).output)# outputs = [model.output["bbox"],model.output["scores"]]# model = tf.keras.models.Model(inputs=model.input, outputs=outputs)print(model.summary())# converter = tf.lite.TocoConverter.from_keras_model_file(keras_file)# converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)  # tf1.3converter = tf.lite.TFLiteConverter.from_keras_model(model)  # tf2.0prefix = [model_name, outputs_layer]# converter.allow_custom_ops = False# converter.experimental_new_converter = True""""https://tensorflow.google.cn/lite/guide/ops_select我们优先推荐使用 TFLITE_BUILTINS 转换模型,然后是同时使用 TFLITE_BUILTINS,SELECT_TF_OPS ,最后是只使用 SELECT_TF_OPS。同时使用两个选项(也就是 TFLITE_BUILTINS,SELECT_TF_OPS)会用 TensorFlow Lite 内置的运算符去转换支持的运算符。有些 TensorFlow 运算符 TensorFlow Lite 只支持部分用法,这时可以使用 SELECT_TF_OPS 选项来避免这种局限性。"""# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,#                                        tf.lite.OpsSet.SELECT_TF_OPS]# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]if optimize:print("weight quantization")# Enforce full integer quantization for all ops and use int input/outputconverter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]prefix += ["optimize"]else:converter.optimizations = [tf.lite.Optimize.DEFAULT]if quantization == "int8":converter.representative_dataset = representative_dataset_gen# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]# converter.inference_input_type = tf.int8  # or tf.uint8# converter.inference_output_type = tf.int8  # or tf.uint8converter.target_spec.supported_types = [tf.int8]elif quantization == "float16":converter.target_spec.supported_types = [tf.float16]prefix += [quantization]if not out_tflite:prefix = [str(n) for n in prefix if n]prefix = "_".join(prefix)out_tflite = os.path.join(model_dir, "{}.tflite".format(prefix))tflite_model = converter.convert()open(out_tflite, "wb").write(tflite_model)print("successfully convert to tflite done")print("save model at: {}".format(out_tflite))def representative_dataset_gen():"""# 生成代表性数据集:return:"""image_dir = '/home/dm/panjinquan3/FaceDetector/tf-yolov3-detection/data/finger_images/'input_size = [320, 320]imgSet = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]for img_path in imgSet:orig_image = cv2.imread(img_path)rgb_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)image_tensor = cv2.resize(rgb_image, dsize=tuple(input_size))image_tensor = np.asarray(image_tensor / 255.0, dtype=np.float32)image_tensor = image_tensor[np.newaxis, :]yield [image_tensor]def parse_args():# weights_path = "../yolov3-micro.h5"weights_path = "../yolov3-micro_freeze_head.h5"parser = argparse.ArgumentParser()parser.add_argument("-c", "--model_path", help="model_path", default=weights_path, type=str)parser.add_argument("--outputs_layer", help="outputs_layer", default=None, type=str)parser.add_argument("-o", "--out_tflite", help="out tflite model path", default=None, type=str)parser.add_argument("-opt", "--optimize", help="optimize model", default=True, type=bool)parser.add_argument("-q", "--quantization", help="quantization model", default="float16", type=str)return parser.parse_args()def get_model(model_path):if not model_path:model_path = os.path.join(os.getcwd(), "*.h5")model_list = glob.glob(model_path)else:model_list = [model_path]return model_listdef unsupport_tflite_op():"""tf.shape,tf.Sizeerror: 'tf.Size' op is neither a custom op nor a flex operror: 'tf.Softmax' op is neither a custom op nor a flex op===========================================================tf.Softmax-->tf.nn.softmax"""passif __name__ == '__main__':args = parse_args()# outputs_layer = "fc1"outputs_layer = args.outputs_layermodel_list = get_model(args.model_path)out_tflite = args.out_tfliteoptimize = args.optimizequantization = args.quantizationfor model_path in model_list:converer_keras_to_tflite_v2(model_path, outputs_layer, out_tflite=out_tflite, optimize=optimize,quantization=quantization)# converer_keras_to_tflite_v1(keras_model, outputs_layer, out_tflite=None)

3. 常见问题解决方法


'tf.ResizeNearestNeighbor' op is neither a custom op nor a flex op
<unknown>:0: error: failed while converting: 'main': Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): ResizeNearestNeighbor.

解决方法: https://panjinquan.blog.csdn.net/article/details/107360700


Quantization not yet supported for op: LEAKY_RELU

Tensorflow2.0(Keras)转换TFlite相关推荐

  1. TensorFlow2.0: keras.metrics的使用

    keras.metrics中有两个api函数可以简化准确率acc和损失值loss的计算.其分别是metrics.Accuracy( )和metrics.Mean( ). 一.建立测量尺 #建立测量尺 ...

  2. TensorFlow2.0 Keras多层感知器模型imdb情感分类

    # 下载 import urllib.request import os import tarfileurl = 'http://ai.stanford.edu/~amaas/data/sentime ...

  3. tensorflow2.0 Keras VGG16 VGG19 系列 代码实现

    模型介绍参看:博文 VGG16 迁移模型 先看看标准答案 import tensorflow as tf from tensorflow import kerasbase_model = keras. ...

  4. 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战

    基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...

  5. 基于TensorFlow2.0的摄像头数字识别

    import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...

  6. 【深度学习】(6) tensorflow2.0使用keras高层API

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中借助keras的接口减少神经网络代码量.主要内容有: 1. metrics指标:2. compile 模型配置:3. fit 模型训练 ...

  7. Tensorflow2.0学习(五) — Keras基础应用(IMDb电影集情感分析)

    今天这一节内容是关于Keras应用分析的最后一节,在熟悉了Keras的基础知识之后,下面几节我们就可以正式接触Tensorflow2.0.根据博主多处查阅,最终还是发现Tensorflow的官方教程好 ...

  8. 【TensorFlow2.0】以后我们再也离不开Keras了?

    TensorFlow2.0 Alpha版已经发布,在2.0中最重要的API或者说到处都出现的API是谁,那无疑是Keras.因此用过2.0的人都会吐槽全世界都是Keras.今天我们就来说说Keras这 ...

  9. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

最新文章

  1. R语言sunburst图(sunburst plot)可视化实战:使用sunburstR包和ggplot2包进行可视化
  2. controller与servlet区别
  3. 抗侧力构件弹性位移如何计算_穿心棒法盖梁施工计算书(工字钢)
  4. 【知识强化】第二章 进程管理 2.1 进程与线程
  5. K8S 基于NFS实现文件集群间共享
  6. aspectj 获取方法入参_深入探索编译插桩技术(二、AspectJ)
  7. 2020牛客国庆集训派对day2 F题 Java大数处理
  8. 前端学习(1269):axios的拦截器
  9. 使用栈将递归函数转化为非递归函数_栈(Stack)及其应用-Python实现
  10. 关于svn、git生成版本号脚本的改进
  11. 微软设计套装 Expression Studio 4 (Ultimate+Web Pro+Encoder Pro) 最新版下载
  12. 控件必须放在具有 runat=server 的窗体标记内错误的解决方法
  13. CentOS云主机安全之新增ssh登录账户、禁止ROOT登陆
  14. Java实现自己想要的代码生成器!
  15. 2021年下半年《信息系统项目管理师》真题
  16. 零食店投资?市场成本风险分析
  17. 产业安全专家谈丨数字经济高速发展,数据要素安全该如何保障?
  18. 网传三星手机大半夜黑屏乱码,原因竟然是闰四月?
  19. 中国AI觉醒 阿里王坚:云智能将成为大趋势
  20. WPF学习笔记16 BookDemo 2

热门文章

  1. NDVI等植被相关指数
  2. python购物车模块
  3. linux网络编程之二-----多播(组播)编程
  4. 提示用户输入一个正整数n,如果n=5,就输出下列图形,其他n值以此类推
  5. 2012百度之星冬季赛第二场第二题 消去游戏I
  6. 在C#里调用C++的dll时需要注意的一些问题转
  7. 如何恢复ORACLE数据(冷备份)
  8. 安卓工业平板电脑的蓝牙开发教程
  9. 解决git bash闪退问题
  10. DataGridView 用户输入时,单元格输入值的设定