什么是Tensorflow的模型

模型部分主要参考了这篇文章和这篇博客;另外,官方文档也给出了很多指导。
Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:

  1. “.meta”:包含了计算图的结构
  2. “.data”:包含了变量的值
  3. “.index”:确认checkpoint
  4. “checkpiont”:一个protocol buffer,包含了最近的一些checkpoints

存储一个Tensorflow的模型

当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。

saver.save(sess,'my_test_model')

sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。
具体的实例:

import tensorflow as tfw1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()

执行上述语句,我们会同级目录下看到新增的文件:

my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta

如果网络架构更改了,Tensorflow会重写上述的文件。

如果我们想要每1000步保存一次,那么需要更改语句:

saver.save(sess, 'my_test_model', global_step=1000)

那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新创建.meta文件。 如果不写步数,默认每次迭代保存一次。

如果我们要仅仅保留最近4次创建的模型,并且每两个小时存储一次模型,可以进行下面的操作:

# saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

如果我们在tf.train.Saver()中不指定任何参数,那么Tensorflow会默认保存所有的变量。假设我们只想保留部分变量或者collection,那么需要显式地表明需要保留的对象。当创建tf.train.Saver()对象时,使用一个包含有关变量的list或者字典声明。比如:

import tensorflow as tfw1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()

导入一个训练好的模型

如果我们要导入一个训练好的模型,需要做以下两步:

创建一个网络

使用函数:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

把存储在my_test_model-1000.meta加载到saver当中。这个操作知识会把在.meta文件中定义的网络追加到当前网络的后面,我们仍然需要加载原来网络的参数数值。

加载参数

操作如下:

with tf.Session() as sess:new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')new_saver.restore(sess, tf.train.lasters_checkpoint('./'))

在这之后,w1w2的数据就会被重新加载进来。

对导入的模型进行的操作

现在,学着加载模型,把模型用于预测、训练甚至更改模型的架构。现在构造一个简单的网络模型,保存并重新导入。注意一点:tf.placeholder的数据不会被保存 !!!!
先定义训练文件:

import tensorflow as tf# 定义用于恢复变量的例子
w1 = tf.placeholder(dtype=tf.float32, name="w1")
w2 = tf.placeholder(dtype=tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}# 定义用于恢复操作的例子   w4=w3*b1,w3=(w1+w2)*b1
w3 = tf.add(w1, w2, name="part_op")
w4 = tf.multiply(w3, b1, name="op_to_restore")sess = tf.Session()
sess.run(tf.global_variables_initializer())  # 时刻记着,要初始化saver = tf.train.Saver()print(sess.run(w4, feed_dict))  # 24.0saver.save(sess, './my_test_model', global_step=1000)sess.close()

定义加载文件:

import tensorflow as tfsess = tf.Session()saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}# w4=w3*b1,w3=(w1+w2)*b1
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  # 60.0print(sess.run(op_to_restore, feed_dict))sess.close()

当导入模型的时候,不但需要恢复计算图和相关的参数,而且需要重新对tf.placeholder喂数据。通过graph.get_tensor_by_name获取保存的操作和占位符。如果我们想要使用网络计算,仅需要给不同的占位符添加不同的数据即可。

如果我们想要对原来的网络添加更多的层数并接着训练它,可以按照下面的步骤处理:

import tensorflow as tfsess = tf.Session()
# 恢复计算图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 获取占位符
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# 恢复操作
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
# 增加新的操作
add_on_op = tf.multiply(op_to_restore, 2.0)
# 别忘了喂数据
print(sess.run(add_on_op, feed_dict))sess.close()

由此可以看出,只需要把原来的操作加载完毕后,当成一个输出数据接入新的网络即可。

也可以把原来网络的一部分加载 到新的网络中,比如下面的操作:
先更改之前的一行代码

w3 = tf.add(w1, w2, name="part_op")

加载操作:

import tensorflow as tfsess = tf.Session()saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))graph = tf.get_default_graph()w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 14.0}w3 = graph.get_tensor_by_name("part_op:0")op = tf.multiply(w3, 4)
print(sess.run(op, feed_dict))  # 108.0
sess.close()

使用SavedModel的格式

SavedMode类把Saver类进行了一个更高层的封装,开发效率可能会更高,但是暂时没有前一种方法常用。Saver类更看重对变量的封装, 而SavedModel更看重压缩封装保存所有有用的信息。

保存操作:

import tensorflow as tftf.reset_default_graph()w1 = tf.Variable(1.0, name="w1")
w2 = tf.Variable(2.0, name="w2")
w3 = tf.multiply(w1, w2, name="w3")builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')with tf.Session() as sess:sess.run(tf.global_variables_initializer())sess.run(w3)builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING],signature_def_map=None,assets_collection=None)
builder.save()

读取操作:

import tensorflow as tfwith tf.Session() as sess:tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],'./SavedModel')w1 = sess.run('w1:0')w2 = sess.run('w2:0')w3 = sess.run('w3:0')print(w1, w2, w3)

TensorFlow保存或加载训练的模型相关推荐

  1. python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...

    本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...

  2. PyTorch模型训练完毕后静态量化、保存、加载int8量化模型

    1. PyTorch模型量化方法 Pytorch模型量化方法介绍有很多可以参考的,这里推荐两篇文章写的很详细可以给大家一个大致的参考Pytorch的量化,官方量化文档 Pytorch的量化大致分为三种 ...

  3. django项目启动加载训练的模型报错OSError: Unable to open file (unable to open file: name = ‘model/model_weigh完美解决

    1.原因分析 此错误原因多样通过网上整理有一下几种 ①h5py版本过高 ,重装h5py ② 相对路径改成绝对路径 ③文件无权限访问,点击文件属性,点击高级.赋予权限 ④这个是我报错的解决办法 因为他单 ...

  4. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

  5. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  6. tensorflow加载训练好的模型实例

    1. 首先了解下tensorflow的一些基础语法知识 这里不再详细说明其细节,只举例学习. 1.1 tensorflow的tf.transpose()简单使用: tf.reshape(tensor, ...

  7. mnist手写数字模型训练、保存、加载及图片预测

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 主要过程 导入 加载数据 创建模型和训练 模型应用 总结 前言 非专业程序员,主业PLC单片机,2019年想扩充知识体 ...

  8. [tensorflow] 模型保存、加载与转换详解

    TensorFlow模型加载与转换详解 本次讲解主要涉及到TensorFlow框架训练时候模型文件的管理以及转换. 首先我们需要明确TensorFlow模型文件的存储格式以及文件个数: model_f ...

  9. tensorflow1.0模型的保存、加载、在训练

    1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...

最新文章

  1. 只要能坚持下来就是好博客
  2. Java并发编程开发笔记——2线程安全性
  3. Spring学习笔记八--Bean生命周期和后置处理器
  4. OpenGL ARB 看来终于想通了,OpenGL SDK终于要出来了。
  5. 腾讯2014年实习生招聘笔试面试经历
  6. 如何判断自己是否到了该辞职的时候
  7. python数据格式简介_Python中数据类型时间的介绍(附代码)
  8. 下面的android入门开发视频教程还不错
  9. 海康威视复赛题 --- 算法说明书
  10. 三维扫描仪行业调研报告 - 市场现状分析与发展前景预测
  11. php 字符串 strpos,PHP字符串处理函数:strpos() -- 内置函数
  12. 2020哔哩哔哩校招后端开发笔试编程题总结
  13. 计算机三级 计算机三级网络技术 如何两天内通过考试
  14. fluke dtx-1800测试精度有必要进行原厂校准吗?
  15. Python的内置函数(四十八)、setattr()函数
  16. CVPR 2022 | 数据堂亮相计算机视觉领域盛会
  17. 论文提纲格式怎么写?
  18. 目前有哪些好用的用例管理工具?
  19. win7此更新不适用计算机,安装Win7补丁遭遇“此更新不适用于你的计算机”
  20. UG10.0怎么导出模型的尺寸图,不用CAD就可以实现!

热门文章

  1. AcWing285. 没有上司的舞会(树形DP)题解
  2. Java 1.2.3 文件输入与输出
  3. ‘python3\r’: No such file or directory
  4. 运算符重载——算术运算符重载
  5. 【oracle】获取近30天日期、近5年、近6个月
  6. 【机房收费系统】---结账
  7. 静态方法、静态内部类和抽象方法的注意问题
  8. Python:assert基本用法
  9. Oracle 查看索引语句
  10. 敏捷开发智慧敏捷系列之四:每日立会开多久?