非极大值抑制之代码解析(utils.py)

代码github地址:https://github.com/eriklindernoren/PyTorch-YOLOv3

1. 非极大值抑制函数代码

# reference: https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/f917503ffe4a21d2b1148d8cb13b89b834517d76/utils/utils.pydef non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):"""剔除目标置信度小于conf_thres,以及通过非极大值抑制筛选预测的信息Return detections(预测):(x1, y1, x2, y2, object_conf, cls_conf, cls_idx)"""# 把(center x, center y, width, height)转换成(x1, y1, x2, y2)prediction[..., :4] = xywh2xyxy(prediction[..., :4])output = [None for _ in range(len(prediction))]  # batch_size# prediction: [batch_size x 10647 x 85]for image_i, image_pred in enumerate(prediction):# 筛选置信度大于conf_thres的行image_pred = image_pred[image_pred[:, 4] >= conf_thres]# 过滤没有大于conf_thres的img_predif not image_pred.size(0):continue# image_pred[:, 5:].max(1)[0]:取出每一行类别概率最大的值, 与筛选的目标置信度相乘score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]# image_pred按照score降序排列image_pred = image_pred[(-score).argsort()]# 得到每一行类别最大的概率与类别代号,keepdim: 保持与image_pred的维度相同class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)# 获取detections:[x1,y1,x2,y2,object_conf,cls_conf,cls_idx],即预测的结果detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1)# 执行非极大值抑制keep_boxes = []while detections.size(0):# 其他框与score最大的框(第一个框)做交并比large_overlap = bbox_iou(detections[0, :4].unsqueeze(0), detections[:, :4]) > nms_thres# 第一行与每一行的类别代号相等的条件label_match = detections[0, -1] == detections[:, -1]# 与第一个框交并比大于阈值,且类别代号相同的条件invalid = large_overlap & label_match# 满足上面条件的行,通过目标置信度对box做加权平均处理weights = detections[invalid, 4:5]# 加权平均合并,赋值给第一行detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()keep_boxes += [detections[0]]detections = detections[~invalid]  # 取剩余部分的预测值if keep_boxes:output[image_i] = torch.stack(keep_boxes)  # 非极大值抑制的结果return output

2. 目标置信度过滤

(1) 首先需要把预测的值[cx,cy,w,h]转换为左上角与右下角坐标[x1,y1,x2,y2],然后由预测的目标置信度大于某个阈值过滤预测的每一行,代码:

# [cx,cy,w,h]->[x1,y1,x2,y2]
prediction[...,:4] = xywh2xyxy(prediction[...,:4])# 通过目标置信项筛选
img_pred = img_pred[img_pred[:,4]>=conf_thres]

(2) 置信度阈值筛选后,数据输出如下:

3. 降序排列与预测重组

(1) 排序的值:每一行的最大类别概率与目标置信度概率的乘积,代码:

# image_pred[:, 5:].max(1)[0]:取出每一行类别概率最大的值, 与筛选的目标置信度相乘
score = img_pred[:,4] * img_pred[:,5:].max(1)[0]# 按照score降序排序
img_pred = img_pred[(-score).argsort()]

(2) 预测重组,在[x1,y1,x2,y2,object_conf]的后面添加类别概率、类别代号:cls_conf、cls_idx。添加的数据是类别概率最大的一个值。代码:

 # 得到每一行类别最大的概率与类别代号,keepdim: 保持与image_pred的维度相同
cls_confs, cls_idx = img_pred[:,5:].max(1,keepdim=True)# 获取detections:[x1,y1,x2,y2,conf,cls_conf,cls_idx]
detections = torch.cat((img_pred[:, :5], cls_confs.float(), cls_idx.float()), 1)

(3) 执行结果:

4. 非极大值抑制:条件筛选

(1) 条件筛选:其他行与第一行交并比IoU大于阈值,且类别代号与第一行相等。代码:

# 其他行与第一行的交并比大于阈值,且类别代号与第一行相等的条件
large_overlap = bbox_iou(detections[0,:4].unsqueeze(0), detections[:,:4])>nms_thres
label_match = detections[0, -1] == detections[:, -1]
invalid = large_overlap & label_match

(2) 剔除满足invalid的行。

detections = detections[~invalid]

(3) 执行结果:

5. 非极大值抑制:目标置信度对box加权平均

(1)  目标置信度对box加权平均。

# 满足上面条件的行,通过目标置信度对box做加权平均处理
weights = detections[invalid, 4:5]
# 加权平均合并
detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
# 添加加权平均的detection
keep_boxes += [detections[0]]

(2) 执行结果:

(3) 最后output结果:

非极大值抑制(PyTorch-YOLOv3代码解析一)相关推荐

  1. array python 交集_NMS原理(非极大值抑制)+python实现

    1.先解释什么叫IoU(intersection-over-union).IoU表示(A∩B)/(A∪B) 即交并比. 非极大值抑制:图一 --> 图二 ,剔除同一个目标上的重叠建议框,最终一个 ...

  2. 【YOLOv3 NMS】YOLOv3中的非极大值抑制

    文章目录 1 NMS问题由来 2 NMS操作流程 2.1 进行NMS前要先有什么 2.2 NMS流程 3 NMS代码解读 4 感谢链接 1 NMS问题由来 利用YOLOv3网络结构提取到out0.ou ...

  3. NMS(Non-Maximum Suppression,非极大值抑制)解析

    非极大值抑制,简称为NMS算法,英文为Non-Maximum Suppression.其思想是搜素局部最大值,抑制极大值.NMS算法在不同应用中的具体实现不太一样,但思想是一样的.非极大值抑制,在计算 ...

  4. PyTorch实现非极大值抑制(NMS)

    NMS即non maximum suppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值.在最近几年常见的物体检测算法(包括rcnn.sppnet.fast-rcnn.f ...

  5. 锚框、交并比和非极大值抑制(tf2.0源码解析)

    锚框.交并比和非极大值抑制(tf2.0源码解析) 文章目录 锚框.交并比和非极大值抑制(tf2.0源码解析) 一.锚框生成 1.锚框的宽高 2.锚框的个数 3.注意点(★★★) 4.tf2.0代码 二 ...

  6. 一种非极大值抑制(non_max_suppression, nms)的代码实现方式

    目录 1. 简介 2. 代码 2.1 坐标形式转换 2.2 iou计算 2.3 nms 1. 简介 非极大值抑制,non_max_suppression,简称nms,常用于目标检测的后处理,去除多余的 ...

  7. 目标检测中的LOU(交并比)和NMS(非极大值抑制)代码实现

    1.LOU, 两个box框的交集比上并集,示意图如下所示: 代码如下所示: #假设box1的维度为[N,4] box2的维度为[M,4] def Lou(box1, box2):N = box1.si ...

  8. 手写非极大值抑制代码(NMS)

    在物体检测领域当中,非极大值抑制应用十分广泛,目的是为了消除多余的框,找到最佳的物体检测的位置.那么具体如何操作呢?如下图所示,有三个boundingbox,其中第一个绿色boundingbox的置信 ...

  9. yolov3 NMS非极大值抑制

    基本原理:对于Bounding Box的列表B及其对应的置信度S,采用下面的计算方式.选择具有最大score的检测框M,将其从B集合中移除并加入到最终的检测结果D中.通常将B中剩余检测框中与M的IoU ...

  10. sklearn逻辑回归 极大似然 损失_收藏!攻克目标检测难点秘籍二,非极大值抑制与回归损失优化之路...

    点击上方"AI算法修炼营",选择加星标或"置顶" 标题以下,全是干货 前面的话 在前面的秘籍一中,我们主要关注了模型加速之轻量化网络,对目标检测模型的实时性难点 ...

最新文章

  1. 怎么将对象里面部分的属性放到一个空的对象里面去
  2. 【Groovy】Groovy 脚本调用 ( Java 类中调用 Groovy 脚本 )
  3. 当装了两个tomcat后,如何修改tomcat端口
  4. html鼠标点击伪类,CSS伪类:CSS3鼠标滑过按钮动画
  5. python print换行_Python中九九乘法表与古诗对话机器人及sep-end值
  6. think-in-java(17)容器深入研究
  7. Spring MVC AOP 初步学习
  8. SwiftUI实战一:从入门到精通
  9. 自己动手用麦咖啡(mcafee)打造自己的安全网站!安全系统(服务器)!
  10. 软件测试的职责描述,软件测试工程师的责任是什么?
  11. 视频编码解码(H264中的profile和level)
  12. Android吉他调音器,吉他调音器:GuitarTuna
  13. 安卓面试中高级安卓开发工程师总结之——大公司面试的方向和套路以及应对方法
  14. MNN Interpreter and Session
  15. CMD连接MySQL,本地phpAdmin登陆
  16. 主打“极致性价比”的酷玩7,能否让酷派再现辉煌?
  17. php 关注微信触发事件,微信api 关注事件
  18. 《浪潮之巅》读者热评
  19. 如何设置每天服务器定时重启?
  20. 设定了所有种子后每次结果还是不一样 pytorch可重复 可复现问题

热门文章

  1. java代码整洁之道_代码整洁之道——我们是作者
  2. 按键精灵开发后台命令脚本的操作教程
  3. 电子商务网站建设规划方案
  4. 我行我素购物管理系统(面向对象)
  5. 百度云盘不限速 又一款百度网盘满速下载工具利器
  6. DBeaver——一款替代Navicat的数据库可视化工具
  7. lintcode java_Lintcode-java版本
  8. 新浪微博OAuth认证简介
  9. Java开发工具 IntelliJ IDEA(idea使用教程,手把手教学)内容很全,一篇管够!!!
  10. 振动试验条件及试验标准