tensorflow中sparse_placeholder在saved_model中保存pb模型的使用方法
# coding:utf-8# @author: liu # @file: sparse_tensor_pb.py # @time: 2020/4/3 11:00 # @desc:# coding:utf-8import tensorflow as tf import random import numpy as np import os import shutilprint(tf.__version__)def create_sparse(batch_size, dtype=np.int32):'''创建稀疏张量,ctc_loss中labels要求是稀疏张量,随机生成序列长度在150~180之间的labels'''indices = []values = []for i in range(batch_size):length = random.randint(150, 180)for j in range(length):indices.append((i, j))value = random.randint(0, 779)values.append(value)indices = np.asarray(indices, dtype=np.int64)values = np.asarray(values, dtype=dtype)shape = np.asarray([batch_size, np.asarray(indices).max(0)[1] + 1], dtype=np.int64) # [64,180]return [indices, values, shape]# 保存成pb模型def saved_model(sess: tf.Session, input: tf.sparse_placeholder , ss, cc, model_path):if os.path.exists(model_path):shutil.rmtree(model_path)builder = tf.saved_model.builder.SavedModelBuilder(model_path)# input_x = tf.saved_model.build_tensor_info(input)indices = tf.saved_model.build_tensor_info(input.indices)values = tf.saved_model.build_tensor_info(input.values)dense_shape = tf.saved_model.build_tensor_info(input.dense_shape)output_a = tf.saved_model.build_tensor_info(ss)output_b = tf.saved_model.build_tensor_info(cc)prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(inputs={"indices": indices, "values": values, "dense_shapes": dense_shape},outputs={"ss": output_a, "cc": output_b},method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], signature_def_map={"predict": prediction_signature})builder.save()def load_model(model_path=None):sess = tf.Session()meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)signature = meta_graph.signature_def# input_x = signature["predict"].inputs["input"].nameindices = signature["predict"].inputs["indices"].namevalues = signature["predict"].inputs["values"].namedense_shape = signature["predict"].inputs["dense_shapes"].nameoutput_a = signature["predict"].outputs["ss"].nameoutput_b = signature["predict"].outputs["cc"].name# return sess, input_x, output_a, output_breturn sess, indices, values, dense_shape, output_a, output_bdef train():a = tf.sparse_placeholder(tf.float32, name="in")values = a.valuesc = tf.sparse_to_dense(a.indices, a.dense_shape, a.values)s = tf.sparse.reduce_sum(a)indices_list = [[0, 1], [0, 4], [2, 3]]# print(indices_list)values_list = [1, 4, 5]dense_shape_list = [4, 5]with tf.Session() as sess:sess.run(tf.global_variables_initializer())ss, cc, va = sess.run([s, c, values], feed_dict={a: tf.SparseTensorValue(indices=indices_list, values=values_list, dense_shape=dense_shape_list)})print(ss)print("-----------")print(cc)# print("va", va)saved_model(sess, a, s, c, "model/1")def predict():indices_list = [[0, 1], [0, 4], [2, 3]]# print(indices_list)values_list = [1, 4, 5]dense_shape_list = [4, 5]# sess, input_x, output_a, output_b = load_model("model/1")sess, indices, values, dense_shape, output_a, output_b = load_model("model/1")with sess:# a, b = sess.run([output_a, output_b], {input_x: tf.SparseTensorValue(indices=indices_list, values=values_list, dense_shape=dense_shape_list)})a, b = sess.run([output_a, output_b], feed_dict={indices:indices_list, values:values_list, dense_shape:dense_shape_list})print("预测结果")print(a)print(b)train() predict()
tensorflow中sparse_placeholder在saved_model中保存pb模型的使用方法相关推荐
- Tensorflow训练神经网络保存*.pb模型及载入*.pb模型
1 神经网络结构 1.0 保存*.pb模型 import tensorflow as tf from tensorflow.python.framework import graph_util fro ...
- Word中截取部分内容并保存为jpg图片的方法
private void button1_Click(object sender, EventArgs e) { var appWord = new Microsoft.Office.Interop. ...
- TensorFlow模型保存pb或ckpt
Tensorflow的保存分为三种:1. checkpoint模式:2. pb模式:3. saved_model模式. https://www.zhihu.com/collection/6445044 ...
- h5模型转化为pb模型,代码及排坑
我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数. 附上h5_to_pb.py(pyt ...
- tensorflow打印模型图_从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)...
最近看到一个巨牛的人工智能教程,分享一下给大家.教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家.平时碎片时间可以当小说看,[点这里可以去膜拜一下大神的" ...
- TensorFlow中查看checkpoint文件中的变量名和对应值
在加载模型时, 需要知道checkpoint中变量名称,一下代码可以查看TensorFlow中checkpoint文件中的变量名: #!/usr/bin/env python # -*- coding ...
- Tensorflow 迁移学习 识别中国军网、中国军视网Logo水印
Tensorflow 目标检测项目 图片logo水印识别.识别中国军网.中国军视网Logo水印. image image Step 0 下载项目 git clone https://github.co ...
- linux tensorflow demo_独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...
作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...
- python tensorflow web_如何在浏览器中使用TensorFlow?
原标题:如何在浏览器中使用TensorFlow? 虽然您可以借助TensorFlow用数量较少的训练数据来训练简单的神经网络,但对于拥有庞大训练数据集的深度神经网络而言,确实需要使用具有CUDA功能的 ...
最新文章
- class viewController has no initializers解决办法
- Kattis - bela
- 类路径是什么意思_多播是什么意思 多播介绍【详解】
- 美国杜克大学计算机专业世界排名,美国杜克大学世界排名情况怎么样?
- 精通spring——深入java ee开发核心技术 pdf_2019精通Spring Boot 42讲 高清pdf完整版
- 【POJ 1679 The Unique MST】最小生成树
- python实现大批量pdf格式论文的重命名与目录制作功能
- php苹果app微信支付 无法返回,微信支付,php_微信支付APP返回-1怎么解决,微信支付,php,移动app - phpStudy...
- MapReduce 之shuffle过程
- san mysql,高性能MySQL:SAN和NAS
- ubuntu 12.04 配置内核崩溃自动重启及转存
- 【人民币识别】基于matlab GUI人民币序列号识别【含Matlab源码 908期】
- 线性代数科学出版社课后练习题答案
- U盘拔出时总是提示有程序正在使用?
- 跨专业北邮计算机考研,北京邮电大学跨专业考研心得
- jquery-重要的方法和注意事项
- LTE的核心网之:MME,SGW,PGW
- 常用元器件的识别(转载)
- 记一次钉钉群聊机器人的开发
- 深度学习不得不知的英文名称
热门文章
- Flutter如何实现下拉刷新和上拉加载更多
- 马化腾:通向互联网未来的七个路标
- Docker 安装最新版禅道16.5版本 原创
- 有一个棋盘,有64个方格,在第一个方格里面放1粒芝麻重量是0.00001kg,第二个里面放2粒,第三个里面放4,棋盘上放的所有芝麻的重量。
- 常微分方程-差分方程
- unity有限状态机和模糊状态机(怪物AI、自动寻路)
- 【竞品分析】Android音乐播放器的竞品分析
- ArcGIS GeoEvent 使用教程(二)
- Android ArcGIS基础使用教程(10.2.8)
- 如何将win10系统安装到U盘?