采用 TensorFlow 的时候,有时候我们需要加载的不止是一个模型,那么如何加载多个模型呢?

原文:https://bretahajek.com/2017/04/importing-multiple-tensorflow-models-graphs/


关于 TensorFlow 可以有很多东西可以说。但这次我只介绍如何导入训练好的模型(图),因为我做不到导入第二个模型并将它和第一个模型一起使用。并且,这种导入非常慢,我也不想重复做第二次。另一方面,将一切东西都放到一个模型也不实际。

在这个教程中,我会介绍如何保存和载入模型,更进一步,如何加载多个模型。

加载 TensorFlow 模型

在介绍加载多个模型之前,我们先介绍下如何加载单个模型,官方文档:https://www.tensorflow.org/programmers_guide/meta_graph。

首先,我们需要创建一个模型,训练并保存它。这部分我不想过多介绍细节,只需要关注如何保存模型以及不要忘记给每个操作命名。

创建一个模型,训练并保存的代码如下:

import tensorflow as tf
### Linear Regression 线性回归###
# Input placeholders
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
# Model parameters 定义模型的权值参数
W1 = tf.Variable([0.1], tf.float32)
W2 = tf.Variable([0.1], tf.float32)
W3 = tf.Variable([0.1], tf.float32)
b = tf.Variable([0.1], tf.float32)# Output 模型的输出
linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,name='activation_opt')# Loss 定义损失函数
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')
# Optimizer and training step 定义优化器运算
optimizer = tf.train.AdamOptimizer(0.001)
train = optimizer.minimize(loss, name='train_step')# Remember output operation for later aplication
# Adding it to a collections for easy acces
# This is not required if you NAME your output operation
# 记得将输出操作添加到一个集合中,但如何你命名了输出操作,这一步可以省略
tf.add_to_collection("activation", linear_model)## Start the session ##
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#  CREATE SAVER
saver = tf.train.Saver()# Training loop 训练
for i in range(10000):sess.run(train, {x: data, y: expected})if i % 1000 == 0:# You can also save checkpoints using global_step variablesaver.save(sess, "models/model_name", global_step=i)# SAVE TensorFlow graph into path models/model_name
# 保存模型到指定路径并命名模型文件名字
saver.save(sess, "models/model_name")

注意,这里是第一个重点–对变量和运算命名。这是为了在加载模型后可以使用指定的一些权值参数,如果不命名的话,这些变量会自动命名为类似“Placeholder_1”的名字。在复杂点的模型中,使用领域(scopes)是一个很好的做法,但这里不做展开。

总之,重点就是为了在加载模型的时候能够调用权值参数或者某些运算操作,你必须给他们命名或者是放到一个集合中。

当保存模型后,在指定保存模型的文件夹中就应该包含这些文件:model_name.indexmodel_name.meta以及其他文件。如果是采用checkpoints后缀命名模型名字,还会有名字包含model_name-1000的文件,其中的数字是对应变量global_step,也就是当前训练迭代次数。

现在我们就可以开始加载模型了。加载模型其实很简单,我们需要的只是两个函数即可:tf.train.import_meta_graphsaver.restore()。此外,就是提供正确的模型保存路径位置。另外,如果我们希望在不同机器使用模型,那么还需要设置参数:clear_device=True

接着,我们就可以通过之前命名的名字或者是保存到的集合名字来调用保存的运算或者是权值参数了。如果使用了领域,那么还需要包含领域的名字才行。而在实际调用这些运算的时候,还必须采用类似{'PlaceholderName:0': data}的输入占位符,否则会出现错误。

加载模型的代码如下:

sess = tf.Session()# Import graph from the path and recover session
# 加载模型并恢复到会话中
saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True)
saver.restore(sess, 'models/model_name')# There are TWO options how to access the operation (choose one)
# 两种方法来调用指定的运算操作,选择其中一个都可以# FROM SAVED COLLECTION: 从保存的集合中调用
activation = tf.get_collection('activation')[0]# BY NAME: 采用命名的方式
activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]# Use imported graph for data
# You have to feed data as {'x:0': data}
# Don't forget on ':0' part!
# 采用加载的模型进行操作,不要忘记输入占位符
data = 50
result = sess.run(activation, {'x:0': data})
print(result)

多个模型

上述介绍了如何加载单个模型的操作,但如何加载多个模型呢?

如果使用加载单个模型的方式去加载多个模型,那么就会出现变量冲突的错误,也无法工作。这个问题的原因是因为一个默认图的缘故。冲突的发生是因为我们将所有变量都加载到当前会话采用的默认图中。当我们采用会话的时候,我们可以通过tf.Session(graph=MyGraph)来指定采用不同的已经创建好的图。因此,如果我们希望加载多个模型,那么我们需要做的就是把他们加载在不同的图,然后在不同会话中使用它们。

这里,自定义一个类来完成加载指定路径的模型到一个局部图的操作。这个类还提供run函数来对输入数据使用加载的模型进行操作。这个类对于我是有用的,因为我总是将模型输出放到一个集合或者对它命名为activation_opt,并且将输入占位符命名为x。你可以根据自己实际应用需求对这个类进行修改和拓展。

代码如下:

import tensorflow as tfclass ImportGraph():"""  Importing and running isolated TF graph """def __init__(self, loc):# Create local graph and use it in the sessionself.graph = tf.Graph()self.sess = tf.Session(graph=self.graph)with self.graph.as_default():# Import saved model from location 'loc' into local graph# 从指定路径加载模型到局部图中saver = tf.train.import_meta_graph(loc + '.meta',clear_devices=True)saver.restore(self.sess, loc)# There are TWO options how to get activation operation:# 两种方式来调用运算或者参数# FROM SAVED COLLECTION:            self.activation = tf.get_collection('activation')[0]# BY NAME:self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]def run(self, data):""" Running the activation operation previously imported """# The 'x' corresponds to name of input placeholderreturn self.sess.run(self.activation, feed_dict={"x:0": data})### Using the class ###
# 测试样例
data = 50         # random data
model = ImportGraph('models/model_name')
result = model.run(data)
print(result)

总结

如果你理解了 TensorFlow 的机制的话,加载多个模型并不是一件困难的事情。上述的解决方法可能不是完美的,但是它简单且快速。最后给出总结整个过程的样例代码,这是在 Jupyter notebook 上的,代码地址如下:

https://gist.github.com/Breta01/f205a9d27090c18d394fbaab98de7c65#file-importmodulesnotebook-ipynb


最后,给出文章中几个代码例子的 github 地址:

  1. Code for creating, training and saving TensorFlow model.
  2. Importing and using TensorFlow graph (model)
  3. Class for importing multiple TensorFlow graphs.
  4. Example of importing multiple TensorFlow modules

欢迎关注我的微信公众号–机器学习与计算机视觉或者扫描下方的二维码,在后台留言,和我分享你的建议和看法,指正文章中可能存在的错误,大家一起交流,学习和进步!

推荐阅读

1.机器学习入门系列(1)–机器学习概览(上)

2.机器学习入门系列(2)–机器学习概览(下)

3.[GAN学习系列] 初识GAN

4.[GAN学习系列2] GAN的起源

5.谷歌开源的 GAN 库–TFGAN

TensorFlow 加载多个模型的方法相关推荐

  1. TensorFlow 加载多个模型的方法 - 知乎 https://zhuanlan.zhihu.com/p/53642222

    TensorFlow 加载多个模型的方法 - 知乎 什么是Tensorflow模型? 当你训练好一个神经网络后,你会想保存好你的模型便于以后使用并且用于生产.因此,什么是Tensorflow模型?Te ...

  2. Tensorflow加载多个模型

    Tensorflow 同时载入多个模型 原创 2017年12月13日 16:46:25 标签: python / tensorflow / 多个模型 97 有时我们希望在一个python的文件空间同时 ...

  3. cesium模型加载-加载fbx格式模型

    整体思路: fbx格式→dae格式→gltf格式→cesium加载gltf格式模型 具体方法: 1. fbx格式→dae格式 工具:3dsMax, 3dsMax插件:OpenCOLLADA, 下载地址 ...

  4. TensorFlow——加载和使用多个模型解决方案

    解决方案 在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于 ...

  5. pytorch模型加载测试_pytorch模型加载方法汇总

    Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvisio ...

  6. 懒加载 字典转模型 自定义cell

    1 懒加载: 1>  什么是懒加载? 懒加载又称为延时加载,即在系统调用的时候加载,如果系统不调用则不会加载.所谓的懒加载其实就是重写其 get 方法. 2>  特点:在使用懒加载的时候要 ...

  7. 超图桌面版加载obj 3D模型 - 2

    在 https://blog.csdn.net/bcbobo21cn/article/details/109041525 里,加载obj格式模型没有出来效果: 下面来看一下其他方法:当前用的版本是10 ...

  8. three.js 加载obj+mtl模型

    本文提供了three.js中 实现将obj+mtl模型加载到场景中 的方法. 我们欲实现将桥模型加载到场景中,并对桥设置透明度: 实现过程: 分别导入three.js中的OBJLoader,MTLLo ...

  9. Three 之 three.js (webgl)基础 第二个入门案例之汽车模型加载和简单模型展示

    Three 之 three.js (webgl)基础 第二个入门案例之汽车模型加载和简单模型展示 目录 ​Three 之 three.js (webgl)基础 第二个入门案例之汽车模型加载和简单模型展 ...

最新文章

  1. SAP SD 客户信贷管理解析
  2. 《黑客与画家》读后感:你对技术一无所知(一些金句)
  3. WebLogic Server的Identity Assertion--转载
  4. 配置Eclipse 实现按任意键代码自动补全
  5. 图像主观质量评价 评分_图像质量分析工具哪家强?
  6. 数据分列将数字转换成文本格式
  7. 遍历文件夹下的所有文件
  8. 开源电子海图和webGIS
  9. 区块链架构与扩容方案
  10. com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException: Table doesn't exist
  11. 读《天才在左,疯子在右》01--偷取时间
  12. 人力资源管理专业知识与实务(初级)【6】
  13. CAS号:2417213-21-7以(ZPS-PVPA)为催化剂载体
  14. 【Codeforces Round #185 (Div. 2) D】Cats Transport
  15. 批量生成独一无二的NFT猫猫图,这项目王多鱼会投吗?
  16. 金誉半导体:MOS管耗尽型和增强型是什么意思?
  17. ReSharper 使用感受
  18. oracle冲账语句_ORA-00xx问题 -oracle卸载不成功
  19. 图说卡尔曼滤波(C++实现)
  20. 中冠百年|投资理财,千万不要犯这些错误

热门文章

  1. python关于包的题怎么做_Python自定义包引入
  2. 公办低分二本_这六所公办二本高校的计算机类相关专业值得低分段考生选择
  3. js时间搓化为今天明天_js转时间戳,时间戳转js
  4. Linux ARM交叉编译工具链制作过程
  5. spring boot使用logback实现多环境日志配置
  6. VLC简介及使用说明
  7. Python中文全攻略
  8. 微软面试题:有100万个数字(1到9),其中只有1个数字重复2次,如何快速找出该数字
  9. 接口报Provisional headers are shown原因和解决方法
  10. modprobe: FATAL: Module xxx.ko not found in directory /lib/modules/$(uname -r)