本文只有 tensorrt python部分涉动态分辨率设置,没有c++的。

目录

pytorch转onnx:

onnx转tensorrt:

python tensorrt推理:


知乎博客也可以参考:

tensorrt动态输入(Dynamic shapes) - 知乎

记录此贴的原因有两个:1.肯定也有很多人需要 。2.就我搜索的帖子没一个讲的明明白白的,官方文档也不利索,需要连蒙带猜。话不多少,直接上代码。

以pytorch转onnx转tensorrt为例,动态shape是图像的长宽。

pytorch转onnx:

def export_onnx(model,image_shape,onnx_path, batch_size=1):x,y=image_shapeimg = torch.zeros((batch_size, 3, x, y))dynamic_onnx=Trueif dynamic_onnx:dynamic_ax = {'input_1' : {2 : 'image_height',3:'image_wdith'},   'output_1' : {2 : 'image_height',3:'image_wdith'}}torch.onnx.export(model, (img), onnx_path, input_names=["input_1"], output_names=["output_1"], verbose=False, opset_version=11,dynamic_axes=dynamic_ax)else:torch.onnx.export(model, (img), onnx_path, input_names=["input_1"], output_names=["output_1"], verbose=False, opset_version=11)

onnx转tensorrt:

按照nvidia官方文档对dynamic shape的定义,所谓动态,无非是定义engine的时候不指定,用-1代替,在推理的时候再确定,因此建立engine 和推理部分的代码都需要修改。

建立engine时,从onnx读取的network,本身的输入输出就是dynamic shapes,只需要增加optimization_profile来确定一下输入的尺寸范围。

def build_engine(onnx_path, using_half,engine_file,dynamic_input=True):trt.init_libnvinfer_plugins(None, '')with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:builder.max_batch_size = 1 # always 1 for explicit batchconfig = builder.create_builder_config()config.max_workspace_size = GiB(1)if using_half:config.set_flag(trt.BuilderFlag.FP16)# Load the Onnx model and parse it in order to populate the TensorRT network.with open(onnx_path, 'rb') as model:if not parser.parse(model.read()):print ('ERROR: Failed to parse the ONNX file.')for error in range(parser.num_errors):print (parser.get_error(error))return None##增加部分if dynamic_input:profile = builder.create_optimization_profile();profile.set_shape("input_1", (1,3,512,512), (1,3,1024,1024), (1,3,1600,1600)) config.add_optimization_profile(profile)#加上一个sigmoid层previous_output = network.get_output(0)network.unmark_output(previous_output)sigmoid_layer=network.add_activation(previous_output,trt.ActivationType.SIGMOID)network.mark_output(sigmoid_layer.get_output(0))return builder.build_engine(network, config) 

python tensorrt推理:

进行推理时,有个不小的暗坑,按照我之前的理解,既然动态输入,我只需要在给输入分配合适的缓存,然后不管什么尺寸直接推理就行了呗,事实证明还是年轻了。按照官方文档的提示,在推理的时候一定要增加这么一行,context.active_optimization_profile = 0,来选择对应的optimization_profile,ok,我加了,但是还是报错了,原因是我们既然在定义engine的时候没有定义输入尺寸,那么在推理的时候就需要根据实际的输入定义好输入尺寸。

def profile_trt(engine, imagepath,batch_size):assert(engine is not None)  input_image,input_shape=preprocess_image(imagepath)segment_inputs, segment_outputs, segment_bindings = allocate_buffers(engine, True,input_shape)stream = cuda.Stream()    with engine.create_execution_context() as context:context.active_optimization_profile = 0#增加部分origin_inputshape=context.get_binding_shape(0)#增加部分if (origin_inputshape[-1]==-1):origin_inputshape[-2],origin_inputshape[-1]=(input_shape)context.set_binding_shape(0,(origin_inputshape))input_img_array = np.array([input_image] * batch_size)img = torch.from_numpy(input_img_array).float().numpy()segment_inputs[0].host = img[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in segment_inputs]#Copy from the Python buffer src to the device pointer dest (an int or a DeviceAllocation) asynchronously,stream.synchronize()#Wait for all activity on this stream to cease, then return.context.execute_async(bindings=segment_bindings, stream_handle=stream.handle)#Asynchronously execute inference on a batch. stream.synchronize()[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in segment_outputs]#Copy from the device pointer src (an int or a DeviceAllocation) to the Python buffer dest asynchronouslystream.synchronize()results = np.array(segment_outputs[0].host).reshape(batch_size, input_shape[0],input_shape[1])    return results.transpose(1,2,0)

只是短短几行代码,结果折腾了一整天,不过好在解决了动态输入的问题,不需要再写一堆乱七八糟的代码,希望让有缘人少走一点弯路。

原文链接:https://blog.csdn.net/weixin_42365510/article/details/112088887

tensorrt动态输入分辨率尺寸相关推荐

  1. 取input 输入_tensorRT动态输入(python)

    关于tensorRT动态输入的例子大多数都是c++版本的,python版本的较少,这里简单总结下python处理tensorRT动态输入时,遇到的一些问题及解决方案. 这里的动态输入是指batch,w ...

  2. 一个小改动,CNN输入固定尺寸图像改为任意尺寸图像

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文小白将和大家一起学习如何在不使用计算量很大的滑动窗口的情况下对 ...

  3. 安卓开发屏幕分辨率尺寸适配问题【原创】

    2019独角兽企业重金招聘Python工程师标准>>> 屏幕分辨率尺寸适配是安卓开发中的难题之一,我开发中的解决办法是: 1.多使用相对布局,即RelativeLayout,或者Li ...

  4. MutualNet:一种“宽度-输入分辨率”互相学习的网络轻量化方法

    本文分享一篇来自 ECCV'20 Oral 的论文『MutualNet: Adaptive ConvNet via Mutual Learning from Network Width and Res ...

  5. OpenVINO整活(一) 输入分辨率

    OpenVINO整活(一) 输入分辨率 OpenVINO分为转换与部署两个部分,如下图所示 在转换步中,需要将输入模型序列化后传入OpenVINO的MOModel Optimizer工具对模型进行优化 ...

  6. 为什么有全连接层的卷积网络输入图片尺寸需要固定的

    一句话: 全连接层的一个神经元对应一个输入. 换句话说, 全连接层要求固定的输入维度. 数学推导: 大家都知道, z=wx+b,全连接神经网络结构一旦固定,需要学习的参数w是固定的,例如 输入图像是 ...

  7. SPP-Net 是怎么让 CNN 实现输入任意尺寸图像的?

    ECCV2014 Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition 解决的问题: there ...

  8. 壁纸最佳分辨率尺寸表

    标签: 壁纸, 最佳分辨率 本文链接: 壁纸最佳分辨率尺寸表 版权所有: hiued – 用户体验数据分析中心, 转载请注明本文出处. 壁纸最佳壁纸尺寸表,部分数据来源于网络.均测试过,数据有效.以下 ...

  9. PostgreSQL查询 动态输入参数

    工作中会碰到临时几天查询相关报表的情况,sql写好后每天只需改日期再执行一次就可以,但是一个个替换或者批量替换不仅耗时耗力,还有可能出错,所以想要能够动态输入查询参数并且同时改变成这个参数的功能. 之 ...

最新文章

  1. java内部类人打电话依赖手机_Java内部类及反射类面试问题,90%的人都不知道
  2. C#.NET如何判断是否有缺少的using
  3. java.awt.action 命令模式_java设计模式之命令模式
  4. ORA-29275:部分多字节字符
  5. x86已安装该产品 剑灵vcredist_MySQL Server v5.7正式版(附安装和配置数据库教程)
  6. 关于去苹果服务器验证充值的一些看法
  7. ai逻辑回归_人工智能中的逻辑是什么?
  8. 修改oracle 的dbname,在oracle 10g上修改dbname的实验
  9. 5.7 Components — Sending Actions From Components to Your Application
  10. 经纪xx系统节点VIP案例介绍和深入分析异常
  11. 《流浪地球》高赞好评被收买改差评?豆瓣如此回应...
  12. [Java] 蓝桥杯 BEGIN-3 入门训练 圆的面积
  13. HMI设计RGB配色表
  14. 总结:form中使用onSubmit=return false防止表单自动提交,以及s...
  15. 康佳LED55K55U电视板砖的拯救历程
  16. java 休眠_Java中 休眠(sleep)
  17. 光功率 博科交换机_交换机查看光模块型号及收发光功率命令
  18. ttl低电平接大电阻_电压不稳定?那是你不懂上拉/下拉电阻原理,5分钟教你应用!...
  19. 字符集详解(一看就懂系列)
  20. “青少年编程能力等级”第一、第二部分:图形化编程 Python编程 含测试样题

热门文章

  1. linux awk 多分隔符
  2. STL中的list详解
  3. ASP.NETcompilation debug=false targetFramework=4.0/错误
  4. VC6.0生成文件的种类和作用
  5. 调用 fork() 两次以避免僵死进程
  6. linux 统一设备模型 pci,Linux设备驱动模型摘抄
  7. oracle表行列权限,Oracle行列互换 横表和纵表
  8. html读取url中文件,HTML5基础知识 - JavaScript API - File - 读取文件为DataURL
  9. android xml pid vid,增加属性标识摄像头的vid与pid,以便知道摄像头与设备文件的对应关系...
  10. 计算机病毒不可能侵入rom,2008年职称计算机考试计算机基础试题7