这篇文章主要介绍了tensorflow1.0学习之模型的保存与恢复(Saver) ,现在分享给大家,也给大家做个参考。一起过来看看吧

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'

...

saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-

"""

Created on Sun Jun 4 10:29:48 2017

@author: Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

dense2= tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

logits= tf.layers.dense(inputs=dense2,

units=10,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)

max_acc=0

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

is_train=False

saver=tf.train.Saver(max_to_keep=3)

#训练阶段

if is_train:

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

#验证阶段

else:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-

"""

Created on Sun Jun 4 10:29:48 2017

@author: Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

dense2= tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

logits= tf.layers.dense(inputs=dense2,

units=10,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

is_train=True

saver=tf.train.Saver(max_to_keep=3)

#训练阶段

if is_train:

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

#验证阶段

else:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

相关推荐:

将TensorFlow的模型网络导出为单个文件的方法

python模型保存与恢复_tensorflow1.0学习之模型的保存与恢复(Saver)_python相关推荐

  1. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  2. c调用python keras模型_使用TensorFlow 2.0创建Keras模型的三种方法

    TensorFlow 2.0和tf.keras提供了三种方式来实现神经网络模型:Sequential API Functional API Model subclassing 下面我将分别使用这三种方 ...

  3. java加载pytorch模型,使用Dev Pytorch 1.0将Pytorch模型加载到C中

    Pytorch 1.0具有将模型转换为火炬脚本程序(以某种方式序列化)的功能,以使其能够在没有Python依赖性的情况下在C中执行 . 这是如何做到的: import torch import tor ...

  4. python程序基本语法实验_Python3.0学习基本语法·1

    1.input() 程序运行到这里,需要输入东西,按下enter才能继续程序 示例: name=input('请输入你的名字:') name+='嘛,大家都懂得' print(name) 请输入你的名 ...

  5. Three.js学习六——模型动画

    目录 Three.js动画系统(Animation system) 实现流程 基本流程 工程文件 场景搭建 添加模型 模型动画 动画实现的基本流程 相关对象方法和代码 完整代码和实现效果 Three. ...

  6. 伯克利、OpenAI等提出基于模型的元策略优化强化学习

    基于模型的强化学习方法数据效率高,前景可观.本文提出了一种基于模型的元策略强化学习方法,实践证明,该方法比以前基于模型的方法更能够应对模型缺陷,还能取得与无模型方法相近的性能. 引言 强化学习领域近期 ...

  7. Tensorflow深度学习实战之(五)--保存与恢复模型

    文章目录 一.保存模型 二.恢复模型 三.使用模型预测 一.保存模型 在训练完Tensorflow模型为了方便对新的数据进行预测需要保存该模型,Tensorflow提供 tf.train.Saver( ...

  8. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(二)

    前情提要 基于上文所说 基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一) 我用tf2.0和Python3.7复现了一个基于CNN做音乐分类器.用余弦相似度评估距离的一个音乐推荐模型. ...

  9. TensorFlow2.0学习笔记-3.模型训练

    3.模型训练 3.1.Keras版本模型训练 • 构建模型(顺序模型.函数式模型.子类模型) • 模型训练: model.fit() • 模型验证: model.evaluate() • 模型预测:  ...

  10. 基于python语言,使用爬虫和深度学习来预测双色球(二、模型训练与测试)

    在上一篇博文中(基于python语言,使用爬虫和深度学习来预测双色球(一.数据的准备)),我们完成了数据的准备,通过爬虫的方式将2003年至今的每一期的中奖数据保存到了txt文件中,那么我们现在就开始 ...

最新文章

  1. 清华90后女学霸范楚楚将加入MIT任助理教授,面试宝典分享!
  2. JDK内置工具--jconsole
  3. 关键字—修饰方法、类、属性和变量的关键字(共9个)
  4. html使用element ui_Kendo UI for jQuery使用教程:自定义小部件(二)
  5. linux系统编程之进程(七):system()函数使用【转】
  6. html5 密码框明文,elementUI的密码框的密文和明文
  7. php node.js django,Vue.js和Django搭建前后端分离项目示例详解
  8. IT众包Web网站服务案例:江苏保税店
  9. mvc做网站怎么在mvc中直接访问.html网页 [问题点数:20分]
  10. PacMan开发-碰撞检测实现
  11. 读大道至简——编程的精义感想
  12. 药品数据查询系统工具(非付费官网50个)
  13. 粒子滤波算法matlab代码,粒子滤波算法原理及Matlab程序(专题).ppt
  14. 【 以项目实战讲解CAD的二次开发(二)】
  15. Debian11更换阿里源
  16. 数据结构 hbb(汉堡包)
  17. 选择排序(Selection sort)是一种简单直观的排序算法
  18. 广电为什么禁止投屏_广电的机顶盒怎么投屏
  19. vivo智能手机产能
  20. HDU-4069___Squiggly Sudoku —— 锯齿数独 + BFS

热门文章

  1. vue同一项目搭建PC端和移动端
  2. 【FPGA的小娱乐】tft显示屏生成信号辅助测试阵列
  3. JS实现方块颜色的渐变
  4. ns注册改服务器,NameSilo域名更改NS服务器简单过程介绍
  5. TCP 握手没成功怎么办?
  6. 美国大学计算机专业排名2014,2014年美国大学计算机专业排名
  7. 2023跨境出海指南:马来西亚网红营销白皮书
  8. 虾皮马来西亚站如何选品?附快速出单秘诀
  9. cryengine3 C++添加结点,制作插件
  10. 10.2 校内集训 解题报告