TensorFlow:将ckpt文件固化成pb文件

本文是将yolo3目标检测框架训练出来的ckpt文件固化成pb文件,主要利用了GitHub上的该项目。

为什么要最终生成pb文件呢?简单来说就是直接通过tf.saver保存行程的ckpt文件其变量数据和图是分开的。我们知道TensorFlow是先画图,然后通过placeholde往图里面喂数据。这种解耦形式存在的方法对以后的迁移学习以及对程序进行微小的改动提供了极大的便利性。但是对于训练好,以后不再改变的话这种存在就不再需要。一方面,ckpt文件储存的数据都是变量,既然我们不再改动,就应当让其变成常量,直接‘烧’到图里面。另一方面,对于线上的模型,我们一般是通过C++或者C语言编写的程序进行调用。所以一般模型最终形式都是应该写成pb文件的形式。

由于这次的程序直接从GitHub上下载后改动较小就能够运行,也就是自己写了很少一部分程序。因此进行调试的时候还出现了以前根本没有注意的一些小问题,同时发现自己对TensorFlow还需要更加详细的去研读。

首先对程序进行保存的时候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)对训练的数据进行保存,保存格式为ckpt。但是在恢复的时候一直提示有问题,(其恢复语句为:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夹路径)。出现问题的原因我估计是因为我是按照每50个epoch进行保存,而不是让其进行固定次数的batch进行保存,这种固定batch次数的保存系统会自动保存最近5次的ckpt文件(该方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')进行回复)。那么如何将利用epoch的次数进行保存呢(这种保存不是近5次的保存,而是每进行一次保存就会留下当时保存的ckpt,而那种按照batch的会在第n次保存,会将n-5次的删除,n>5)。

我们可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),获取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)进行恢复。当然为了安全起见,应该对ckpt和ckpt.checkpoint_path进行判断是否存在后,再进行恢复语句的调用。即:

  1. saver = tf.train.Saver()

  2. ckpt = tf.train.get_checkpoint_state(model_path)

  3. if ckpt and ckpt.model_checkpoint_path:

  4. saver.restore(sess, ckpt.model_checkpoint_path)

对于固化网络,网上有很多的介绍。之所以再介绍,还是由于是用了别人的网络而不是自己的网络遇到的坑。在固化时候我们需要知道输出tensor的名字,而再恢复的时候我们需要知道placeholder的名字。但是,如果网络复杂或者别人的网络命名比较复杂,或者name=,根本就没有自己命名而用的系统自定义的,这样捋起来还是比较费劲的。当时在网上查找的一些方法,像打印整个网络变量的方法(先不管输出的网路名称,甚至随便起一个名字,先固化好pb文件,然后对pb文件进行读取,最后打印变量的名字:

  1. graph = tf.get_default_graph()

  2. input_graph_def = graph.as_graph_def()

  3. output_graph_def = graph_util.convert_variables_to_constants(

  4. sess,

  5. input_graph_def,

  6. ['cls_score/cls_score', 'cls_prob'] # We split on comma for convenience

  7. )

  8. with tf.gfile.GFile(output_graph, "wb") as f:

  9. f.write(output_graph_def.SerializeToString())

  10. print ('开始打印节点名字')

  11. for op in graph.get_operations():

  12. print(op.name)

  13. print("%d ops in the final graph." % len(output_graph_def.node))

这样尽然也能打印出来(尽管输出名字是随便命名的)。但是打印出来的根本对不上,其实不可能对的上,因为打印出来的是变量名(也就是训练的数据),不是输出结果。

那么怎么办?答案简单的让我也很无语。其实,对ckpt进行数据恢复的时候,直接打印输出的tensor名字就可以。比如说在saver以及placeholder定义的时候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我们在后面跟一句:print output,从打印出来的信息即可查看。placeholder的查看方法同样如此。

 对网络进行固化:

 代码:

  1. input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,))

  2. input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32)

  3. predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)

  4. boxes, scores, classes = predictor.predict(input_image, input_image_shape)

  5. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

  6. saver = tf.train.Saver()

  7. ckpt = tf.train.get_checkpoint_state(model_path)

  8. if ckpt and ckpt.model_checkpoint_path:

  9. saver.restore(sess, ckpt.model_checkpoint_path)

  10. # 采用meta 结构加载,不需要知道网络结构

  11. # saver = tf.train.import_meta_graph(model_path, clear_devices=True)

  12. # 这里的model_path是model.ckpt.meta文件的全路径

  13. # ckpt_model_path 是保存模型的文件夹路径

  14. # saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path))

  15. graph = tf.get_default_graph()

  16. input_graph_def = graph.as_graph_def()

  17. output_graph_def = graph_util.convert_variables_to_constants(

  18. sess,

  19. input_graph_def,

  20. ['concat_11','concat_12','concat_13'] # We split on comma for convenience

  21. )

  22. # # Finally we serialize and dump the output graph to the filesystem

  23. with tf.gfile.GFile(output_graph, "wb") as f:

  24. f.write(output_graph_def.SerializeToString())

由于固化的时候是需要先恢复ckpt网络的,所以还是在restore前写了placeholder和输出tensor的定义(需要注点意的是,我们保存的ckpt文件是训练阶段的graph和变量等,其inference输出和最终predict的输出的Tensor不一样,因此predict与inference的输出相比,还包括了一些后处理,比如说nms等等,只有这些后处理也是TensorFlow框架内的方法写的,才能使最终形成的pb文件能够做到输入一张图片,直接输出最终结果。因此,对于目标检测任务,把后处理任务也交由TensorFlow内的api来实现,可免去夸平台读取pb文件后仍然需要重新进行后处理等相关程序的编写带来的不必要麻烦)。然后结合保存变量的那个文件(ckpt),将变量恢复到inference过程所需的变量数据(predict包括inference和eval两个过程,训练过程只有inference和loss过程参与,而预测过程多了一个后处理eval过程,eval过程无变量。这样在生成pb文件的时候也把后处理eval固化进去。喂给网络数据,即可得到输出tensor。

由于有读者在此问到了还是没有弄明白'concat_11','concat_12','concat_13'是如何得来的,我在这里就在详细说一下:

是这样的,在我们恢复网络的时候肯定需要知道saver这个对象的,在这里介绍两种方法生成这个对象的方法。

一:

saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

其中meta_graph_location就是保存模型时的.meta文件的路径。保存后有四个文件(checkpoint、.index、.data-00000-of-00001和.meta文件)。.meta文件就是整个TensorFlow的结构图。

二:

saver = tf.train.Saver()

本文采用的是第二种方法(上面已经有详细的代码),由于这种方法得到的saver对象,他不知道具体图是什么样的,因此在恢复前我有用如下代码

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

把整个结构又加载了一遍。如果采用第一种方法,是不需要在重写这两行代码的。

我们要的就是 boxes, scores, classes这三个tensor的结果,并且想知道他们三个tensor的名字。你直接利用print(boxes, scores, classes)打印出来这三个tensor就会出来这三个tensor具体信息(包括名字,和shape,dtype等)。这个只是利用第二种方法得到saver对象,然后恢复ckpt文件,不涉及到固化pb文件问题。固化pb文件是需要知道这三个tensor的名字,所以需要打印看一下。

如果说,我只拿到了保存后的四个文件(checkpoint、.index、.data-00000-of-00001和.meta文件),其相应用代码写成的结构图不清楚,比如说利用这两行代码:

predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path)
boxes, scores, classes = predictor.predict(input_image, input_image_shape)

画出的结构图,我不知道。那么,想要知道具体的placehold和输出tensor的名字,这个我就不太清楚了。

读取pb文件:

代码:

  1. def pb_detect(image_path, pb_model_path):

  2. os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index

  3. image = Image.open(image_path)

  4. resize_image = letterbox_image(image, (416, 416))

  5. image_data = np.array(resize_image, dtype = np.float32)

  6. image_data /= 255.

  7. image_data = np.expand_dims(image_data, axis = 0)

  8. with tf.Graph().as_default():

  9. output_graph_def = tf.GraphDef()

  10. with open(pb_model_path, "rb") as f:

  11. output_graph_def.ParseFromString(f.read())

  12. tf.import_graph_def(output_graph_def, name="")

  13. with tf.Session() as sess:

  14. sess.run(tf.global_variables_initializer())

  15. input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0")

  16. input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0")

  17. # 定义输出的张量名称

  18. #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

  19. boxes = sess.graph.get_tensor_by_name("concat_11:0")

  20. scores = sess.graph.get_tensor_by_name("concat_12:0")

  21. classes = sess.graph.get_tensor_by_name("concat_13:0")

  22. # 读取测试图片

  23. # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字(需要在名字后面加:0),不是操作节点的名字

  24. out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes],

  25. feed_dict={

  26. input_image_tensor: image_data,

  27. input_image_tensor_shape: [image.size[1], image.size[0]]

  28. })

可以看到读取pb文件只需要比恢复ckpt文件容易的多,直接将placeholder的名字获取到,将数据输入恢复的网络,以及读取输出即可。

小记:

有可能是TensorFlow版本更新或者其他原因,在后来工作中加载pb文件是报错了:ValueError: Fetch argument <tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024) dtype=float32> cannot be interpreted as a Tensor. (tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024), dtype=float32) is not an element of this graph.)

将上面读取pb文件的代码with tf.Graph().as_default():改成

  1. global graph

  2. graph = tf.get_default_graph()

TensorFlow:将ckpt文件固化成pb文件相关推荐

  1. 利用Python批量将csv文件转化成xml文件

    文章目录 一.前言 二.Python代码实现 一.前言 将 csv 格式转换成xml格式有许多方法,可以用数据库的方式,也有许多软件可以将 csv 转换成xml.但是比较麻烦,本文利用 Python ...

  2. Pcap文件转化成Pcd文件

    通过RSview将点云文件保存成了Pcap格式,但这种格式不能很好的支持PCL点云库,故不能很好的实现点云获取.滤波.分割.配准.检索.特征提取.识别.追踪.曲面重建.可视化等,因此我们需要将Pcap ...

  3. QT Designer 生成的ui文件转化成py文件以及简单使用

    QT Designer 生成的ui文件转化成py文件以及简单使用 设计 转换 使用 方法一 方法二 设计 首先用QTdesigner 设计你的界面 然后保存成.ui文件 这一步大家应该都可以做到,就不 ...

  4. sarscape 将dem文件转化成stl_STL文件,一种前处理网格划分技术??

    源:吴冠中作品 点击关注CAE仿真空间, 点亮"在看",优质内容不错过对于从事专业仿真技术的工程师而言,我们已经习惯而且能够熟练的利用诸如ANSA.Hypermesh等网格划分前处 ...

  5. 利用pandas实现json文件转化成csv文件

    补充上篇博客提到的json文件数据转换成csv文件.作为数据分析最常用文件类型json与csv,出于一定情况下,我们需要将json中字典类型的数据,转换为csv存储,这又用到python强大的pand ...

  6. Vaa3D_批量将tiff文件转化成v3draw文件

    通过这个代码可以实现输入tif.tiff文件所在的文件夹,在对应的位置生成一个文件夹然后在该文件夹中生成v3draw图像.如果不需要可以注释这部分代码. QString folder=QFileDia ...

  7. 一份basic文件转化成c文件(自己收藏)

    basic的原文件,先通过BCX转化器,转化一下,注意转化过来的东东,错误是N的多哦. 得动手自己修改: 10   REM  ************************************* ...

  8. ipynb文件转化成py文件

    在当前文件夹运行cmd, 输入 jupyter nbconvert --to script xxx.ipynb [注]xxx.ipynb是需要转换成py的文件名称. 搞定!

  9. 使用C#把Tensorflow训练的.pb文件用在生产环境

    训练了很久的Tf模型,终于要到生产环境中去考验一番了.今天花费了一些时间去研究tf的模型如何在生产环境中去使用.大概整理了这些方法. 继续使用分步骤保存了的ckpt文件 这个貌似脱离不了tensorf ...

最新文章

  1. cufflinks基于dataframe数据绘制股票数据:直方图、时序图
  2. 2019全球信息通信业热点回顾
  3. switch( )的经典引用
  4. QT:触摸屏支持手指触摸,增加touch事件touchevent,记录前后touch坐标并处理
  5. MAC上安装iTerm2+oh my zsh+设置Dracula主题
  6. HTTP流媒体播放技术发展以及nginx点播源站
  7. SAP UI5 应用开发教程之六十六 - 基于 OData V4 的 SAP UI5 表格控件如何实现删除功能试读版
  8. 从零认识单片机(9)
  9. ubuntu 启动图形界面命令_Windows 10 远程连接 Ubuntu 18.04 Server图形界面
  10. vue项目编写html,从头搭建、编写一个VUE项目
  11. Git学习系列(三)版本回退和管理文件的修改及删除操作
  12. 机器学习代码实战——KMeans(聚类)
  13. MySQL命令行登录数据库
  14. SpringCloud七:配置中心Eureka+Config+Bus+RabbitMQ
  15. 一个Scrapy爬虫实例
  16. 在线LaTeX公式编辑器(备忘)
  17. Adobe将支持HTTP流媒体直播 预示着ipad将可以用flash吗?
  18. css元素可拖动,css3实现可拖动的魔方3d效果
  19. Spark Structured Steaming实战
  20. 智能合约Smart Contract技术详解

热门文章

  1. 如何在Linux使用Eclipse + CDT开发C/C++程序? (OS) (Linux) (C/C++) (gcc) (g++)
  2. AWS-CLI-V2-Install
  3. 使用 shell 脚本对 Linux 系统和进程资源进行监控
  4. crytojs加密 java解密,使用CryptoJS在Javascript中加密并在Java中解密
  5. 数据结构 单链表 C
  6. mysql查询时间between and_Mysql中用between...and...查询日期时注意事项
  7. mongo在哪创建管理员_如何给mongodb管理员权限
  8. python面向对象代码示例
  9. mysql中如何将一个表中的部分记录合并,MySQL数据库将多条记录的单个字段合并成一条记录_MySQL...
  10. fritz_如何使用Fritz.ai将机器学习应用于Android