1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

  1. Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
  2. 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  3. 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. tf.train.get_checkpoint_state

tf.train.get_checkpoint_state函数通过checkpoint文件找到模型文件名

tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)

该函数返回的是checkpoint文件CheckpointState proto类型的内容,其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。

1.参数cheackpoint_dir:  模型所在的文件夹名字

3.实例展示

获取一个文件例子:

import tensorflow as tf
import numpy as npx = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + bloss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.6)
train = optimizer.minimize(loss)isTrain = False
train_steps = 10
checkpoint_steps = 5
checkpoint_dir = ''saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))with tf.Session() as sess:sess.run(tf.initialize_all_variables())if isTrain:for i in xrange(train_steps):sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0:saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)else:ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)         #ckpt.model_checkpoint_path就是模型的路径名字。表示获取最新的tensorflow模型文件的文件名else:passprint(sess.run(w))print(sess.run(b))#isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
#train_steps:表示训练的次数
#checkpoint_steps:表示训练多少次保存一下checkpoint
#checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

获取多有模型例子:

with tf.Session() as sess:            ckpt=tf.train.get_checkpoint_state('Model/')print(ckpt)if ckpt and ckpt.all_model_checkpoint_paths:#加载模型#这一部分是有多个模型文件时,对所有模型进行测试验证for path in ckpt.all_model_checkpoint_paths:saver.restore(sess,path)                global_step=path.split('/')[-1].split('-')[-1]accuracy_score=sess.run(accuracy,feed_dict=validate_feed)print("After %s training step(s),valisation accuracy = %g"%(global_step,accuracy_score))else:print('No checkpoint file found')return#time.sleep(eval_interval_secs)return

Tensorflow学习(二)之——保存加载模型、Saver的用法相关推荐

  1. pytorch学习笔记(6):GPU和如何保存加载模型

    参考文档:https://mp.weixin.qq.com/s/kmed_E4MaDwN-oIqDh8-tg 上篇文章我们完成了一个 vgg 网络的实现,那么现在已经掌握了一些基础的网络结构的实现,距 ...

  2. 图片预加载学习(二):有序加载之图片切换

    基本效果同前一篇,业务有所变化:前一篇是先显示进度条待所有的图片加载完成了再显示图片,这一篇是先显示第一张图片然后依次加载其他图片(比较适合于有内容的图片,人在看第一张图片时程序默默的加载后面的图片) ...

  3. 6.2 模型保存 --- 加载和保存模型结构权重

    一.只保存/加载模型的结构 保存模型的结构,而非其权重或训练配置项: json_string = model.to_json() model.save('my_model.h5') my_model_ ...

  4. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  5. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  6. tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)

    下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...

  7. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  8. 网页怎么预先加载模型_使用预先训练的模型进行转移学习

    网页怎么预先加载模型 深度学习 (Deep Learning) 什么是转学? (What is Transfer Learning?) Transfer learning is a research ...

  9. OpenGL教程翻译 第二十二课 使用Assimp加载模型

    第二十二课 使用Assimp加载模型 原文地址:http://ogldev.atspace.co.uk/(源码请从原文主页下载) 背景 到现在为止我们都在使用手动生成的模型.正如你所想的,指明每个顶点 ...

最新文章

  1. java程序员遇到的问题_Java 程序员平时最常遇到的故障:系统OOM (一)
  2. Symantec Endpoint - quarantine
  3. OpenVAS漏洞扫描
  4. python模块之email: 电子邮件编码解码 (二、编码邮件)
  5. Rational Rose2007无法启动,提示缺少“suite objects.dll”
  6. (1)段寄存器属性探测
  7. Java程序员从笨鸟到菜鸟之(七十五)细谈struts2(十四)struts2+ajax实现异步验证...
  8. 大数据、数据挖掘、机器学习与模式识别的关系
  9. 使用Elasticsearch,Kafka和Cassandra构建流式数据中心
  10. 同一对象多条数据同时插入数据库
  11. Postgre 中的空值判断
  12. win10使用navicat管理数据库
  13. 阿里云数据传输服务低价不低质,服务再升级 1
  14. Linux基础知识以及常见面试问题
  15. 城域网光纤、拨号光纤与ADSL的区别
  16. 今天第一次来这里开博,大家多多指教
  17. python 91图片站爬虫
  18. python简单爬虫程序分析_Python简单爬虫
  19. FPGA可以转行数字IC验证吗?
  20. deepin/ubuntu 网易云解锁 UnblockNeteaseMusic

热门文章

  1. cocos2dx luabinding C/C++/LUA部分
  2. CMS垃圾收集器详解(转载)
  3. 2019华北最大国际消防展览会
  4. 虚拟机服务器需要多大内存吗,虚拟机服务器需要多大内存
  5. 考计算机证怎么考?在哪里可以报名?
  6. 解决最新版本chrome及edge 页面崩溃的问题
  7. 通过ceph-ansible安装ceph
  8. LDAP概念和原理介绍
  9. 抖音头条小程序担保支付php版demo源码
  10. 【Android开发经验】APP的缓存文件到底应该存在哪?看完这篇文章你应该就自己清楚了