pytorch模型转ONNX转TensorRT,模型转换和推理部署
一、pth模型转ONNX
import os
import sys
import torch
import numpy as npfrom feat.model import ResNet # 导入自己的模型类def load_checkpoint(checkpoint_file, model):"""Loads the checkpoint from the given file."""err_str = "Checkpoint '{}' not found"assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)checkpoint = torch.load(checkpoint_file, map_location="cpu")return checkpoint["epoch"]if __name__ == '__main__':os.environ['CUDA_VISIBLE_DEVICES']='0' # 设置运行显卡号model_filename='resnet_epoch_17.pyth'# init modelmodel = ResNet()load_checkpoint(model_filename, model)model = model.cuda()model.eval()onnx_name = 'resnet.onnx' # 输出onnx文件example = torch.randn((1,3,224,224)) # 模型输入大小example = example.cuda()input_names = ["input"]output_names = ["outputs"]dynamic_axes = {"input": {0: "batch_size"}, "outputs": {0: "batch_size"}}# 模型转换并保存torch.onnx.export(model, example,onnx_name, opset_version=12, input_names=input_names, output_names=output_names, dynamic_axes=None)
二、测试ONNX模型精度
import os
import sys
import torch
import numpy as np
import onnxruntime
import timeif __name__ == '__main__':os.environ['CUDA_VISIBLE_DEVICES']='0' # 设置运行显卡号model_filename='resnet_epoch_17.pyth'# init modelmodel = ResNet()load_checkpoint(model_filename, model)model = model.cuda()model.eval()session = onnxruntime.InferenceSession(onnx_name,providers=['CUDAExecutionProvider'])img = np.random.randn(1,3,224,224).astype(np.float32) # 随机输出t1 = time.time()onnx_preds = session.run(None, {"input": img})print("onnx preds result: ", onnx_preds)t2 = time.time()pth_preds = model(torch.from_numpy(img).cuda())print("pth preds result: ", pth_preds)t3 = time.time()
对比打印结果,确认结果保持一致
onnx preds res: [array([[-0.13128008, 0.04037811, 0.0529038 , 0.101323 , -0.03352938, [43/1903]0.03099938, 0.06380229, -0.03544223, -0.03368076, 0.06361518, -0.00668521, -0.01996843, -0.0132075 , -0.03448019, 0.17793381, 0.08131739, 0.10232763, -0.09122676, 0.01173838, 0.03181053, -0.05899123, 0.01569226, -0.04734752, -0.12551421, 0.00686131, -0.00749457, -0.03729884, 0.05349742, 0.0304895 , 0.02956274, 0.00393172, 0.00196273, 0.01296113, -0.03985897, -0.06289426, -0.0825834 , -0.28903952, 0.02842386, -0.1718263 , -0.05555207, -0.03707219, 0.10904352, 0.06582819, 0.04960179, 0.01508415, 0.05469472, 0.28663486, 0.1183752 , -0.06070469, -0.05200525, -0.03477468, -0.06193898, -0.04432139, 0.0843045 , -0.12080704, 0.00163073, -0.08544722, 0.11994477, 0.02619292, 0.05066012, -0.00332941, -0.1488586 , 0.07936171, 0.06203181, -0.0645356 , -0.07661135, -0.05883927, -0.00459472, -0.06721105, -0.02880175, -0.00337263, -0.00927516, 0.03289868, 0.10054352, -0.09545278, -0.0216963 , 0.11413048, -0.04580398, 0.02614305, -0.08269466, 0.01835637, 0.17654261, 0.0573773 , -0.06440263, 0.01176349, 0.00998674, 0.02840159, 0.14086637, -0.02473863, 0.05228964, -0.03329878, -0.02751228, -0.04788758, 0.1546051 , 0.05838795, -0.02351469, -0.01315547, -0.13732813, -0.08146078, 0.01943143, -0.08991284, 0.14222968, -0.14729632, 0.24547395, -0.05293949, 0.04446511, 0.05436133, -0.09403729, -0.0900671 , 0.04516568, 0.10035874, -0.03281724, 0.19480802, -0.11344203, -0.02487336, -0.08126407, -0.00491623, 0.04313428, -0.10474856, -0.11427435, -0.01765379, -0.04613522, 0.08338863, 0.00564523, 0.14067101, 0.05428562, 0.12530491, -0.2503076 ]], dtype=float32)]
pth preds res: tensor([[-0.1313, 0.0404, 0.0529, 0.1013, -0.0335, 0.0310, 0.0638, -0.0354, -0.0337, 0.0636, -0.0067, -0.0200, -0.0132, -0.0345, 0.1779, 0.0813, 0.1023, -0.0912, 0.0117, 0.0318, -0.0590, 0.0157, -0.0473, -0.1255, 0.0069, -0.0075, -0.0373, 0.0535, 0.0305, 0.0296, 0.0039, 0.0020, 0.0130, -0.0399, -0.0629, -0.0826, -0.2890, 0.0284, -0.1718, -0.0556, -0.0371, 0.1090, 0.0658, 0.0496, 0.0151, 0.0547, 0.2866, 0.1184, -0.0607, -0.0520, -0.0348, -0.0619, -0.0443, 0.0843, -0.1208, 0.0016, -0.0854, 0.1199, 0.0262, 0.0507, -0.0033, -0.1489, 0.0794, 0.0620, -0.0645, -0.0766, -0.0588, -0.0046, -0.0672, -0.0288, -0.0034, -0.0093, 0.0329, 0.1005, -0.0955, -0.0217, 0.1141, -0.0458, 0.0261, -0.0827, 0.0184, 0.1765, 0.0574, -0.0644, 0.0118, 0.0100, 0.0284, 0.1409, -0.0247, 0.0523, -0.0333, -0.0275, -0.0479, 0.1546, 0.0584, -0.0235, -0.0132, -0.1373, -0.0815, 0.0194, -0.0899, 0.1422, -0.1473, 0.2455, -0.0529, 0.0445, 0.0544, -0.0940, -0.0901, 0.0452, 0.1004, -0.0328, 0.1948, -0.1134, -0.0249, -0.0813, -0.0049, 0.0431, -0.1047, -0.1143, -0.0177, -0.0461, 0.0834, 0.0056, 0.1407, 0.0543, 0.1253, -0.2503]], device='cuda:0', grad_fn=<DivBackward0>)
onnx cost time: 0.0062367916107177734 pth cost time: 0.030622243881225586
三、ONNX转TensorRT
import os
import tensorrt as trtTRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt_runtime = trt.Runtime(TRT_LOGGER)BASE_DIR = os.path.dirname(os.path.abspath(__file__))EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)os.environ['CUDA_VISIBLE_DEVICES'] = '2'def get_engine(input_shape, onnx_file_path = "", engine_file_path=""):"""Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""def build_engine():"""Takes an ONNX file and creates a TensorRT engine to run inference with"""with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser, builder.create_builder_config() as config:# builder.max_workspace_size = 1 << 32 # 256MiBsconfig.max_workspace_size = 1 << 33 # 1024MB# config.set_flag(trt.BuilderFlag.FP16) # 使用Fp16精度,如果使用FP32需要屏蔽这一句。builder.max_batch_size = 1# Parse model fileif not os.path.exists(onnx_file_path):print('ONNX file {} not found, please run torch2onnx first to generate it.'.format(onnx_file_path))exit(0)print('Loading ONNX file from path {}...'.format(onnx_file_path))with open(onnx_file_path, 'rb') as model:print('Beginning ONNX file parsing')if not parser.parse(model.read()):print ('ERROR: Failed to parse the ONNX file.')for error in range(parser.num_errors):print (parser.get_error(error))return None# The actual yolov3.onnx is generated with batch size 64. Reshape input to batch size 1network.get_input(0).shape = input_shapeprint('Completed parsing of ONNX file')print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))# config = trt.IBuilderConfig(max_workspace_size = 1 << 32)# config.engine = builder.build_engine(network, config)print("Completed creating Engine")with open(engine_file_path, "wb") as f:f.write(engine.serialize())return engineif os.path.exists(engine_file_path):# If a serialized engine exists, use it instead of building an engine.print("Reading engine from file {}".format(engine_file_path))with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())else:return build_engine()if __name__ == '__main__':onnx_file = 'resnet.onnx'engin_file = 'resnet.engine'input_shape = [1, 3, 224, 224]get_engine(input_shape, onnx_file, engin_file)
四、测试TensorRT模型精度
import os
import sys
import cv2
import copy
import torch
import numpy as np
import time
import onnxruntime
import pycuda.driver as cuda
import tensorrt as trtos.environ['CUDA_VISIBLE_DEVICES']='3'
TRT_LOGGER = trt.Logger()
import trt_common
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if sys.getdefaultencoding() != 'utf-8':reload(sys)sys.setdefaultencoding('utf-8')# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):def __init__(self, host_mem, device_mem):self.host = host_memself.device = device_memdef __str__(self):return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)def __repr__(self):return self.__str__()def get_engine(engine_file_path):with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):inputs = []outputs = []bindings = []stream = cuda.Stream()for binding in engine:size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_sizedtype = trt.nptype(engine.get_binding_dtype(binding))# Allocate host and device buffershost_mem = cuda.pagelocked_empty(size, dtype)device_mem = cuda.mem_alloc(host_mem.nbytes)# Append the device buffer to device bindings.bindings.append(int(device_mem))# Append to the appropriate list.if engine.binding_is_input(binding):inputs.append(HostDeviceMem(host_mem, device_mem))else:outputs.append(HostDeviceMem(host_mem, device_mem))return inputs, outputs, bindings, stream# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):# Transfer input data to the GPU.[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]# Run inference.context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)# Transfer predictions back from the GPU.[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]# Synchronize the streamstream.synchronize()# Return only the host outputs.return [out.host for out in outputs]if __name__ == '__main__':os.environ['CUDA_VISIBLE_DEVICES']='3'onnx_name = 'resnet.onnx'trt_name = 'resnet.engine'session = onnxruntime.InferenceSession(onnx_name,providers=['CUDAExecutionProvider'])import pycuda.autoprimaryctxengine = get_engine(trt_name)context = engine.create_execution_context()inputs, outputs, bindings, stream = allocate_buffers(engine)img = cv2.imread('test.jpg')img = cv2.resize(img, (224,224))img = img.transpose([2,0,1]).astype(np.float32)img = np.expand_dims(img, axis=0)t1 = time.time()onnx_preds = session.run(None, {"input": img})#print("onnx_preds: ", onnx_preds)t2 = time.time()inputs[0].host = np.ascontiguousarray(img)trt_outputs = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)data = copy.deepcopy(trt_outputs[0])#print("preds: ", data)t3 = time.time()print("onnx: ", t2-t1, " trt: ", t3-t2)
五、ERROR
error1
ERROR: Failed to parse the ONNX file.
In node 84 (importConv): UNSUPPORTED_NODE: Assertion failed: inputs.at(2).is_weights() && "The bias tensor is required to be an initializer for the Conv operator."
solution:
pip install onnx-simplifier
通过simplify重新保存ONNX模型
import onnx
from onnxsim import simplifyonnx_model = onnx.load('resnet.onnx')
model_simp, check = simplify(onnx_model)onnx.save(model_simp, 'resnet_sim.onnx')
error2
ValueError: ndarray is not contiguous
solution:
数组不连续,使用np.ascontiguousarray(img) 处理数组
inputs[0].host = np.ascontiguousarray(img)
error3
Error Code 1: Myelin (Compiled against cuBLASLt 11.11.3.0 but running against cuBLASLt 11.4.1.0.)
solution:
tensorrt 和 torch同时使用调用了不同版本的libmyelin.so,不同同时使用。tensorrt和onnxruntime同时使用也会发生。
pytorch模型转ONNX转TensorRT,模型转换和推理部署相关推荐
- 【TensorRT】PyTorch模型转换为ONNX及TensorRT模型
文章目录 1. PyTorch模型转TensorRT模型流程 2. PyTorch模型转ONNX模型 3. ONNX模型转TensorRT模型 3.1 TensorRT安装 3.2 将ONNX模型转换 ...
- 1、pth转onnx模型、onnx转tensorrt模型、python中使用tensorrt进行加速推理(全网最全,不信你打我)
本文向所有亲们介绍在python当中配置tensorrt环境.使用tensorrt环境进行推理的教程,主要分为两大部分,第一部分环境配置,第二部分前向推理. 第一部分 环境配置 第一步:检查你的系统类 ...
- yolov5 pt->onnx->om yolov5模型转onnx转om模型转换
yolov5 pt->onnx->om yolov5-6.1版本 models/yolo.py Detect函数修改 class Detect(nn.Module):def forward ...
- 【地平线开发板 模型转换】将pytorch生成的onnx模型转换成.bin模型
文章目录 1 获取onnx模型 2 启动docker容器 3 onnx模型检查 3.1 为什么要检查? 3.2 如何操作 4 图像数据预处理 4.1 一些问题的思考 4.2 图片挑选与放置 4.2 使 ...
- 【Pytorch基础教程33】算法模型部署(MLFlow/ONNX/tf serving)
内容概况 服务器上训练好模型后,需要将模型部署到线上,接受请求.完成推理并且返回结果. 保存模型结构和参数最简单的是torch.save保存为checkpoint,但一般用于训练时记录过程,训练中断可 ...
- ONNX系列四 --- 使用ONNX使TensorFlow模型可移植
目录 TensorFlow简介 安装和导入转换器 快速浏览模型 将TensorFlow模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- ...
- ONNX系列二 --- 使用ONNX使Keras模型可移植
目录 Keras简介 快速浏览模型 安装和导入转换器 将Keras模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携式 ...
- 【模型加速】自定义TensorRT NMS3D插件(1)
需求是这样的,在做PointPillars模型的加速的时候我注意到网络的检测头部分小型操作很多,加速效果不明显.此外,3D检测模型的NMS部分通常是作为后处理的一部分来单独实现,TensorRT并没有 ...
- 模型转换:pytorch模型转onnx, onnx转tensorflow, tensorflow转tflite
文章目录 软件版本: pytorch模型转onnx onnx模型转tensorflow tensorflow模型转tflite 软件版本: tensorflow 2.3.1 pytorch 1.6.0 ...
最新文章
- 九度-1463-招聘会
- python +java 用socket在局域网进行图片上传给springboot后端并进行前端访问
- HDbaseT 高清传输更简单——只需一根网线
- .NET Core请求控制器Action方法正确匹配,但为何404?
- C++设计模式-AbstractFactory抽象工厂模式
- 平扫加重建什么意思_在这款“奸商模拟器”里,帮助战场老兵重建家园吧!
- Kudu : kudu 主键相关
- WPF DataGrid 对行中单元格的访问
- 科技无障碍盛会举办,人工智能和创新成为高频词!
- JavaScript:正则表达式 分组
- Android编码规范05
- 2021年实验中学高考成绩查询,2021年北京高中排名,高中高考成绩排名一览表
- AutoCAD2011,2020安装教程
- 笔记本计算机的功率一般多少,笔记本电脑功率是多少 怎么看笔记本功率多大...
- 正点原子IIC例程讲解笔记(三)——24cxx.c中函数理解
- python统计词频瓦尔登湖_点评《瓦尔登湖》
- 软件工程(数据流图例题详解)
- (八十一)探索hidl-gen使用及IWifi.hal 实现
- xz2p更新android 9,索尼 XZ2 迎来安卓 9.0 更新,但少了全面屏手势
- 游戏开发论坛_国内游戏开发站点与论坛
热门文章
- pandas中数据的选取
- 信捷XC PLC与3台施耐德ATV12变频器通讯程序 信捷XC PLC与3台施耐德ATV12变频器通讯
- 使用express框架时,用MongoDB存放session时,出现错误,解决方法
- http://noi.openjudge.cn/ch0107/13/
- 光伏速进时代下的“建维”一体化路线
- 计算机技术中的图像融合,刘少稳:图像融合技术(CARTO-Merge)的应用方法_复旦大学附属中山医院_ 刘少稳 林佳雄 _365心血管网...
- plsql配置连接oracle,不安装oracle客户端
- [RK3399][Android7.1] 4通道ADC芯片ES7210驱动源码
- TRON COIN智能合约审计报告
- Sonar Qube连续代码质量管理(三)sonar-scanner-3.3.0.1492在Windows环境下安装部署和代码检查使用