saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。

Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。

  • saver()可以选择global_step参数来为ckpt文件名添加数字标记:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
  • keep_checkpoint_every_n_hours与max_to_keep类似,定义每n小时保存一个ckpt文件。
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):sess.run(..training_op..)if step % 1000 == 0:# Append the step number to the checkpoint name:saver.save(sess, 'my-model', global_step=step)

Restore

restore(sess, save_path)
# sess: A Session to use to restore the parameters.
# save_path: Path where parameters were previously saved.
  • sess: 保存参数的会话。
  • save_path: 保存参数的路径。
  • 当从文件中恢复变量时,不需要事先对他们进行初始化,因为“恢复”自身就是一种初始化变量的方法。
  • 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

程序后半段代码我们可以改为:

sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())is_train=False
saver=tf.train.Saver(max_to_keep=3)#训练阶段
if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段
else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考:
https://www.cnblogs.com/denny402/p/6940134.html
https://blog.csdn.net/hellocsz/article/details/89097380

saver.save和saver.restore相关推荐

  1. Android 中Canvas的save(),saveLayer()和restore()解析

    1.save()方法 : 用来保存Canvas的状态,save()方法之后的代码,可以调用Canvas的平移.放缩.旋转.裁剪等操作! 2.restore()方法: 用来恢复Canvas之前保存的状态 ...

  2. canvas save()和canvas restore()状态的保存和恢复使用方法及实例

    canvas.save()用来保存先前状态的 canvas.restore()用来恢复之前保存的状态 注:两种方法必须搭配使用,否则没有效果 <!DOCTYPE html> <htm ...

  3. TensorFlow精进之路(一):Softmax回归模型训练MNIST

    1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...

  4. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  5. python量化投资培训清华大学深研院_GitHub - CatsJuice/quantitative-investment-learning: 使用Python进行量化投资的学习报告...

    quantitative-investment-learning 使用Python进行量化投资的学习报告 Python量化投资学习报告 CatsJuice 编辑于 2019-4-26 上一次更新: 2 ...

  6. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)...

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  7. tensorflow tf.train.Saver.restore() (用于下次训练时恢复模型)

    # 保存当前的Session到文件目录tf.train.Saver().save(sess, 'net/my_net.ckpt') # 然后在下次训练时恢复模型: tf.train.Saver().r ...

  8. Tensorflow保存神经网络参数有妙招:Saver和Restore

    摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络. 本文分享自华为云社区<[Python人工智能] 十一. ...

  9. tensorflow保存和恢复模型saver.restore

    1.本文只对一些细节点做补充,大体的步骤就不详述了 2.保存模型 ① 首先我使用的是tensorflow-gpu 1.4.0 ② 这个版本生成的ckpt文件是这样的: 其中.meta存放的是网络模型和 ...

  10. python中save 函数_Tensorflow之Saver的用法详解

    Saver的用法 1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver ...

最新文章

  1. copy模块中的copy与deepcopy的区别
  2. vscode使用教程python-用VSCode写python的正确姿势
  3. KEIL-MDK 5 CMSIS的问题
  4. Wechat的支付逻辑流程
  5. HTML-CSS背景渐进色
  6. 业务功能中包含邮件发送,怎么测试?
  7. linux怎么才能算telnet成功_怎么表白才算成功呢
  8. 德勤预判:2022技术七大趋势
  9. JPA + Hibernate + PostgreSQL + Maven基本配置示例
  10. Java中timer的schedule()和schedualAtFixedRate()函数的区别
  11. DSS的Windows版本如何安装呢?
  12. cad快捷栏怎么调出来_cad怎么显示工具栏快捷键 cad快捷键常见问题解决办法
  13. 卸载WPS后Office文档图标显示异常
  14. uni-app开发经验分享十三:实现手机扫描二维码并跳转全过程
  15. 死锁预防之银行家算法
  16. 外贸人如何在领英linkedin上高效开发客户
  17. UGUI ContentSizeFitter 嵌套 适配
  18. 美团外卖退款显示服务器异常,外卖遇到异常订单几种类型及处理技巧
  19. testtesttesttesttesttesttest
  20. 绝地潜兵服务器不稳定,爽快的合作射爆游戏《绝地潜兵》评测评分汇总

热门文章

  1. 电脑如何做动态图 GIF表情包怎么制作
  2. json 转 实体对象 报解析错误
  3. 扣扣机器人唱歌_qq小冰唱歌指令生成器下载-qq小冰唱歌关键字命令生成器 _5577安卓网...
  4. Shopee上货软件,新手小白必备神器
  5. 路由器自动获取ip失败
  6. 网站服务器停止响应是什么意思?
  7. PowerBI切换日期维度
  8. crt计算机显示器,crt显示器最高分辨率_crt显示器最高刷新率
  9. 国外广告联盟:玩转国外CPC网站作弊
  10. 激活win10专业版最简单的方法