导包

import io
import numpy as npfrom torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

定义模型

import torch.nn as nn
import torch.nn.init as initclass SuperResolutionNet(nn.Module):def __init__(self, upscale_factor, inplace=False):super(SuperResolutionNet, self).__init__()self.relu = nn.ReLU(inplace=inplace)self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))self.pixel_shuffle = nn.PixelShuffle(upscale_factor)self._initialize_weights()def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))return xdef _initialize_weights(self):init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv4.weight)# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

如果出现

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

conda install -n base -c conda-forge widgetsnbextension
conda install -c conda-forge ipywidgets

导入 pre-trained 权重

model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# set the model to inference mode
torch_model.eval()

导出 为 ONNX

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)# Export the model
torch.onnx.export(torch_model,               # model being runx,                         # model input (or a tuple for multiple inputs)"super_resolution.onnx",   # where to save the model (can be a file or file-like object)export_params=True,        # store the trained parameter weights inside the model fileopset_version=10,          # the ONNX version to export the model todo_constant_folding=True,  # whether to execute constant folding for optimizationinput_names = ['input'],   # the model's input namesoutput_names = ['output'], # the model's output namesdynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes'output' : {0 : 'batch_size'}})

试着导入看看,(跟预测无关,可以跳过)

import onnxonnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

重头戏来了,使用 onnxruntime 进行预测

import onnxruntimeort_session = onnxruntime.InferenceSession("super_resolution.onnx")def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

对比导出前与,导出后模型预测的结果,没有报错,就是表示onnx预测结果与之前基本一致。

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

参考

  1. Exporting a model from pytorch to onnx and running it using onnx runtime
  2. Load and predict with ONNX Runtime and a very simple model

使用 ONNX 模型做预测相关推荐

  1. tensorflow中保存模型、加载模型做预测(不需要再定义网络结构)

    下面用一个线下回归模型来记载保存模型.加载模型做预测 参考文章: http://blog.csdn.net/thriving_fcl/article/details/71423039 训练一个线下回归 ...

  2. 如何在Java应用里集成Spark MLlib训练好的模型做预测

    前言 昨天媛媛说,你是不是很久没写博客了.我说上一篇1.26号,昨天3.26号,刚好两个月,心中也略微有些愧疚.今天正好有个好朋友问,怎么在Java应用里集成Spark MLlib训练好的模型.在St ...

  3. c++调用mxnet模型做预测

    python在深度学习领域很火,做实验用python很舒服,但是生产环境下可能还是需要c/c++. 那么问题来了, mxnet训练出来的模型如何在c/c++下调用? 以下是一些填坑的经验分享一下 mx ...

  4. 如何用ARIMA模型做预测?

    1.作用 ARIMA模型的全称叫做自回归移动平均模型,是统计模型中最常见的一种用来进行时间序列预测的模型. 2.输入输出描述 输入:特征序列为1个时间序列数据定量变量 输出:未来N天的预测值 3.学习 ...

  5. 用lstm模型做预测_使用LSTM深度学习模型进行温度的时间序列单步和多步预测

    本文的目的是提供代码示例,并解释使用python和TensorFlow建模时间序列数据的思路. 本文展示了如何进行多步预测并在模型中使用多个特征. 本文的简单版本是,使用过去48小时的数据和对未来1小 ...

  6. 用lstm模型做预测_深度学习模型 CNN+LSTM 预测收盘价

    -- 本篇文章 by HeartBearting 上一篇浏览量很大,感谢各位的关注! 能够在这里分享一些实验,一起领略 数据科学之美,也很开心. 以后,这个实验的模型会不断深化. 之后,也会分享一些 ...

  7. joblib 读取模型后对单条数据做预测并解决Reshape your data either using array报错

    用joblib读取模型并对模型做预测: # 这里假设模型已经训练完成,就不训练了 import joblib model = joblib.load("kn_model.m") p ...

  8. garch dcc用matlab,用matlab工具箱怎么对garch模型做...

    对garch模型做预测可以用matlab自带的garchfit()函数,该函数主要用于估计ARMAX / GARCH模型参数. garchfit()函数使用格式: [Coeff,Errors,LLF, ...

  9. 阿里天池竞赛 A股上市公司营收预测 使用LSTM模型做时序预测

    参赛结束了,最后结果一百多名,先把清洗好的数据和预测算法文件记录下来. 使用的完全代码和数据 https://download.csdn.net/download/infent/10693927 代码 ...

最新文章

  1. python控制git版本库
  2. Android性能优化之运算篇(二)
  3. python数组改变维数
  4. (翻译) MongoDB(7) 安装MongoDB
  5. 淡入淡出效果 || 高亮显示案例
  6. java 获取微信公众号code为空
  7. Linux版本配置环境变量,如何linux环境下配置环境变量过程图解
  8. SQL中in参数在存储过程中传递及使用的方法
  9. Eclipse快速创建Bottom Up类型的Web Service服务端
  10. 文件的属性 计算机知识,计算机基础知识文件的属性(二)
  11. 1030: [JSOI2007]文本生成器 ac自动机+dp
  12. Arrays.asList( ) 返回一个特殊的“ArrayList”
  13. 【BZOJ1150】数据备份(堆/优先队列)
  14. 数据库概述之数据库设计实例分析
  15. coreseek mysql_coreseek,php,mysql全文检索部署(一)-阿里云开发者社区
  16. 远程服务器连接计算机和用户名填写,windos系统服务器:添加远程连接用户名方法...
  17. python去除图片背景(透明色)
  18. 公司新来了个P8员工,然后内卷了...
  19. 计算机电脑无法充电,iphone连接电脑无法充电怎么办
  20. Redis大厂面试20题

热门文章

  1. Javascript - 面向对象
  2. 【加密解密】密码学学习
  3. 第二章 Flask——Flask中的request
  4. win7远程连接开启方法
  5. 最少拦截系统,简单dp,(学长说这是贪心?!。。。。。。也是醉了)
  6. javascript函数执行前期变量环境初始化过程
  7. java 线程间的通讯(升级版)
  8. 《设计模式详解》创建型模式 - 单例模式
  9. 软件设计师18-系统开发和运行01
  10. [转]关于Python里的类型注解