我们在用tensorflow训练模型时,可能需要训练很长很长一段时间,为了方便下次使用,应该将模型保存起来。在sklearn中,我们可以使用pickle模块进行模型保存;而在tensorflow中,我们可以使用它自带的Saver()类进行模型的保存。

(一)Saver类

  Saver类是用于保存和恢复变量的。它有将变量保存到checkpoint和从checkpoint中恢复变量的操作。

  Checkpoints是一个二进制文件,它的属性值和tensor变量值一一对应。最好的检查checkpoints内容的方法就是用一个Saver去加载它。

  Saver可以自动的为chackpoint文件进行计数。这可以让你在训练模型时,保存多个checkpoint(通过计数来区分)。例如你可以通过训练的epoch来标识你的checkpoint文件。为了防止过分使用内存,你可以为saver设置最多保存的checkpoint文件数量。

  你可以通过为save()函数传入global_step参数值来标识checkpoint文件例如:

saver.save(sess, 'my-model', global_step=0)           ==>filename: 'my-model-0'
saver.save(sess, 'my-model', global_step=1000)        ==>filename: 'my-model-1000'

属性:

last_checkpoints当前所有保存的checkpoint文件的名字的list集合。你可以将这个返回的文件名list的任意一个元素作为restore()函数的参数,用于恢复指定的checkpoint。returns:返回checkpoints文件名列表,从最旧到最新的排序。

主要方法:

1.__init__(var_list=None,reshape=False,sharded=False,max_to_keep=5,keep_checkpoint_every_n_hours=10000.0,name=None,restore_sequentially=False,saver_def=None,builder=None,defer_build=False,allow_empty=False,write_version=tf.train.SaverDef.V2,pad_step_number=False,save_relative_paths=False,filename=None)Saver类的构造函数。Args:var_list:需要保存的参数列表,可以dict或list形式。如果这个参数为None,则默认保存所有可以保存的对象。一般使用缺省值即可。max_to_keep:保存的checkpoint文件的最大数量。默认为只保存最后5个。keep_checkpoint_every_n_hours:多久保存一次checkpoint文件,默认10000小时每次。其他参数不常用。2.save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix='meta',write_meta_graph=True,write_state=True,strip_default_attrs=False)保存变量。这个方法用来保存变量,它需要一个session参数来指明哪个图。保存的参数必须已经被初始化过了。args:sess:保存变量需要的sessionsave_path:checkpoint文件保存的路径。global_step:如果指定了,则会将这个数字添加到save_path后面,用于唯一标识checkpoint文件。latest_filename:和save_path在同一个文件夹中,用于最后一个checkpoint文件的命名。默认为checkpoint。其他不常用。3.restore(sess,save_path)从save_path中恢复模型的参数。它需要一个session,需要恢复的参数不需要初始化,因为恢复本身就是一种初始化变量的方法。而参数save_path就是save()函数产生的文件的路径名。args:sess:一个sessionsave_path:保存的路径

(二)使用举例

1.保存模型:
  使用Saver保存模型的参数时,一定要将saver = tf.train.Saver定义在你保存的的参数定义之后,即定义在需要的tf.Variable之后。定义在saver之后的参数无法被保存,切记切记!!!

  在下面的例子中,我们生成了y = (x - 1) ^ 2 - 2的样本,然后加上了一些噪音,试着用tensorflow训练出拟合该曲线的参数。

__author__ = 'liuwei'import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plth = 1
v = -2#prepare data
x_train = np.linspace(-2, 4, 201)                        #x样本
noise = np.random.randn(*x_train.shape) * 0.4            #噪音
y_train = (x_train - h) ** 2 + v + noise                 #y样本n = x_train.shape[0]x_train = np.reshape(x_train, (n, 1))                    #重塑
y_train = np.reshape(y_train, (n, 1))#画出产生的数据的形状
'''
plt.rcParams['figure.figsize'] = (10, 6)
plt.scatter(x_train, y_train)
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.show()
'''
#create variable
X = tf.placeholder(tf.float32, [1])                      #两个占位符,x和y
Y = tf.placeholder(tf.float32, [1])h_est = tf.Variable(tf.random_uniform([1], -1, 1))       #定义需要训练的参数,在saver之前定义
v_est = tf.Variable(tf.random_uniform([1], -1, 1))saver = tf.train.Saver()                                 #保存模型参数的savervalue = (X - h_est) ** 2 + v_est                         #拟合的曲线loss = tf.reduce_mean(tf.square(value - Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)for epoch in range(100):                             #100个epochfor (x, y) in zip(x_train, y_train):sess.run(optimizer, feed_dict={X: x, Y: y})#保存checkpointsaver.save(sess, './model_iter', global_step=epoch)#saver the final modelsaver.save(sess, './final_model')                    #最后一个epoch对应的checkpointh_ = sess.run(h_est)v_ = sess.run(v_est)print(h_, v_)

运行结果如下:

2.恢复参数
  恢复参数时,我们只需要定义保存的Variable,不需要初始化,因为恢复过程其实就是一种初始化。恢复参数的代码如下:

__author__ = 'liuwei'import tensorflow as tf
import numpy as np h_est = tf.Variable(tf.random_uniform([1], -1, 1))     #只定义,没有初始化
v_est = tf.Variable(tf.random_uniform([1], -1, 1))saver = tf.train.Saver()                      #saver类path = './final_model'                        #要恢复的checkpoint路径with tf.Session() as sess:saver.restore(sess, path)                 #恢复参数print(sess.run(h_est), sess.run(v_est))

运行结果为:

完整代码:https://github.com/liuwei1206/tensorflow-study/tree/master/2.saver/test

tensorflow中用saver保存模型相关推荐

  1. Tensorflow使用Saver保存模型PermissionDeniedError (see above for traceback)

    使用tensorflow训练模型时需要保存训练后得到的模型,并在测试时加载模型. Tensorflow中可以使用Saver类进行参数保存. 保存参数是可以选择的,如果不传参就是保存所有参数. weig ...

  2. Tensorflow详解保存模型(基础版)

    我们都知道tensorflow最后生成的模型文件含: checkpoint xxxxx.meta xxxxx.ckpt.data-xxx xxxxx.index 学习和使用tensorflow的小伙伴 ...

  3. TensorFlow的Saver保存类

    一.Saver的介绍 有时可能只需要保存或者加载部分变量. 比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最 ...

  4. Tensorflow详解保存模型(进阶版一):如何有选择的保存变量

    当然掌握了基础版还不够,我们来看一下进阶版一:如何有选择的保存变量: 这里还要另外涉及两个函数: tf.variable_scope("xxx") 和 tf.get_variabl ...

  5. Tensorflow:如何保存/恢复模型?

    在Tensorflow中训练模型后: 您如何保存经过训练的模型? 您以后如何还原此保存的模型? #1楼 对于TensorFlow版本<0.11.0RC1: 保存的检查点包含模型中Variable ...

  6. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  7. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  8. Tensorflow |(5)模型保存与恢复、自定义命令行参数

    Tensorflow |(1)初识Tensorflow Tensorflow |(2)张量的阶和数据类型及张量操作 Tensorflow |(3)变量的的创建.初始化.保存和加载 Tensorflow ...

  9. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

最新文章

  1. 关于学习Python的一点学习总结(3->标识符->if->模块->字符)
  2. 深度学习也利用进化论!李飞飞谈创建具身智能体,学动物进化法则
  3. [博客之路]如何增加一个博客的PR值(一)
  4. 计算机缓存Cache机制理解
  5. scrum与第一次teamwork
  6. 对比excel,用python绘制柱状图时添加table数据表
  7. nodejs 遍历json数据_PostgreSQL 务实应用(四/5)JSON
  8. (1)PCIE接口应用领域(学无止境)
  9. java c 性能分析工具_Jprofiler使用介绍--java性能分析工具中文帮助
  10. 支付宝花呗接口接入php,支付宝小程序开通花呗接口,这是正式向微信小程序正式宣战?...
  11. 机器人“病患”会流血会休克,魔鬼训练从斯坦福医院开始 |准医生的噩梦
  12. node.js 创建服务器_Node.js HTTP软件包–创建HTTP服务器
  13. Flink Table和SQL的基本API
  14. python笔记 字典赋值
  15. iphone手机屏幕开发尺寸
  16. eyoucms目录结构
  17. 计算机网络设置无法保存,win10系统启用网络发现无法保存的解决步骤
  18. 校友会2019中国大学计算机,校友会2019中国大学一流专业排名800强发布,北大清华复旦前三...
  19. 星际争霸2Beta版单机使用方法
  20. 见过世面的程序员,到底有多厉害

热门文章

  1. 编程之美2014 热身赛 题目3 : 树上的三角形
  2. Java环境变量CLASSPATH详解(转载)
  3. 团购“罗生门”渐退行业舞台 千团大战仍需企业自律
  4. Serverless 实战 —— Serverless 的运行原理与组件架构
  5. JavaScript玩转机器学习:张量(Tensors) 和 操作(operations)
  6. Nginx 过滤模块
  7. Kubernetes 入门:运行不同类型的 Job
  8. SpringBoot之项目运行常见报错
  9. python读取html文件中的表格数据_使用解析html表pd.read_html文件其中单元格本身包含完整表...
  10. 【Python】random库