TensorFlow Lite 开发手册(5)——TensorFlow Lite模型使用实例(分类模型)

  • (一)新建CLion工程
  • (二)编写Cmakelist
  • (三)编写main.cpp
  • (四)下载预训练模型
  • (五)修改模型配置
  • (六)运行实例

(一)新建CLion工程

到(https://download.csdn.net/download/weixin_42499236/11892106)下载该工程,解压后如下图所示:

(二)编写Cmakelist

# Minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.15)
project(testlite)

# TensorFlow Lite (this vintage) is built with C++14.
set(CMAKE_CXX_STANDARD 14)

# NOTE(review): absolute paths below are machine-specific; consider exposing
# the TensorFlow source root as a cache variable (e.g. -DTF_ROOT=...).
include_directories(/home/ai/CLionProjects/tensorflow-master/)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/flatbuffers/include)
include_directories(/home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/downloads/absl)

add_executable(testlite main.cpp bitmap_helpers.cc utils.cc)

# Link the prebuilt static TFLite library plus its system dependencies.
target_link_libraries(testlite
        /home/ai/CLionProjects/tensorflow-master/tensorflow/lite/tools/make/gen/linux_x86_64/lib/libtensorflow-lite.a
        -lpthread -ldl -lrt)

(三)编写main.cpp

  • 导入头文件
// System headers (C / POSIX).
#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

// C++ standard library.
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

// Project / TensorFlow Lite headers.
#include "bitmap_helpers.h"
#include "get_top_n.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/optional_debug_tools.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "absl/memory/memory.h"
#include "utils.h"

// NOTE(review): file-scope `using namespace std;` is discouraged, but it is
// kept because the code below relies on unqualified `string`.
using namespace std;
  • 调用GPU、NNAPI加速(若无GPU,则默认使用CPU)
#define LOG(x) std::cerrdouble get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;// 调用GPU
TfLiteDelegatePtr CreateGPUDelegate(tflite::label_image::Settings* s) {#if defined(__ANDROID__)TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default();gpu_opts.inference_preference =TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED;gpu_opts.is_precision_loss_allowed = s->allow_fp16 ? 1 : 0;return evaluation::CreateGPUDelegate(s->model, &gpu_opts);
#elsereturn tflite::evaluation::CreateGPUDelegate(s->model);
#endif
}TfLiteDelegatePtrMap GetDelegates(tflite::label_image::Settings* s) {TfLiteDelegatePtrMap delegates;if (s->gl_backend) {auto delegate = CreateGPUDelegate(s);if (!delegate) {LOG(INFO) << "GPU acceleration is unsupported on this platform.";} else {delegates.emplace("GPU", std::move(delegate));}}if (s->accel) {auto delegate = tflite::evaluation::CreateNNAPIDelegate();if (!delegate) {LOG(INFO) << "NNAPI acceleration is unsupported on this platform.";} else {delegates.emplace("NNAPI", tflite::evaluation::CreateNNAPIDelegate());}}return delegates;
}
  • 读取标签文件
// Loads the label file into |result|, one label per line of text.
//
// The label count before padding is reported through |found_label_count|.
// The vector is then padded with empty strings up to a multiple of 16 —
// presumably so models whose output size is rounded up still index within
// bounds (matches the upstream label_image example; confirm if reused).
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
  std::ifstream stream(file_name);
  if (!stream) {
    LOG(FATAL) << "Labels file " << file_name << " not found\n";
    return kTfLiteError;
  }

  result->clear();
  string label;
  while (std::getline(stream, label)) {
    result->push_back(label);
  }
  *found_label_count = result->size();

  // Pad with empty labels until the size is a multiple of 16.
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return kTfLiteOk;
}
  • 打印模型节点信息
// Prints one profiling line for a single executed op, e.g.:
//   time (ms) , Node xxx, OpCode xxx, symbolic name
//        5.352, Node   5, OpCode   4, DEPTHWISE_CONV_2D
void PrintProfilingInfo(const tflite::profiling::ProfileEvent* e,
                        uint32_t subgraph_index, uint32_t op_index,
                        TfLiteRegistration registration) {
  const double elapsed_ms =
      (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0;
  LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
            << elapsed_ms
            << ", Subgraph " << std::setw(3) << std::setprecision(3)
            << subgraph_index
            << ", Node " << std::setw(3) << std::setprecision(3) << op_index
            << ", OpCode " << std::setw(3) << std::setprecision(3)
            << registration.builtin_code << ", "
            << EnumNameBuiltinOperator(static_cast<tflite::BuiltinOperator>(
                   registration.builtin_code))
            << "\n";
}
  • 定义模型推理函数
// Runs classification on one bitmap image using the model/labels from |s|.
//
// Steps: load the .tflite model, build an interpreter, apply optional
// GPU/NNAPI delegates, resize the input image into the input tensor, invoke
// the interpreter (warmup + timed loop), optionally dump per-op profiling,
// and finally print the top-N labels.  Exits the process on any failure.
void RunInference(tflite::label_image::Settings* s) {
  if (!s->model_name.c_str()) {
    LOG(ERROR) << "no model file name\n";
    exit(-1);
  }

  // Load the .tflite flatbuffer model.
  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;
  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
  if (!model) {
    LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
    exit(-1);
  }
  s->model = model.get();
  LOG(INFO) << "Loaded model " << s->model_name << "\n";
  model->error_reporter();
  LOG(INFO) << "resolved reporter\n";

  // Build the interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) {
    LOG(FATAL) << "Failed to construct interpreter\n";
    exit(-1);
  }
  interpreter->UseNNAPI(s->old_accel);
  interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);

  // Optionally print interpreter details: tensor sizes, input names, etc.
  if (s->verbose) {
    LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";
    int t_size = interpreter->tensors_size();
    for (int i = 0; i < t_size; i++) {
      if (interpreter->tensor(i)->name)
        LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                  << interpreter->tensor(i)->bytes << ", "
                  << interpreter->tensor(i)->type << ", "
                  << interpreter->tensor(i)->params.scale << ", "
                  << interpreter->tensor(i)->params.zero_point << "\n";
    }
  }

  if (s->number_of_threads != -1) {
    interpreter->SetNumThreads(s->number_of_threads);
  }

  // Default input geometry; read_bmp overwrites these with the actual
  // bitmap dimensions.
  int image_width = 224;
  int image_height = 224;
  int image_channels = 3;

  // Read the BMP image into a raw byte vector.
  std::vector<uint8_t> in = tflite::label_image::read_bmp(
      s->input_bmp_name, &image_width, &image_height, &image_channels, s);

  int input = interpreter->inputs()[0];
  if (s->verbose) LOG(INFO) << "input: " << input << "\n";
  const std::vector<int> inputs = interpreter->inputs();
  const std::vector<int> outputs = interpreter->outputs();
  if (s->verbose) {
    LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
    LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
  }

  // Apply the requested delegates (GPU / NNAPI) to the graph.
  auto delegates_ = GetDelegates(s);
  for (const auto& delegate : delegates_) {
    if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
        kTfLiteOk) {
      LOG(FATAL) << "Failed to apply " << delegate.first << " delegate.";
    } else {
      LOG(INFO) << "Applied " << delegate.first << " delegate.";
    }
  }

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Failed to allocate tensors!";
  }
  if (s->verbose) PrintInterpreterState(interpreter.get());

  // Query the input tensor's expected dimensions.
  // Assumes an NHWC-shaped input tensor — TODO confirm for other models.
  TfLiteIntArray* dims = interpreter->tensor(input)->dims;
  int wanted_height = dims->data[1];
  int wanted_width = dims->data[2];
  int wanted_channels = dims->data[3];

  // Resize the image into the input tensor, converting type as needed.
  switch (interpreter->tensor(input)->type) {
    case kTfLiteFloat32:
      s->input_floating = true;
      tflite::label_image::resize<float>(
          interpreter->typed_tensor<float>(input), in.data(), image_height,
          image_width, image_channels, wanted_height, wanted_width,
          wanted_channels, s);
      break;
    case kTfLiteUInt8:
      tflite::label_image::resize<uint8_t>(
          interpreter->typed_tensor<uint8_t>(input), in.data(), image_height,
          image_width, image_channels, wanted_height, wanted_width,
          wanted_channels, s);
      break;
    default:
      LOG(FATAL) << "cannot handle input type "
                 << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }

  // Attach a profiler and run warmup iterations (only when timing loops).
  auto profiler = absl::make_unique<tflite::profiling::Profiler>(
      s->max_profiling_buffer_entries);
  interpreter->SetProfiler(profiler.get());
  if (s->profiling) profiler->StartProfiling();
  if (s->loop_count > 1)
    for (int i = 0; i < s->number_of_warmup_runs; i++) {
      if (interpreter->Invoke() != kTfLiteOk) {
        LOG(FATAL) << "Failed to invoke tflite!\n";
      }
    }

  // Timed inference loop; report the average wall-clock time per run.
  struct timeval start_time, stop_time;
  gettimeofday(&start_time, nullptr);
  for (int i = 0; i < s->loop_count; i++) {
    if (interpreter->Invoke() != kTfLiteOk) {
      LOG(FATAL) << "Failed to invoke tflite!\n";
    }
  }
  gettimeofday(&stop_time, nullptr);
  LOG(INFO) << "invoked \n";
  LOG(INFO) << "average time: "
            << (get_us(stop_time) - get_us(start_time)) /
                   (s->loop_count * 1000)
            << " ms \n";

  // Dump per-op profiling events if profiling was enabled.
  if (s->profiling) {
    profiler->StopProfiling();
    auto profile_events = profiler->GetProfileEvents();
    for (size_t i = 0; i < profile_events.size(); i++) {
      auto subgraph_index = profile_events[i]->event_subgraph_index;
      auto op_index = profile_events[i]->event_metadata;
      const auto subgraph = interpreter->subgraph(subgraph_index);
      const auto node_and_registration =
          subgraph->node_and_registration(op_index);
      const TfLiteRegistration registration = node_and_registration->second;
      PrintProfilingInfo(profile_events[i], subgraph_index, op_index,
                         registration);
    }
  }

  // Extract the top-N results from the output tensor.
  const float threshold = 0.001f;
  std::vector<std::pair<float, int>> top_results;
  int output = interpreter->outputs()[0];
  TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
  // Assume output dims to be something like (1, 1, ..., size).
  auto output_size = output_dims->data[output_dims->size - 1];
  switch (interpreter->tensor(output)->type) {
    case kTfLiteFloat32:
      tflite::label_image::get_top_n<float>(
          interpreter->typed_output_tensor<float>(0), output_size,
          s->number_of_results, threshold, &top_results, true);
      break;
    case kTfLiteUInt8:
      tflite::label_image::get_top_n<uint8_t>(
          interpreter->typed_output_tensor<uint8_t>(0), output_size,
          s->number_of_results, threshold, &top_results, false);
      break;
    default:
      // BUG FIX: the original printed the *input* tensor's type here.
      LOG(FATAL) << "cannot handle output type "
                 << interpreter->tensor(output)->type << " yet";
      exit(-1);
  }

  // Map result indices back to human-readable labels and print them.
  std::vector<string> labels;
  size_t label_count;
  if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
    exit(-1);

  for (const auto& result : top_results) {
    const float confidence = result.first;
    const int index = result.second;
    LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
  }
}

// Entry point: run inference with the default (compile-time) settings.
int main() {
  tflite::label_image::Settings s;
  RunInference(&s);
}

(四)下载预训练模型

# Get the MobileNet v1 model (contains mobilenet_v1_1.0_224.tflite).
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp

# Get the labels file, then move it to /tmp.
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C /tmp mobilenet_v1_1.0_224/labels.txt
mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/

(五)修改模型配置

在label_image.h中修改Settings:

// Runtime configuration for the classifier example.  Paths are absolute and
// machine-specific; edit them to match your checkout.
struct Settings {
  bool verbose = false;         // Print interpreter/tensor details.
  bool accel = false;           // Use the NNAPI delegate.
  bool old_accel = false;       // Use the legacy interpreter NNAPI path.
  bool input_floating = false;  // Set at runtime for float32 input models.
  bool profiling = false;       // Collect per-op profiling events.
  bool allow_fp16 = false;      // Allow fp16 precision for fp32 ops.
  bool gl_backend = false;      // Use the GPU delegate.
  int loop_count = 1;           // Number of timed inference runs.
  float input_mean = 127.5f;    // Normalization mean for float input.
  float input_std = 127.5f;     // Normalization stddev for float input.
  string model_name =
      "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.tflite";
  // BUG FIX: was left uninitialized; non-owning, set by RunInference.
  tflite::FlatBufferModel* model = nullptr;
  string input_bmp_name = "/home/ai/CLionProjects/tflite/grace_hopper.bmp";
  string labels_file_name =
      "/home/ai/CLionProjects/tflite/mobilenet_v1_1.0_224/labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
  int number_of_results = 5;  // N for the top-N report.
  int max_profiling_buffer_entries = 1024;
  int number_of_warmup_runs = 2;
};

(六)运行实例

Top5分类结果输出如下:

Loaded model /tmp/mobilenet_v1_1.0_224.tflite
resolved reporter
invoked
average time: 68.12 ms
0.860174: 653 653:military uniform
0.0481017: 907 907:Windsor tie
0.00786704: 466 466:bulletproof vest
0.00644932: 514 514:cornet, horn, trumpet, trump
0.00608029: 543 543:drumstick

结果显示该图像被正确分类,平均耗时68.12ms,速度非常快!

TensorFlow Lite 开发手册(5)——TensorFlow Lite模型使用实例(分类模型)相关推荐

  1. 机器学习数据预处理之缺失值:预测填充(回归模型填充、分类模型填充)

    机器学习数据预处理之缺失值:预测填充(回归模型填充.分类模型填充) garbage in, garbage out. 没有高质量的数据,就没有高质量的数据挖掘结果,数据值缺失是数据分析中经常遇到的问题 ...

  2. 【图像分割模型】实例分割模型—DeepMask

    这是专栏<图像分割模型>的第11篇文章.在这里,我们将共同探索解决分割问题的主流网络结构和设计思想. 本文介绍了用于实例分割任务的模型结构--DeepMask. 作者 | 孙叔桥 编辑 | ...

  3. 线性插值改变图像尺寸_【图像分割模型】实例分割模型—DeepMask

    这是专栏<图像分割模型>的第11篇文章.在这里,我们将共同探索解决分割问题的主流网络结构和设计思想. 本文介绍了用于实例分割任务的模型结构--DeepMask. 作者 | 孙叔桥 编辑 | ...

  4. R语言VaR市场风险计算方法与回测、用LOGIT逻辑回归、PROBIT模型信用风险与分类模型...

    全文链接:http://tecdat.cn/?p=27530  市场风险指的是由金融市场中资产的价格下跌或价格波动增加所导致的可能损失. 相关视频 市场风险包含两种类型:相对风险和绝对风险.绝对风险关 ...

  5. 人口预测和阻尼-增长模型_使用分类模型预测利率-第1部分

    人口预测和阻尼-增长模型 A couple of years ago, I started working for a quant company called M2X Investments, an ...

  6. 人口预测和阻尼-增长模型_使用分类模型预测利率-第3部分

    人口预测和阻尼-增长模型 This is the final article of the series " Predicting Interest Rate with Classifica ...

  7. 人口预测和阻尼-增长模型_使用分类模型预测利率-第2部分

    人口预测和阻尼-增长模型 We are back! This post is a continuation of the series "Predicting Interest Rate w ...

  8. R语言VaR市场风险计算方法与回测、用Logit逻辑回归、Probit模型信用风险与分类模型

    最近我们被客户要求撰写关于信用风险与分类的研究报告,包括一些图形和统计输出. 市场风险指的是由金融市场中资产的价格下跌或价格波动增加所导致的可能损失. 市场风险包含两种类型:相对风险和绝对风险.绝对风 ...

  9. 文本分类模型_文本分类模型之TextCNN

    六年的大学生涯结束了,目前在搜索推荐岗位上继续进阶,近期正好在做类目预测多标签分类的项目,因此把相关的模型记录总结一下,便于后续查阅总结. 一.理论篇: 在我们的场景中,文本数据量比较大,因此直接采用 ...

最新文章

  1. 计算机虚拟网络毕业论文,计算机毕业论文——基于WEB的虚拟计算机网络实验平台.doc...
  2. 配置php7.2.4支持swoole2.1.1扩展
  3. QT信号和槽函数学习笔记
  4. 吴恩达作业7:梯度下降优化算法
  5. C++对象模型3--无重写的单继承
  6. Django中的form模块的高级处理
  7. Sqoop是一款开源的工具,主要用于在HADOOP(Hive)与传统的数据库(mysql、oracle...)间进行数据的传递...
  8. 高级与低级编程语言的解释,哪一种更容易上手?
  9. excel 打开文件后自动卡死的解决方法
  10. php codeigniter 语言,CodeIgniter多语言实现方法详解
  11. 面试经历---YY欢聚时代
  12. 面试官没想到我对redis数据结构这么了解,直接给offer
  13. 【历史上的今天】7 月 29 日:Win10 七周年;微软和雅虎的搜索协议;微软发行 NT 4.0
  14. 宝塔Linux面板 软件商店中安装不了任何php版本的解决方法
  15. 冒险岛2无限服务器断开,冒险岛2无限龙无限命版
  16. win10与内置ubuntu之间复制粘贴操作
  17. 完整的项目管理流程包括什么?
  18. 【面试】北京航天无人机系统工程研究所
  19. 机械手定位(带角度)的思路及3点计算旋转中心
  20. 洛谷P1562 还是N皇后(DFS+状态压缩+位运算)

热门文章

  1. Django 数据库常用字段类型
  2. 【北大青鸟天府校区的Java专业怎么样?】
  3. [BZOJ3238] [AHOI2013] 差异 - 后缀自动机
  4. 子类继承父类,父类实现接口,子类中调用父类和接口的同名成员变量会出现歧义
  5. 直接法-穷举、递推和迭代
  6. CF785C (1600)
  7. SMS发送WapPush
  8. 解决Error inflating class com.google.android.material.appbar.CollapsingToolbarLayout
  9. CS党必须了解的P/NP常识
  10. 分布式存储系统——HBase