ONNX是开放式神经网络(Open Neural Network Exchange)的简称,主要由微软和合作伙伴社区创建和维护。很多深度学习训练框架(如Tensorflow, PyTorch, Scikit-learn, MXNet等)的模型都可以导出或转换为标准的ONNX格式,采用ONNX格式作为统一的界面,各种嵌入式平台就可以只需要解析ONNX格式的模型而不用支持多种多样的训练框架,本文主要介绍如何通过代码或JSON文件的形式来构造一个ONNX单算子模型或者整个graph,以及使用ONNX Runtime进行推理得到算子或模型的计算结果。

一. ONNX文件格式

ONNX文件是基于Protobuf进行序列化。了解Protobuf协议的同学应该知道,Protobuf都会有一个*.proto的文件定义协议,ONNX的该协议定义在https://github.com/onnx/onnx/blob/master/onnx/onnx.proto3 文件中。

从onnx.proto3协议中我们需要重点知道的数据结构如下:

  • ModelProto:模型的定义,包含版本信息,生产者和GraphProto。
  • GraphProto: 包含很多重复的NodeProto, initializer, ValueInfoProto等,这些元素共同构成一个计算图,在GraphProto中,这些元素都是以列表的方式存储,连接关系是通过Node之间的输入输出进行表达的。
  • NodeProto: onnx的计算图是一个有向无环图(DAG),NodeProto定义算子类型,节点的输入输出,还包含属性。
  • ValueInforProto: 定义输入输出这类变量的类型。
  • TensorProto: 序列化的权重数据,包含数据的数据类型,shape等。
  • AttributeProto: 具有名字的属性,可以存储基本的数据类型(int, float, string, vector等)也可以存储onnx定义的数据结构(TENSOR, GRAPH等)。

二. Python API

2.1 搭建ONNX模型

ONNX是用DAG来描述网络结构的,也就是一个网络(Graph)由节点(Node)和边(Tensor)组成,ONNX提供的helper类中有很多API可以用来构建一个ONNX网络模型,比如make_node, make_graph, make_tensor等,下面是一个单个Conv2d的网络构造示例:

import onnx
from onnx import helper
from onnx import TensorProto
import numpy as npweight = np.random.randn(36)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2, 4, 4])
W = helper.make_tensor('W', TensorProto.FLOAT, [2, 2, 3, 3], weight)
B = helper.make_tensor('B', TensorProto.FLOAT, [2], [1.0, 2.0])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 2, 2, 2])node_def = helper.make_node('Conv', # node name['X', 'W', 'B'],['Y'], # outputs# attributesstrides=[2,2],)graph_def = helper.make_graph([node_def],'test_conv_mode',[X], # graph inputs[Y], # graph outputsinitializer=[W, B],
)mode_def = helper.make_model(graph_def, producer_name='onnx-example')
onnx.checker.check_model(mode_def)
onnx.save(mode_def, "./Conv.onnx")

搭建的这个Conv算子模型使用netron可视化如下图所示:

这个示例演示了如何使用helper的make_tensor_value_info, make_mode, make_graph, make_model等方法来搭建一个onnx模型。

相比于PyTorch或其它框架,这些API看起来仍然显得比较繁琐,一般我们也不会用ONNX来搭建一个大型的网络模型,而是通过其它框架转换得到一个ONNX模型。

2.2 Shape Inference

很多时候我们从pytorch, tensorflow或其他框架转换过来的onnx模型中间节点并没有shape信息,如下图所示:

我们经常希望能直接看到网络中某些node的shape信息,shape_inference模块可以推导出所有node的shape信息,这样可视化模型时将会更友好:

import onnx
from onnx import shape_inferenceonnx_model = onnx.load("./test_data/mobilenetv2-1.0.onnx")
onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.save(onnx_model, "./test_data/mobilenetv2-1.0_shaped.onnx")

可视化经过shape_inference之后的模型如下图:

2.3 ONNX Optimizer

ONNX的optimizer模块提供部分图优化的功能,例如最常用的:fuse_bn_into_conv,fuse_pad_into_conv等等。

查看onnx支持的优化方法:

from onnx import optimizer
all_passes = optimizer.get_available_passes()
print("Available optimization passes:")
for p in all_passes:print(p)
print()

应用图优化到onnx模型上进行变换:

passes = ['fuse_bn_into_conv']
# Apply the optimization on the original model
optimized_model = optimizer.optimize(onnx_model, passes)

将mobile net v2应用fuse_bn_into_conv之后,BatchNormalization的参数合并到了Conv的weight和bias参数中,如下图所示:

三. ONNX Runtime计算ONNX模型

onnx本身只是一个协议,定义算子与模型结构等,不涉及具体的计算。onnx runtime是类似JVM一样将ONNX格式的模型运行起来的解释器,包括对模型的解析、图优化、后端运行等。

安装onnx runtime:

python3 -m pip install onnxruntime

推理:

import onnx
import onnxruntime as ort
import numpy as np
import cv2def preprocess(img_data):mean_vec = np.array([0.485, 0.456, 0.406])stddev_vec = np.array([0.229, 0.224, 0.225])norm_img_data = np.zeros(img_data.shape).astype('float32')for i in range(img_data.shape[0]):# for each pixel in each channel, divide the value by 255 to get value between [0, 1] and then normalizenorm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]return norm_img_dataimg = cv2.imread("test_data/dog.jpeg")
img = cv2.resize(img, (224,224), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_data = np.transpose(img, (2, 0, 1))
input_data = preprocess(input_data)
input_data = input_data.reshape([1, 3, 224, 224])
sess = ort.InferenceSession("test_data/mobilenetv2-1.0.onnx")
input_name = sess.get_inputs()[0].name
result = sess.run([], {input_name: input_data})
result = np.reshape(result, [1, -1])
index = np.argmax(result)
print("max index:", index)

ONNX构建并运行模型相关推荐

  1. ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

    目录 PyTorch简介 导入转换器 快速浏览模型 将PyTorch模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携 ...

  2. TensorFlow:实战Google深度学习框架(一)计算、数据、运行模型

    第3章 TensorFlow入门 3.1 TensorFlow计算模型--计算图 3.1.1 计算图的概念 3.1.2 计算图的使用 3.2 TensorFlow数据模型--张量 3.2.1 张量的概 ...

  3. 通过Dapr实现一个简单的基于.net的微服务电商系统(十二)——istio+dapr构建多运行时服务网格...

    多运行时是一个非常新的概念.在 2020 年,Bilgin Ibryam 提出了 Multi-Runtime(多运行时)的理念,对基于 Sidecar 模式的各种产品形态进行了实践总结和理论升华.那到 ...

  4. java源文件编译成jar_从源文件和JAR文件构建Java代码模型

    java源文件编译成jar 最近,我花了一些时间来研究有效java ,该方法正在GitHub上达到300星(可以免费帮助实现目标:D). Effectivejava是在您的Java代码上运行查询的工具 ...

  5. 从源文件和JAR文件构建Java代码模型

    最近,我花了一些时间来研究有效java ,该方法正在GitHub上达到300星(随时帮助实现目标:D). Effectivejava是在您的Java代码上运行查询的工具. 它基于我参与的另一个项目ja ...

  6. TensorFlow2.0(二)--Keras构建神经网络分类模型

    Keras构建分类模型 1. tf.keras简介 2. 利用tf.keras构建神经网络分类模型 2.1 导入相应的库 2.2 数据读取与展示 2.3 数据归一化 2.4 构建模型 2.5 模型的编 ...

  7. onnx实现对pytorch模型推理加速

    向AI转型的程序员都关注了这个号???????????? 人工智能大数据与深度学习  公众号:datayx 微软宣布将多平台通用ONNX机器学习引擎开源,此举将让机器学习框架,向着机器学习框架的标准化 ...

  8. 使用SpaCy构建自定义 NER 模型

    什么是NER? 命名实体识别(NER)是一种自然语言处理技术,用于在给定的文本内容中提取适当的实体,并将提取的实体分类到预定义的类别下. 简单来说,NER 是一种用于从给定文本中提取诸如人名.地名.公 ...

  9. Pytorch版本MobileNetV3转ONNX然后转om模型使用Pyacl离线推理

    Pytorch版本MobileNetV3转ONNX然后转om模型使用Pyacl离线推理 概述:本文主要讲述把MobileNet转成华为Altas服务器离线推理om模型的过程,本人在转换过程中也遇到过比 ...

最新文章

  1. 流程控制if、while、for
  2. Dijkstra 贪心算法 动态规划
  3. 走差异化发展路线思想
  4. anaconda 怎么安装xlrd_Anaconda 安装 tensorflow 和 keras
  5. 【ZOJ - 3780】Paint the Grid Again(拓扑排序,图论,证明性质)
  6. Drupal常用开发工具(一)——Devel模块
  7. MapBalanceReduce介绍
  8. 用servlet进行用户名和密码校验
  9. springboot定时删除log4j_SpringBoot整合log4j2进行日志配置及防坑指南
  10. 计算机的发展经历了选择题,计算机发展历程的相关选择题.doc
  11. select * 与 count(*)数量不一致_技术分享 | MySQL:count(*)、count(字段) 实现上区别
  12. python tkinter listbox控件 简书_python tkinter模块的控件操作(1)
  13. 【CAD】DWF文件格式详细说明,清晰易懂
  14. 速领,阿里巴巴Java开发手册终极版
  15. 个人微信api接口调用,微信好友收发消息
  16. opencv的下载与安装
  17. python定量城市研究_借助Python来实现的定量城市研究
  18. 使用react-cropper-pro实现图片裁切压缩上传
  19. 二维码签到的几大优势,你了解几个?
  20. RDIFramework.NET ━ .NET快速信息化系统开发框架 V2.8 版本━新增岗位管理-WinForm部分

热门文章

  1. python中如何遍历26个英文字母?三种办法
  2. python随机数写入excel
  3. 文件扩展名,你知道这些吗?(续)
  4. 2018-2019-2 20165334『网络对抗技术』Exp5:MSF基础应用
  5. Java Quene
  6. 自然语言处理之情感分析
  7. getapp.php,getApp.php
  8. Java中常见的集合框架及常用的方法
  9. EfficientNet Backbone结构解析 -- 以EfficientNet-B0为例说明
  10. 用python 求矩形最大面积_LeetCode 84. 柱状图中最大的矩形 | Python