Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它

本教程我们将描述如何将PyTorch中定义的模型转换为ONNX格式,然后使用ONNX运行时运行它。

ONNX运行时是一个针对ONNX模型的性能关注引擎,它可以高效地跨多个平台和硬件(Windows、Linux和Mac以及cpu和gpu)进行推理。ONNX运行时已被证明在多个模型上显著提高了性能。

对于本教程,您将需要安装ONNX和ONNX运行时。您可以使用pip install ONNX onnxruntime获得ONNX和ONNX运行时的二进制构建。请注意,ONNX运行时兼容Python 3.6到3.9版本。

注意: 本教程需要PyTorch主分支,它可以按照“https://github.com/pytorch/pytorch#from-source”说明安装。

# Some standard imports
import io
import numpy as npfrom torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

超分辨率是提高图像、视频分辨率的一种方法,广泛应用于图像处理或视频编辑。本教程我们将使用一个小的超分辨率模型。

首先,让我们在PyTorch中创建一个超分辨率模型。该模型使用Shi等人在“使用有效的亚像素卷积神经网络实现实时单图像和视频超分辨率”中描述的有效的亚像素卷积层,以提高图像的分辨率。

该模型将图像YCbCr的Y分量作为输入,并以超分辨率输出放大后的Y分量。

这个模型直接来自PyTorch的例子,没有修改:

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as initclass SuperResolutionNet(nn.Module):def __init__(self, upscale_factor, inplace=False):super(SuperResolutionNet, self).__init__()self.relu = nn.ReLU(inplace=inplace)self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))self.pixel_shuffle = nn.PixelShuffle(upscale_factor)self._initialize_weights()def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))return xdef _initialize_weights(self):init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv4.weight)# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

通常,你会训练这个模型;然而,本教程我们将下载一些预先训练过的力量。请注意,这个模型并不是为了获得良好的准确性而完全训练的,这里仅用于演示目的。

在导出模型之前调用torch_model.eval()或torch_model.train(False)是很重要的,以便将模型转换为推理模式。这是必需的,因为像dropout或batchnorm这样的操作符在推理和训练模式中的行为不同。

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# set the model to inference mode
torch_model.eval()

在PyTorch中导出模型可以通过跟踪或脚本实现。

本教程将使用一个通过跟踪导出的模型作为示例。

要导出模型,我们调用torch.onnx.export()函数。

这将执行模型,记录用于计算输出的操作符的跟踪。

因为export运行这个模型,所以我们需要提供一个输入张量x。只要它的类型和大小正确,其中的值就可以是随机的。

注意,除非指定为动态轴,否则输出的ONNX图中的所有输入尺寸都是固定的。

在本例中,我们使用batch_size 1的输入导出模型,然后在torch.onnx.export()中的dynamic_axes参数中将第一个维度指定为动态的。

因此,导出的模型将接受size [batch_size, 1, 224, 224]的输入,其中batch_size可以是可变的。

要了解关于PyTorch的导出接口的更多细节,请查看torch.onnx文档。

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)# Export the model
torch.onnx.export(torch_model,               # model being runx,                         # model input (or a tuple for multiple inputs)"super_resolution.onnx",   # where to save the model (can be a file or file-like object)export_params=True,        # store the trained parameter weights inside the model fileopset_version=10,          # the ONNX version to export the model todo_constant_folding=True,  # whether to execute constant folding for optimizationinput_names = ['input'],   # the model's input namesoutput_names = ['output'], # the model's output namesdynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes'output' : {0 : 'batch_size'}})

我们还计算了模型之后的输出torch_out,我们将使用它来验证我们导出的模型在ONNX运行时计算出的值相同。

但是在用ONNX运行时验证模型的输出之前,我们将用ONNX的API检查ONNX模型。

首先,onnx.load("super_resolution.onnx")将加载保存的模型并输出一个onnx.ModelProto结构(用于绑定ML模型的顶层文件/容器格式,更多信息参考onnx.proto documentation文档)。

然后,onnx_checker .check_model(onnx_model)将验证模型的结构,并确认模型有一个有效的模式。

通过检查模型的版本、图的结构、节点及其输入和输出来验证ONNX图的有效性。

import onnxonnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

现在让我们使用ONNX运行时的Python api来计算输出。

这部分通常可以在单独的进程或另一台机器上完成,但我们将继续在同一进程中进行,以便验证ONNX运行时和PyTorch为网络计算的值是否相同。

为了使用ONNX运行时运行模型,我们需要使用所选的配置参数(这里我们使用默认配置)为模型创建一个推断会话。

创建会话之后,我们使用run() api对模型进行评估。这个调用的输出是一个列表,其中包含由ONNX运行时计算的模型的输出。

import onnxruntimeort_session = onnxruntime.InferenceSession("super_resolution.onnx")def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")

我们应该看到PyTorch和ONNX运行时的输出在数字上与给定的精度匹配(rtol=1e-03和atol=1e-05)。

作为附注,如果他们不匹配,那么ONNX导出有问题,所以请联系我们。

Running the model on an image using ONNX Runtime

到目前为止,我们已经从PyTorch导出了一个模型,并展示了如何加载它并在ONNX运行时使用一个虚拟张量作为输入来运行它。

在本教程中,我们将使用一张著名的猫的图片,如下图所示:

首先,让我们加载图像,预处理它使用标准的PIL python库。注意,这种预处理是为训练/测试神经网络而处理数据的标准实践。

我们首先调整图像的大小以适应模型输入的大小(224x224)。然后我们将图像分割成Y、Cb和Cr三个分量。

这些分量代表灰度图像(Y),以及色度分量蓝差(Cb)和红差(Cr)。对于人眼来说,Y分量更敏感,我们感兴趣的是我们要转换的这个分量。

在提取Y分量后,我们将它转换成一个张量,这将是我们模型的输入。

from PIL import Image
import torchvision.transforms as transformsimg = Image.open("./_static/img/cat.jpg")resize = transforms.Resize([224, 224])
img = resize(img)img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)

现在,作为下一步,让我们使用表示灰度调整后的猫图像的张量,并在ONNX运行时中运行超分辨率模型,如前所述。

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

此时,模型的输出是一个张量。现在,我们将处理模型的输出,从输出张量中构造出最终的输出图像,并保存图像。后处理步骤在这里采用了超分辨率模型的PyTorch实现(https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py)。

img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')# get the output image follow post-processing step from PyTorch implementation
final_img = Image.merge("YCbCr", [img_out_y,img_cb.resize(img_out_y.size, Image.BICUBIC),img_cr.resize(img_out_y.size, Image.BICUBIC),]).convert("RGB")# Save the image, we will compare this with the output image from mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")

ONNX运行时是一个跨平台引擎,你可以在多个平台上运行它,包括cpu和gpu。

ONNX运行时也可以部署到云上,使用Azure机器学习服务进行模型推理。更多的信息在这里。

这里有更多关于ONNX运行时性能的信息。

关于ONNX运行时的更多信息,请点击这里。

脚本的总运行时间:(0分钟0.000秒)

Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它相关推荐

  1. ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

    目录 PyTorch简介 导入转换器 快速浏览模型 将PyTorch模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携 ...

  2. PyTorch导出JIT模型并用C++ API libtorch调用

    PyTorch导出JIT模型并用C++ API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型 ...

  3. [LibTorch] C++ 调用 PyTorch 导出的模型

    参考文章 C++部署pytorch模型 利用LibTorch部署PyTorch模型 官方文档 问题 pytorch 的神经网络模型有很多,但 libtorch 就特别少.现在面临的问题是要在 C++ ...

  4. pytorch导出onnx格式模型时,不固定输入输出维度

    Pytorch模型转换为onnx格式模型后,模型的输入.输出维度跟转换模型时,用的dummy_input的维度有关系,属于固定尺寸的输入与输出.可以采用以下代码修改onnx模型的输入输出维度: imp ...

  5. PyTorch模型部署:pth转onnx跨框架部署详解+代码

    文章目录 引言 基础概念 onnx:跨框架的模型表达标准 onnxruntime:部署模型的推理引擎 示例代码 0)安装onnx和onnxruntime 1)pytorch模型转onnx模型 2)on ...

  6. OpenCV转换PyTorch分类模型并使用OpenCV Python启动

    OpenCV转换PyTorch分类模型并使用OpenCV Python启动 转换PyTorch分类模型并使用OpenCV Python启动 目标 介绍 要求 实践 模型转换管道 模型评估 评估模式 测 ...

  7. python做的游戏可以导出吗_Python for RenderDoc批量导出模型和贴图

    故事背景: 美术那里有需求,需要别人游戏的模型,来借鉴一下,问我是否有工具可以一键导出模型.我就搜索了一下RenderDoc批量导出图片,结果搜到了用C++改RenderDoc源码的文章.让Rende ...

  8. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

  9. PyTorch 1.0 中文官方教程:ONNX 现场演示教程

    译者:冯宝宝 本教程将向您展示如何使用ONNX将已从PyTorch导出的神经模型传输模型转换为Apple CoreML格式.这将允许您在Apple设备上轻松运行深度学习模型,在这种情况下,可以从摄像机 ...

最新文章

  1. mxnet 配置gpu
  2. ping32终端安全管理方案_RFID固定资产管理解决方案,RFID资产管理,RFID手持终端
  3. 今年双11,阿里业务100%上云
  4. JavaScript状态2018
  5. cp105b linux 驱动,富士施乐 DocuPrint CP105b驱动
  6. Java导入导出Excel控件简介
  7. Android集成讯飞语音、百度语音、阿里语音识别
  8. Python第八课:函数(def)
  9. 三、 CSS3流星雨划过动画特效
  10. python打印不换行_python打印后如何不换行
  11. 基于SpringBoot的社区综合治理系统设计与实现
  12. 如何成为有效学习的高手 学习笔记
  13. Ubuntu系统垃圾清理
  14. Linux 流量控制TC
  15. SNIPER—— SNIP的实战版本 (目标检测)(two-stage)(深度学习)(Arvix 2018)
  16. 记录Google被和谐的日子
  17. java.lang.reflect.InvocationTargetException异常处理方法
  18. 嵌入式软件学习路线(入门)
  19. Android 11 踩雷之 App无法唤起相机
  20. agilent仪表的GPIB接口

热门文章

  1. oracle增量脚本(记录)创建触发器监控对一张表的增删改
  2. ZAM 3D 制作3D动画字幕 用于Xaml导出
  3. 实验5 —— 编写、调试具有多个段的程序
  4. django实现搜索功能
  5. 关于数论【莫比乌斯反演】
  6. C++中的未定义的行为
  7. 本地邮件系统的安装及配置
  8. linux之service命令
  9. dedecms调用日期格式化形式大全
  10. Python网路请求(GET示例)