# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型# 执行本段程序时注意当前的工作路径
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver.save(sess, "Model/model.ckpt")# Part2: 加载TensorFlow模型的方法import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2saver = tf.train.Saver()with tf.Session() as sess:saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"print(sess.run(result)) # [ 3.]# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图import tensorflow as tfsaver = tf.train.import_meta_graph("Model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess, "./Model/model.ckpt") # 注意路径写法print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名import tensorflow as tf# 声明的变量名称name与已保存的模型中的变量名称name不一致
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2# 若直接生命Saver类对象,会报错变量找不到
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
saver = tf.train.Saver({"v1": u1, "v2": u2})with tf.Session() as sess:saver.restore(sess, "./Model/model.ckpt")print(sess.run(result)) # [ 3.]# Part5: 保存滑动平均模型import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():print(variables.name) # v:0ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():print(variables.name) # v:0# v/ExponentialMovingAverage:0saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(tf.assign(v, 10))sess.run(maintain_averages_op)saver.save(sess, "Model/model_ema.ckpt")print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]# Part6: 通过变量重命名直接读取变量的滑动平均值import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})with tf.Session() as sess:saver.restore(sess, "./Model/model_ema.ckpt")print(sess.run(v)) # 0.0999999# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
# 注意此处的变量名称name一定要与已保存的变量名称一致
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
# {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
# 此处的v取自上面变量v的名称name="v"saver = tf.train.Saver(ema.variables_to_restore())with tf.Session() as sess:saver.restore(sess, "./Model/model_ema.ckpt")print(sess.run(v)) # 0.0999999# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中import tensorflow as tf
from tensorflow.python.framework import graph_utilv1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分graph_def = tf.get_default_graph().as_graph_def()output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def, ['add'])with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:f.write(output_graph_def.SerializeToString())# Part9: 载入包含变量及其取值的模型import tensorflow as tf
from tensorflow.python.platform import gfilewith tf.Session() as sess:model_filename = "Model/combined_model.pb"with gfile.FastGFile(model_filename, 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())result = tf.import_graph_def(graph_def, return_elements=["add:0"])print(sess.run(result)) # [array([ 3.], dtype=float32)]

model的存储与读取相关推荐

  1. TF2.0 subclass存储及读取模型

    Tensorflow Subclass存储问题 问题描述:项目中通过tf.keras.layer.Layers及f.keras.layer.Model进行构建模型,在存储的过程中能够存储自己的模型,但 ...

  2. Mybatis解决数据库Blob类型存储与读取问题

    1.Blob介绍 首先,先简单介绍下数据库Blob字段,Blob(Binary Large Object)是指二进制大对象字段,顺带介绍下Clob类型,Clob(Character Large Obj ...

  3. paip.odbc DSN的存储与读取

    paip.odbc DSN的存储与读取 作者Attilax ,  EMAIL:1466519819@qq.com  来源:attilax的专栏 地址:http://blog.csdn.net/atti ...

  4. Python Json存储与读取

    前言 Python 中的文件数据存储和读取可以说是非常方便了,这里记录一下 JSon 数据的存储和读取,需要用到的模块就是 json,该模块能够将简单的 Python 数据结构转储到文件中,并在程序再 ...

  5. .Net下二进制形式的文件(图片)的存储与读取 [ZT]

    .Net下图片的常见存储与读取凡是有以下几种: 存储图片:以二进制的形式存储图片时,要把数据库中的字段设置为Image数据类型(SQL Server),存储的数据是Byte[]. 1.参数是图片路径: ...

  6. 使用SharedPreferences存储和读取数据

    转:http://www.worlduc.com/blog2012.aspx?bid=19403392 1.任务目标 (1)掌握Android中SharedPreferences的使用方法. 2.任务 ...

  7. C++ STL容器vector篇(三) vector容器大小和数组大小, 插入和删除元素, 存储和读取元素

    vector容器的大小(capacity)和存放数据的大小(size) #include <iostream> #include <vector>using namespace ...

  8. .net 数据存储 mysql_asp.net实现存储和读取数据库图片

    本文实例为大家分享了asp.net存储和读取数据库图片的具体代码,供大家参考,具体内容如下 1. 创建asp.net web窗体项目 代码如下: 上传图片 展示图片 效果图如下: 2. 创建数据库 数 ...

  9. php如何从mongo获取视频文件,使用mongodb对文件(图片、音频、视频)的存储、读取操作...

    使用mongodb对文件(图片.音频.视频)的存储.读取操作 实现代码示例: package mongo.util; import java.io.File; import java.io.IOExc ...

最新文章

  1. 机器翻译难敌人类灵活多变的语言
  2. Storm源码阅读之SpoutOutputCollector
  3. mvc中循环遍历分配的代码
  4. samba 2.2.7a 编译
  5. Javascript构造函数的继承
  6. php form 后台函数,Discuz!开发之后台表单生成函数介绍
  7. ArcGIS API for JavaScript心得体验
  8. C语言 — 编程规范、标识符命名规范
  9. 巨人肩膀_如何站在巨人的肩膀上
  10. 新生报到系统_中大深圳校区欢迎你!5个院系1271名本科新生报到
  11. c++运行时报Floating point exception错误
  12. 计算机网络运动会入场词,运动会入场词
  13. 两个PDF比较标出差异_[连玉君专栏]如何检验分组回归后的组间系数差异?
  14. AT24C02数据存储
  15. 数学建模_统计回归模型的梳理与总结:逐步回归,残差检验,自相关
  16. meta20 无法安装 google play_【黑科技】安卓手机安装Google Play
  17. 大厂程序员必备的一套浏览器书签,我帮你整理好了。[下载导入浏览器]
  18. python写手机应用宝下载_APK 批量爬取脚本(应用宝和360市场)
  19. dsf5.0没登录显示登录弹框
  20. Shell 遍历数组的方法

热门文章

  1. 商家入驻商城 多商户商城 宝塔安装搭建教程 说明 小程序、h5、pc端
  2. 太阳能收集充电器设计
  3. 遗传算法(四)MATLAB GA工具箱使用 附解TSP问题
  4. 代码随想录算法训练营第十五天 | 102. 二叉树的层序遍历 | 226.翻转二叉树 | 101. 对称二叉树
  5. 手把手教你:如何让围棋人工智能Leela Zero陪你“人机大战”
  6. 作为一个菜鸟程序员跳槽可行吗?
  7. curl malformed
  8. 格科微电子技术支持(应届)面试
  9. 勾股OA安装配置教程
  10. 阿里后台四年,想要跳槽字节,艰难4面,已收开发岗offer