为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型。

1. 保存模型

tensorflow提供了一个API可以方便的保存和还原神经网络的模型。这个API就是tf.train.saver类。

import tensorflow as tf# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2init_op = tf.global_variables_initializer()
# 声明tf.train.Saver()类用于保存模型
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:print("save here...")sess.run(init_op)# 保存模型到下面路径下saver.save(sess,"/Users/lilong/Desktop/tt/model.ckpt")print(sess.run(result))

运行结果:

save here...
[-1.6226364]

这里的代码实现了一个简单的加法功能,通过saver.save函数把模型保存到了相应的路径下,这里一定要注意第一次保存一定是saver.save,而不是saver.restore
虽然上面的模型保存路径只提供了一个,但是这个目录下一般会出现三个文件,这是因为tensorflow会将计算图的结构和图上的参数值分开保存。

  • model.ckpt.meta:保存了计算图的网路结构
  • model.ckpt.data.:保存了变量的取值
  • checkpoint:保存了一个目录下的所有的模型文件列表

2. 加载保存的模型

import tensorflow as tf# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2# 加载模型的代码和保存模型的代码的区别是:没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来
#init_op = tf.global_variables_initializer()saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:print("Reading checkpoints...")# 加载已经保存的模型saver.restore(sess,"/Users/lilong/Desktop/tt/model.ckpt")print(sess.run(result))

这里要注意的是加载模型和保存模型的区别是:加载模型的代码没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。

上面是单独加载模型,当然也可以如下面这样保存好模型后直接加载:

import tensorflow as tf# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2init_op = tf.global_variables_initializer()
# 声明tf.train.Saver()类用于保存模型
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:print("save here...")sess.run(init_op)# 保存模型到下面路径下saver.save(sess,"/Users/lilong/Desktop/tt/model.ckpt")print(sess.run(result))# 加载保存了两个变量和的模型
with tf.Session() as sess:print("Reading checkpoints...")# 加载已经保存的模型saver.restore(sess,"/Users/lilong/Desktop/tt/model.ckpt")print(sess.run(result))

运行结果:

save here...
[-1.6226364]
Reading checkpoints...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/tt/model.ckpt
[-1.6226364]

还可以这样加载已经持久化的模型:

import tensorflow as tf
#  直接加载持久化的图。
saver = tf.train.import_meta_graph("/Users/lilong/Desktop/tt/model.ckpt.meta")with tf.Session() as sess:print('get here...')saver.restore(sess, "/Users/lilong/Desktop/tt/model.ckpt")print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

输出:

get here...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/tt/model.ckpt
[-1.6226364]

这里得到的是指定的张量的值。

4. 加载模型时给变量重命名

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 16 16:17:17 2018@author: lilong
"""import tensorflow as tf# 保存计算两个变量和的模型v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result1 = v1 + v2init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:print("save here...")sess.run(init_op)# 保存模型到下面路径下saver.save(sess,"/Users/lilong/Desktop/qq/model.ckpt")print(sess.run(result1))for variables in tf.global_variables(): print ('variables_1:',variables.name)# 这里声明的变量和已经保存的模型中的变量名称不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
result2 = v1 + v2saver1 = tf.train.Saver({"v1": v1, "v2": v2})
#saver1 = tf.train.Saver()# 加载保存了两个变量和的模型
with tf.Session() as sess:print("Reading checkpoints...")# 加载已经保存的模型saver1.restore(sess,"/Users/lilong/Desktop/qq/model.ckpt")print(sess.run(result2))for variables in tf.global_variables(): print ('variables_2:',variables.name)

运行:

save here...
[3.]
variables_1: v1:0
variables_1: v2:0
Reading checkpoints...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/qq/model.ckpt
[3.]
variables_2: v1:0
variables_2: v2:0
variables_2: other-v1:0
variables_2: other-v2:0

这里对变量v1,v2的名称进行了修改,所以如果直接使用tf.train.Saver()来保存默认的模型,那么程序就会报找不到变量的错误,因为模型保存时和加载时的名称不一致,这个时候可以使用字典把模型保存时的变量名和需要加载的变量联系起来。
这样的好处之一是方便使用变量的滑动平均值,在tensorflow中的每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际就是获取影子变量的值,如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时就不需要调用函数来获取滑动平均值了。

4. 保存滑动平均模型

import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
# 在没有申请滑动平均值时只有一个变量
for variables in tf.global_variables(): print('Before MovingAverage:',variables.name)ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在申请滑动平均模型之后,tensorflow会自动生成一个影子变量:v/ExponentialMovingAverage:0
for variables in tf.global_variables(): print ('After MovingAverage:',variables.name)# 保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)sess.run(tf.assign(v, 10))sess.run(maintain_averages_op)# 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。saver.save(sess, "model/model2.ckpt")print ('last:',sess.run([v, ema.average(v)])) # 输出:[10.0, 0.099999905]# 通过变量重命名直接读取变量的滑动平均值,通过这个方法就可以用完全一样的代码来计算滑动平均模型的前向传播的结果
v_1 = tf.Variable(0, dtype=tf.float32, name="v")
# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v_1})
with tf.Session() as sess:saver.restore(sess, "model/model2.ckpt")print('here:',sess.run(v_1)) # 输出0.099999905,这个值就是原来模型中变量v的滑动平均值

运行结果:

Before MovingAverage: v:0
After MovingAverage: v:0
After MovingAverage: v/ExponentialMovingAverage:0
last: [10.0, 0.099999905]
INFO:tensorflow:Restoring parameters from model/model2.ckpt
here: 0.099999905

可以看到通过变量重命名直接读取变量的滑动平均值。

为了方便加载时重命名滑动平均变量,tensorflow提供了variables_to_restore()函数,来生成tf.train.Saver类需要的变量重命名字典:

import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
# variables_to_restore()函数可以直接生成字典
print('here:',ema.variables_to_restore())#saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:print("Reading checkpoints...")saver.restore(sess, "model/model2.ckpt")print ('run:',sess.run(v))

运行结果:

here: {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
Reading checkpoints...
INFO:tensorflow:Restoring parameters from model/model2.ckpt
run: 0.099999905

使用tf.train.Saver会保存运行tensorflow中程序所需要的全部信息,而某些情况下并不需要全部的信息,比如测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要其他的一些信息,有时将变量取值和计算图分成不同的文件存储也不方便,于是有了convert_variables_to_constants函数,该函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个tensorflow图可以统一保存在一个文件中。
示例:

import tensorflow as tf
from tensorflow.python.framework import graph_utilv1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2init_op = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)# 导出当前计算图的graphdef部分,只需要这一部分就可以完成从输入层到输出层的计算过程graph_def = tf.get_default_graph().as_graph_def()# 将图中的变量及其取值转化为常量,同时将图中不必要的节点。这里我们只关心程序中的某些计算节点,# 和这些无关的计算节点就没有计算并保存了。'add'是计算机节点名字output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])# 将导入的模型存入文件with tf.gfile.GFile("model/combined_model.pb", "wb") as f:f.write(output_graph_def.SerializeToString())# 通过下面的程序就可以直接计算定义的加法运算的结果,该方法可以用于迁移学习
from tensorflow.python.platform import gfile
with tf.Session() as sess:model_filename = "model/combined_model.pb"# 读取保存的模型文件,并将文件解析成对应的graph protocol bufferwith gfile.FastGFile(model_filename, 'rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())# graph_def中保存的图加载到当前的图中。return_elements给出返回的张量的名称。# 这里的add不是计算机节点的名称,而是张量的名称,所以会是add:0result = tf.import_graph_def(graph_def, return_elements=["add:0"])print('run:',sess.run(result))

输出:

INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
run: [array([3.], dtype=float32)]

参考:《Tensorflow实战Google深度学习框架》

tensorflow 模型的保存和加载相关推荐

  1. numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...

    一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...

  2. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  3. 线性回归之模型的保存和加载

    线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib   [目前这行代码报错,直接写import joblib ...

  4. PyTorch | 模型的保存和加载

    PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...

  5. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  6. paddlepaddle模型的保存和加载

    导读 深度学习中模型的计算图可以被分为两种,静态图和动态图,这两种模型的计算图各有优劣. 静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当 ...

  7. PyTorch基础-模型的保存和加载-09

    模型的保存 import numpy as np import torch from torch import nn,optim from torch.autograd import Variable ...

  8. 调gensim库,word2vec模型的保存和加载

    一.模型的保存 模型保存可以有很多种格式,根据格式的不同可以分为2种,一种是保存为.model的文件,一种是非.model文件的保存.我常用的保存格式是.model和.vector直接上代码和结果: ...

  9. 机器学习算法------2.11 模型的保存和加载(joblib.dump()、joblib.load())

    #  模型保存 joblib.dump(estimator, "./data/test.pkl") # 模型加载 estimator = joblib.load("./d ...

最新文章

  1. Ubuntu 14.04 64bit上安装Intel官方集显更新驱动程序
  2. 学习笔记Hive(八)—— 查询优化
  3. 一直记不住window下面的盘符切换
  4. Tomcat6 ,servlet配置(可用)
  5. csdn-markdown 编辑器
  6. ubuntu 配置 静态ip
  7. Linux Shell脚本_设置时区并同步时间
  8. 95-38-150-Buffer-CompositeByteBuf
  9. R2B fpga flow script
  10. 中断python快捷键_python的快捷键
  11. C++头文件、源文件的编译链接
  12. 管道通信的基本流程和代码
  13. 人生若只如初见,当时只道是寻常
  14. 跨境网上收款 找PayPal没错(php如何实现paypal支付)
  15. 动态规划(DP)小结
  16. 招行股东会通过收购永隆银行议案
  17. MYSQL使用OR关键字查询,MySQL带OR关键字的多条件查询
  18. 谜题51:那个点是什么?
  19. mac电脑如何打包dmg安装包文件
  20. 当谈到携程机票产品经理的数据意识,我们在谈什么?

热门文章

  1. OpenCV文件输入输出的序列化功能的实例(附完整代码)
  2. OpenCV基本的SIMD的实例(附完整代码)
  3. OpenGL创建多维数据集的多个实例
  4. c++扔鸡蛋问题egg dropping puzzle(附完整源码)
  5. QT的QAction类的使用
  6. c++内存,堆和栈的区别
  7. pandas.DataFrame.iloc的使用
  8. 21.等值线图(Counter Plot)、Contour Demo、Creating a “meshgrid”、Calculation of the Values、等
  9. Flink流计算编程--在WindowedStream中体会EventTime与ProcessingTime
  10. 史上最简单的SpringCloud教程 | 第七篇: 高可用的分布式配置中心(Spring Cloud Config)