参考:https://github.com/tflearn/tflearn/issues/964

解决方法:

"""
Tensorflow graph freezer
Converts Tensorflow trained models in .pbCode adapted from:
https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py
"""import os, argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.python.framework import graph_utildef freeze_graph(model_folder,output_graph="frozen_model.pb"):# We retrieve our checkpoint fullpathtry:checkpoint = tf.train.get_checkpoint_state(model_folder)input_checkpoint = checkpoint.model_checkpoint_pathprint("[INFO] input_checkpoint:", input_checkpoint)except:input_checkpoint = model_folderprint("[INFO] Model folder", model_folder)# Before exporting our graph, we need to precise what is our output node# This is how TF decides what part of the Graph he has to keep and what part it can dumpoutput_node_names = "FullyConnected/Softmax" # NOTE: Change here# We clear devices to allow TensorFlow to control on which device it will load operationsclear_devices = True# We import the meta graph and retrieve a Saversaver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)# We retrieve the protobuf graph definitiongraph = tf.get_default_graph()input_graph_def = graph.as_graph_def()# We start a session and restore the graph weights
    with tf.Session() as sess:saver.restore(sess, input_checkpoint)# We use a built-in TF helper to export variables to constantsoutput_graph_def = graph_util.convert_variables_to_constants(sess,                        # The session is used to retrieve the weightsinput_graph_def,             # The graph_def is used to retrieve the nodes output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) # Finally we serialize and dump the output graph to the filesystemwith tf.gfile.GFile(output_graph, "wb") as f:f.write(output_graph_def.SerializeToString())print("%d ops in the final graph." % len(output_graph_def.node))print("[INFO] output_graph:",output_graph)print("[INFO] all done")if __name__ == '__main__':parser = argparse.ArgumentParser(description="Tensorflow graph freezer\nConverts trained models to .pb file",prefix_chars='-')parser.add_argument("--mfolder", type=str, help="model folder to export")parser.add_argument("--ograph", type=str, help="output graph name", default="frozen_model.pb")args = parser.parse_args()print(args,"\n")freeze_graph(args.mfolder,args.ograph)# However, before doing model.save(...) on TFLearn i have to do
# ************************************************************
# del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
# ************************************************************"""
Then I call this command
python tf_freeze.py --mfolder=<path_to_tflearn_model>NoteThe <path_to_tflearn_model> must not have the ".data-00000-of-00001".The output_node_names variable may change depending on your architecture. The thing is that you must reference the layer that has the softmax activation function.
"""

注意:

1、需要在 tflearn的model.save 前:

del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

作用:去除模型里训练OP。

参考:https://github.com/tflearn/tflearn/issues/605#issuecomment-298478314

2、如果是有batch normalzition,或者残差网络层,会出现:

Error when loading the frozen graph with tensorflow.contrib.layers.python.layers.batch_norm
ValueError: graph_def is invalid at node u'BatchNorm/cond/AssignMovingAvg/Switch': Input tensor 'BatchNorm/moving_mean:0' Cannot convert a tensor of type float32 to an input of type float32_ref
freeze_graph.py doesn't seem to store moving_mean and moving_variance properly

An ugly way to get it working:
manually replace the wrong node definitions in the frozen graph
RefSwitch --> Switch + add '/read' to the input names
AssignSub --> Sub + remove use_locking attributes

则需要在restore模型后加入:

# fix batch norm nodes
for node in gd.node:if node.op == 'RefSwitch':node.op = 'Switch'for index in xrange(len(node.input)):if 'moving_' in node.input[index]:node.input[index] = node.input[index] + '/read'elif node.op == 'AssignSub':node.op = 'Sub'if 'use_locking' in node.attr: del node.attr['use_locking']

参考:https://github.com/tensorflow/tensorflow/issues/3628

I met the same issue when I was trying to export graph and variables by saved_model module. And finally I found a walk around to fix this issue:

Remove the TRAIN_OPS collections from graph collection. e.g.:

with dnn.graph.as_default():del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

The dumped graph may not be available for training again (by tflearn), but should be able to perform prediction and evaluation. This is useful when serving model by another module or language (e.g. tensorflow serving or tensorflow go binding). I'll do more further tests about this.

If you wanna re-train the model, please use the builtin "save" method and re-construction the graph and load the saved data when re-training.

2、可能需要在代码修改这行,

output_node_names = "FullyConnected/Softmax" # NOTE: Change here

参考:https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py

@vparikh10 @ratfury @rakashi I faced the same situation just like you.
From what I understood, you may have to change this line according to your network definition.
In my case, instead of having output_node_names = "Accuracy/prediction", I have output_node_names = "FullyConnected_2/Softmax".

I made this change after reading this suggestion

对我自己而言,写成softmax或者Softmax都是不行的!然后我将所有的node names打印出来:打印方法:
    with tf.Session() as sess:model = get_cnn_model(max_len, volcab_size)model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True, batch_size=1000, n_epoch=1)init_op = tf.initialize_all_variables()sess.run(init_op)for v in sess.graph.get_operations():print(v.name)

然后确保output_node_names在里面。


附:gist里的代码,将output node names转换为参数
import os, argparseimport tensorflow as tf# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

dir = os.path.dirname(os.path.realpath(__file__))def freeze_graph(model_dir, output_node_names):"""Extract the sub graph defined by the output nodes and convert all its variables into constant Args:model_dir: the root folder containing the checkpoint state fileoutput_node_names: a string, containing all the output node's names, comma separated"""if not tf.gfile.Exists(model_dir):raise AssertionError("Export directory doesn't exists. Please specify an export ""directory: %s" % model_dir)if not output_node_names:print("You need to supply the name of a node to --output_node_names.")return -1# We retrieve our checkpoint fullpathcheckpoint = tf.train.get_checkpoint_state(model_dir)input_checkpoint = checkpoint.model_checkpoint_path# We precise the file fullname of our freezed graphabsolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])output_graph = absolute_model_dir + "/frozen_model.pb"# We clear devices to allow TensorFlow to control on which device it will load operationsclear_devices = True# We start a session using a temporary fresh Graphwith tf.Session(graph=tf.Graph()) as sess:# We import the meta graph in the current default Graphsaver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)# We restore the weights
        saver.restore(sess, input_checkpoint)# We use a built-in TF helper to export variables to constantsoutput_graph_def = tf.graph_util.convert_variables_to_constants(sess, # The session is used to retrieve the weightstf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) # Finally we serialize and dump the output graph to the filesystemwith tf.gfile.GFile(output_graph, "wb") as f:f.write(output_graph_def.SerializeToString())print("%d ops in the final graph." % len(output_graph_def.node))return output_graph_defif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")args = parser.parse_args()freeze_graph(args.model_dir, args.output_node_names)

转载于:https://www.cnblogs.com/bonelee/p/8445261.html

将tflearn的模型保存为pb,给TensorFlow使用相关推荐

  1. tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛...

    tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用. 数据目录在data,data下放了汉字识别图片: data$ ls 0  1  10  11  12  13  14  1 ...

  2. 模型保存的序列化文件pb 什么是PB文件 pb是protocol(协议) buffer(缓冲)的缩写

    pb是protocol(协议) buffer(缓冲)的缩写 TensorFlow 模型保存为pb文件的解释,怎么使用pb文件/模型的Save and Restore_u014264373的博客-CSD ...

  3. TensorFlow模型保存和提取方法

    2019独角兽企业重金招聘Python工程师标准>>> 一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存 ...

  4. TensorFlow模型保存和提取方法(含滑动平均模型)

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

  5. Tensorflow2.3用SaveModel保存训练模型.pb等文件+opencvino转IR文件

    #在Tensorflow2.3框架下训练的深度网络模型,目标用于工业部署,所以需要将训练好的模型保存,转成用于C++工业部署的文件格式,期间不断遇到问题,最终成功转换,记录下过程,避免后人踩坑. 1. ...

  6. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  7. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...

  8. 深度学习模型保存_Web服务部署深度学习模型

    本文的目的是介绍如何使用Web服务快速部署深度学习模型,虽然TF有TFserving可以进行模型部署,但是对于Pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦):因此,本文 ...

  9. keras保存模型_TF2 8.模型保存与加载

    举个例子:先训练出一个模型 import 接下来第一种方法:只保留模型的参数:这个有2种方法: model.save_weights("adasd.h5")model.load_w ...

最新文章

  1. 关于 java.util.concurrent 您不知道的 5 件事--转
  2. java 调用SAP RFC函数错误信息集锦
  3. 解决Git中fatal: refusing to merge unrelated histories(亲测)
  4. jcmd:JDK14中的调试神器
  5. java statement 返回类型,6.3 返回类型和返回语句 | Return type Return statement
  6. science量子计算机,第一快讯|《Science》量子计算机被证明超越了经典计算机
  7. 机器学习真的可以起作用吗?(1)
  8. centOS 安装及部署 SVN
  9. PAT 乙级 1045 快速排序
  10. Redis与Redisson的分布式锁
  11. EasyAR4.1平面识别
  12. python 全国省市区列表查询
  13. 软件测试工程师的简历怎么写?
  14. Window10屏幕亮度无法调节尝试解决方法
  15. 互动大屏,unity透明视频的实现方法:
  16. 微信公众号开发(1)微信公众号简介
  17. c语言中if函数应用举例,if函数(if函数的应用举例)
  18. Python 分析《三国演义》看司马懿三父子如何用计谋干掉了曹操后代
  19. 摩羯座|摩羯座性格分析
  20. 用Ultra-Light-Fast-Generic-Face-Detector-1MB寻找人眼

热门文章

  1. handler回调主线程_Android使用Handler实现子线程与子线程、子线程与主线程之间通信...
  2. java handler使用方法_Android中Handler的使用方法及实例(基础回顾)
  3. 南开15计算机基础,南开大学计算机基础06-07_B卷
  4. linux+7+logger,linux日志logger命令详解
  5. mysql-connector-net-6.7.4.msi,在ActiveReports中使用MySQL数据库
  6. 【以太坊】web3.js的1.0版本和0.2.0版本的安装及区别
  7. git第一次提交代码到码云,git pull 报错:fatal: refusing to merge unrelated histories
  8. 【响应式Web前端设计】!important的用法及作用
  9. multiprocessing python_Python多线程/进程(threading、multiprocessing)知识覆盖详解
  10. 怎么升级浏览器_下载的chrome无法访问此网站怎么解决