Tensorflow SavedModel 模型的保存和加载
前言
最近参加了天池上的Apache Flink极客挑战赛——垃圾图片分类比赛,里面涉及到了Java调用tensorflow的SavedModel格式的模型进行预测,于是专门对此内容进行了调研。这里记录了SavedModel模型的优势,结构以及保存和加载的方法。
SavedModel的优势
Tensorflow训练的模型可以保存为ckpt格式,但是这种格式的模型文件在跨语言方面不是很灵活。而SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。
SavedModel的结构
以SavedModel格式保存模型时,tensorflow将创建一个SavedModel目录,该目录由以下子目录和文件组成:
assets/
assets.extra/
variables/
variables.data-???-of-???
variables.index
saved_model.pb|saved_model.pbtxt
其中,各目录和文件的说明如下:
assets/是包含辅助(外部)文件(如词汇表)的子文件夹。资产被复制到SavedModel位置,并且可以在加载特定的MetaGraphDef时读取。
assets.extra是一个子文件夹,高级库和用户可以将自己的资源添加进去,这些资源将与模型共存但不由图形加载。此子文件夹不由SavedModel库管理。
variables/是包含tf.train.saver输出的子文件夹。
saved_model.pb或saved_model.pbtxt是SavedModel协议缓冲区。它将图形定义作为MetaGraphDef协议缓冲区。
MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。
assets/和assets.extra目录是可选的。
SavedModel模型的保存
这里以手写识别模型为例。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.saved_model.signature_def_utils import predict_signature_def
from tensorflow.saved_model import tag_constants
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784], name=“Input”) # 为输入op添加命名"Input"
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))
tf.identity(y, name=“Output”) # 为输出op命名为"Output"
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
将模型保存到文件
简单方法:
tf.saved_model.simple_save(sess,
“./model_simple”,
inputs={“Input”: x},
outputs={“Output”: y})
复杂方法
builder = tf.saved_model.builder.SavedModelBuilder(“./model_complex”)
signature = predict_signature_def(inputs={‘Input’: x},
outputs={‘Output’: y})
builder.add_meta_graph_and_variables(sess=sess,
tags=[tag_constants.SERVING],
signature_def_map={‘predict’: signature})
builder.save()
代码解析:
x = tf.placeholder(tf.float32, [None, 784], name=“Input”) # 为输入op添加命名"Input" 这里是为输入op进行命名,当然也可以不命名,系统会默认给一个名称"Placeholder",当我们需要引用多个op的时候,给每个op一个命名,确实方便我们后面的使用。
tf.identity(y, name=“Output”) # 为输出op命名为"Output" 使用tf.identity为输出tensor命名。
代码中给出了两种方法进行模型保存。复杂方法较简单方法的最大优势在于——可以自己定义tag,在签名的定义上更加灵活。
tag的作用: 一个模型可以包含不同的MetaGraphDef,比如你想保存graph的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。
SavedModel模型的加载
Python
import numpy as np
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [“serve”], “./model_simple”)
graph = tf.get_default_graph()
input = np.expand_dims(mnist.test.images[0], 0)
x = sess.graph.get_tensor_by_name('Input:0')
y = sess.graph.get_tensor_by_name('Output:0')
batch_xs, batch_ys = mnist.test.next_batch(1)
scores = sess.run(y,feed_dict={x: batch_xs})
print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))
tf.saved_model.loader.load的第二个参数是定义的tag值,要和保存时定义保持一致;第三个参数是模型保存的路径。
Java
Java需要添加maven依赖:
<dependency><groupId>org.tensorflow</groupId><artifactId>tensorflow</artifactId><version>1.11.0</version></dependency><dependency><groupId>org.tensorflow</groupId><artifactId>proto</artifactId><version>1.11.0</version></dependency>
package com.garbage;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
/**
- @Author: Jeremy
- @Date: 2019/9/17 11:29
*/
public class SavedModelLoader {
public static void main(String[] args) throws Exception{
ConfigProto configProto = ConfigProto.newBuilder()
.setAllowSoftPlacement(true)
.build();
SavedModelBundle model = SavedModelBundle.loader(“YOUR_MODEL_PATH”)
.withConfigProto(configProto.toByteArray())
.withTags(“serve”)
.load();
SignatureDef modelSig = MetaGraphDef.parseFrom(model.metaGraphDef()).getSignatureDefOrThrow(“serving_default”);
int numInputs = modelSig.getInputsCount();
String inputTensorName = modelSig.getInputsMap().get(“Input”).getName();
String outputTensorName = modelSig.getOutputsMap().get(“Output”).getName();
System.out.println(String.format(“numInputs: %d, inputTensorName: %s, outputTensor: %s”, numInputs, inputTensorName, outputTensorName));
}
}
输出:
numInputs: 1, inputTensorName: Input_9:0, outputTensor: Softmax_10:0
使用Java预测的部分不是本文的重点,以后有机会再写。
总结
本篇博文总结了Tensorflow SaveModel模型,介绍了SaveModel模型的优势,目录结构,使用python进行模型保存,使用python和java加载模型。
Tensorflow SavedModel 模型的保存和加载相关推荐
- numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...
一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...
- PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard
文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...
- 线性回归之模型的保存和加载
线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib [目前这行代码报错,直接写import joblib ...
- PyTorch | 模型的保存和加载
PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...
- pytorch模型的保存和加载、checkpoint
pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...
- paddlepaddle模型的保存和加载
导读 深度学习中模型的计算图可以被分为两种,静态图和动态图,这两种模型的计算图各有优劣. 静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当 ...
- tensorflow 模型的保存和加载
为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...
- PyTorch基础-模型的保存和加载-09
模型的保存 import numpy as np import torch from torch import nn,optim from torch.autograd import Variable ...
- 调gensim库,word2vec模型的保存和加载
一.模型的保存 模型保存可以有很多种格式,根据格式的不同可以分为2种,一种是保存为.model的文件,一种是非.model文件的保存.我常用的保存格式是.model和.vector直接上代码和结果: ...
- 机器学习算法------2.11 模型的保存和加载(joblib.dump()、joblib.load())
# 模型保存 joblib.dump(estimator, "./data/test.pkl") # 模型加载 estimator = joblib.load("./d ...
最新文章
- Ubuntu9.10使用windows的字体的方法!
- C#动态属性(.NET Framework4.5支持)
- python list 取重复次数
- 如何同时GET√5斤网易味央猪肉和正确的APP IM开发姿势?
- servlet3.0新特性_查看Servlet 3.0的新增功能
- Linux内核 eBPF:Hacking Linux USDT with Ftrace
- 自学python 编程基础科学计算及数据分析 pdf_自学Python:编程基础、科学计算及数据分析...
- batch批处理(转载)
- 寻找绝对隐蔽的后门的办法 分享
- Android 循环缓冲区
- (转)GridView固定表头
- 华为网关服务器型号,02311CWM CN21ITGC SP212 I350-T4 华为服务器四口千兆网卡
- namp安装及官方使用手册翻译及注释5
- solidworks迈迪插件_迈迪工具集V55特别PJ版_打包下载
- linux配置文件前面有分号,linux中的分号 ||
- SAP查询销售订单库存
- Android TV H5 电视应用
- jena4.1.0安装及使用
- Frequent values RMQ
- Javascript基础知识之四(常用数组方法)
热门文章
- c语言 输出1到n之间的全部素数,输出1到n中所有的素数
- This Product is covered by one or more of the folloWing patents
- 快速原型VS敏捷、迭代
- Java开发社招面试经验:2021最新Java面试笔试
- 电子技术基础(三)__第5章 之逻辑函数的卡诺图化简方法
- 高质量论文配图配色,让你的图更加亮眼
- ios计算机错误,用iTunes更新IOS14失败,显示发生未知错误(4000)的简单解决办法!...
- fiddler+mitmproxy+夜神模拟器安装
- 阿里云dataV大屏可视化的使用攻略——vue项目
- python数据科学包第三天(股票数据分析、时间事件日志)