一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),实际在这个文件目录下会生成4个人文件:

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt"),加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")。

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行

# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确

# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型

# 执行本段程序时注意当前的工作路径

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")

v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")

result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

saver.save(sess, "Model/model.ckpt")

# Part2: 加载TensorFlow模型的方法

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")

v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")

result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:

saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"

print(sess.run(result)) # [ 3.]

# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图

import tensorflow as tf

saver = tf.train.import_meta_graph("Model/model.ckpt.meta")

with tf.Session() as sess:

saver.restore(sess, "./Model/model.ckpt") # 注意路径写法

print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]

# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名

import tensorflow as tf

# 声明的变量名称name与已保存的模型中的变量名称name不一致

u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")

u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")

result = u1 + u2

# 若直接生命Saver类对象,会报错变量找不到

# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}

# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中

saver = tf.train.Saver({"v1": u1, "v2": u2})

with tf.Session() as sess:

saver.restore(sess, "./Model/model.ckpt")

print(sess.run(result)) # [ 3.]

# Part5: 保存滑动平均模型

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")

for variables in tf.global_variables():

print(variables.name) # v:0

ema = tf.train.ExponentialMovingAverage(0.99)

maintain_averages_op = ema.apply(tf.global_variables())

for variables in tf.global_variables():

print(variables.name) # v:0

# v/ExponentialMovingAverage:0

saver = tf.train.Saver()

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

sess.run(tf.assign(v, 10))

sess.run(maintain_averages_op)

saver.save(sess, "Model/model_ema.ckpt")

print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]

# Part6: 通过变量重命名直接读取变量的滑动平均值

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")

saver = tf.train.Saver({"v/ExponentialMovingAverage": v})

with tf.Session() as sess:

saver.restore(sess, "./Model/model_ema.ckpt")

print(sess.run(v)) # 0.0999999

# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")

# 注意此处的变量名称name一定要与已保存的变量名称一致

ema = tf.train.ExponentialMovingAverage(0.99)

print(ema.variables_to_restore())

# {'v/ExponentialMovingAverage': }

# 此处的v取自上面变量v的名称name="v"

saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:

saver.restore(sess, "./Model/model_ema.ckpt")

print(sess.run(v)) # 0.0999999

# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中

import tensorflow as tf

from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")

v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")

result = v1 + v2

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

# 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分

graph_def = tf.get_default_graph().as_graph_def()

output_graph_def = graph_util.convert_variables_to_constants(sess,

graph_def, ['add'])

with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:

f.write(output_graph_def.SerializeToString())

# Part9: 载入包含变量及其取值的模型

import tensorflow as tf

from tensorflow.python.platform import gfile

with tf.Session() as sess:

model_filename = "Model/combined_model.pb"

with gfile.FastGFile(model_filename, 'rb') as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

result = tf.import_graph_def(graph_def, return_elements=["add:0"])

print(sess.run(result)) # [array([ 3.], dtype=float32)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

本文标题: TensorFlow模型保存和提取的方法

本文地址: http://www.cppcns.com/jiaoben/python/222000.html

python保存模型方法_TensorFlow模型保存和提取的方法相关推荐

  1. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  2. 加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

    众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如: x1 = {"d":" ...

  3. python如何保存训练好的模型_Python机器学习7:如何保存、加载训练好的机器学习模型...

    本文将介绍如何使用scikit-learn机器学习库保存Python机器学习模型.加载已经训练好的模型.学会了这个,你才能够用已有的模型做预测,而不需要每次都重新训练模型. 本文将使用两种方法来实现模 ...

  4. Python时间序列模型推理预测实战:时序推理数据预处理(特征生成、lstm输入结构组织)、模型加载、模型预测结果保存、条件判断模型循环运行

    Python时间序列模型推理预测实战:时序推理数据预处理(特征生成.lstm输入结构组织).模型加载.模型预测结果保存.条件判断模型循环运行 目录

  5. python模型保存save_浅谈keras保存模型中的save()和save_weights()区别

    今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...

  6. python保存模型与参数_Pytorch - 模型和参数的保存与恢复

    模型训练后,需要保存到文件,以供测试和部署:或,继续之前的训练状态. 1. Best Practices 主要有两种模型序列化保存和加载恢复的方法. 1.1 方法 M1 - 推荐 只保存和加载恢复模型 ...

  7. Python Word2vec训练医学短文本字/词向量实例实现,Word2vec训练字向量,Word2vec训练词向量,Word2vec训练保存与加载模型,Word2vec基础知识

    一.Word2vec概念 (1)Word2vec,是一群用来产生词向量的相关模型.这些模型为浅而双层的神经网络,用来训练以重新建构语言学之词文本.网络以词表现,并且需猜测相邻位置的输入词,在word2 ...

  8. tensorflow 1.x Saver(保存与加载模型) 预测

    20201231 tensorflow 1.X 模型保存 https://blog.csdn.net/qq_35290785/article/details/89646248 保存模型 saver=t ...

  9. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  10. 保存和加载pytorch模型

    当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...

最新文章

  1. 开发管理 CheckLists(4) -风险管理
  2. 微信朋友圈装x代码_NBA总决赛朋友圈装X图鉴:直男之间有真正的友谊吗?
  3. 正则表达式30分钟教程
  4. VTK:PolyData之PointLocator
  5. 微信小程序怎么新建php文件,微信小程序中创建小程序页面的步骤介绍(图文)...
  6. C语言学习输入输出函数,函数的调用
  7. java朗控点异常_Java语言基础(day_04)
  8. Win32汇编——过程控制(环境变量、命令行参数、可执行文件执行)
  9. js基础知识汇总06
  10. centos7永久修改主机名
  11. Android P 怎样屏蔽HOME键和RECENT键
  12. 企业微信和个人微信区别到底有哪些
  13. 批量创建文件夹并命名
  14. 友推app微信分享功能集成攻略
  15. 查看hive的当前参数值
  16. 梦幻单机游戏添加怪物lua
  17. HDOJ 1069 Monkey and Banana 解题报告
  18. 戴尔笔记本——减低风扇声音的一种办法
  19. 如何在Android中使用Realm数据库
  20. Gatsby 学习 - 01 Gatsby 介绍、创建页面

热门文章

  1. thinkPad电脑无人操作时休眠设置
  2. 【解决】Failed to process import candidates for configuration class [cn.itcast.eureka.EurekaApplication]
  3. python ide哪个好用_好用的Python IDE推荐
  4. Android手机游戏大全apk
  5. python爬取上海高级人民法院网开庭公告数据
  6. 使用计算机能播放音乐也能观看视频,我电脑可以放歌有声音。怎么播放视频没声音啊?给我解决方案...
  7. 比较自然语言与计算机语言,计算机语言与自然语言的比较研究.pdf
  8. 用php怎么输出一首诗,如何用一首诗总结你的2018年?
  9. 广东省计算机一级技巧,广东省计算机一级
  10. android 植入谷歌广告,将谷歌广告添加到Android应用程序