本文是《TensorFlow从浅入深》系列之第13篇

TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法)

TensorFlow从浅入深系列之二 -- 教你通过思维导图深度理解深层神经网络

TensorFlow从浅入深系列之三 -- 教你如何对MNIST手写识别

TensorFlow从浅入深系列之四 -- 教你深入理解过拟合问题(正则化)

TensorFlow从浅入深系列之五 -- 教你详解滑动平均模型

TensorFlow从浅入深系列之六 -- 教你深入理解经典损失函数(交叉熵、均方误差)

TensorFlow从浅入深系列之七 -- 教你使用验证数据集判断模型效果

TensorFlow从浅入深系列之八 -- 教你学会变量管理

TensorFlow从浅入深系列之九 -- 教你认识图像识别中经典数据集

TensorFlow从浅入深系列之十 -- 教你认识卷积神经网络的基本网路结构及其与全连接神经网络的差异

TensorFlow从浅入深系列之十一 -- 教你深入理解卷积神经网络中的卷积层

TensorFlow从浅入深系列之十二 -- 教你深入理解卷积神经网络中的池化层

目录

1、持久化代码实现

2、加载保存的TensorFlow模型

3、加载部分变量

4、加载变量时重命名


1、持久化代码实现

TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver类。一下代码给出了保存TensorFlow计算图的方法。

#!/usr/bin/env python
# -*- coding:utf-8 -*-import tensorflow as tf# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2init_op = tf.global_variables_initializer()
saver = tf.train.Saver()with tf.Session() as sess:sess.run(init_op)# saver.sabe函数保存到“Saved_model/model.ckpt”saver.save(sess, "Saved_model/model.ckpt")

解析:

  • 在这段代码中,通过saver.save 函数将TensorFlow模型保存到了“Saved_model/model.ckpt”文件中。TensorFlow模型一般会存在后缀为.ckpt的文件中 。
  • 虽然以上程序只指定了 一个文件路径,但是在这个文件目录下会出现三个文件:
  1. 第一个文件为model.ckpt.meta,它保存了 TensorFlow计算图的结构
  2. 第二个文件为model.ckpt,这个文件中保存了TensorFlow 程序中每一个变量的取值。
  3. 第三个文件为checkpoint文件,这个文件中保存了一个目录下所有的模型文件列表

2、加载保存的TensorFlow模型

以下代码中给出了加载这个已经保存的TensorFlow模型的方法

#!/usr/bin/env python
# -*- coding:utf-8 -*-import tensorflow as tf# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2saver = tf.train.Saver()# 加载保存的模型,加载全部模型
with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(result))

上述代码输出为:

解析:

这段加载模型的代码基本上和保存模型的代码是一样的。在加载模型的程序中也是先定义了TensorFlow计算图上的所有运算,并声明了 一个tf.train.Saver类。两段代码唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过己经保存的模型加载进来。


如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。一下代码给出一个样例:

import tensorflow as tf# 加载持久化的图
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess,"Saved_model/model.ckpt")# 通过张量的名称来获取张量print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

这段代码与上述代码达到的效果相同。是两种方式加载模型。


3、加载部分变量

为了保存或者加载部分变量,在声明 tf.train.Saver 类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train. Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来 。如果运行修改后只加载了v1的代码会得到变革未初始化的错误:

tensorflow.python.framework.errors.FailedPreconditionError:Attempting touse uninitialized value v2

4、加载变量时重命名

tf.train.Saver类也支持在保存或者加载时给变量重命名。下面给出了一个简单的样例程序说明变量重命名是如何被使用的。

#!/usr/bin/env python
# -*- coding:utf-8 -*-import tensorflow as tf# tf.reset_default_graph()# 声明变量
V1 = tf.Variable(tf.constant(1.0, shape=[1]), name="a1")
V2 = tf.Variable(tf.constant(2.0, shape=[1]), name="a2")
# result = V1 + V2# 这里要注意,checkpoint中的变量名的问题,不然就会出现问题
saver = tf.train.Saver({"Variable": V1, "Variable_1": V2})# 加载保存的模型,加载全部模型
with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(V1+V2))

上述关于查看checkpoint文件中的变量名的问题,请参考博文TensorFlow中查看checkpoint文件中的变量名和对应值

【TensorFlow】TensorFlow从浅入深系列之十三 -- 教你深入理解模型持久化(模型保存、模型加载)相关推荐

  1. 【TensorFlow】TensorFlow从浅入深系列之十一 -- 教你深入理解卷积神经网络中的卷积层

    本文是<TensorFlow从浅入深>系列之第11篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...

  2. 【TensorFlow】TensorFlow从浅入深系列之十 -- 教你认识卷积神经网络的基本网路结构及其与全连接神经网络的差异

    本文是<TensorFlow从浅入深>系列之第10篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...

  3. 【TensorFlow】TensorFlow从浅入深系列之九 -- 教你认识图像识别中经典数据集

    本文是<TensorFlow从浅入深>系列之第9篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

  4. 【TensorFlow】TensorFlow从浅入深系列之二 -- 教你通过思维导图深度理解深层神经网络

    本文是<TensorFlow从浅入深>系列之第2篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

  5. 【TensorFlow】TensorFlow从浅入深系列之十二 -- 教你深入理解卷积神经网络中的池化层

    本文是<TensorFlow从浅入深>系列之第12篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...

  6. 【TensorFlow】TensorFlow从浅入深系列之八 -- 教你学会变量管理

    本文是<TensorFlow从浅入深>系列之第8篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

  7. 【TensorFlow】TensorFlow从浅入深系列之七 -- 教你使用验证数据集判断模型效果

    本文是<TensorFlow从浅入深>系列之第7篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

  8. 【TensorFlow】TensorFlow从浅入深系列之六 -- 教你深入理解经典损失函数(交叉熵、均方误差)

    本文是<TensorFlow从浅入深>系列之第6篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

  9. 【TensorFlow】TensorFlow从浅入深系列之五 -- 教你详解滑动平均模型

    本文是<TensorFlow从浅入深>系列之第5篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维导 ...

最新文章

  1. SVN基本的理解和使用
  2. CentOS6.5 搭建 LNMP (linux + nginx + mysql + php)
  3. StringEscapeUtils类的使用
  4. PHP文件上传 (以上传txt文件为例)
  5. 还在为运维烦恼?体验云上运维服务,提意见赢好礼!【华为云分享】
  6. 利用matlab绘制函数图像
  7. 苹果iPhone XI新爆料:用了被小米当噱头的TOF技术
  8. Linux多任务编程(二)---fork()函数及其基础实验
  9. matlab模式识别大作业_史上最萌最认真的机器学习/深度学习/模式识别入门指导手册(二)...
  10. 对JSP内置对象的部分总结
  11. 转载:子网掩码以及子网划分
  12. 第十四届教育技术与计算机国际会议新增SSCI, ESCI期刊
  13. CF679A.Bear and Prime 100 (交互题)
  14. 50件关于学霸与学渣的小事
  15. Nvidia Xavier平台CAN收发控制器调试记录
  16. IEEE754浮点数格式详解
  17. 通过js获取本地IP地址
  18. 护士执业证注册照片底色怎么更换?照片换背景底色的方法
  19. 项目沟通与干系人管理:沟通渠道选择、干系人权力/利益方格
  20. 用Java实现矩阵乘法

热门文章

  1. python获取系统参数_python 常用系统参数
  2. pythondev更新到3_python版本升级到3.7
  3. python中的zip是什么意思_python中zip是什么函数
  4. mysql limti_mysql优化
  5. pb90代码如何连接sql2008r2_RabbitMQ各种交换机机制,代码实践篇
  6. python数据清洗csv_Pandas 数据处理,数据清洗详解
  7. 服务器向客户机发信息,服务器如何主动给客户端发消息
  8. 安杰文高等计算机与生产技术学校,法国留学院校推荐:安杰文高等计算机与生产技术学校...
  9. 计算机二级安装64位的还是,电脑操作系统安装,该选择32位还是64位?
  10. php 遍历所有网站网址,使用selenium获取网址所加载所有资源url列表信息