使用 ONNX 模型做预测
导包
import io
import numpy as npfrom torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
定义模型
import torch.nn as nn
import torch.nn.init as initclass SuperResolutionNet(nn.Module):def __init__(self, upscale_factor, inplace=False):super(SuperResolutionNet, self).__init__()self.relu = nn.ReLU(inplace=inplace)self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))self.pixel_shuffle = nn.PixelShuffle(upscale_factor)self._initialize_weights()def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))return xdef _initialize_weights(self):init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv4.weight)# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
如果出现
ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
则
conda install -n base -c conda-forge widgetsnbextension
conda install -c conda-forge ipywidgets
导入 pre-trained 权重
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # just a random number# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# set the model to inference mode
torch_model.eval()
导出 为 ONNX
# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)# Export the model
torch.onnx.export(torch_model, # model being runx, # model input (or a tuple for multiple inputs)"super_resolution.onnx", # where to save the model (can be a file or file-like object)export_params=True, # store the trained parameter weights inside the model fileopset_version=10, # the ONNX version to export the model todo_constant_folding=True, # whether to execute constant folding for optimizationinput_names = ['input'], # the model's input namesoutput_names = ['output'], # the model's output namesdynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes'output' : {0 : 'batch_size'}})
试着导入看看,(跟预测无关,可以跳过)
import onnxonnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)
重头戏来了,使用 onnxruntime 进行预测
import onnxruntimeort_session = onnxruntime.InferenceSession("super_resolution.onnx")def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
对比导出前与,导出后模型预测的结果,没有报错,就是表示onnx预测结果与之前基本一致。
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
参考
- Exporting a model from pytorch to onnx and running it using onnx runtime
- Load and predict with ONNX Runtime and a very simple model
使用 ONNX 模型做预测相关推荐
- tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)
下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...
- 如何在Java应用里集成Spark MLlib训练好的模型做预测
前言 昨天媛媛说,你是不是很久没写博客了.我说上一篇1.26号,昨天3.26号,刚好两个月,心中也略微有些愧疚.今天正好有个好朋友问,怎么在Java应用里集成Spark MLlib训练好的模型.在St ...
- c++调用mxnet模型做预测
python在深度学习领域很火,做实验用python很舒服,但是生产环境下可能还是需要c/c++. 那么问题来了, mxnet训练出来的模型如何在c/c++下调用? 以下是一些填坑的经验分享一下 mx ...
- 如何用ARIMA模型做预测?
1.作用 ARIMA模型的全称叫做自回归移动平均模型,是统计模型中最常见的一种用来进行时间序列预测的模型. 2.输入输出描述 输入:特征序列为1个时间序列数据定量变量 输出:未来N天的预测值 3.学习 ...
- 用lstm模型做预测_使用LSTM深度学习模型进行温度的时间序列单步和多步预测
本文的目的是提供代码示例,并解释使用python和TensorFlow建模时间序列数据的思路. 本文展示了如何进行多步预测并在模型中使用多个特征. 本文的简单版本是,使用过去48小时的数据和对未来1小 ...
- 用lstm模型做预测_深度学习模型 CNN+LSTM 预测收盘价
-- 本篇文章 by HeartBearting 上一篇浏览量很大,感谢各位的关注! 能够在这里分享一些实验,一起领略 数据科学之美,也很开心. 以后,这个实验的模型会不断深化. 之后,也会分享一些 ...
- joblib 读取模型后对单条数据做预测并解决Reshape your data either using array报错
用joblib读取模型并对模型做预测: # 这里假设模型已经训练完成,就不训练了 import joblib model = joblib.load("kn_model.m") p ...
- garch dcc用matlab,用matlab工具箱怎么对garch模型做...
对garch模型做预测可以用matlab自带的garchfit()函数,该函数主要用于估计ARMAX / GARCH模型参数. garchfit()函数使用格式: [Coeff,Errors,LLF, ...
- 阿里天池竞赛 A股上市公司营收预测 使用LSTM模型做时序预测
参赛结束了,最后结果一百多名,先把清洗好的数据和预测算法文件记录下来. 使用的完全代码和数据 https://download.csdn.net/download/infent/10693927 代码 ...
最新文章
- python控制git版本库
- Android性能优化之运算篇(二)
- python数组改变维数
- (翻译) MongoDB(7) 安装MongoDB
- 淡入淡出效果 || 高亮显示案例
- java 获取微信公众号code为空
- Linux版本配置环境变量,如何linux环境下配置环境变量过程图解
- SQL中in参数在存储过程中传递及使用的方法
- Eclipse快速创建Bottom Up类型的Web Service服务端
- 文件的属性 计算机知识,计算机基础知识文件的属性(二)
- 1030: [JSOI2007]文本生成器 ac自动机+dp
- Arrays.asList( ) 返回一个特殊的“ArrayList”
- 【BZOJ1150】数据备份(堆/优先队列)
- 数据库概述之数据库设计实例分析
- coreseek mysql_coreseek,php,mysql全文检索部署(一)-阿里云开发者社区
- 远程服务器连接计算机和用户名填写,windos系统服务器:添加远程连接用户名方法...
- python去除图片背景(透明色)
- 公司新来了个P8员工,然后内卷了...
- 计算机电脑无法充电,iphone连接电脑无法充电怎么办
- Redis大厂面试20题