文章目录

  • 引言
  • 基础概念
    • onnx:跨框架的模型表达标准
    • onnxruntime:部署模型的推理引擎
  • 示例代码
    • 0)安装onnx和onnxruntime
    • 1)pytorch模型转onnx模型
    • 2)onnx模型检验
    • 3)调用ONNX Runtime测试输入图片
  • 参考教程

引言

目前训练阶段最流行的框架是pytorch和tensorflow,训练完的模型常常需要集成到现有的应用程序中、部署到不同类型的平台上(比如云端、边端等)。

如果要在每个平台上实现所有模型的框架,会极大增加环境的复杂性,优化不同框架和硬件的所有组合非常耗时

所以需要一种通用的解决方案,来集成、部署、优化不同框架训练出的模型。ONNX就是为了解决这个问题而诞生。

现在常用的跨框架部署方案是:

  • 训练模型→中间模型(如onnx)→推理引擎(如onnxruntime、TensorRT)

基础概念

onnx:跨框架的模型表达标准

开放式神经网络交换(ONNX),为深度学习和传统ML的AI模型提供了一种开源格式。它定义了一个可扩展的计算图形模型,以及内置运算符和标准数据类型的定义。

当前,ONNX专注于推理所需的功能(暂不支持训练的相关功能)。

Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 许多框架(包括 TensorFlow、PyTorch、SciKit-Learn、Keras、Chainer、MXNet、MATLAB 和 SparkML)中的模型都可以导出或转换为标准 ONNX 格式。 模型采用 ONNX 格式后,可在各种平台和设备上运行。

onnxruntime:部署模型的推理引擎

onnxruntime是一种用于将 ONNX 模型部署到生产环境的高性能推理引擎,与许多流行的ML / DNN框架兼容,包括PyTorch,TensorFlow / Keras,scikit-learn等。github地址

  • 针对云和 Edge 进行了优化,适用于 Linux、Windows 和 Mac。

  • 使用 C++ 编写,还包含 C、Python、C#、Java 和 Javascript (Node.js) API,可在各种环境中使用

  • 同时支持 DNN 和传统 ML 模型,并与不同硬件上的加速器(例如,NVidia GPU 上的 TensorRT、Intel 处理器上的 OpenVINO、Windows 上的 DirectML 等)集成。

采用ONNX Runtime的主要优点是:

  • 提高各种ML模型的推理性能
  • 减少训练大型模型的时间和成本
  • 使用Python进行训练,但可以部署到C#/ C ++ / Java应用程序中
  • 在不同的硬件和操作系统上运行
  • 支持在多个不同框架中创建的模型

API文档:https://www.onnxruntime.ai/python/index.html

示例代码

0)安装onnx和onnxruntime

# pytorch一般自带onnx
conda install -c conda-forge onnx
# onnxruntime的其他版本有可能报错ImportError: cannot import name 'get_all_providers'
pip install onnxruntime       # CPU build
pip install onnxruntime-gpu   # GPU build

如果需要编译安装,可以参照ONNX Runtime 的GitHub仓库地址 上的说明。

需要注意的是:

  • 默认安装方式只支持Python3
  • windows下编译安装需要Visual C++ 2019 runtime,且只支持win10及以上版本

1)pytorch模型转onnx模型

利用torch.onnx.export导出模型到ONNX格式。

import torch.onnx
import torchvision# Standard ImageNet input - 3 channels, 224x224,
# values don't matter as we care about network structure.
# But they can also be real inputs.
dummy_input = torch.randn(1, 3, 224, 224)
# Obtain your model, it can be also constructed in your script explicitly
model = torchvision.models.alexnet(pretrained=True)
# Invoke export
torch.onnx.export(model, dummy_input, "alexnet.onnx")

2)onnx模型检验

检查onnx模型,并打印模型的结构表示。onnx模型还可以通过可视化工具(如 Netron)查看。

import onnx# Load the ONNX model
model = onnx.load("alexnet.onnx")# Check that the IR is well formed
onnx.checker.check_model(model)# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

3)调用ONNX Runtime测试输入图片

采用python API,加载一张图片测试我们转化后的onnx模型。


import numpy as np # we're going to use numpy to process input and output data
import onnxruntime # to inference ONNX models, we use the ONNX Runtime
import time
from PIL import Imagedef load_labels(path):with open(path) as f:data = json.load(f)return np.asarray(data)# 图像预处理
def preprocess(input_data):# convert the input data into the float32 inputimg_data = input_data.astype('float32')#normalizemean_vec = np.array([0.485, 0.456, 0.406])stddev_vec = np.array([0.229, 0.224, 0.225])norm_img_data = np.zeros(img_data.shape).astype('float32')for i in range(img_data.shape[0]):norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]#add batch channelnorm_img_data = norm_img_data.reshape(1, 3, 224, 224).astype('float32')return norm_img_datadef softmax(x):x = x.reshape(-1)e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=0)def postprocess(result):return softmax(np.array(result)).tolist()# Load the raw image
img = Image.open("D://Pic//cat.jpg")
img = img.resize((224, 224), Image.BILINEAR)
print("Image size: ", img.size)image_data = np.array(img).transpose(2, 0, 1)
input_data = preprocess(image_data)# Run the model on the backend
session = onnxruntime.InferenceSession('.//alexnet.onnx', None)# get the name of the first input of the model
input_name = session.get_inputs()[0].name
print('Input Name:', input_name)# Inference
start = time.time()
raw_result = session.run([], {input_name: input_data})
end = time.time()
res = postprocess(raw_result)inference_time = np.round((end - start) * 1000, 2)
idx = np.argmax(res)print('========================================')
print('Final top prediction is: %d'% idx)
print('========================================')print('========================================')
print('Inference time: ' + str(inference_time) + " ms")
print('========================================')

参考教程

[1] https://github.com/onnx/tutorials/blob/master/tutorials/PytorchOnnxExport.ipynb
[2] https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
[3] https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb

PyTorch模型部署:pth转onnx跨框架部署详解+代码相关推荐

  1. Spring基于注解TestContext 测试框架使用详解

    原创整理不易,转载请注明出处:Spring基于注解TestContext 测试框架使用详解 代码下载地址:http://www.zuidaima.com/share/1775574182939648. ...

  2. 将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功)

    将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功) 模型转换 声明:本文原创,未经许可严禁转载,原文地址https://blog.csdn.net/hutao1030813002/ ...

  3. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  4. pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...

  5. linux系统pkl,pytorch模型 .pt, .pth, .pkl有什么区别及如何保存

    pytorch模型 .pt, .pth, .pkl有什么区别及如何保存 发布时间:2020-07-22 10:47:44 来源:亿速云 阅读:371 作者:小猪 小编这次要给大家分享的是pytorch ...

  6. 网易考拉海购Dubbok框架优化详解

    网易考拉海购Dubbok框架优化详解 摘要:微服务化是当前电商产品演化的必然趋势,网易考拉海购通过微服务化打破了业务爆发增长的架构瓶颈.本文结合网易考拉海购引用的开源Dubbo框架,分享支持考拉微服务 ...

  7. java定时任务框架elasticjob详解

    这篇文章主要介绍了java定时任务框架elasticjob详解,Elastic-Job是ddframe中dd-job的作业模块中分离出来的分布式弹性作业框架.该项目基于成熟的开源产品Quartz和Zo ...

  8. Android UI 测试框架Espresso详解

    Android UI 测试框架Espresso详解 1. Espresso测试框架 2.提供Intents Espresso 2.1.安装 2.2.为Espresso配置Gradle构建文件 2.3. ...

  9. Android 进阶——Framework 核心之Android Storage Access Framework(SAF)存储访问框架机制详解(二)

    文章大纲 引言 一.DirectFragment 1.当选中DirectoryFragment中RecyclerView的Item时 2.选中DirectoryFragment中RecyclerVie ...

最新文章

  1. 深挖之后吓一跳,谷歌AI专利何止一个dropout,至少30项今日生效
  2. Spy++的使用方法及下载
  3. data URI scheme及其应用
  4. bootstrap datetimepicker、bootstrap datepicker日期组件对范围的简单封装
  5. 这 5 条 IntelliJ IDEA 调试技巧太强了!
  6. 泛微数字化督查督办平台:不见面也能高效落实工作、管理到位
  7. CodeSmith注册
  8. 22条创业军规,让你5分钟读完《创业维艰》
  9. 阿里P7级别面试经验总结,面试心得体会
  10. 计算机管理员绩效指标,网络管理员绩效kpi考核标准..doc
  11. C# asp.net图片拼接方法
  12. 观察者模式(java)浅析
  13. qq看点模块测试用例
  14. vm怎么装vim_安装nginx报-bash: vm: command not found..错误提示vim文本编辑器命令没有安装...
  15. wind上怎么连接mysql_windows上连接mysql数据库怎么连接
  16. python递归算法(斐波那契数列,汉诺塔、二分法查找)
  17. 第4天:python的数据类型、用户交互以及基本运算符
  18. python程序填空_python练习题-基础巩固-第一周
  19. 今天这教程难度有点高,反爬虫之跳过淘宝滑块验证!爬虫必会教程
  20. 制作网页游戏的页面。(开始网页,登录账号网页和进入网页)

热门文章

  1. JS window事件全集解析 (转载)
  2. DreamFactory 第8章 保护您的DreamFactory环境
  3. ZooKeeper配额指南
  4. 通过Gogs部署git仓库
  5. java判断字符串中是否含有某个字符串
  6. CQRS之旅——旅程6(我们系统的版本管理)
  7. Docker Swarm 初步认识 及 集群搭建
  8. 刷新echart控件
  9. 【Html】Html基本标记
  10. C#LeetCode刷题-图