现在的深度学习框架一般都是基于 Python 来实现,构建、训练、保存和调用模型都可以很容易地在 Python 下完成。但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直接调用 TensorFlow 的 C/C++ 接口来导入 TensorFlow 预训练好的模型。

1.环境配置 点此查看 C/C++ 接口的编译

2. 导入预定义的图和训练好的参数值

    // set up your input pathsconst string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";const string checkpointPath = "/home/senius/python/c_python/test/model-10";auto session = NewSession(SessionOptions()); //&emsp;创建会话if (session == nullptr){throw runtime_error("Could not create Tensorflow session.");}Status status;// Read in the protobuf graph we exportedMetaGraphDef graph_def;status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);&emsp; //&emsp;导入图模型if (!status.ok()){throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());}// Add the graph to the sessionstatus = session->Create(graph_def.graph_def());&emsp; //&emsp;将图模型加入到会话中if (!status.ok()){throw runtime_error("Error creating graph: " + status.ToString());}// Read weights from the saved checkpointTensor checkpointPathTensor(DT_STRING, TensorShape());checkpointPathTensor.scalar<std::string>()() = checkpointPath;&emsp;// 读取预训练好的权重status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},{graph_def.saver_def().restore_op_name()}, nullptr);if (!status.ok()){throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());}
复制代码

3. 准备测试数据

    const string filename = "/home/senius/python/c_python/test/04t30t00.npy";//Read TXT data to arrayfloat Array[1681*41];ifstream is(filename);for (int i = 0; i < 1681*41; i++){is >> Array[i];}is.close();tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));auto input_tensor_mapped = input_tensor.tensor<float, 5>();float *pdata = Array;// copying the data into the corresponding tensorfor (int x = 0; x < 41; ++x)//depth{for (int y = 0; y < 41; ++y) {for (int z = 0; z < 41; ++z) {const float *source_value = pdata + x * 1681 + y * 41 + z;input_tensor_mapped(0, x, y, z, 0) = *source_value;}}}
复制代码
  • 本例中输入数据是一个 [None, 41, 41, 41, 1] 的张量,我们需要先从 TXT 文件中读出测试数据,然后正确地填充到张量中去。

4. 前向传播得到预测值

    std::vector<tensorflow::Tensor> finalOutput;std::string InputName = "X"; // Your input placeholder's namestd::string OutputName = "sigmoid"; // Your output tensor's namevector<std::pair<string, Tensor> > inputs;inputs.push_back(std::make_pair(InputName, input_tensor));// Fill input tensor with your input datasession->Run(inputs, {OutputName}, {}, &finalOutput);auto output_y = finalOutput[0].scalar<float>();std::cout << output_y() << "\n";
复制代码
  • 通过给定输入和输出张量的名字,我们可以将测试数据传入到模型中,然后进行前向传播得到预测值。

5. 一些问题

  • 本模型是在 TensorFlow 1.4 下训练的,然后编译 TensorFlow 1.4 的 C++ 接口可以正常调用模型,但若是想调用更高版本训练好的模型,则会报错,据出错信息猜测可能是高版本的 TensorFlow 中添加了一些低版本没有的函数,所以不能正常运行。
  • 若是编译高版本的 TensorFlow ,比如最新的 TensorFlow 1.11 的 C++ 接口,则无论是调用旧版本训练的模型还是新版本训练的模型都不能正常运行。出错信息如下:Error loading checkpoint from /media/lab/data/yongsen/Tensorflow_test/test/model-40: Invalid argument: Session was not created with a graph before Run()!,网上暂时也查不到解决办法,姑且先放在这里。

6. 完整代码

#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/io_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/parsing_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/array_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/math_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/data_flow_ops.h>#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <fstream>using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;int main()
{// set up your input pathsconst string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";const string checkpointPath = "/home/senius/python/c_python/test/model-10";auto session = NewSession(SessionOptions());if (session == nullptr){throw runtime_error("Could not create Tensorflow session.");}Status status;// Read in the protobuf graph we exportedMetaGraphDef graph_def;status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);if (!status.ok()){throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());}// Add the graph to the sessionstatus = session->Create(graph_def.graph_def());if (!status.ok()){throw runtime_error("Error creating graph: " + status.ToString());}// Read weights from the saved checkpointTensor checkpointPathTensor(DT_STRING, TensorShape());checkpointPathTensor.scalar<std::string>()() = checkpointPath;status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},{graph_def.saver_def().restore_op_name()}, nullptr);if (!status.ok()){throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());}cout << 1 << endl;const string filename = "/home/senius/python/c_python/test/04t30t00.npy";//Read TXT data to arrayfloat Array[1681*41];ifstream is(filename);for (int i = 0; i < 1681*41; i++){is >> Array[i];}is.close();tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));auto input_tensor_mapped = input_tensor.tensor<float, 5>();float *pdata = Array;// copying the data into the corresponding tensorfor (int x = 0; x < 41; ++x)//depth{for (int y = 0; y < 41; ++y) {for (int z = 0; z < 41; ++z) {const float *source_value = pdata + x * 1681 + y * 41 + z;
//                input_tensor_mapped(0, x, y, z, 0) = *source_value;input_tensor_mapped(0, x, y, z, 0) = 1;}}}std::vector<tensorflow::Tensor> finalOutput;std::string InputName = "X"; // Your input placeholder's namestd::string OutputName = "sigmoid"; // Your output placeholder's namevector<std::pair<string, Tensor> > inputs;inputs.push_back(std::make_pair(InputName, input_tensor));// Fill input tensor with your input datasession->Run(inputs, {OutputName}, {}, &finalOutput);auto output_y = finalOutput[0].scalar<float>();std::cout << output_y() << "\n";return 0;
}
复制代码
  • Cmakelist 文件如下
cmake_minimum_required(VERSION 3.8)
project(Tensorflow_test)set(CMAKE_CXX_STANDARD 11)set(SOURCE_FILES main.cpp)include_directories(/home/senius/tensorflow-r1.4/home/senius/tensorflow-r1.4/tensorflow/bazel-genfiles/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/protobuf/include/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/host_obj/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/gen/proto/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/downloads/nsync/public/home/senius/tensorflow-r1.4/tensorflow/contrib/makefile/downloads/eigen/home/senius/tensorflow-r1.4/bazel-out/local_linux-py3-opt/genfiles
)add_executable(Tensorflow_test ${SOURCE_FILES})target_link_libraries(Tensorflow_test/home/senius/tensorflow-r1.4/bazel-bin/tensorflow/libtensorflow_cc.so/home/senius/tensorflow-r1.4/bazel-bin/tensorflow/libtensorflow_framework.so)
复制代码

获取更多精彩,请关注「seniusen」!

在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现相关推荐

  1. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  2. Tensorflow基于pb模型进行预训练(pb模型转CKPT模型)

    Tensorflow基于pb模型进行预训练(pb模型转CKPT模型) 在网上看到很多教程都是tensorflow基于pb模型进行推理,而不是进行预训练.最近在在做项目的过程中发现之前的大哥只有一个pb ...

  3. 应用在生物医学领域中的NLP预训练语言模型(PubMedBERT)

    文章目录 1. 背景 2. 在生物医学和专业领域建立神经语言模型的新范式 3. 创建一个全面的基准和排行榜,以加快生物医学NLP的进度 4. PubMedBert:优于之前所有的语言模型,并获得最新生 ...

  4. 【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)

    说明 使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoint ...

  5. Keras 的预训练权值模型用来进行预测、特征提取和微调(fine-tuning)

    转至:Keras中文文档 https://keras.io/zh/applications/ 应用 Applications Keras 的应用模块(keras.applications)提供了带有预 ...

  6. 【深度学习】预训练的卷积模型比Transformer更好?

    引言 这篇文章就是当下很火的用预训练CNN刷爆Transformer的文章,LeCun对这篇文章做出了很有深意的评论:"Hmmm".本文在预训练微调范式下对基于卷积的Seq2Seq ...

  7. UP-DETR:收敛更快!精度更高!华南理工微信开源无监督预训练目标检测模型...

    关注公众号,发现CV技术之美 0 写在前面 基于Transformer编码器-解码器结构的DETR达到了与Faster R-CNN类似的性能.受预训练Transformer在自然语言处理方面取得巨大成 ...

  8. MICCAI 2020 | 基于3D监督预训练的全身病灶检测SOTA(预训练代码和模型已公开)...

    关注公众号,发现CV技术之美 ▊ 研究背景介绍 由于深度学习任务往往依赖于大量的标注数据,医疗图像的样本标注又会涉及到较多的专业知识,标注人员需要对病灶的大小.形状.边缘等信息进行准确的判断,甚至需要 ...

  9. PromptCLUE:大规模多任务Prompt预训练中文开源模型

    简介 PromptCLUE:大规模多任务Prompt预训练中文开源模型. 中文上的三大统一:统一模型框架,统一任务形式,统一应用方式.支持几十个不同类型的任务,具有较好的零样本学习能力和少样本学习能力 ...

最新文章

  1. 基于r-Kernel的LiteOS操作系统
  2. 如何操作提升手机端网站的排名优化?
  3. android开发调用照相机
  4. Android中APK直接通过JNI访问驱动
  5. Boost::Exception提供的各种常用 error_info typedef的预期用途的测试
  6. Educational Codeforces Round 107 (Rated for Div. 2)
  7. mysql oa数据库设计_OA项目1:环境搭建之数据库创建与环境添加
  8. OpenCV中感兴趣区域的选取与检测(一)
  9. win32收不到F10按键消息解决的方法
  10. 安装neptune-client库
  11. jsp购物车加mysql_网上购物车(jsp+servlet+mysql)
  12. 薅羊毛!某东、某宝、某宁一次搞定~
  13. $(...).nicescroll is not a function报错分析
  14. Linux程序设计第二版练习题(第五章)
  15. python 报童模型
  16. Flash Programer 给CC2530下载Hex文件 error解决办法 汇总
  17. 【Python】断言(assert)
  18. 成本要素****没有被分配到成本组件结构01中的成本组件
  19. 2022/7/2 Jenkins详细教程
  20. 疯狂Python讲义学习笔记(含习题)之网络编程

热门文章

  1. 解决Geoserver请求跨域的几种思路
  2. 当try,catch,finally中均有return语句时,会返回哪一个?---finally中的return
  3. 浅析Memcache和Redis
  4. 专访中科创达王璠:怎样做好嵌入式人工智能的算法开发?
  5. ProgressDialog用法
  6. 极度舒适的 Python 入门教程,佩奇也能学会~
  7. 基因组组装程序linux,基因组组装软件SOAPdenovo安装使用
  8. 广外计算机考研专业课,【广外考研论坛】 21广外各专业考研问题全解答!纯干货!...
  9. java树形菜单_Java构建树形菜单
  10. Nginx关于浏览器缓存相关的配置指令