使用TensorFlow训练模型的过程中,需要适时对模型进行保存,以及对保存的模型进行restore,以方便后续对模型进行处理。比如进行测试,或者部署;比如拿别的模型进行fine-tune,等等。当然,直接的保存和restore比较简单,无需多言,但是保存和restore中还牵涉到其他问题,以及针对各种需求的各种参数等,可能不便一下都记好。因此,有必要对此进行一个总结。本文就是对使用TensorFlow保存和restore模型的相关内容进行一下总结,以便备忘。

保存模型

保存模型是整个内容的第一步,当然也十分简单。无非是创建一个saver,并在一个Session里完成保存。比如:

saver = tf.train.Saver()
with tf.Session() as sess:saver.save(sess, model_name)

以上代码在0.11以下版本的TensorFlow里会保存与下面类似的3个文件:

checkpoint

model.ckpt-1000.meta

model.ckpt-1000.ckpt

在0.11及以上版本的TensorFlow里则会保存与下类似的4个文件:

checkpoint

model.ckpt-1000.index

model.ckpt-1000.data-00000-of-00001

model.ckpt-1000.meta

其中checkpoint列出保存的所有模型以及最近的模型;meta文件是模型定义的内容;ckpt(或data和index)文件是保存的模型数据;内里细节无需过多关注,如果想了解,stackOverflow上有一个解释的回答。

当然,除了上面最简单的保存方式,也可以指定保存的步数,多长时间保存一次,磁盘上最多保有几个模型(将前面的删除以保持固定个数),如下:

创建saver时指定参数:

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

其中savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量;如果省略,则保存所有;

max_to_keep指定磁盘上最多保有几个模型;keep_checkpoint_every_n_hours指定多少小时保存一次。

保存模型时指定参数:

saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)

如上,其中可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。

Restore模型

具体来说,Restore模型的过程可以分为两个部分,首先是创建模型,可以手动创建,也可以从meta文件里加载graph进行创建。

创建模型与训练模型时创建模型的代码相同,可以直接复制过来使用。

从meta文件里进行加载,可以直接在Session里进行如下操作:

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')

后面的参数直接使用meta文件的路径即可。如此,即将模型定义的graph加载进来了。

当然,还有一点需要注意,并非所有的TensorFlow模型都能将graph输出到meta文件中或者从meta文件中加载进来,如果模型有部分不能序列化的部分,则此种方法可能会无效。

然后就是为模型加载数据,可以使用下面两种方法:

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')saver.restore(sess, tf.train.latest_checkpoint('./'))

此方法加载指定文件夹下最近保存的一个模型的数据;或者

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))

此方法可以指定具体某个数据,需要注意的是,指定的文件不要包含后缀。

使用Restore的模型

将模型数据加载进来之后,下一步就是利用加载的模型进行下一步的操作了。这可以根据不同需要以如下几种方式进行操作。

1.查看模型参数

可以直接查看Restore进来的模型的参数,如下:

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')saver.restore(sess, tf.train.latest_checkpoint('./'))tvs = [v for v in tf.trainable_variables()]for v in tvs:print(v.name)print(sess.run(v))

如名所言,以上是查看模型中的trainable variables;或者我们也可以查看模型中的所有tensor或者operations,如下:

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')saver.restore(sess, tf.train.latest_checkpoint('./'))gv = [v for v in tf.global_variables()]for v in gv:print(v.name)

上面通过global_variables()获得的与前trainable_variables类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。

下面则可以获得几乎所有的operations相关的tensor:

with tf.Session() as sess:saver = tf.train.import_meta_graph('model.ckpt-1000.meta')saver.restore(sess, tf.train.latest_checkpoint('./'))ops = [o for o in sess.graph.get_operations()]for o in ops:print(o.name)

首先,上面的sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。

此种方法获得的tensor比较齐全,可以从中一窥模型全貌。不过,最方便的方法还是推荐使用tensorboard来查看,当然这需要你提前将sess.graph输出。

2.直接使用原始模型进行训练或测试(前传)

这种操作比较简单,无非是找到原始模型的输入、输出即可。

只要搞清楚输入输出的tensor名字,即可直接使用TensorFlow中graph的get_tensor_by_name函数,建立输入输出的tensor:

with tf.get_default_graph() as graph:data = graph.get_tensor_by_name('data:0')output = graph.get_tensor_by_name('output:0')

如上,需要特别注意,get_tensor_by_name后面传入的参数,如果没有重复,需要在后面加上“:0”。

从模型中找到了输入输出之后,即可直接使用其继续train整个模型,或者将输入数据feed到模型里,并前传得到test输出了。

需要说明的是,有时候从一个graph里找到输入和输出tensor的名字并不容易,所以,在定义graph时,最好能给相应的tensor取上一个明显的名字,比如:

data = tf.placeholder(tf.float32, shape=shape, name='input_data')preds = tf.nn.softmax(logits, name='output')

诸如此类。这样,就可以直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。

3.扩展原始模型

除了直接使用原始模型,还可以在原始模型上进行扩展,比如对1中的output继续进行处理,添加新的操作,可以完成对原始模型的扩展,如:

with tf.get_default_graph() as graph:data = graph.get_tensor_by_name('data:0')output = graph.get_tensor_by_name('output:0')logits = tf.nn.softmax(output)

4.使用原始模型的某部分

有时候,我们有对某模型的一部分进行fine-tune的需求,比如使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可以使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。

with tf.get_default_graph() as graph:graph.get_tensor_by_name('fc1:0')fc1 = tf.stop_gradient(fc1)# add new procedure on fc1

转载于:https://my.oschina.net/u/2272631/blog/1556094

TensorFlow保存和恢复模型的方法总结相关推荐

  1. tensorflow保存和恢复模型saver.restore

    1.本文只对一些细节点做补充,大体的步骤就不详述了 2.保存模型 ① 首先我使用的是tensorflow-gpu 1.4.0 ② 这个版本生成的ckpt文件是这样的: 其中.meta存放的是网络模型和 ...

  2. TensorFlow - 保存和恢复

    TensorFlow - 保存和恢复 https://tensorflow.google.cn/guide/saved_model TensorFlow 指南 - TensorFlow 工作原理 ht ...

  3. TensorFlow2 -官方教程 :保存和恢复模型

    文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...

  4. Tensorflow: 保存和复原模型(save and restore)

    报错: is not valid checkpoint 解决: module_file = tf.train.latest_checkpoint(diag_obj.save_path) saver.r ...

  5. TensorFlow 保存和加载模型

    参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化

  6. Tensorflow深度学习实战之(五)--保存与恢复模型

    文章目录 一.保存模型 二.恢复模型 三.使用模型预测 一.保存模型 在训练完Tensorflow模型为了方便对新的数据进行预测需要保存该模型,Tensorflow提供 tf.train.Saver( ...

  7. 保存和加载模型的方法

    目录 保存模型权重 保存整个模型 保存模型权重 1. 使用回调函数保存 2. 手动保存 这种是在model.fit时传入保存checkpoint的回调函数.使用的回调函数是tf.keras.callb ...

  8. Tensorflow |(5)模型保存与恢复、自定义命令行参数

    Tensorflow |(1)初识Tensorflow Tensorflow |(2)张量的阶和数据类型及张量操作 Tensorflow |(3)变量的的创建.初始化.保存和加载 Tensorflow ...

  9. Tensorflow:如何保存/恢复模型?

    在Tensorflow中训练模型后: 您如何保存经过训练的模型? 您以后如何还原此保存的模型? #1楼 对于TensorFlow版本<0.11.0RC1: 保存的检查点包含模型中Variable ...

最新文章

  1. 刚入行的小菜鸡,怎样做好功能测试?
  2. golang 字符串拼接方式
  3. This Style does not belong to the supplied Workbook. Are you trying to assign a style from one workb
  4. Web前端岗位面试题汇总(含答案)
  5. mvvm 自动绑定_ZK的实际应用:MVVM –表单绑定
  6. 数据结构—链表-单链表基本操作实现
  7. concurrentbag 删除指定元素_Python 列表,for循环,元组的使用(修改、添加、删除、排序、切片)
  8. php的数组操作,PHP的数组操作
  9. Liunx教程超详细(完整)
  10. pip 使用国内阿里云软件源
  11. 总结中间方攻击和CA认证中心
  12. 微信多开软件苹果版_快手充值快币微信充值苹果版;
  13. 将quantopian的动量策略迁移到老虎证券量化api
  14. 2021github仓库操作流程手册指南
  15. 解决记事本写java时出现中文乱码问题
  16. Redis-敲黑板划重点
  17. 【Django】Django配置文件和设计模式详解
  18. GitHub申请账号
  19. 把自己电脑做成web服务器+内网穿透并发布网页
  20. 携程技术面试官素质。。。。。。哎

热门文章

  1. 在openEuler上做开发?这个大赛拿出30万寻找开源的yyds
  2. 带你读AI论文:基于Transformer的直线段检测
  3. 漫谈SCA(软件成分分析)测试技术:原理、工具与准确性
  4. 探究Python源码,终于弄懂了字符串驻留技术
  5. 补习系列(8)-springboot 单元测试之道
  6. 野生前端的数据结构基础练习(5)——散列
  7. Kotlin学习笔记 第二章 类与对象 第十四 十五节 委托 委托属性
  8. LIF模型及其变种 Training Spiking Deep Networks for Neuromorphic Hardware
  9. 高等组合学笔记(十五):容斥原理,错排问题
  10. R Studio更换外部包镜像的方法