Tensorflow:模型变量保存

觉得有用的话,欢迎一起讨论相互学习~Follow Me

参考文献Tensorflow实战Google深度学习框架
实验平台:
Tensorflow1.4.0
python3.5.0

Tensorflow常用保存模型方法

import tensorflow as tf
saver = tf.train.Saver()  # 创建保存器
with tf.Session() as sess:saver.save(sess,"/path/model.ckpt")  #保存模型到相应ckpt文件saver.restore(sess,"/path/model.ckpt")  #从相应ckpt文件中恢复模型变量

使用tf.train.Saver会保存运行Tensorflow程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似的变量初始化,模型保存等辅助节点的信息。Tensorflow提供了convert_varibales_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个Tensorflow计算图可以统一存放在一个文件中。

将变量取值保存为pb文件

# pb文件保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_utilv1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2init_op = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)  # 初始化所有变量# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程graph_def = tf.get_default_graph().as_graph_def()# 将需要保存的add节点名称传入参数中,表示将所需的变量转化为常量保存下来。output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])# 将导出的模型存入文件中with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:f.write(output_graph_def.SerializeToString())# 2. 加载pb文件。from tensorflow.python.platform import gfilewith tf.Session() as sess:model_filename = "Saved_model/combined_model.pb"# 读取保存的模型文件,并将其解析成对应的GraphDef Protocol Bufferwith gfile.FastGFile(model_filename, 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())# 将graph_def中保存的图加载到当前图中,其中保存的时候保存的是计算节点的名称,为add# 但是读取时使用的是张量的名称所以是add:0result = tf.import_graph_def(graph_def, return_elements=["add:0"])print(sess.run(result))
# Converted 2 variables to const ops.
# [array([3.], dtype=float32)]

转载于:https://www.cnblogs.com/cloud-ken/p/9317238.html

Tensorflow模型变量保存相关推荐

  1. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  2. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  3. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  4. Tensorflow模型的保存与恢复的细节

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  5. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  6. TensorFlow 模型的保存与恢复

    TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见 https://zhuanlan.zhihu.com/p/32887066 下面,我以mnist手写数据集用sof ...

  7. tensorflow——模型的保存和恢复tf.trian.saver()

    保存 1创建saver对象,确定save哪些:saver=tf.trian.Saver(),不填写参数的话默认全部 2指定在哪个session中保存,以及保存路径:saver.save(sess, ' ...

  8. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...

  9. 5.2 TensorFlow:模型的加载,存储,实例

    背景 之前已经写过TensorFlow图与模型的加载与存储了,写的很详细,但是或闻有人没看懂,所以在附上一个关于模型加载与存储的例子,CODE是我偶然看到了,就记下来了.其中模型很巧妙,比之前nump ...

  10. TensorFlow模型持久化

    模型持久化的目的在于可以使模型训练后的结果重复使用,节省重复训练模型的时间. 模型保存 train.Saver类是TensorFlow提供的用于保存和还原模型的API,使用非常简单. import t ...

最新文章

  1. C++知识点2——指针、引用基础
  2. Windows下更改mysql data目录
  3. 【C 语言】结构体 ( 结构体变量内存操作 | 通过 “ . “ 操作符操作结构体内存空间 | 通过 “ -> “ 操作符操作结构体内存空间 )
  4. 排版 项目 html,实现HTML自动排版的法则2_html
  5. 微信小程序_组件学习_001
  6. android studio failed to open zip file,Android Studio出现Failed to open zip file问题的解决方法...
  7. 腾讯吃鸡 android,腾讯吃鸡手游《光荣使命》正式上线:安卓/iOS不限号测试
  8. webshell提权教程linux,Linux下WEBSHELL提权
  9. python怎么开发工具_为程序员和新手准备的8大Python开发工具
  10. php与mysql关系大揭秘_【慕课笔记】PHP与MySQL关系大揭秘
  11. 数据结构笔记(三)-- 链式实现顺序表
  12. css:transform,transition,animation总结
  13. python——extend用新序列扩展其他列表
  14. Mybatis关联关系
  15. XP的故障恢复控制台
  16. python print输出指定小数位数
  17. canvas实现动态点线背景,鼠标画点连线。
  18. c语言 编码 乐学,c语言乐学作业
  19. php 合成图片 微信公众号合成海报
  20. SpringBoot+Vue实现第三方QQ登录(一)

热门文章

  1. linux访问vdma的数据,Xilinx VDMA 24位流输出与32位AXI总线的内存流数据关系
  2. JAVA rs 是否要关闭_关闭结果集rs和statement以后,是否还要关闭数据库连接呢?...
  3. 405.数字转换为十六进制数
  4. 32个参数累加_「机械设计教程」滚珠丝杠选型过程中考虑的9个参数
  5. Kruskal算法实现最小生成树MST(java)
  6. 无线通信基础(一):高斯随机变量
  7. Spring Cloud随记----远程配置文件资源库的建立-涉及一些简单的git操作
  8. 企业微信oauth认证_企业微信开发之授权登录
  9. 八个小技巧教你做出舒服的MG动画
  10. 刷题记录 kuangbin带你飞专题一:简单搜索