非极大值抑制(PyTorch-YOLOv3代码解析一)
非极大值抑制之代码解析(utils.py)
代码github地址:https://github.com/eriklindernoren/PyTorch-YOLOv3
1. 非极大值抑制函数代码
# reference: https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/f917503ffe4a21d2b1148d8cb13b89b834517d76/utils/utils.pydef non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):"""剔除目标置信度小于conf_thres,以及通过非极大值抑制筛选预测的信息Return detections(预测):(x1, y1, x2, y2, object_conf, cls_conf, cls_idx)"""# 把(center x, center y, width, height)转换成(x1, y1, x2, y2)prediction[..., :4] = xywh2xyxy(prediction[..., :4])output = [None for _ in range(len(prediction))] # batch_size# prediction: [batch_size x 10647 x 85]for image_i, image_pred in enumerate(prediction):# 筛选置信度大于conf_thres的行image_pred = image_pred[image_pred[:, 4] >= conf_thres]# 过滤没有大于conf_thres的img_predif not image_pred.size(0):continue# image_pred[:, 5:].max(1)[0]:取出每一行类别概率最大的值, 与筛选的目标置信度相乘score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]# image_pred按照score降序排列image_pred = image_pred[(-score).argsort()]# 得到每一行类别最大的概率与类别代号,keepdim: 保持与image_pred的维度相同class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)# 获取detections:[x1,y1,x2,y2,object_conf,cls_conf,cls_idx],即预测的结果detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1)# 执行非极大值抑制keep_boxes = []while detections.size(0):# 其他框与score最大的框(第一个框)做交并比large_overlap = bbox_iou(detections[0, :4].unsqueeze(0), detections[:, :4]) > nms_thres# 第一行与每一行的类别代号相等的条件label_match = detections[0, -1] == detections[:, -1]# 与第一个框交并比大于阈值,且类别代号相同的条件invalid = large_overlap & label_match# 满足上面条件的行,通过目标置信度对box做加权平均处理weights = detections[invalid, 4:5]# 加权平均合并,赋值给第一行detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()keep_boxes += [detections[0]]detections = detections[~invalid] # 取剩余部分的预测值if keep_boxes:output[image_i] = torch.stack(keep_boxes) # 非极大值抑制的结果return output
2. 目标置信度过滤
(1) 首先需要把预测的值[cx,cy,w,h]转换为左上角与右下角坐标[x1,y1,x2,y2],然后由预测的目标置信度大于某个阈值过滤预测的每一行,代码:
# [cx,cy,w,h]->[x1,y1,x2,y2]
prediction[...,:4] = xywh2xyxy(prediction[...,:4])# 通过目标置信项筛选
img_pred = img_pred[img_pred[:,4]>=conf_thres]
(2) 置信度阈值筛选后,数据输出如下:
3. 降序排列与预测重组
(1) 排序的值:每一行的最大类别概率与目标置信度概率的乘积,代码:
# image_pred[:, 5:].max(1)[0]:取出每一行类别概率最大的值, 与筛选的目标置信度相乘
score = img_pred[:,4] * img_pred[:,5:].max(1)[0]# 按照score降序排序
img_pred = img_pred[(-score).argsort()]
(2) 预测重组,在[x1,y1,x2,y2,object_conf]的后面添加类别概率、类别代号:cls_conf、cls_idx。添加的数据是类别概率最大的一个值。代码:
# 得到每一行类别最大的概率与类别代号,keepdim: 保持与image_pred的维度相同
cls_confs, cls_idx = img_pred[:,5:].max(1,keepdim=True)# 获取detections:[x1,y1,x2,y2,conf,cls_conf,cls_idx]
detections = torch.cat((img_pred[:, :5], cls_confs.float(), cls_idx.float()), 1)
(3) 执行结果:
4. 非极大值抑制:条件筛选
(1) 条件筛选:其他行与第一行交并比IoU大于阈值,且类别代号与第一行相等。代码:
# 其他行与第一行的交并比大于阈值,且类别代号与第一行相等的条件
large_overlap = bbox_iou(detections[0,:4].unsqueeze(0), detections[:,:4])>nms_thres
label_match = detections[0, -1] == detections[:, -1]
invalid = large_overlap & label_match
(2) 剔除满足invalid的行。
detections = detections[~invalid]
(3) 执行结果:
5. 非极大值抑制:目标置信度对box加权平均
(1) 目标置信度对box加权平均。
# 满足上面条件的行,通过目标置信度对box做加权平均处理
weights = detections[invalid, 4:5]
# 加权平均合并
detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
# 添加加权平均的detection
keep_boxes += [detections[0]]
(2) 执行结果:
(3) 最后output结果:
非极大值抑制(PyTorch-YOLOv3代码解析一)相关推荐
- array python 交集_NMS原理(非极大值抑制)+python实现
1.先解释什么叫IoU(intersection-over-union).IoU表示(A∩B)/(A∪B) 即交并比. 非极大值抑制:图一 --> 图二 ,剔除同一个目标上的重叠建议框,最终一个 ...
- 【YOLOv3 NMS】YOLOv3中的非极大值抑制
文章目录 1 NMS问题由来 2 NMS操作流程 2.1 进行NMS前要先有什么 2.2 NMS流程 3 NMS代码解读 4 感谢链接 1 NMS问题由来 利用YOLOv3网络结构提取到out0.ou ...
- NMS(Non-Maximum Suppression,非极大值抑制)解析
非极大值抑制,简称为NMS算法,英文为Non-Maximum Suppression.其思想是搜素局部最大值,抑制极大值.NMS算法在不同应用中的具体实现不太一样,但思想是一样的.非极大值抑制,在计算 ...
- PyTorch实现非极大值抑制(NMS)
NMS即non maximum suppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值.在最近几年常见的物体检测算法(包括rcnn.sppnet.fast-rcnn.f ...
- 锚框、交并比和非极大值抑制(tf2.0源码解析)
锚框.交并比和非极大值抑制(tf2.0源码解析) 文章目录 锚框.交并比和非极大值抑制(tf2.0源码解析) 一.锚框生成 1.锚框的宽高 2.锚框的个数 3.注意点(★★★) 4.tf2.0代码 二 ...
- 一种非极大值抑制(non_max_suppression, nms)的代码实现方式
目录 1. 简介 2. 代码 2.1 坐标形式转换 2.2 iou计算 2.3 nms 1. 简介 非极大值抑制,non_max_suppression,简称nms,常用于目标检测的后处理,去除多余的 ...
- 目标检测中的LOU(交并比)和NMS(非极大值抑制)代码实现
1.LOU, 两个box框的交集比上并集,示意图如下所示: 代码如下所示: #假设box1的维度为[N,4] box2的维度为[M,4] def Lou(box1, box2):N = box1.si ...
- 手写非极大值抑制代码(NMS)
在物体检测领域当中,非极大值抑制应用十分广泛,目的是为了消除多余的框,找到最佳的物体检测的位置.那么具体如何操作呢?如下图所示,有三个boundingbox,其中第一个绿色boundingbox的置信 ...
- yolov3 NMS非极大值抑制
基本原理:对于Bounding Box的列表B及其对应的置信度S,采用下面的计算方式.选择具有最大score的检测框M,将其从B集合中移除并加入到最终的检测结果D中.通常将B中剩余检测框中与M的IoU ...
- sklearn逻辑回归 极大似然 损失_收藏!攻克目标检测难点秘籍二,非极大值抑制与回归损失优化之路...
点击上方"AI算法修炼营",选择加星标或"置顶" 标题以下,全是干货 前面的话 在前面的秘籍一中,我们主要关注了模型加速之轻量化网络,对目标检测模型的实时性难点 ...
最新文章
- 怎么将对象里面部分的属性放到一个空的对象里面去
- 【Groovy】Groovy 脚本调用 ( Java 类中调用 Groovy 脚本 )
- 当装了两个tomcat后,如何修改tomcat端口
- html鼠标点击伪类,CSS伪类:CSS3鼠标滑过按钮动画
- python print换行_Python中九九乘法表与古诗对话机器人及sep-end值
- think-in-java(17)容器深入研究
- Spring MVC AOP 初步学习
- SwiftUI实战一:从入门到精通
- 自己动手用麦咖啡(mcafee)打造自己的安全网站!安全系统(服务器)!
- 软件测试的职责描述,软件测试工程师的责任是什么?
- 视频编码解码(H264中的profile和level)
- Android吉他调音器,吉他调音器:GuitarTuna
- 安卓面试中高级安卓开发工程师总结之——大公司面试的方向和套路以及应对方法
- MNN Interpreter and Session
- CMD连接MySQL,本地phpAdmin登陆
- 主打“极致性价比”的酷玩7,能否让酷派再现辉煌?
- php 关注微信触发事件,微信api 关注事件
- 《浪潮之巅》读者热评
- 如何设置每天服务器定时重启?
- 设定了所有种子后每次结果还是不一样 pytorch可重复 可复现问题