# coding:utf-8# @author: liu
# @file: sparse_tensor_pb.py
# @time: 2020/4/3 11:00
# @desc:# coding:utf-8import tensorflow as tf
import random
import numpy as np
import os
import shutilprint(tf.__version__)def create_sparse(batch_size, dtype=np.int32):'''创建稀疏张量,ctc_loss中labels要求是稀疏张量,随机生成序列长度在150~180之间的labels'''indices = []values = []for i in range(batch_size):length = random.randint(150, 180)for j in range(length):indices.append((i, j))value = random.randint(0, 779)values.append(value)indices = np.asarray(indices, dtype=np.int64)values = np.asarray(values, dtype=dtype)shape = np.asarray([batch_size, np.asarray(indices).max(0)[1] + 1], dtype=np.int64)  # [64,180]return [indices, values, shape]# 保存成pb模型def saved_model(sess: tf.Session, input: tf.sparse_placeholder , ss, cc, model_path):if os.path.exists(model_path):shutil.rmtree(model_path)builder = tf.saved_model.builder.SavedModelBuilder(model_path)# input_x = tf.saved_model.build_tensor_info(input)indices = tf.saved_model.build_tensor_info(input.indices)values = tf.saved_model.build_tensor_info(input.values)dense_shape = tf.saved_model.build_tensor_info(input.dense_shape)output_a = tf.saved_model.build_tensor_info(ss)output_b = tf.saved_model.build_tensor_info(cc)prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(inputs={"indices": indices, "values": values, "dense_shapes": dense_shape},outputs={"ss": output_a, "cc": output_b},method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], signature_def_map={"predict": prediction_signature})builder.save()def load_model(model_path=None):sess = tf.Session()meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)signature = meta_graph.signature_def# input_x = signature["predict"].inputs["input"].nameindices = signature["predict"].inputs["indices"].namevalues = signature["predict"].inputs["values"].namedense_shape = signature["predict"].inputs["dense_shapes"].nameoutput_a = signature["predict"].outputs["ss"].nameoutput_b = signature["predict"].outputs["cc"].name# return sess, input_x, output_a, output_breturn sess, indices, values, dense_shape, output_a, output_bdef train():a = tf.sparse_placeholder(tf.float32, name="in")values = a.valuesc = tf.sparse_to_dense(a.indices, a.dense_shape, a.values)s = tf.sparse.reduce_sum(a)indices_list = [[0, 1], [0, 4], [2, 3]]# print(indices_list)values_list = [1, 4, 5]dense_shape_list = [4, 5]with tf.Session() as sess:sess.run(tf.global_variables_initializer())ss, cc, va = sess.run([s, c, values], feed_dict={a: tf.SparseTensorValue(indices=indices_list, values=values_list, dense_shape=dense_shape_list)})print(ss)print("-----------")print(cc)# print("va", va)saved_model(sess, a, s, c, "model/1")def predict():indices_list = [[0, 1], [0, 4], [2, 3]]# print(indices_list)values_list = [1, 4, 5]dense_shape_list = [4, 5]# sess, input_x, output_a, output_b = load_model("model/1")sess, indices, values, dense_shape, output_a, output_b = load_model("model/1")with sess:# a, b = sess.run([output_a, output_b], {input_x: tf.SparseTensorValue(indices=indices_list, values=values_list, dense_shape=dense_shape_list)})a, b = sess.run([output_a, output_b], feed_dict={indices:indices_list, values:values_list, dense_shape:dense_shape_list})print("预测结果")print(a)print(b)train()
predict()

tensorflow中sparse_placeholder在saved_model中保存pb模型的使用方法相关推荐

  1. Tensorflow训练神经网络保存*.pb模型及载入*.pb模型

    1 神经网络结构 1.0 保存*.pb模型 import tensorflow as tf from tensorflow.python.framework import graph_util fro ...

  2. Word中截取部分内容并保存为jpg图片的方法

    private void button1_Click(object sender, EventArgs e) { var appWord = new Microsoft.Office.Interop. ...

  3. TensorFlow模型保存pb或ckpt

    Tensorflow的保存分为三种:1. checkpoint模式:2. pb模式:3. saved_model模式. https://www.zhihu.com/collection/6445044 ...

  4. h5模型转化为pb模型,代码及排坑

    我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数. 附上h5_to_pb.py(pyt ...

  5. tensorflow打印模型图_从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)...

    最近看到一个巨牛的人工智能教程,分享一下给大家.教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家.平时碎片时间可以当小说看,[点这里可以去膜拜一下大神的" ...

  6. TensorFlow中查看checkpoint文件中的变量名和对应值

    在加载模型时, 需要知道checkpoint中变量名称,一下代码可以查看TensorFlow中checkpoint文件中的变量名: #!/usr/bin/env python # -*- coding ...

  7. Tensorflow 迁移学习 识别中国军网、中国军视网Logo水印

    Tensorflow 目标检测项目 图片logo水印识别.识别中国军网.中国军视网Logo水印. image image Step 0 下载项目 git clone https://github.co ...

  8. linux tensorflow demo_独家 | 在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...

    作者:MOHD SANAD ZAKI RIZVI 翻译:吴金笛 校对:丁楠雅 本文约5500字,建议阅读15分钟. 本文首先介绍了TensorFlow.js的重要性及其组件,并介绍使用其在浏览器中构建 ...

  9. python tensorflow web_如何在浏览器中使用TensorFlow?

    原标题:如何在浏览器中使用TensorFlow? 虽然您可以借助TensorFlow用数量较少的训练数据来训练简单的神经网络,但对于拥有庞大训练数据集的深度神经网络而言,确实需要使用具有CUDA功能的 ...

最新文章

  1. class viewController has no initializers解决办法
  2. Kattis - bela
  3. 类路径是什么意思_多播是什么意思 多播介绍【详解】
  4. 美国杜克大学计算机专业世界排名,美国杜克大学世界排名情况怎么样?
  5. 精通spring——深入java ee开发核心技术 pdf_2019精通Spring Boot 42讲 高清pdf完整版
  6. 【POJ 1679 The Unique MST】最小生成树
  7. python实现大批量pdf格式论文的重命名与目录制作功能
  8. php苹果app微信支付 无法返回,微信支付,php_微信支付APP返回-1怎么解决,微信支付,php,移动app - phpStudy...
  9. MapReduce 之shuffle过程
  10. san mysql,高性能MySQL:SAN和NAS
  11. ubuntu 12.04 配置内核崩溃自动重启及转存
  12. 【人民币识别】基于matlab GUI人民币序列号识别【含Matlab源码 908期】
  13. 线性代数科学出版社课后练习题答案
  14. U盘拔出时总是提示有程序正在使用?
  15. 跨专业北邮计算机考研,北京邮电大学跨专业考研心得
  16. jquery-重要的方法和注意事项
  17. LTE的核心网之:MME,SGW,PGW
  18. 常用元器件的识别(转载)
  19. 记一次钉钉群聊机器人的开发
  20. 深度学习不得不知的英文名称

热门文章

  1. Flutter如何实现下拉刷新和上拉加载更多
  2. 马化腾:通向互联网未来的七个路标
  3. Docker 安装最新版禅道16.5版本 原创
  4. 有一个棋盘,有64个方格,在第一个方格里面放1粒芝麻重量是0.00001kg,第二个里面放2粒,第三个里面放4,棋盘上放的所有芝麻的重量。
  5. 常微分方程-差分方程
  6. unity有限状态机和模糊状态机(怪物AI、自动寻路)
  7. 【竞品分析】Android音乐播放器的竞品分析
  8. ArcGIS GeoEvent 使用教程(二)
  9. Android ArcGIS基础使用教程(10.2.8)
  10. 如何将win10系统安装到U盘?