5.4 TensorFlow Model Persistence

5.4.1. Saving Models as ckpt Files

To load a model, you must first define a computation graph whose structure is exactly the same as the original one; only then can the model be loaded. The variables in this graph do not need to be initialized, because their values are restored from the saved model.
A model saved this way produces three files in the target directory (a short inspection sketch follows this list):
* The .ckpt.meta file stores the structure of the TensorFlow computation graph.
* The .ckpt file stores the values of all variables in the graph.
* The checkpoint file lists all model files in the directory.
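
To confirm what a checkpoint actually stores, it can be inspected with tf.train.NewCheckpointReader. This is a minimal sketch, assuming the Saved_model/model.ckpt produced by the saving code below has already been written:

import tensorflow as tf

# Assumption: Saved_model/model.ckpt was written by the saver.save() call below.
reader = tf.train.NewCheckpointReader("Saved_model/model.ckpt")
# Map from each saved variable name to its shape.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)  # e.g. v1 [1]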

import tensorflow as tf

# Save a model that computes the sum of two variables.
# The names "v1"/"v2" are needed by the renaming example below.
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1), name="v1")
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

# Load the model that computes the sum of the two variables.
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-1.6226364]
# Load the persisted graph directly. Because v3 was never exported,
# running it raises an error.
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(v3))  # fails: v3 is not in the checkpoint

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-0.81131822]
[-0.81131822]

# Variable renaming: a dictionary links the variable names used when the
# model was saved to the variables that should be loaded.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})
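
A minimal sketch completing the renaming example, assuming the Saved_model/model.ckpt written above stored variables named "v1" and "v2":

with tf.Session() as sess:
    # The values saved under "v1" and "v2" are loaded into the
    # variables now named "other-v1" and "other-v2".
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run([v1, v2]))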


5.4.2.1 Saving Moving-Average Variables

import tensorflow as tf

# Use a moving average.
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
    print(variables.name)

v:0
v:0
v/ExponentialMovingAverage:0

# Save the moving-average model.
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # Saving stores both variables: v:0 and v/ExponentialMovingAverage:0.
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))

[10.0, 0.099999905]

# Load the moving-average model.
v = tf.Variable(0, dtype=tf.float32, name="v")

# Through variable renaming, the moving average of the original
# variable v is assigned directly to v.
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999
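
For reference, the restored value matches the shadow-variable update shadow = decay × shadow + (1 − decay) × variable: with decay 0.99, the shadow initialized to 0, and v assigned 10, a single update gives 0.99 × 0 + 0.01 × 10 = 0.1, which is the 0.099999905 saved above (up to float precision).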


5.4.2.2 Usage Example of the variables_to_restore Function

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())

# Equivalent to saver = tf.train.Saver(ema.variables_to_restore())
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

{u'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
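
As the comment above notes, the same restore can be written by passing the generated dictionary straight to the Saver; a minimal equivalent sketch:

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))  # restores the moving average, 0.0999999 as before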

5.4.3. Saving Models as pb Files

# Saving a model as a pb file.
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    # Freeze the graph: replace variables with constants, keeping only
    # the nodes needed to compute 'add'.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
------------------------------------------------------------------------
# Loading a pb file.
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))

[array([ 3.], dtype=float32)]

A tensor name carries a :0 suffix, indicating that it is the first output of a computation node; the name of the computation node itself has no :0 suffix.
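
The distinction can be checked on a tiny graph; a minimal sketch (the names a, b and add below are illustrative, not taken from the model above):

import tensorflow as tf

a = tf.constant([1.0], name="a")
b = tf.constant([2.0], name="b")
c = tf.add(a, b, name="add")
print(c.op.name)  # "add"   -> the computation node itself
print(c.name)     # "add:0" -> the first output of node "add"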

5.5 TensorFlow Best-Practice Sample Program

To make the program more extensible, reduce redundant code, and improve programming efficiency, we split the different functional modules apart; this section also abstracts the forward-propagation process into a standalone library function.

mnist_inference

import tensorflow as tf

#1. Parameters describing the network structure.
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

#2. Obtain variables through the tf.get_variable function.
def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights

#3. Define the forward-propagation process of the network.
def inference(input_tensor, regularizer):
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

mnist_train

import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

#1. Parameters of the training process.
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"
#2. Define the training process.
def train(mnist):
    # Define the input and output placeholders.
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Define the loss function, learning rate, moving-average op and training op.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # Initialize the TensorFlow persistence class.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
#3. Main entry point.
def main(argv=None):
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    train(mnist)

if __name__ == '__main__':
    main()
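
For reference, with staircase=True the learning rate follows learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ^ floor(global_step / decay_steps), where decay_steps = mnist.train.num_examples / BATCH_SIZE is one epoch of batches; the rate therefore starts at 0.8 and is multiplied by 0.99 once per epoch.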
-------------------------------------------------------------------------
Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
After 1 training step(s), loss on training batch is 3.05851.
After 1001 training step(s), loss on training batch is 0.207949.
After 2001 training step(s), loss on training batch is 0.214515.
After 3001 training step(s), loss on training batch is 0.237391.
After 4001 training step(s), loss on training batch is 0.115064.
After 5001 training step(s), loss on training batch is 0.103093.
After 6001 training step(s), loss on training batch is 0.133556.
....

mnist_eval

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
#1. Load the latest model every 10 seconds.
# Interval between two evaluations.
EVAL_INTERVAL_SECS = 10

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        y = mnist_inference.inference(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Extract the global step from the checkpoint file name.
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)
#2. Main program.
def main(argv=None):
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    evaluate(mnist)

if __name__ == '__main__':
    main()
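
Note that building the Saver from variables_to_restore() makes restore load each variable's moving-average (shadow) value into the plain variable, so the validation accuracy reported here is computed with the averaged weights rather than the raw trained weights, mirroring the renaming trick from section 5.4.2.2.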
------------------------------------------------------------------------
Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-4001
After 4001 training step(s), validation accuracy = 0.9826
INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-5001
After 5001 training step(s), validation accuracy = 0.983
INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-6001
After 6001 training step(s), validation accuracy = 0.9832
INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-7001
After 7001 training step(s), validation accuracy = 0.9834...

Reposted from: https://www.cnblogs.com/exciting/p/8542859.html
