训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存。如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值。建议可以使用Saver类保存和加载模型的结果。

1、使用tf.train.Saver.save()方法保存模型

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)

  • sess: 用于保存变量操作的会话。
  • save_path: String类型,用于指定训练结果的保存路径。
  • global_step: 如果提供的话,这个数字会添加到save_path后面,用于构建checkpoint文件。这个参数有助于我们区分不同训练阶段的结果。

2、使用tf.train.Saver.restore方法价值模型

tf.train.Saver.restore(sess, save_path)

  • sess: 用于加载变量操作的会话。
  • save_path: 同保存模型是用到的的save_path参数。

下面通过一个代码演示这两个函数的使用方法

import tensorflow as tf
import numpy as npx = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + bloss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))with tf.Session() as sess:sess.run(tf.initialize_all_variables())if isTrain:for i in xrange(train_steps):sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0:saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)else:ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)else:passprint(sess.run(w))print(sess.run(b))

转载于:https://www.cnblogs.com/txq157/p/7242385.html

转载:tensorflow保存训练后的模型相关推荐

  1. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  2. 如何在TensorFlow中训练Boosted Trees模型

    在使用结构化数据时,诸如梯度提升决策树和随机森林之类的树集合方法是最流行和最有效的机器学习工具之一. 树集合方法训练速度快,无需大量调整即可正常工作,并且不需要大型数据集进行训练. 在TensorFl ...

  3. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  4. Pytorch——保存训练好的模型参数

    文章目录 1.前言 2.torch.save(保存模型) 3.torch.load整个网络 4.torch.load网络参数(只提取参数) 5.调用三个函数 1.前言 训练好了一个模型, 我们当然想要 ...

  5. python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...

    本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...

  6. 保存训练好的模型并调用

    当我们训练好一个model后,下次如果还想用这个model,我们就需要把这个model保存下来,下次直接导入就好了,不然每次都跑一遍,训练时间短还好,要是一次跑好几天的那怕是要天荒地老了. sciki ...

  7. libsvm 训练后,模型参数详解

    本节主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义,以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子. 测试数据使用的是lib ...

  8. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  9. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

最新文章

  1. seaborn系列 (13) | 点图pointplot()
  2. SEH(结构化异常处理)
  3. android ringtone获取uri,android – 如何通过文件路径从MediaStore获取Uri?
  4. 【易网库】周年庆幸运大抽奖, 有机会获3年免费主机空间
  5. 程序员的自我救赎---4.2:消息中心讲解与应用
  6. 微信小程序提醒并延迟跳转
  7. 团队第一阶段冲刺——第七天
  8. 查看一个网站使用的开发技术
  9. 结构体structure
  10. 解决Navicat连接MySQL总是报错1251的方法
  11. Sqlyog的安装与使用
  12. 熔断机制什么意思_熔断机制是什么意思?
  13. O2O商业模式的现状以及发展趋势是什么?
  14. 活动倒计时的一些想法
  15. [linux虚拟机] 使用yum命令时,解析不了yum源,Cannot find a valid baseurl for repo: base/7/x86_6
  16. Use // eslint-disable-next-line to ignore the next line. Use /* eslint-disable */ to ignore all warn
  17. mysql宕机日志查询_Mysql容器异常宕机
  18. 泛微软件服务器是多少,泛微oa云服务器要求
  19. Matplotlib绘制圆环图
  20. 编程猫python讲师面试_【编程猫工资|编程猫待遇怎么样】-看准网

热门文章

  1. Web工程师修行笔记_必备单词(第三部)
  2. Duilib嵌入CEF以及JavaScript与C++交互
  3. UVAlive 6131 dp+斜率优化
  4. Eigen(7)Map类
  5. brainfu*k语言执行
  6. 中国计算机手机全部被黑,云南一高校电脑全被黑
  7. mysql递归查询所有上下节点_非递归打印二叉树的所有路径,保存父节点和孩子节点到底有啥差别...
  8. python数据分析神器_太香了!墙裂推荐6个Python数据分析神器!!
  9. linux c文件属性,【linux c learn 之stat】获取文件的属性
  10. python能不能用c打开文件_C/C++/Python等 使用二进制模式打开文件与不使用二进制模式的区别...