PyTorch模型部署:pth转onnx跨框架部署详解+代码
文章目录
- 引言
- 基础概念
- onnx:跨框架的模型表达标准
- onnxruntime:部署模型的推理引擎
- 示例代码
- 0)安装onnx和onnxruntime
- 1)pytorch模型转onnx模型
- 2)onnx模型检验
- 3)调用ONNX Runtime测试输入图片
- 参考教程
引言
目前训练阶段最流行的框架是pytorch和tensorflow,训练完的模型常常需要集成到现有的应用程序中、部署到不同类型的平台上(比如云端、边端等)。
如果要在每个平台上实现所有模型的框架,会极大增加环境的复杂性,优化不同框架和硬件的所有组合非常耗时。
所以需要一种通用的解决方案,来集成、部署、优化不同框架训练出的模型。ONNX就是为了解决这个问题而诞生。
现在常用的跨框架部署方案是:
- 训练模型→中间模型(如onnx)→推理引擎(如onnxruntime、TensorRT)
基础概念
onnx:跨框架的模型表达标准
开放式神经网络交换(ONNX),为深度学习和传统ML的AI模型提供了一种开源格式。它定义了一个可扩展的计算图形模型,以及内置运算符和标准数据类型的定义。
当前,ONNX专注于推理所需的功能(暂不支持训练的相关功能)。
Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 许多框架(包括 TensorFlow、PyTorch、SciKit-Learn、Keras、Chainer、MXNet、MATLAB 和 SparkML)中的模型都可以导出或转换为标准 ONNX 格式。 模型采用 ONNX 格式后,可在各种平台和设备上运行。
onnxruntime:部署模型的推理引擎
onnxruntime是一种用于将 ONNX 模型部署到生产环境的高性能推理引擎,与许多流行的ML / DNN框架兼容,包括PyTorch,TensorFlow / Keras,scikit-learn等。github地址
针对云和 Edge 进行了优化,适用于 Linux、Windows 和 Mac。
使用 C++ 编写,还包含 C、Python、C#、Java 和 Javascript (Node.js) API,可在各种环境中使用。
同时支持 DNN 和传统 ML 模型,并与不同硬件上的加速器(例如,NVidia GPU 上的 TensorRT、Intel 处理器上的 OpenVINO、Windows 上的 DirectML 等)集成。
采用ONNX Runtime的主要优点是:
- 提高各种ML模型的推理性能
- 减少训练大型模型的时间和成本
- 使用Python进行训练,但可以部署到C#/ C ++ / Java应用程序中
- 在不同的硬件和操作系统上运行
- 支持在多个不同框架中创建的模型
API文档:https://www.onnxruntime.ai/python/index.html
示例代码
0)安装onnx和onnxruntime
# pytorch一般自带onnx
conda install -c conda-forge onnx
# onnxruntime的其他版本有可能报错ImportError: cannot import name 'get_all_providers'
pip install onnxruntime # CPU build
pip install onnxruntime-gpu # GPU build
如果需要编译安装,可以参照ONNX Runtime 的GitHub仓库地址 上的说明。
需要注意的是:
- 默认安装方式只支持Python3
- windows下编译安装需要Visual C++ 2019 runtime,且只支持win10及以上版本
1)pytorch模型转onnx模型
利用torch.onnx.export导出模型到ONNX格式。
import torch.onnx
import torchvision# Standard ImageNet input - 3 channels, 224x224,
# values don't matter as we care about network structure.
# But they can also be real inputs.
dummy_input = torch.randn(1, 3, 224, 224)
# Obtain your model, it can be also constructed in your script explicitly
model = torchvision.models.alexnet(pretrained=True)
# Invoke export
torch.onnx.export(model, dummy_input, "alexnet.onnx")
2)onnx模型检验
检查onnx模型,并打印模型的结构表示。onnx模型还可以通过可视化工具(如 Netron)查看。
import onnx# Load the ONNX model
model = onnx.load("alexnet.onnx")# Check that the IR is well formed
onnx.checker.check_model(model)# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
3)调用ONNX Runtime测试输入图片
采用python API,加载一张图片测试我们转化后的onnx模型。
import numpy as np # we're going to use numpy to process input and output data
import onnxruntime # to inference ONNX models, we use the ONNX Runtime
import time
from PIL import Imagedef load_labels(path):with open(path) as f:data = json.load(f)return np.asarray(data)# 图像预处理
def preprocess(input_data):# convert the input data into the float32 inputimg_data = input_data.astype('float32')#normalizemean_vec = np.array([0.485, 0.456, 0.406])stddev_vec = np.array([0.229, 0.224, 0.225])norm_img_data = np.zeros(img_data.shape).astype('float32')for i in range(img_data.shape[0]):norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]#add batch channelnorm_img_data = norm_img_data.reshape(1, 3, 224, 224).astype('float32')return norm_img_datadef softmax(x):x = x.reshape(-1)e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=0)def postprocess(result):return softmax(np.array(result)).tolist()# Load the raw image
img = Image.open("D://Pic//cat.jpg")
img = img.resize((224, 224), Image.BILINEAR)
print("Image size: ", img.size)image_data = np.array(img).transpose(2, 0, 1)
input_data = preprocess(image_data)# Run the model on the backend
session = onnxruntime.InferenceSession('.//alexnet.onnx', None)# get the name of the first input of the model
input_name = session.get_inputs()[0].name
print('Input Name:', input_name)# Inference
start = time.time()
raw_result = session.run([], {input_name: input_data})
end = time.time()
res = postprocess(raw_result)inference_time = np.round((end - start) * 1000, 2)
idx = np.argmax(res)print('========================================')
print('Final top prediction is: %d'% idx)
print('========================================')print('========================================')
print('Inference time: ' + str(inference_time) + " ms")
print('========================================')
参考教程
[1] https://github.com/onnx/tutorials/blob/master/tutorials/PytorchOnnxExport.ipynb
[2] https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
[3] https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
PyTorch模型部署:pth转onnx跨框架部署详解+代码相关推荐
- Spring基于注解TestContext 测试框架使用详解
原创整理不易,转载请注明出处:Spring基于注解TestContext 测试框架使用详解 代码下载地址:http://www.zuidaima.com/share/1775574182939648. ...
- 将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功)
将训练好的pytorch模型的pth文件转换成onnx模型(亲测成功) 模型转换 声明:本文原创,未经许可严禁转载,原文地址https://blog.csdn.net/hutao1030813002/ ...
- 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用
首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...
- pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...
- linux系统pkl,pytorch模型 .pt, .pth, .pkl有什么区别及如何保存
pytorch模型 .pt, .pth, .pkl有什么区别及如何保存 发布时间:2020-07-22 10:47:44 来源:亿速云 阅读:371 作者:小猪 小编这次要给大家分享的是pytorch ...
- 网易考拉海购Dubbok框架优化详解
网易考拉海购Dubbok框架优化详解 摘要:微服务化是当前电商产品演化的必然趋势,网易考拉海购通过微服务化打破了业务爆发增长的架构瓶颈.本文结合网易考拉海购引用的开源Dubbo框架,分享支持考拉微服务 ...
- java定时任务框架elasticjob详解
这篇文章主要介绍了java定时任务框架elasticjob详解,Elastic-Job是ddframe中dd-job的作业模块中分离出来的分布式弹性作业框架.该项目基于成熟的开源产品Quartz和Zo ...
- Android UI 测试框架Espresso详解
Android UI 测试框架Espresso详解 1. Espresso测试框架 2.提供Intents Espresso 2.1.安装 2.2.为Espresso配置Gradle构建文件 2.3. ...
- Android 进阶——Framework 核心之Android Storage Access Framework(SAF)存储访问框架机制详解(二)
文章大纲 引言 一.DirectFragment 1.当选中DirectoryFragment中RecyclerView的Item时 2.选中DirectoryFragment中RecyclerVie ...
最新文章
- 深挖之后吓一跳,谷歌AI专利何止一个dropout,至少30项今日生效
- Spy++的使用方法及下载
- data URI scheme及其应用
- bootstrap datetimepicker、bootstrap datepicker日期组件对范围的简单封装
- 这 5 条 IntelliJ IDEA 调试技巧太强了!
- 泛微数字化督查督办平台:不见面也能高效落实工作、管理到位
- CodeSmith注册
- 22条创业军规,让你5分钟读完《创业维艰》
- 阿里P7级别面试经验总结,面试心得体会
- 计算机管理员绩效指标,网络管理员绩效kpi考核标准..doc
- C# asp.net图片拼接方法
- 观察者模式(java)浅析
- qq看点模块测试用例
- vm怎么装vim_安装nginx报-bash: vm: command not found..错误提示vim文本编辑器命令没有安装...
- wind上怎么连接mysql_windows上连接mysql数据库怎么连接
- python递归算法(斐波那契数列,汉诺塔、二分法查找)
- 第4天:python的数据类型、用户交互以及基本运算符
- python程序填空_python练习题-基础巩固-第一周
- 今天这教程难度有点高,反爬虫之跳过淘宝滑块验证!爬虫必会教程
- 制作网页游戏的页面。(开始网页,登录账号网页和进入网页)