TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见

https://zhuanlan.zhihu.com/p/32887066

下面,我以mnist手写数据集用softmax回归为例,说明如何对训练好的模型进行保存与恢复。

1. 训练模型并保存为模型文件

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as npmnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()x = tf.placeholder("float", shape=[None, 784], name='input_x')  # 输入图像占位符
y_ = tf.placeholder("float", shape=[None, 10])  # 标签类别占位符# 模型参数一般用Variable来表示
W = tf.Variable(tf.zeros([784, 10]), name='w')  # 权重W是一个784x10的矩阵(因为我们有784个特征和10个输出值)
b = tf.Variable(tf.zeros([10]), name='b')  # 偏置b是一个10维的向量(因为我们有10个分类)sess.run(tf.initialize_all_variables())  # 变量需要通过seesion初始化后,才能在session中使用
# 使用Tensorflow提供的回归模型softmax,y代表输出,把向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值
y = tf.nn.softmax(tf.matmul(x, W) + b, name='predict')cross_entropy = - tf.reduce_sum(y_ * tf.log(y))  # 计算交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  # 梯度下降算法以0.01的学习速率最小化交叉熵# tf.argmax返回某个tensor对象在某一维上的其数据最大值所在的索引值
# 下面这行返回一组布尔值如[True, False, True, True]
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 把布尔值转换成浮点数,然后取平均值,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))for i in range(1000):batch = mnist.train.next_batch(50)  # 每一步迭代加载50个训练样本,然后执行一次train_stepsess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})if i % 100 == 0:print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 模型在测试数据集上面的正确率# 随便从测试集中取一个例子做测试
print(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}))
print(sess.run(tf.argmax(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}), axis=1)))   # 预测结果
print(mnist.test.labels[15])    # 标签值

对上述代码不熟悉的请参考:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_pros.html

a.保存为ckpt格式的模型文件

saver = tf.train.Saver()
saver.save(sess, "save_path/file_name")

生成的模型文件如下:

b.保存为pb格式的模型文件

builder = tf.saved_model.builder.SavedModelBuilder('./model2')
builder.add_meta_graph_and_variables(sess, ["mytag"])
builder.save()

生成的模型文件如下:

运行结果:(第6个元素最大,表示数字5,说明预测正确)

0.2847
0.8778
0.8945
0.8972
0.9031
0.9015
0.9109
0.9007
0.8901
0.9061
[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-022.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-051.32500082e-02   6.82865766e-06]]
[5]
[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]

2. 模型文件的恢复与使用

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_filemnist = input_data.read_data_sets('MNIST_data', one_hot=True)# pb模型的恢复
def restore_model_pb():sess = tf.Session()tf.saved_model.loader.load(sess, ['mytag'], os.getcwd() + '\model2')input_x = sess.graph.get_tensor_by_name('input_x:0')op = sess.graph.get_tensor_by_name('predict:0')print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))sess.close()# ckpt模型的恢复
def restore_model_ckpt():sess = tf.Session()# 加载模型结构saver = tf.train.import_meta_graph('./save_path/file_name.meta')# 只需要指定目录就可以恢复所有变量信息saver.restore(sess, tf.train.latest_checkpoint('./save_path'))# 直接获取保存的变量print(sess.run('w:0'))input_x = sess.graph.get_tensor_by_name('input_x:0')# # 获取需要进行计算的operatorop = sess.graph.get_tensor_by_name('predict:0')print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))sess.close()restore_model_pb()
# 打印所有变量的值
# print_tensors_in_checkpoint_file("save_path/file_name", None, True)

运行结果:

[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-022.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-051.32500082e-02   6.82865766e-06]]

java中调用以上模型文件请参考:java调用tensorflow模型文件

TensorFlow 模型的保存与恢复相关推荐

  1. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  2. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  3. Tensorflow模型的保存与恢复的细节

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  4. tensorflow——模型的保存和恢复tf.trian.saver()

    保存 1创建saver对象,确定save哪些:saver=tf.trian.Saver(),不填写参数的话默认全部 2指定在哪个session中保存,以及保存路径:saver.save(sess, ' ...

  5. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  6. 基于Python的模型的保存、恢复、继续训练

    资源下载地址:https://download.csdn.net/download/sheziqiong/86774566 资源下载地址:https://download.csdn.net/downl ...

  7. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  8. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  9. TensorFlow:模型的保存与恢复(Saver)

    目录 前言 1 实例化对象 2 保存训练过程中或者训练好的, 模型图及权重参数 2.1保存训练模型 2.2 查看保存 3. 重载模型的图及权重参数(模型恢复)     前言 我们经常在训练完一个模型之 ...

最新文章

  1. QIIME 2用户文档. 18使用q2-vsearch聚类OTUs(2018.11)
  2. 汽车需要镀晶吗?镀晶是起什么作用的?
  3. GooglePR说明
  4. 数据竞赛入门-金融风控(贷款违约预测)四、建模与调参
  5. 马斯克光顾北京包子铺被偶遇 本人盖章:好吃!
  6. oc61--block
  7. 赚钱的基本逻辑就是价值交换
  8. bootstrap checkbox_[推荐]icheck-bootstrap(漂亮的ckeckbox/radiobox)
  9. bzoj千题计划315:bzoj3172: [Tjoi2013]单词(AC自动机)
  10. 升余弦滤波器与无码间串扰(一)
  11. Facebook更名Meta,扎克伯格押注元宇宙
  12. 数据产品经理该懂的python技术
  13. 利用 GDB 快速阅读 postgresql 的内核代码
  14. JQuery是什么?怎么使用JQ?
  15. PaddleOCR——训练总结
  16. 2020一战中科大计算机初复试经验贴
  17. 微型计算机也称为个人计算机由,微型计算机概述计算机概述微型电脑组装系统台式电脑...
  18. html编写在线打字通,前端关键字(打字练习)共1347个字符
  19. Excel报表公式值替换后,报错打开提示部分内容有问题, 是否尝试尽量恢复
  20. 关于专利书写以及申报的一点心得体会

热门文章

  1. 周六暴走香山,欢迎参加 ^_^
  2. 新浪网首页新闻资讯爬虫项目
  3. IDEA 开发一个服务端脚手架(archetype)
  4. 机器学习笔记(2021-08-02 第一稿)
  5. Netty中的EventExecutor
  6. 怎么看台式电脑型号 怎么看台式电脑型号和型号
  7. jdbc结合sqlserver的javaWeb工程的分页查询共通操作代码
  8. 华为手机获取状态栏高度是错误的_聊聊获取屏幕高度这件事
  9. binwalk使用整理
  10. 扫清电路设计软件盲点,protel DXP电路设计软件批量修改