NMS方法的总结可以参考我之前的文章:
https://blog.csdn.net/qq_34919792/article/details/108186234

非极大值抑制(Non-Maximum Suppression,NMS),顾名思义就是抑制不是极大值的元素。在检测中,我们通过将IOU大于一定阈值的框做一个筛选,只保留置信度最高的框。

网上比较经典的实现思路

def py_cpu_nms(dets, thresh): """Pure Python NMS baseline.""" #x1、y1、x2、y2、以及score赋值 x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] #每一个检测框的面积 areas = (x2 - x1 + 1) * (y2 - y1 + 1) #按照score置信度降序排序 order = scores.argsort()[::-1] keep = [] #保留的结果框集合while order.size > 0: i = order[0] keep.append(i) #保留该类剩余box中得分最高的一个 #得到相交区域,左上及右下 xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) #计算相交的面积,不重叠时面积为0 w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h #计算IoU:重叠面积 /(面积1+面积2-重叠面积) ovr = inter / (areas[i] + areas[order[1:]] - inter) #保留IoU小于阈值的box inds = np.where(ovr <= thresh)[0] order = order[inds + 1] #因为ovr数组的长度比order数组少一个,所以这里要将所有下标后移一位 return keep

其实在NMS可以做pytorch版本,速度和cpython加速后一样,即和pytorch自带的库一样,但是更适合改。

def IOU(box_a, box_b):inter = intersect(box_a, box_b)area_a = ((box_a[:, 2]-box_a[:, 0]) *(box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]area_b = ((box_b[:, 2]-box_b[:, 0]) *(box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]union = area_a + area_b - interreturn inter / union  # [A,B]def torch_nms(boxes, scores, iou_threshold):_, idx = scores.sort(0, descending=True) # descending表示降序boxes_idx = boxes[idx]iou = IOU(boxes_idx, boxes_idx).triu_(diagonal=1) #取上三角矩阵,不包含对角线B = iouwhile 1:A = BmaxA, _ = torch.max(A, dim=0)E = (maxA <= iou_threshold).float().unsqueeze(1).expand_as(A)B = iou.mul(E)if A.equal(B) == True:breakkeep= idx[maxA <= iou_threshold]return keep

这样的代码其实很好改进,比如可以把IOU去改成GIOU或者DIOU或者自己魔改魔改,也可以去改进NMS,提供个DIOU的改进示例。

def DIOU(box_a, box_b, delta = 0.9):inter = intersect(box_a, box_b)area_a = ((box_a[:, :, 2]-box_a[:, :, 0]) *(box_a[:, :, 3]-box_a[:, :, 1])).unsqueeze(2).expand_as(inter)  # [A,B]area_b = ((box_b[:, :, 2]-box_b[:, :, 0]) *(box_b[:, :, 3]-box_b[:, :, 1])).unsqueeze(1).expand_as(inter)  # [A,B]union = area_a + area_b - interx1 = ((box_a[:, :, 2]+box_a[:, :, 0]) / 2).unsqueeze(2).expand_as(inter)y1 = ((box_a[:, :, 3]+box_a[:, :, 1]) / 2).unsqueeze(2).expand_as(inter)x2 = ((box_b[:, :, 2]+box_b[:, :, 0]) / 2).unsqueeze(1).expand_as(inter)y2 = ((box_b[:, :, 3]+box_b[:, :, 1]) / 2).unsqueeze(1).expand_as(inter)t1 = box_a[:, :, 1].unsqueeze(2).expand_as(inter)b1 = box_a[:, :, 3].unsqueeze(2).expand_as(inter)l1 = box_a[:, :, 0].unsqueeze(2).expand_as(inter)r1 = box_a[:, :, 2].unsqueeze(2).expand_as(inter)t2 = box_b[:, :, 1].unsqueeze(1).expand_as(inter)b2 = box_b[:, :, 3].unsqueeze(1).expand_as(inter)l2 = box_b[:, :, 0].unsqueeze(1).expand_as(inter)r2 = box_b[:, :, 2].unsqueeze(1).expand_as(inter)cr = torch.max(r1, r2)cl = torch.min(l1, l2)ct = torch.min(t1, t2)cb = torch.max(b1, b2)D = (((x2 - x1)**2 + (y2 - y1)**2) / ((cr-cl)**2 + (cb-ct)**2 + 1e-7))out = inter / union - D ** deltareturn out if use_batch else out.squeeze(0)def torch_nms(boxes, scores, iou_threshold):_, idx = scores.sort(0, descending=True) # descending表示降序boxes_idx = boxes[idx]iou = DIOU(boxes_idx, boxes_idx).triu_(diagonal=1) #取上三角矩阵,不包含对角线B = iouwhile 1:A = BmaxA, _ = torch.max(A, dim=0)E = (maxA <= iou_threshold).float().unsqueeze(1).expand_as(A)B = iou.mul(E)if A.equal(B) == True:breakkeep= idx[maxA <= iou_threshold]return keep

手写NMS和魔改(Pytorch版本)相关推荐

  1. 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍

    作者丨PENG Bo@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/476297195 编辑丨极市平台 本文的代码,在 win10 和 linux 均可直接编译运行: ...

  2. mybatis 无法初始化类_从零开始手写 mybatis(一)MVP 版本

    什么是 MyBatis ? MyBatis 是一款优秀的持久层框架,它支持定制化 SQL.存储过程以及高级映射. MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集. MyB ...

  3. 基于手写字体数据集MNIST的pytorch图像分类实例

    logistic回归 logistic回归可用于的是做二分类问题,使用sigmoid函数将所有的正数和负数都变成0-1之间的数,这样就可以用这个数来确定到底属于哪一类,可以简单的认为概率大于0.5即为 ...

  4. 卷积神经网络 手写数字识别(包含Pytorch实现代码)

    Hello!欢迎来到六个核桃Lu! 运用卷积神经网络 实现手写数字识别 1 算法分析及设计 卷积神经网络: 图1-2 如图1-2,卷积神经网络由若干个方块盒子构成,盒子从左到右仿佛越来越小,但却越来越 ...

  5. pytorch lstm 写诗文的魔改,测试,猜想

    首先目前自然语言处理的网络基本都是transformers的变体. 我们就不从热闹了,就使用简单的FC层设计一个,首先一般自然语言都是一个概率问题, 所以就是一个分类问题,一般都是有多少的字就分为多少 ...

  6. pytorch实现手写数字识别_Paddle和Pytorch实现MNIST手写数字集识别对比

    一.简介 1. Paddle PaddlePaddle是百度自主研发的集深度学习核心框架.工具组件和服务平台为一体的技术领先.功能完备的开源深度学习平台,有全面的官方支持的工业级应用模型,涵盖自然语言 ...

  7. 手写数字识别环境安装(pytorch)(window10)

    Anaconda TODO Pytorch 必须使用conda install 进行安装 在cmd中nvidia-smi查看本机CUDA版本 在官网复制下载命令,选取的CUDA版本不能高于本机CUDA ...

  8. 手写数字代码识别(pytorch)实现

    数据预览: import pandas df=pandas.read_csv('C:\\Users\\HP\\Desktop\\mnist_train.csv',header=None) df.hea ...

  9. 实验3 手写字体识别【机器学习】

    推荐 python实现手写数字识别(小白入门) 原文MNIST Handwritten Digit Recognition in PyTorch 翻译用PyTorch实现MNIST手写数字识别(非常详 ...

  10. Git手写笔记(简单秒懂)详细讲解

    Git手写笔记 Git是什么:版本控制系统 1.常见的版本控制工具 :1.集中式版本控制工具 2. 分布式版本控制工具 集中式版本控制工具:集中式三个人每个人都有一台,电脑合并到一起,三台链接到一起, ...

最新文章

  1. linux 用户态 隐藏进程 简介
  2. 余数相同问题(信息学奥赛一本通-T1080)
  3. JS: 浅拷贝vs深拷贝 | 刷题打卡
  4. 《2021新青年生长力报告》:水果青年、农货青年、设计青年,哪个最潮?
  5. 使用base标签后图片无法加载_Spring 源码学习(二)-默认标签解析
  6. qt下调用win32api 修改分辨率
  7. Centos 7 telnet 详解
  8. Spark学习笔记6:Spark调优与调试
  9. android自动计步_Android计步模块实例代码(类似微信运动)
  10. MATLAB 2017a 下载及安装
  11. 清理c盘、c盘哪些文件可以删、图形显示文件大小软件
  12. qgis二次开发环境
  13. (中英)作文 —— 标题与小标题
  14. 不要用 Mounty,一次惊险的数据恢复记录
  15. 教你2种常用的电商高并发处理解决方案
  16. ASP.NET 快递管理管理系统毕业设计实现步骤
  17. 高效记忆/形象记忆(10)英语单词记忆-音标法
  18. 如何取消EXCEL文件的“受保护的视图“
  19. 跳转指令JMP(04)和跳转结束指令JME(05)
  20. Gradle学习之Android-DSL AppExtension篇

热门文章

  1. iOS 逆向编程(二)越狱入门知识
  2. 单片机驱动程序是什么,驱动文件组成。
  3. OAuth2通过token访问资源服务器
  4. java开发微信公众号退款_微信公众号退款开发
  5. 微信小程序滑动切换选项卡
  6. Java命名和java图标来由
  7. 使用Patch激活CleanMyPC时报错找不到文件
  8. 大型央企云边协同建设方案及其借鉴意义分析
  9. jsmind 线条_jsmind实例扩展(思维导图)
  10. python环境下,PIP卸载、重装、升级