参考:

TensorFlow:保存和提取模型
最全Tensorflow模型保存和提取的方法——附实例

  1. 模型的保存会覆盖,后一次保存的模型会覆盖上一次保存的模型。最多保存近5次结果。
  2. 应当保存效果最优时候的模型,而不是训练最后一次的模型。所以应该在每次进行模型性能评估后与保存的目前最后效果比较,如果性能更好则进行模型的保存。
  3. 模型的复用,当你想用别的性能评估指标的时候,不需要再次训练模型来获得指标值,可以提取最优模型直接计算新指标的值。
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())is_train=False
saver=tf.train.Saver(max_to_keep=3)#训练阶段
if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段
else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

实操:

说明:

  1. Social Attentional Memory Network 是一个推荐系统的模型,代码中没有模型保存和提取操作,数据量也算是小的,可以下载下来练习一下如何实际操作。
  2. SAMN 是我用这个模型进行的练习,可以参考,代码后面标注 lly 的是我写的或者修改的内容。

步骤:

  1. 先在原代码的主目录的下面建一个文件夹 model 。
  2. 第一次进行训练,进入目录执行 python SAMN.py ,其中参数 is_train = True
    训练完后发现model文件夹下面多了五个模型,最后一次保存的模型为最后模型,出现在第171次迭代的时候,即epoch=170

    然后在控制台可以看到,epoch=170时候的评估结果:
    迭代第 166 次的损失为:26.586210:迭代第 167 次的损失为:26.567725:迭代第 168 次的损失为:26.586499:迭代第 169 次的损失为:26.571110:迭代第 170 次的损失为:26.668282:recall--------------------------------------------------------------------------------0.16846666666666665 0.19796666666666665 0.22703333333333334 0.24936666666666668 0.2713666666666667ndcg----------------------------------------------------------------------------------0.103169807535364 0.11131981364691529 0.11824016391770284 0.12317271387061263 0.12777428228959994save epoch  170
  1. 第二次使用保存好的模型,先将 SAMN.py 文件的参数 is_train 改为 False,再执行文件。
    执行完后可以看到控制台输出的评估结果和之前训练的时候的结果一样,证明操作成功。(最优结果我只保留了k=[10, 20, 50]的情况)

tensorflow--模型的保存和提取相关推荐

  1. Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

    我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存 ...

  2. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  3. 简单完整地讲解tensorflow模型的保存和恢复

    http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要讲到: 1. 什么是Tensorflow模型? 2. 如何保存Tenso ...

  4. Tensorflow模型的保存与恢复的细节

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  5. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  6. TensorFlow 模型的保存与恢复

    TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见 https://zhuanlan.zhihu.com/p/32887066 下面,我以mnist手写数据集用sof ...

  7. tensorflow——模型的保存和恢复tf.trian.saver()

    保存 1创建saver对象,确定save哪些:saver=tf.trian.Saver(),不填写参数的话默认全部 2指定在哪个session中保存,以及保存路径:saver.save(sess, ' ...

  8. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...

  9. 5.2 TensorFlow:模型的加载,存储,实例

    背景 之前已经写过TensorFlow图与模型的加载与存储了,写的很详细,但是或闻有人没看懂,所以在附上一个关于模型加载与存储的例子,CODE是我偶然看到了,就记下来了.其中模型很巧妙,比之前nump ...

  10. TensorFlow模型持久化

    模型持久化的目的在于可以使模型训练后的结果重复使用,节省重复训练模型的时间. 模型保存 train.Saver类是TensorFlow提供的用于保存和还原模型的API,使用非常简单. import t ...

最新文章

  1. swagger2 集成无效_Springboot2 集成Swagger2,解决配置完成后不显示的坑
  2. Azure Arc:微软是怎么玩多云游戏的?
  3. payara 创建 集群_在Payara Server和GlassFish中配置密码
  4. ActionBar(3):搜索条
  5. 使用sourcetree 的git flow
  6. 缺陷管理工具JIRA和禅道对比
  7. oracle lsnrctl status unknown,理解 oracle 的 lsnrctl status
  8. Hive/MaxCompute SQL性能优化(三):数据倾斜优化实战
  9. 多线程下载视频,并运用Fmmpeg合成
  10. 消失的网秦:创始人遭绑架 414 天,睡觉都戴手铐
  11. C# 简单判断枚举值是否被定义
  12. MySQL解压版安装及配置(本地windows环境)
  13. 计算机无法使用网络连接到服务器,电脑无法连接网络并诊断提示DNS服务器未响应的解决方法...
  14. linux EHCI DRIVER之中断处理函数ehci_irq()分析(一)
  15. 我不会是亚瑟王,但我想成为梅林
  16. 【90天英语通】零基础自学新概念英语
  17. 虚拟机(Vmware)磁盘扩容
  18. java编写日期年月日的代码_求Java高手写道题设int year,month,day分别表示一个日期中的年月日,试编程求a) 对于任意三个整数,判...
  19. HDU 6203 贪心 + LCA + dfs序 + BIT
  20. 【庖丁解牛】成功解决yum安装mysql时报错libmysqlclient.so.18

热门文章

  1. Linux 设备驱动模型中的class(类)
  2. access 增加字段 工具_Java效率工具之Lombok
  3. C++ 解析Json
  4. 管理系统制作的python代码_python学生管理系统代码实现
  5. string contains不区分大小写_String基础复习
  6. 数据结构之树:树的介绍——9
  7. OSError: [Errno 22] Invalid argument:**
  8. 贪吃蛇python小白_面向 python 小白的贪吃蛇游戏
  9. Python自动化办公——xlrd、xlwt读写Excel
  10. LeetCode 1481. 不同整数的最少数目(计数+排序+贪心)