saver.save和saver.restore
saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。
Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。
- saver()可以选择global_step参数来为ckpt文件名添加数字标记:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
- max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
- keep_checkpoint_every_n_hours与max_to_keep类似,定义每n小时保存一个ckpt文件。
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):sess.run(..training_op..)if step % 1000 == 0:# Append the step number to the checkpoint name:saver.save(sess, 'my-model', global_step=step)
Restore
restore(sess, save_path)
# sess: A Session to use to restore the parameters.
# save_path: Path where parameters were previously saved.
- sess: 保存参数的会话。
- save_path: 保存参数的路径。
- 当从文件中恢复变量时,不需要事先对他们进行初始化,因为“恢复”自身就是一种初始化变量的方法。
- 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。
saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。
saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()
程序后半段代码我们可以改为:
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())is_train=False
saver=tf.train.Saver(max_to_keep=3)#训练阶段
if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段
else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
参考:
https://www.cnblogs.com/denny402/p/6940134.html
https://blog.csdn.net/hellocsz/article/details/89097380
saver.save和saver.restore相关推荐
- Android 中Canvas的save(),saveLayer()和restore()解析
1.save()方法 : 用来保存Canvas的状态,save()方法之后的代码,可以调用Canvas的平移.放缩.旋转.裁剪等操作! 2.restore()方法: 用来恢复Canvas之前保存的状态 ...
- canvas save()和canvas restore()状态的保存和恢复使用方法及实例
canvas.save()用来保存先前状态的 canvas.restore()用来恢复之前保存的状态 注:两种方法必须搭配使用,否则没有效果 <!DOCTYPE html> <htm ...
- TensorFlow精进之路(一):Softmax回归模型训练MNIST
1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...
- [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)
[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...
- python量化投资培训清华大学深研院_GitHub - CatsJuice/quantitative-investment-learning: 使用Python进行量化投资的学习报告...
quantitative-investment-learning 使用Python进行量化投资的学习报告 Python量化投资学习报告 CatsJuice 编辑于 2019-4-26 上一次更新: 2 ...
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)...
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
- tensorflow tf.train.Saver.restore() (用于下次训练时恢复模型)
# 保存当前的Session到文件目录tf.train.Saver().save(sess, 'net/my_net.ckpt') # 然后在下次训练时恢复模型: tf.train.Saver().r ...
- Tensorflow保存神经网络参数有妙招:Saver和Restore
摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络. 本文分享自华为云社区<[Python人工智能] 十一. ...
- tensorflow保存和恢复模型saver.restore
1.本文只对一些细节点做补充,大体的步骤就不详述了 2.保存模型 ① 首先我使用的是tensorflow-gpu 1.4.0 ② 这个版本生成的ckpt文件是这样的: 其中.meta存放的是网络模型和 ...
- python中save 函数_Tensorflow之Saver的用法详解
Saver的用法 1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver ...
最新文章
- copy模块中的copy与deepcopy的区别
- vscode使用教程python-用VSCode写python的正确姿势
- KEIL-MDK 5 CMSIS的问题
- Wechat的支付逻辑流程
- HTML-CSS背景渐进色
- 业务功能中包含邮件发送,怎么测试?
- linux怎么才能算telnet成功_怎么表白才算成功呢
- 德勤预判:2022技术七大趋势
- JPA + Hibernate + PostgreSQL + Maven基本配置示例
- Java中timer的schedule()和schedualAtFixedRate()函数的区别
- DSS的Windows版本如何安装呢?
- cad快捷栏怎么调出来_cad怎么显示工具栏快捷键 cad快捷键常见问题解决办法
- 卸载WPS后Office文档图标显示异常
- uni-app开发经验分享十三:实现手机扫描二维码并跳转全过程
- 死锁预防之银行家算法
- 外贸人如何在领英linkedin上高效开发客户
- UGUI ContentSizeFitter 嵌套 适配
- 美团外卖退款显示服务器异常,外卖遇到异常订单几种类型及处理技巧
- testtesttesttesttesttesttest
- 绝地潜兵服务器不稳定,爽快的合作射爆游戏《绝地潜兵》评测评分汇总