需求是这样的,在做PointPillars模型的加速的时候我注意到网络的检测头部分小型操作很多,加速效果不明显。此外,3D检测模型的NMS部分通常是作为后处理的一部分来单独实现,TensorRT并没有直接支持3D NMS的导出。本着学习的目的,我将PointPillars模型中的检测头(单头)和3D NMS两部分合并到一个TensorRT Plugin,实现端到端的推理。其最终效果如下右图所示,自定义的NMS3D Plugin包含了整个后处理部分。

如何在onnx的输出后面增加NMS3D节点?

这一步涉及到修改onnx模型,可借助TensorRT自带的小工具ONNX GraphSurgeon来完成。它可以增加或者移除某些onnx节点,修改名字或者维度等等。ONNX GraphSurgeon工具的安装也很简单,先安装nvidia-pyindex,然后再安装onnx-graphsurgeon。

pip install nvidia-pyindex
pip install onnx-graphsurgeon

然后再是修改计算图的操作,我这里给出两种实现方式仅供参考。

方法一:

# Here we'll register a function to do all the subgraph-replacement heavy-lifting.
# NOTE: Since registered functions are entirely reusable, it may be a good idea to
# refactor them into a separate module so you can use them across all your models.
@gs.Graph.register()
def add_nms3d(self, inputs, outputs):# Disconnect output nodes of all input tensorsfor inp in inputs:inp.outputs.clear()### Disconnet input nodes of all output tensorsfor out in outputs:out.inputs.clear()attrs = collections.OrderedDict()attrs['anchor_sizes'] = anchor_sizesattrs['anchor_bottom_heights'] = anchor_bottom_heights# Insert the new node.return self.layer(op="NMS3D", inputs=inputs, outputs=outputs, name="nms3d", attrs=attrs)def simplify_onnx():model = onnx.load("pointpillar_raw.onnx")graph = gs.import_onnx(onnx_model)tmap = graph.tensors()inputs = [tmap['cls_preds'],tmap['box_preds'],tmap['dir_cls_preds']]outputs = [gs.Variable(name="nms3d_output", dtype=np.float32, shape=(1,100,9))]graph.add_nms3d(inputs, outputs)graph.outputs = outputsgraph.cleanup()graph.toposort()onnx.save(model_simplify, "pointpillar_simplify.onnx")print("export ok...")

方法二:

def simplify_onnx():#model = onnx.load("pointpillar_raw.onnx")model = onnx.load("pointpillar_fcn_max_nchw_cudapp.onnx")while len(model.graph.output):model.graph.output.remove(model.graph.output[0])model.graph.output.extend([onnx.helper.make_tensor_value_info('nms3d_output', onnx.TensorProto.FLOAT, [1,100,9]),])  attrs = collections.OrderedDict()attrs['anchor_sizes'] = anchor_sizesattrs['anchor_bottom_heights'] = anchor_bottom_heightsgraph = gs.import_onnx(model)tmap = graph.tensors()inputs = [tmap['cls_preds'],tmap['box_preds'],tmap['dir_cls_preds']]outputs = [tmap['nms3d_output']]nms3d_layer = graph.layer(op="NMS3D", inputs=inputs, outputs=outputs, name="nms3d", attrs=attrs)graph.cleanup()graph.toposort()onnx_module = gs.export_onnx(graph)onnx.save(onnx_module, "pointpillar_simplify.onnx")print("export ok...")

【参考文献】

TensorRT详细入门指北,如果你还不了解TensorRT,过来看看吧! - 知乎

安装onnx-graphsurgeon_人类高质量算法工程师的博客-CSDN博客

如何修改已有的ONNX模型 - 知乎

Polygraphy逐层对比onnx和tensorrt模型的输出 - 知乎

深度学习系列4:onnx_IE06的博客-CSDN博客_onnx学习

【模型加速】自定义TensorRT NMS3D插件(1)相关推荐

  1. 【模型加速】TensorRT详解

    ■ TensorRT概述 NVIDIA®TensorRT™的核心是一个C++库,可以促进在NVIDIA图形处理单元(GPU)上的高性能推断.它旨在与Tensorflow.Caffe.Pytorch.M ...

  2. 【模型加速】TensorRT安装、测试及常见问题

    ■ 安装过程 一.安装依赖环境 ● Ubuntu 20.04 ● CUDA 11.1 ● cuDNN 8.0.4 ● python 3.8.5 – 可以通过命令查看cuda.cudnn.python版 ...

  3. yolo模型部署——tensorRT模型加速+triton服务器模型部署

    将最近的工作做个记录,方便日后学习回顾: 1.针对项目需求开发满足任务的模型,拿到任务就要去选相应的算法,由于是工程应用型,必须找填坑多的算法,这样遇到问题可参考的资料多. 2.做好以后,还要将开发的 ...

  4. Ultralytics公司YOLOv8来了(训练自己的数据集并基于NVIDIA TensorRT和华为昇腾端到端模型加速)--跟不上“卷“的节奏

    Official YOLOv8 训练自己的数据集并基于NVIDIA TensorRT和华为昇腾端到端模型加速 说明: 本项目支持YOLOv8的对应的package的版本是:ultralytics-8. ...

  5. php多选筛选,DEDECMS自定义模型筛选多选版插件

    DEDECMS自定义模型筛选多选版插件,像分类信息网站一样的筛选功能. 一.文件夹说明: incluede         核心函数目录 二.安装说明 1.把这些文件夹全部复制到根目录粘贴,或者按文件 ...

  6. win10下 yolov8 tensorrt模型加速部署【实战】

    Windows10下yolov8 tensorrt模型加速部署[实战] TensorRT-Alpha基于tensorrt+cuda c++实现模型end2end的gpu加速,支持win10.linux ...

  7. 【模型加速】PointPillars模型TensorRT加速实验(7)

    按照[模型加速]PointPillars模型TensorRT加速实验(7)中给出的思路对已有的推理代码进行优化,简而言之就是保持数据在GPU显存中流动,尽量避免内存和显存之间的流动. PFN推理v2 ...

  8. 模型加速之INT8量化原理及实践(基于TensorRT)

    一.模型量化: 1.量化的定义是将网络参数从Float-32量化到更低位数,如Float-16.INT8.1bit等. 2.量化的作用:更小的模型尺寸.更低的功耗.更快的计算速度.下图是不同数据结构比 ...

  9. tensorrt轻松部署高性能dnn推理_部署环境之:tensorRT的插件

    TensorRT是一个高性能的深度学习推理(Inference)优化器,可以为深度学习应用提供低延迟.高吞吐率的部署推理.TensorRT可用于对超大规模数据中心.嵌入式平台或自动驾驶平台进行推理加速 ...

最新文章

  1. 【转】推荐两款富文本编辑器:NicEdit和Kindeditor
  2. iview template模式_使用Iview Menu 导航菜单(非 template/render 模式)
  3. 科大星云诗社动态20210827
  4. 【C语言】switch…case无break情况(2)
  5. PHP SOCKET编程详解
  6. C++类实例以及子类在内存中的分配
  7. [react] 什么是React.forwardRef?它有什么作用?
  8. ARM linux的启动部分源代码简略分析【转】
  9. Tomcat学习总结(18)—— Tomcat启动时org.apache.catalina.util.SessionIdGenerator产生安全随机类SecureRandom的实例慢问题解决
  10. Jmeter中的几种协议
  11. 【采访】腾讯社交广告高校算法大赛第一周周冠军——郭达雅 比赛经验及心得分享
  12. 苹果三星手机被诉辐射超标;淘集集启动破产清算;Drupal 8.8.0 发布 | 极客头条...
  13. 关于php 调用接口 微信云支付 HmacSha256 加密 request_content 生成 authen_code
  14. redis分布式锁实现(以抢红包为例)
  15. springboot项目启动遇到问题:AopAutoConfiguration matched: - @ConditionalOnProperty (spring.aop.auto=true)
  16. 8421码转16进制的c语言,将8421BCD码转换为十进制数(转)
  17. 换个视角!那么用户到底想要怎么样的产品?
  18. 幻方加密代码——自动生成幻方密钥方法,罗伯法单偶数阶的解法代码基于python
  19. Excel 查重小技巧,适用于office2003
  20. Linus Torvalds:最庆幸的是 30 年后,Linux 不是一个“死”项目

热门文章

  1. 金华免费服务器_金华云主机
  2. 操作系统实验(linux内核编译,添加系统调用,windows进程创建,脚本程序编写)
  3. HDMI转 toMIPI DSI驱动板1080P 2K 4K TC358870 东芝IC LCD 3D打印机 VR 永星电子 Yongxing
  4. 如何优雅的写UI——(1)MFC六大核心机制-程序初始化
  5. 外汇超短线交易中的“剥头皮”
  6. 网络编程_5(超时检测+UNIX域套接字+抓包工具+包头分析)
  7. HDU 4899 Hero meet devil
  8. 山体滑坡动画用什么软件制作_3d动画都是使用什么软件制作的
  9. 万能声卡驱动(Alsa)的安装方法
  10. php 配置 memcache,php如何配置memcache