前言
最近参加了天池上的Apache Flink极客挑战赛——垃圾图片分类比赛,里面涉及到了Java调用tensorflow的SavedModel格式的模型进行预测,于是专门对此内容进行了调研。这里记录了SavedModel模型的优势,结构以及保存和加载的方法。

SavedModel的优势
Tensorflow训练的模型可以保存为ckpt格式,但是这种格式的模型文件在跨语言方面不是很灵活。而SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。

SavedModel的结构
以SavedModel格式保存模型时,tensorflow将创建一个SavedModel目录,该目录由以下子目录和文件组成:

assets/
assets.extra/
variables/
variables.data-???-of-???
variables.index
saved_model.pb|saved_model.pbtxt
其中,各目录和文件的说明如下:

assets/是包含辅助(外部)文件(如词汇表)的子文件夹。资产被复制到SavedModel位置,并且可以在加载特定的MetaGraphDef时读取。
assets.extra是一个子文件夹,高级库和用户可以将自己的资源添加进去,这些资源将与模型共存但不由图形加载。此子文件夹不由SavedModel库管理。
variables/是包含tf.train.saver输出的子文件夹。
saved_model.pb或saved_model.pbtxt是SavedModel协议缓冲区。它将图形定义作为MetaGraphDef协议缓冲区。
MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。
assets/和assets.extra目录是可选的。

SavedModel模型的保存
这里以手写识别模型为例。

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.saved_model.signature_def_utils import predict_signature_def
from tensorflow.saved_model import tag_constants

mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784], name=“Input”) # 为输入op添加命名"Input"
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))
tf.identity(y, name=“Output”) # 为输出op命名为"Output"

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()

for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

将模型保存到文件

简单方法:

tf.saved_model.simple_save(sess,
“./model_simple”,
inputs={“Input”: x},
outputs={“Output”: y})

复杂方法

builder = tf.saved_model.builder.SavedModelBuilder(“./model_complex”)
signature = predict_signature_def(inputs={‘Input’: x},
outputs={‘Output’: y})
builder.add_meta_graph_and_variables(sess=sess,
tags=[tag_constants.SERVING],
signature_def_map={‘predict’: signature})
builder.save()
代码解析:

x = tf.placeholder(tf.float32, [None, 784], name=“Input”) # 为输入op添加命名"Input" 这里是为输入op进行命名,当然也可以不命名,系统会默认给一个名称"Placeholder",当我们需要引用多个op的时候,给每个op一个命名,确实方便我们后面的使用。
tf.identity(y, name=“Output”) # 为输出op命名为"Output" 使用tf.identity为输出tensor命名。
代码中给出了两种方法进行模型保存。复杂方法较简单方法的最大优势在于——可以自己定义tag,在签名的定义上更加灵活。
tag的作用: 一个模型可以包含不同的MetaGraphDef,比如你想保存graph的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。

SavedModel模型的加载
Python
import numpy as np
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [“serve”], “./model_simple”)
graph = tf.get_default_graph()

input = np.expand_dims(mnist.test.images[0], 0)
x = sess.graph.get_tensor_by_name('Input:0')
y = sess.graph.get_tensor_by_name('Output:0')
batch_xs, batch_ys = mnist.test.next_batch(1)
scores = sess.run(y,feed_dict={x: batch_xs})
print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

tf.saved_model.loader.load的第二个参数是定义的tag值,要和保存时定义保持一致;第三个参数是模型保存的路径。

Java
Java需要添加maven依赖:

    <dependency><groupId>org.tensorflow</groupId><artifactId>tensorflow</artifactId><version>1.11.0</version></dependency><dependency><groupId>org.tensorflow</groupId><artifactId>proto</artifactId><version>1.11.0</version></dependency>

package com.garbage;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

/**

  • @Author: Jeremy
  • @Date: 2019/9/17 11:29
    */
    public class SavedModelLoader {
    public static void main(String[] args) throws Exception{
    ConfigProto configProto = ConfigProto.newBuilder()
    .setAllowSoftPlacement(true)
    .build();
    SavedModelBundle model = SavedModelBundle.loader(“YOUR_MODEL_PATH”)
    .withConfigProto(configProto.toByteArray())
    .withTags(“serve”)
    .load();
    SignatureDef modelSig = MetaGraphDef.parseFrom(model.metaGraphDef()).getSignatureDefOrThrow(“serving_default”);
    int numInputs = modelSig.getInputsCount();
    String inputTensorName = modelSig.getInputsMap().get(“Input”).getName();
    String outputTensorName = modelSig.getOutputsMap().get(“Output”).getName();
    System.out.println(String.format(“numInputs: %d, inputTensorName: %s, outputTensor: %s”, numInputs, inputTensorName, outputTensorName));
    }
    }
    输出:

numInputs: 1, inputTensorName: Input_9:0, outputTensor: Softmax_10:0
使用Java预测的部分不是本文的重点,以后有机会再写。

总结
本篇博文总结了Tensorflow SaveModel模型,介绍了SaveModel模型的优势,目录结构,使用python进行模型保存,使用python和java加载模型。

Tensorflow SavedModel 模型的保存和加载相关推荐

  1. numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...

    一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...

  2. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  3. 线性回归之模型的保存和加载

    线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib   [目前这行代码报错,直接写import joblib ...

  4. PyTorch | 模型的保存和加载

    PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...

  5. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  6. paddlepaddle模型的保存和加载

    导读 深度学习中模型的计算图可以被分为两种,静态图和动态图,这两种模型的计算图各有优劣. 静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当 ...

  7. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  8. PyTorch基础-模型的保存和加载-09

    模型的保存 import numpy as np import torch from torch import nn,optim from torch.autograd import Variable ...

  9. 调gensim库,word2vec模型的保存和加载

    一.模型的保存 模型保存可以有很多种格式,根据格式的不同可以分为2种,一种是保存为.model的文件,一种是非.model文件的保存.我常用的保存格式是.model和.vector直接上代码和结果: ...

  10. 机器学习算法------2.11 模型的保存和加载(joblib.dump()、joblib.load())

    #  模型保存 joblib.dump(estimator, "./data/test.pkl") # 模型加载 estimator = joblib.load("./d ...

最新文章

  1. Ubuntu9.10使用windows的字体的方法!
  2. C#动态属性(.NET Framework4.5支持)
  3. python list 取重复次数
  4. 如何同时GET√5斤网易味央猪肉和正确的APP IM开发姿势?
  5. servlet3.0新特性_查看Servlet 3.0的新增功能
  6. Linux内核 eBPF:Hacking Linux USDT with Ftrace
  7. 自学python 编程基础科学计算及数据分析 pdf_自学Python:编程基础、科学计算及数据分析...
  8. batch批处理(转载)
  9. 寻找绝对隐蔽的后门的办法 分享
  10. Android 循环缓冲区
  11. (转)GridView固定表头
  12. 华为网关服务器型号,02311CWM CN21ITGC SP212 I350-T4 华为服务器四口千兆网卡
  13. namp安装及官方使用手册翻译及注释5
  14. solidworks迈迪插件_迈迪工具集V55特别PJ版_打包下载
  15. linux配置文件前面有分号,linux中的分号 ||
  16. SAP查询销售订单库存
  17. Android TV H5 电视应用
  18. jena4.1.0安装及使用
  19. Frequent values RMQ
  20. Javascript基础知识之四(常用数组方法)

热门文章

  1. c语言 输出1到n之间的全部素数,输出1到n中所有的素数
  2. This Product is covered by one or more of the folloWing patents
  3. 快速原型VS敏捷、迭代
  4. Java开发社招面试经验:2021最新Java面试笔试
  5. 电子技术基础(三)__第5章 之逻辑函数的卡诺图化简方法
  6. 高质量论文配图配色,让你的图更加亮眼
  7. ios计算机错误,用iTunes更新IOS14失败,显示发生未知错误(4000)的简单解决办法!...
  8. fiddler+mitmproxy+夜神模拟器安装
  9. 阿里云dataV大屏可视化的使用攻略——vue项目
  10. python数据科学包第三天(股票数据分析、时间事件日志)