将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'

...

saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

# -*- coding: utf-8 -*-

"""

Created on Sun Jun 4 10:29:48 2017

@author: Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

dense2= tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

logits= tf.layers.dense(inputs=dense2,

units=10,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)

max_acc=0

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

is_train=False

saver=tf.train.Saver(max_to_keep=3)

#训练阶段

if is_train:

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

#验证阶段

else:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

# -*- coding: utf-8 -*-

"""

Created on Sun Jun 4 10:29:48 2017

@author: Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x,

units=1024,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

dense2= tf.layers.dense(inputs=dense1,

units=512,

activation=tf.nn.relu,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

logits= tf.layers.dense(inputs=dense2,

units=10,

activation=None,

kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)

acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

is_train=True

saver=tf.train.Saver(max_to_keep=3)

#训练阶段

if is_train:

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')

if val_acc>max_acc:

max_acc=val_acc

saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

#验证阶段

else:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})

print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

python模型保存与恢复_tensorflow1.0学习之模型的保存与恢复(Saver)相关推荐

  1. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  2. c调用python keras模型_使用TensorFlow 2.0创建Keras模型的三种方法

    TensorFlow 2.0和tf.keras提供了三种方式来实现神经网络模型:Sequential API Functional API Model subclassing 下面我将分别使用这三种方 ...

  3. java加载pytorch模型,使用Dev Pytorch 1.0将Pytorch模型加载到C中

    Pytorch 1.0具有将模型转换为火炬脚本程序(以某种方式序列化)的功能,以使其能够在没有Python依赖性的情况下在C中执行 . 这是如何做到的: import torch import tor ...

  4. python程序基本语法实验_Python3.0学习基本语法·1

    1.input() 程序运行到这里,需要输入东西,按下enter才能继续程序 示例: name=input('请输入你的名字:') name+='嘛,大家都懂得' print(name) 请输入你的名 ...

  5. Three.js学习六——模型动画

    目录 Three.js动画系统(Animation system) 实现流程 基本流程 工程文件 场景搭建 添加模型 模型动画 动画实现的基本流程 相关对象方法和代码 完整代码和实现效果 Three. ...

  6. 伯克利、OpenAI等提出基于模型的元策略优化强化学习

    基于模型的强化学习方法数据效率高,前景可观.本文提出了一种基于模型的元策略强化学习方法,实践证明,该方法比以前基于模型的方法更能够应对模型缺陷,还能取得与无模型方法相近的性能. 引言 强化学习领域近期 ...

  7. Tensorflow深度学习实战之(五)--保存与恢复模型

    文章目录 一.保存模型 二.恢复模型 三.使用模型预测 一.保存模型 在训练完Tensorflow模型为了方便对新的数据进行预测需要保存该模型,Tensorflow提供 tf.train.Saver( ...

  8. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(二)

    前情提要 基于上文所说 基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一) 我用tf2.0和Python3.7复现了一个基于CNN做音乐分类器.用余弦相似度评估距离的一个音乐推荐模型. ...

  9. TensorFlow2.0学习笔记-3.模型训练

    3.模型训练 3.1.Keras版本模型训练 • 构建模型(顺序模型.函数式模型.子类模型) • 模型训练: model.fit() • 模型验证: model.evaluate() • 模型预测:  ...

  10. 基于python语言,使用爬虫和深度学习来预测双色球(二、模型训练与测试)

    在上一篇博文中(基于python语言,使用爬虫和深度学习来预测双色球(一.数据的准备)),我们完成了数据的准备,通过爬虫的方式将2003年至今的每一期的中奖数据保存到了txt文件中,那么我们现在就开始 ...

最新文章

  1. 单机编程c语言,完美的8051单机C语言编程模板.doc
  2. 你应该关注的几个网站
  3. java学习(171): 枚举类
  4. Hi,我们是-MobileNet-家族
  5. Java基础-hashMap原理剖析
  6. 常见的计算机端口及服务
  7. 头条小程序服务器设置,今日头条小程序怎么开发?如何注册申请
  8. 创业维艰--书摘+乱七八糟
  9. 组合数计算(从1000到1e9的组合数各类求法)
  10. 使用fiddler代理,手机无法上网
  11. 人工智能革命:人类永生还是灭亡(下)
  12. ctf工具整理-持续更新
  13. 介绍java 8 的 Period 和 Duration 类
  14. 2022内蒙古最新建筑施工塔式起重机(建筑特种作业)模拟考试题库及答案
  15. Bash Shellshock(Bash远程代码执行)漏洞批量利用脚本
  16. 【转】走火大神说:去年这时候又辞退了一个老油条,不知道他现在是否在开公司了,可以对比一下混工资的水平...
  17. chromedriver与chrome浏览器各版本对应下载
  18. Android软件开发之获取通讯录联系人信息
  19. [HTML]入门小知识,列表?框架?表格?来吧。纯手工制作,满满都是智慧
  20. 合天网安 第四周 | Check your source code

热门文章

  1. Java 8 VM GC Tuning Guide Charter3-4
  2. mysql5.7导出数据提示–secure-file-priv选项问题的解决方法
  3. idea2020导入maven工程(解决项目文件没有蓝色方块问题)
  4. python画图系列整理
  5. ​最强全集,数据科学领域,那些你不能不知道的大咖们!
  6. WPF 通过Image控件实现多张图片的播放
  7. 429. N 叉树的层序遍历(中等 树 广度优先搜索)
  8. itext生成pdf加页码和总页码
  9. 教程向|3D建模最难之面部雕刻,详细教程带给大家(下)
  10. 国外广告联盟骗局汇总(持续更新中)