将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun  4 10:29:48 2017@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, units=10, activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())saver=tf.train.Saver(max_to_keep=1)
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())is_train=False
saver=tf.train.Saver(max_to_keep=3)#训练阶段
if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段
else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun  4 10:29:48 2017@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, units=10, activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())is_train=True
saver=tf.train.Saver(max_to_keep=3)#训练阶段
if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段
else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

View Code

参考文章:http://blog.csdn.net/u011500062/article/details/51728830

tensorflow 1.0 学习:模型的保存与恢复(Saver)相关推荐

  1. TensorFlow:模型的保存与恢复(Saver)

    目录 前言 1 实例化对象 2 保存训练过程中或者训练好的, 模型图及权重参数 2.1保存训练模型 2.2 查看保存 3. 重载模型的图及权重参数(模型恢复)     前言 我们经常在训练完一个模型之 ...

  2. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  3. nvidia-docker2完成tensorflow/serving深度学习模型在线部署

    深度学习技术已经广泛应用在各个行业领域.实际应用,通过大量数据往往可以训练一个泛化能力好的模型,但如何将模型进行快捷.方便的远程部署,已成为好多企业考虑的问题.现阶段,常用的深度学习模型远程部署工具有 ...

  4. TensorFlow 模型的保存与恢复

    TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见 https://zhuanlan.zhihu.com/p/32887066 下面,我以mnist手写数据集用sof ...

  5. tensorflow 1.0 学习:参数和特征的提取

    tensorflow 1.0 学习:参数和特征的提取 在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如: #取出所有参与训练的参数 params=tf.tra ...

  6. 基于Python的模型的保存、恢复、继续训练

    资源下载地址:https://download.csdn.net/download/sheziqiong/86774566 资源下载地址:https://download.csdn.net/downl ...

  7. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  8. Tensorflow模型的保存与恢复的细节

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  9. TensorFlow 2.0 - 自定义模型、训练过程

    文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...

最新文章

  1. linux shell数据重定向(输入重定向与输出重定向)详细分析
  2. 大数据时代 安全与应用的博弈
  3. App上架应用市场,如何攻破安全过检难题
  4. Kubernetes的四种用户部署场景
  5. 使用命名空间解决名字冲突
  6. RabbitMQ 交换器、持久化
  7. 【PAT甲级 BigInteger】1019 General Palindromic Number (20 分) Java版 7/7通过
  8. URAL1815 Farm in San Andreas(费马点,圆圆相交)
  9. 如何编写杂项设备驱动
  10. 真机开包!国产至强5600服务器35张图赏
  11. Leecode-2 Add Two Numbers
  12. iphone连不上微软服务器,windows没法连接到iphone是什么意思
  13. STL中迭代器的介绍及分类
  14. java验证码(采用struts2实现)
  15. 机械动力学瑞利法matlab程序,工业机器人的动力学建模与仿真+MATLAB程序
  16. 计算机桌面的图标怎么删除,桌面图标删不掉怎么办 如何彻底删除桌面图标
  17. C语言要点系统复习三:scanf读取缓冲区的那些事
  18. deadline集群渲染_Maya笔记
  19. 中国Linux内核开发者大会
  20. Word粘贴时出现“文件未找到:MathPage.WLL”的解决方案

热门文章

  1. MFC中CStatic控件动态改变
  2. 京东对话中国农民丰收节交易会 供应链下沉打通产销全链
  3. 东平县农民丰收节-农业大健康·万祥军:功能性谋定生态品牌
  4. JS-WEB-API(存储)
  5. spring 配置文件无法加载,junit找不到xml配置文件java.lang.IllegalStateException: Failed to load ApplicationContext...
  6. 第七章 控制PL/SQL错误
  7. IBatis.Net学习笔记二--下载、编译、运行NPetShop
  8. 【干货】Facebook产品经理:高效对接and流程解读
  9. pmcaff系列活动《走进今日头条》
  10. 轴对称 Navier-Stokes 方程组的点态正则性准则 I