Tensorflow: 保存和复原模型(save and restore)
报错:
is not valid checkpoint
解决:
module_file = tf.train.latest_checkpoint(diag_obj.save_path) saver.restore(sess, module_file)
Tensorflow: 保存和复原模型(save and restore)
目前我主要看到了两种方法来保存和复原tensorflow model,先总结一下:
MetaGraph
这种就是我们经常看到的 tf.train.Saver
对应的东西。使用这种方法保存模型,会产生两种文件。
- meta: 里面存储的是整个graph的定义
- checkpoint: 这里保存的是
variable
的状态。
这里通过如下的方式保存一个模型
checkpoint_dir = "mysaver"# first creat a simple graph
graph = tf.Graph()#define a simple graph
with graph.as_default():x = tf.placeholder(tf.float32,shape=[],name='input')y = tf.Variable(initial_value=0,dtype=tf.float32,name="y_variable")update_y = y.assign(x)saver = tf.train.Saver(max_to_keep=3)init_op = tf.global_variables_initializer()# train the model and save the model every 4000 iterations.
sess = tf.Session(graph=graph)
sess.run(init_op)
for i in range(1,10000):y_result = sess.run(update_y,feed_dict={x:i})if i %4000 == 0:saver.save(sess,checkpoint_dir,global_step=i)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
这些是产生的文件
checkpoint
mysaver-4000.data-00000-of-00001
mysaver-4000.index
mysaver-4000.meta
mysaver-8000.data-00000-of-00001
mysaver-8000.index
mysaver-8000.meta
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
稍后我们可以复原model
tf.reset_default_graph()
restore_graph = tf.Graph()
with tf.Session(graph=restore_graph) as restore_sess:restore_saver = tf.train.import_meta_graph('mysaver-8000.meta')restore_saver.restore(restore_sess,tf.train.latest_checkpoint('./'))print(restore_sess.run("y_variable:0"))
- 1
- 2
- 3
- 4
- 5
- 6
上面这段python代码的输出如下:
INFO:tensorflow:Restoring parameters from ./mysaver-8000
8000.0
- 1
- 2
因为最新的checkpoint文件是在 8000th iterations保存的,所以当model复原后 y_variable的值是 80000
SavedModel
还有一种保存模型的方法就是 SavedModel
。
这种方法我是在看tensorflow servicing的时候看到的,个人的感觉,这是一种更适合部署的方法。暂时没有去研究tensorflow servicing。但是我看很多代码都使用到了通过这种方式保存的文件。比如imagenet example。所以这里着重介绍怎么使用从别的地方拿到的SavedModel文件。
建立 SavedModel
主要分为三部
* 建立一个 tf.saved_model.builder.SavedModelBuilder
.
* 使用刚刚建立的 builder把当前的graph和variable添加进去:SavedModelBuilder.add_meta_graph_and_variables(...)
* 可以使用 SavedModelBuilder.add_meta_graph
添加多个meta graph
复原 SavedModel
这个需要通过这个 model 来完成的:tf.saved_model.loader
通过命令来查看和执行SavedModel
上面的通过编程的方式来建立和复原SavedModel
, 我现在基本上不需要发布模型给别人用,但是经常想使用一下别人已经训练好的模型。当拿到别人的模型的时候,需要知道怎么使用。官方提供了一个工具:saved_model_cli
,这个工具包含了 show 和 run 两类命令
感兴趣的同学可以查看官方文档 或者这篇博客对应的 jupyter notebook
可视化 SavedModel
我们知道google提供 TensorBoard给我们可视化的调试tensorflow, tensorboard一个最基本的功能就是把graph展示出来。但是有时候我们拿到别人 SavedModel
, 我们需要把这个model跑一遍,产生summary文件才能在tensorboard里面看。google deepdream 参考代码里面提供了一个很方便的代码可以让我们快速的把graph展示出来。代码如下, 这个代码是我也放到我的github了,大家也可以直接去看google deepdram 参考代码
# these function is copied from google deepdream example code
import numpy as np
from IPython.display import clear_output, Image, display, HTML
def strip_consts(graph_def, max_const_size=32):"""Strip large constant values from graph_def."""strip_def = tf.GraphDef()for n0 in graph_def.node:n = strip_def.node.add() n.MergeFrom(n0)if n.op == 'Const':tensor = n.attr['value'].tensorsize = len(tensor.tensor_content)if size > max_const_size:tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>"%size)return strip_defdef rename_nodes(graph_def, rename_func):res_def = tf.GraphDef()for n0 in graph_def.node:n = res_def.node.add() n.MergeFrom(n0)n.name = rename_func(n.name)for i, s in enumerate(n.input):n.input[i] = rename_func(s) if s[0]!='^' else '^'+rename_func(s[1:])return res_def
def show_graph(graph_def, max_const_size=32):"""Visualize TensorFlow graph."""if hasattr(graph_def, 'as_graph_def'):graph_def = graph_def.as_graph_def()strip_def = strip_consts(graph_def, max_const_size=max_const_size)code = """<script>function load() {{document.getElementById("{id}").pbtxt = {data};}}</script><link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()><div style="height:600px"><tf-graph-basic id="{id}"></tf-graph-basic></div>""".format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))iframe = """<iframe seamless style="width:800px;height:620px;border:0" srcdoc="{}"></iframe>""".format(code.replace('"', '"'))display(HTML(iframe))
Tensorflow: 保存和复原模型(save and restore)相关推荐
- TensorFlow保存和恢复模型的方法总结
使用TensorFlow训练模型的过程中,需要适时对模型进行保存,以及对保存的模型进行restore,以方便后续对模型进行处理.比如进行测试,或者部署:比如拿别的模型进行fine-tune,等等.当然 ...
- TensorFlow 保存和加载模型
参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化
- tensorflow保存和恢复模型saver.restore
1.本文只对一些细节点做补充,大体的步骤就不详述了 2.保存模型 ① 首先我使用的是tensorflow-gpu 1.4.0 ② 这个版本生成的ckpt文件是这样的: 其中.meta存放的是网络模型和 ...
- QPainter保存与恢复:save与restore函数浅析
在Qt中进行图像绘制,需要用到QPainter对象,这个对象可以帮助我们完成一些简单功能的绘制,比如说绘制线条,绘制折线等简单的绘制功能. QPainter对象,有两个很有意思的函数,这两个函数相互之 ...
- Tensorflow |(5)模型保存与恢复、自定义命令行参数
Tensorflow |(1)初识Tensorflow Tensorflow |(2)张量的阶和数据类型及张量操作 Tensorflow |(3)变量的的创建.初始化.保存和加载 Tensorflow ...
- tensorflow model save and restore
TensorFlow 模型保存/载入 我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个skl ...
- 转载:tensorflow保存训练后的模型
训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存.如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值.建议可以使用Saver类保存和加载模型的结果. 1 ...
- Tensorflow保存神经网络参数有妙招:Saver和Restore
摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络. 本文分享自华为云社区<[Python人工智能] 十一. ...
- Tensorflow保存模型和加载预训练模型
训练好的模型需要保存下来或者加载已经训练完成的模型,就用到了ckpt文件. 目录 1.了解tensorflow保存的文件 (1)checkpoint (2)MyModel.meta (3)MyMode ...
最新文章
- 30 段极简 Python 代码:这些小技巧你都 Get 了么?
- 中文发音关系频谱的猜想
- css权重计算方法浅谈
- java returnaddress_Java虚拟机规范】Java SE 7虚拟机结构
- 通用客户端表单验证函数修正版
- java如何画出表格_Java利用iText7画个性化表格
- Jmeter---jason提取器处理上下游传参(四)
- mysql自定义变量
- 【最优解法】1087 有多少不同的值 (20分)_17行代码AC
- Angular本地数据存储LocalStorage
- mysql数据库属性_mysql - 数据库操作和数据属性
- rtsp摘要认证协议(Response计算方法)
- python中列表相加规则_在Python字典列表中使用公共键/值求和值
- 追踪盗窃12亿用户登录数据的网络犯罪团伙
- python实现雪花动态图_如何通过雪花算法用Python实现一个简单的发号器
- 字节流Stream(Output 、Input)、字符流(Reader、Writer)
- 计算机grand,The Grand
- 计算机文化与计算机技术有什么区别,什么是计算机文化?
- 愿守内心宁静,砥砺此生修行
- 决策树、装袋、提升和随机森林的对比理解
热门文章
- linux spec 脚本,关于linux:shell脚本的单元测试
- java安装_使用Java 9模块化来发布零依赖本机应用程序
- 卸载后的mysql和navicat怎么清除干净_清除电脑“牛皮癣 ”,带你回归清爽体验~...
- 修改时间服务器失败,电脑系统同步时间失败怎么办 修改时间服务器的方法。...
- 沈阳师范大学计算机题库,沈阳师范大学软件学院计算机学科专业基础综合历年考研真题汇编-20210607153358.docx-原创力文档...
- html一半文字一半图片,一个div的子div宽是200高是350 里面怎么让图片显示一半 另外一半文字居中!?...
- 每日一皮:是金子无论到哪里、哪怕变个形状都会发光..
- 每日一皮:当我看到Bug背后的一切...我退缩了...
- Spring Boot自定义 Servlet Filter 的两种方式
- Spring Boot 2.x基础教程:使用Swagger2构建强大的API文档