ONNX Runtime:由微软推出,用于优化和加速机器学习推理和训练,适用于ONNX模型,是一个跨平台推理和训练机器学习加速器(ONNX Runtime is a cross-platform inference and training machine-learning accelerator),源码地址:https://github.com/microsoft/onnxruntime,最新发布版本为v1.11.1,License为MIT:

      1.ONNX Runtime Inferencing:高性能推理引擎

      (1).可在不同的操作系统上运行,包括Windows、Linux、Mac、Android、iOS等;

      (2).可利用硬件增加性能,包括CUDA、TensorRT、DirectML、OpenVINO等;

      (3).支持PyTorch、TensorFlow等深度学习框架的模型,需先调用相应接口转换为ONNX模型;

      (4).在Python中训练,确可部署到C++/Java等应用程序中。

      2.ONNX Runtime Training:于2021年4月发布,可加快PyTorch对模型训练,可通过CUDA加速,目前多用于Linux平台。

      通过conda命令安装执行:

conda install -c conda-forge onnxruntime

      以下为测试代码:通过ResNet-50对图像进行分类

import numpy as np
import onnxruntime
import onnx
from onnx import numpy_helper
import urllib.request
import os
import tarfile
import json
import cv2# reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
def download_onnx_model():labels_file_name = "imagenet-simple-labels.json"model_tar_name = "resnet50v2.tar.gz"model_directory_name = "resnet50v2"if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):print("files exist, don't need to download")else:print("files don't exist, need to download ...")onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"# retrieve our model from the ONNX model zoourllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)print("download completed, start decompress ...")file = tarfile.open(model_tar_name)file.extractall("./")file.close()return model_directory_name, labels_file_namedef load_labels(path):with open(path) as f:data = json.load(f)return np.asarray(data)def images_preprocess(images_path, images_name):input_data = []for name in images_name:img = cv2.imread(images_path + name)img = cv2.resize(img, (224, 224))img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)data = np.array(img).transpose(2, 0, 1)#print(f"name: {name}, opencv image shape(h,w,c): {img.shape}, transpose shape(c,h,w): {data.shape}")# convert the input data into the float32 inputdata = data.astype('float32')# normalizemean_vec = np.array([0.485, 0.456, 0.406])stddev_vec = np.array([0.229, 0.224, 0.225])norm_data = np.zeros(data.shape).astype('float32')for i in range(data.shape[0]):norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]# add batch channelnorm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')input_data.append(norm_data)return input_datadef softmax(x):x = x.reshape(-1)e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=0)def postprocess(result):return softmax(np.array(result)).tolist()def inference(onnx_model, labels, input_data, images_name, images_label):session = onnxruntime.InferenceSession(onnx_model, None)# get the name of the first input of the modelinput_name = session.get_inputs()[0].namecount = 0for data in input_data:print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")count += 1raw_result = session.run([], {input_name: data})res = postprocess(raw_result)idx = np.argmax(res)print(f"  result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")sort_idx = np.flip(np.squeeze(np.argsort(res)))print("  top 5 labels are:", labels[sort_idx[:5]])def main():model_directory_name, labels_file_name = download_onnx_model()labels = load_labels(labels_file_name)print("the number of categories is:", len(labels)) # 1000images_path = "../../data/image/"images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]if len(images_name) != len(images_label):print("Error: images count and labes'length don't match")returninput_data = images_preprocess(images_path, images_name)onnx_model = model_directory_name + "/resnet50v2.onnx"inference(onnx_model, labels, input_data, images_name, images_label)print("test finish")if __name__ == "__main__":main()

测试图像如下所示:

执行结果如下所示:

GitHub: https://github.com/fengbingchun/PyTorch_Test

ONNX Runtime介绍相关推荐

  1. ONNX Runtime使用简单介绍

    前面系列博客中有用tensorRT.OpenVINO加速模型推理 TensorRT加速方法介绍(python pytorch模型)_竹叶青lvye的博客-CSDN博客_tensorrt加速 OpenV ...

  2. Optimum + ONNX Runtime: 更容易、更快地训练你的 Hugging Face 模型

    介绍 基于语言.视觉和语音的 Transformer 模型越来越大,以支持终端用户复杂的多模态用例.增加模型大小直接影响训练这些模型所需的资源,并随着模型大小的增加而扩展它们.Hugging Face ...

  3. 微软推出了Cloud Native Application Bundles和开源ONNX Runtime

    微软的Microsoft Connect(); 2018年的开发者大会 对Azure和IoT Edge服务进行了大量更新; Windows Presentation Foundation,Window ...

  4. Spring框架Runtime介绍(导包)

    一.Spring框架Runtime介绍,如图 1.1 Test: Spring提供测试功能 1.2 Core Container:Spring核心容器,Spring启动的基本条件, 1.2.1 Bea ...

  5. ONNX Runtime: ubutnu16.04编译 (编到怀疑人生)

    ONNX Runtime: ubutnu16.04编译 1. 前言 ONNX Runtime是什么? ONNX Runtime是适用于Linux,Windows和Mac上ONNX格式的机器学习模型的高 ...

  6. Pytorch的pth模型转onnx,再用ONNX Runtime调用推理(附python代码)

    我做的是一个简单的二分类任务,用了个vgg11,因为要部署到应用,所以将 PyTorch 中定义的模型转换为 ONNX 格式,然后在 ONNX Runtime 中运行它,那就不用了在机子上配pytor ...

  7. 【YOLOv5】手把手教你使用LabVIEW ONNX Runtime部署 TensorRT加速,实现YOLOv5实时物体识别(含源码)

    文章目录 前言 一.TensorRT简介 二.准备工作 三.YOLOv5模型的获取 1.下载源码 2.安装模块 3.下载预训练模型 4.转换为onnx模型 四.LabVIEW使用TensorRT加速Y ...

  8. iOS开发中runtime介绍

    一.runtime简介 RunTime简称运行时.OC就是运行时机制,也就是在运行时候的一些机制,其中最主要的是消息机制. 对于C语言,函数的调用在编译的时候会决定调用哪个函数. 对于OC的函数,属于 ...

  9. runtime介绍及基本使用

    1. 概念 runtime(运行时系统),是一套基于C语言API,包含在 <objc/runtime.h>和<objc/message.h>中,运行时系统的功能是在运行期间(而 ...

最新文章

  1. ABAP实例之ALV
  2. vue组件一直注册不了_【报Bug】现在究竟支不支持Vue.use内注册组件
  3. spark运行时加载hive,hdfs配置文件
  4. SAP Cloud for Customer里根据External Reference搜索销售订单
  5. C语言 void 指针 - C语言零基础入门教程
  6. 【算法导论】第7章快速排序
  7. protobuf 中的嵌套消息的使用 主要对set_allocated_和mutable_的使用
  8. 录像带转存电脑的方法_《波西亚时光》录像带使用方法介绍
  9. vscode php插件_vscode+phpstudy+xdebug无法断点(踩坑记)
  10. Class create, device create, device create file
  11. 力扣116. 填充每个节点的下一个右侧节点指针(JavaScript)
  12. 马云获福布斯终身成就奖;华为推出首款 4G 芯片 Balong 711;PyPy 7.2 发布 | 极客头条...
  13. 计算机表演赛bug,只会编程序,敲代码,找bug?不,他们保研浙大、去美国进修……...
  14. 20190908每日一句
  15. MySQL · myrocks · 事务锁分析
  16. 工具类 --UUIDUtil ---32位UUID生成器
  17. 政府安全资讯精选 2017年第四期:聚焦美国网络安全新动态
  18. php文件显示文字乱码怎么解决,php遍历到的文件是中文文件名 显示为乱码 该如何解决...
  19. 小程序视频旋转的相关问题
  20. 影响计算机速度的有哪些配件,影响电脑上网速度的重要因素有哪些?

热门文章

  1. 图解 负载均衡算法及分类
  2. 未来教育计算机二级操作题素材,未来教育计算机二级的操作题答案.doc
  3. 速卖通排序规则,优化产品信息,让店铺引流更精准
  4. 基于matlab解决配料问题及其衍生问题和灵敏度分析
  5. The request was rejected because the URL contained a potentially malicious String ;报错解决
  6. 利用python+selenium带上cookies自动登录bilibili
  7. 吃瓜吃瓜:国内首档程序员真人秀来袭!天才程序员鹿死谁手?
  8. Windows 10 Build 14997中Edge浏览器已默认阻止Flash运行
  9. 就大学毕业典礼的演讲所感触的
  10. pta习题2-2 阶梯电价