tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存,先要创建一个Saver对象:如
saver=tf.train.Saver()
在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
saver=tf.train.Saver(max_to_keep=0)
但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。
当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即
saver=tf.train.Saver(max_to_keep=1)
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
看一个mnist实例:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017@author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())saver=tf.train.Saver(max_to_keep=1) for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).
在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。
saver=tf.train.Saver(max_to_keep=1) max_acc=0 for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) sess.close()
如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。
saver=tf.train.Saver(max_to_keep=3) max_acc=0 f=open('ckpt/acc.txt','w') for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1) f.close() sess.close()
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/') saver.restore(sess,model_file)
则程序后半段代码我们可以改为:
sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=False saver=tf.train.Saver(max_to_keep=3)#训练阶段 if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段 else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。
整个源程序:
# -*- coding: utf-8 -*- """ Created on Sun Jun 4 10:29:48 2017@author: Administrator """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784]) y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x, units=1024, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss) dense2= tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss) logits= tf.layers.dense(inputs=dense2, units=10, activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits) train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=True saver=tf.train.Saver(max_to_keep=3)#训练阶段 if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段 else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc)) sess.close()
View Code
参考文章:http://blog.csdn.net/u011500062/article/details/51728830
tensorflow 1.0 学习:模型的保存与恢复(Saver)相关推荐
- TensorFlow:模型的保存与恢复(Saver)
目录 前言 1 实例化对象 2 保存训练过程中或者训练好的, 模型图及权重参数 2.1保存训练模型 2.2 查看保存 3. 重载模型的图及权重参数(模型恢复) 前言 我们经常在训练完一个模型之 ...
- Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载
我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...
- nvidia-docker2完成tensorflow/serving深度学习模型在线部署
深度学习技术已经广泛应用在各个行业领域.实际应用,通过大量数据往往可以训练一个泛化能力好的模型,但如何将模型进行快捷.方便的远程部署,已成为好多企业考虑的问题.现阶段,常用的深度学习模型远程部署工具有 ...
- TensorFlow 模型的保存与恢复
TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见 https://zhuanlan.zhihu.com/p/32887066 下面,我以mnist手写数据集用sof ...
- tensorflow 1.0 学习:参数和特征的提取
tensorflow 1.0 学习:参数和特征的提取 在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如: #取出所有参与训练的参数 params=tf.tra ...
- 基于Python的模型的保存、恢复、继续训练
资源下载地址:https://download.csdn.net/download/sheziqiong/86774566 资源下载地址:https://download.csdn.net/downl ...
- 简单完整地讲解tensorflow模型的保存和恢复
http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...
- Tensorflow模型的保存与恢复的细节
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- TensorFlow 2.0 - 自定义模型、训练过程
文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...
最新文章
- linux shell数据重定向(输入重定向与输出重定向)详细分析
- 大数据时代 安全与应用的博弈
- App上架应用市场,如何攻破安全过检难题
- Kubernetes的四种用户部署场景
- 使用命名空间解决名字冲突
- RabbitMQ 交换器、持久化
- 【PAT甲级 BigInteger】1019 General Palindromic Number (20 分) Java版 7/7通过
- URAL1815 Farm in San Andreas(费马点,圆圆相交)
- 如何编写杂项设备驱动
- 真机开包!国产至强5600服务器35张图赏
- Leecode-2 Add Two Numbers
- iphone连不上微软服务器,windows没法连接到iphone是什么意思
- STL中迭代器的介绍及分类
- java验证码(采用struts2实现)
- 机械动力学瑞利法matlab程序,工业机器人的动力学建模与仿真+MATLAB程序
- 计算机桌面的图标怎么删除,桌面图标删不掉怎么办 如何彻底删除桌面图标
- C语言要点系统复习三:scanf读取缓冲区的那些事
- deadline集群渲染_Maya笔记
- 中国Linux内核开发者大会
- Word粘贴时出现“文件未找到:MathPage.WLL”的解决方案
热门文章
- MFC中CStatic控件动态改变
- 京东对话中国农民丰收节交易会 供应链下沉打通产销全链
- 东平县农民丰收节-农业大健康·万祥军:功能性谋定生态品牌
- JS-WEB-API(存储)
- spring 配置文件无法加载,junit找不到xml配置文件java.lang.IllegalStateException: Failed to load ApplicationContext...
- 第七章 控制PL/SQL错误
- IBatis.Net学习笔记二--下载、编译、运行NPetShop
- 【干货】Facebook产品经理:高效对接and流程解读
- pmcaff系列活动《走进今日头条》
- 轴对称 Navier-Stokes 方程组的点态正则性准则 I