直接放入nms yolov5源码,以一个二分类的模型举例说明,对nms的理解

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,labels=(), max_det=100000,return_index = False):"""Runs Non-Maximum Suppression (NMS) on inference resultsReturns:list of detections, on (n,6) tensor per image [xyxy, conf, cls]"""nc = prediction.shape[2] - 5  # number of classesxc = prediction[..., 4] > conf_thres  # candidates# Checksassert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'# Settingsmin_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and heightmax_nms = 200000  # maximum number of boxes into torchvision.ops.nms()time_limit = 60.0  # seconds to quit afterredundant = True  # require redundant detectionsmulti_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)merge = False  # use merge-NMSt = time.time()i = Noneoutput = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]for xi, x in enumerate(prediction):  # image index, image inference# Apply constraints# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-heightx = x[xc[xi]]  # confidence# Cat apriori labels if autolabellingif labels and len(labels[xi]):l = labels[xi]v = torch.zeros((len(l), nc + 5), device=x.device)v[:, :4] = l[:, 1:5]  # boxv[:, 4] = 1.0  # confv[range(len(l)), l[:, 0].long() + 5] = 1.0  # clsx = torch.cat((x, v), 0)# If none remain process next imageif not x.shape[0]:continue# Compute confx[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf# Box (center x, center y, width, height) to (x1, y1, x2, y2)box = xywh2xyxy(x[:, :4])# Detections matrix nx6 (xyxy, conf, cls)if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).Tx = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)else:  # best class onlyconf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]# Filter by classif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]# Apply finite constraint# if not torch.isfinite(x).all():#     x = x[torch.isfinite(x).all(1)]# Check shapen = x.shape[0]  # number of boxesif not n:  # no boxescontinueelif n > max_nms:  # excess boxesx = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence# Batched NMSc = x[:, 5:6] * (0 if agnostic else max_wh)  # classesboxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scoresi = torchvision.ops.nms(boxes, scores, iou_thres)  # NMSif i.shape[0] > max_det:  # limit detectionsi = i[:max_det]if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrixweights = iou * scores[None]  # box weightsx[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxesif redundant:i = i[iou.sum(1) > 1]  # require redundancyoutput[xi] = x[i]if (time.time() - t) > time_limit:print(f'WARNING: NMS time limit {time_limit}s exceeded')break  # time limit exceededif return_index == True:return output, ielse:return output
  • 最开始两个变量:
    nc = prediction.shape[2] - 5  # number of classesxc = prediction[..., 4] > conf_thres  # candidates
  • nc = prediction.shape[2] - 5 # number of classes

prediction是网络模型的直接输出
输出其shape是(1, 50000, 7), 1表示的是图片的个数,50000表示是网络预测的候选框的个数,7表示一组数其意义如下:

由这个可以看出来,nc可以得到网络预测的类别个数。

  • xc = prediction[…, 4] > conf_thres # candidates
    prediction[…, 4]是shape 为torch.Size([1, 50000]) tensor,prediction[…, 4]意义是取所有预测值的第5个值,表示目标框含有目标的概率值,整个表达式prediction[…, 4] > conf_thres,返回的是一个与prediction[…, 4]具有相同shape(1, 50000)的tensor,其每个值是True或者False,然后把这个tensor赋值给xc,所以xc的shape是(1, 50000),每个值是True或者False,用其值表达每个box的置信度是否大于或者小于conf_thres值。

  • 输入参数conf_thres和iou_thres的check

assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

这两行很简单,是check conf和iou是否在0和1之间,属于对参数的check。

  • 变量定义
    变量定义直接参考变量后面注释,对其中文翻译
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height(像素值)最小和最大的box的宽和高max_nms = 200000  # maximum number of boxes into torchvision.ops.nms()输送到torchvision.ops.nms()接口中最大box总数量time_limit = 60.0  # seconds to quit afternms函数执行超时设置redundant = True  # require redundant detections需要额外的检测multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)merge = False  # use merge-NMS
  • 输出tensor定义
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]

output被定义成一个list,list的长度等于预测图片的个数prediction.shape[0],也就是1,每个list的元素是一个包含6个字段的tensor。

  • 用for循环依次遍历每一张图片,对每一张图片的推理结果进行处理
for xi, x in enumerate(prediction):  # image index, image inference

xi是图片的index,其值是0,表示第0张图片,x是图片的推理结果,其shape是torch.Size([50000, 7]),这里巧妙的利用了enumerate将图片的index和图片的推理结果分离开,分别存放在了xi和x里面。本来prediction是一个shape为torch.Size([1, 50000, 7])的三维张量,张量的轴1是图片的index,张量的另外两个轴对应图片的推理结果,这里利用了enumerate将prediction的这两部分给分离开了。这行代码很巧妙。

x = x[xc[xi]]  # confidence

xc是之前计算出的shape为(1, 50000)的tensor,每个值是True或者False,用其值表达每个box的置信度是否大于或者小于conf_thres值。xc[xi],[xi]这个是对xc tensor的取值方式,xi的值是0,则表示取第0张图片的所有True或者False值。xc[xi]的shape是torch.Size([50000]),与x的第0个维度一致,x[xc[xi]]也是对x tensor的取值方式,x[xc[xi]]整个表达式则表示对x第0个维度上xc[xi]所有为True的保留,为False的则舍弃,也就是对第i(这里i是0)张图片所有confidence值大于conf_thres的box取出,然后重新赋值给x。

# Cat apriori labels if autolabellingif labels and len(labels[xi]):l = labels[xi]v = torch.zeros((len(l), nc + 5), device=x.device)v[:, :4] = l[:, 1:5]  # boxv[:, 4] = 1.0  # confv[range(len(l)), l[:, 0].long() + 5] = 1.0  # clsx = torch.cat((x, v), 0)

这段代码不确定意思,暂时忽略,后面弄明白再补充。

# If none remain process next imageif not x.shape[0]:continue

这里是检查通过conf_thres条件过滤后是否x.shape[0]大于0,如果有存在大于conf_thres的box,那么进一步处理。如果没有大于conf_thres的box,那么就处理下一张图片。

# Compute confx[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

这行是计算每个框所预测的类别的概率,yolo论文的公式是
Pr⁡(Class i∣Object )∗Pr⁡(Object )∗IOUpred truth =Pr⁡(Class i)∗IOUpred truth \operatorname{Pr}\left(\text { Class }_{i} \mid \text { Object }\right) * \operatorname{Pr}(\text { Object }) * \mathrm{IOU}_{\text {pred }}^{\text {truth }}=\operatorname{Pr}\left(\text { Class }_{i}\right) * \mathrm{IOU}_{\text {pred }}^{\text {truth }} Pr( Class i​∣ Object )∗Pr( Object )∗IOUpred truth ​=Pr( Class i​)∗IOUpred truth ​
这个引自于yolo论文gives us class-specific confidence scores for each
box. These scores encode both the probability of that class
appearing in the box and how well the predicted box fits the
object.
从这句话可以看出来,这个公式计算得到的scores值将每个类别的概率值以及预测框对目标框定范围的准确程度,而代码没有使用IOU这一部分,仅仅只做了类别的概率计算。x[:, 5:]对应的是Pr(Class i∣Object ){Pr}\left(\text { Class }_{i} \mid \text { Object }\right)Pr( Class i​∣ Object ),x[:, 4:5]对应的是Pr(Object ){Pr}(\text { Object })Pr( Object )。
而概率的计算也很巧妙,x[:, 5:]对应5000个2列tensor,x[:, 4:5],对应5000个1列tensor,x[:, 5:] *= x[:, 4:5]则表示为2列的每个元素和一列的每个元素分别相乘得到2列的元素, 元素再赋值给x[:, 5:]的两列。经过这一步计算x的后面两列(第6列,第7列)的值就是代表了目标的所在类别的confidence值了。

# Box (center x, center y, width, height) to (x1, y1, x2, y2)box = xywh2xyxy(x[:, :4])

原来yolo的输出box的格式是center x, center y, width, height,经过这一步后将box的四个值表达为(x1, y1, x2, y2)

# Detections matrix nx6 (xyxy, conf, cls)if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).Tx = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)else:  # best class onlyconf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

以上代码构建nx6的检测矩阵,对于multi_label,还没有明白,等明白了再补充。

conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

这里使用了tensor的max函数,这个函数的详细用法可以参考https://www.jianshu.com/p/3ed11362b54f,max的输入参数是1表示对每行求最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。keepdim可以参考https://blog.csdn.net/zylooooooooong/article/details/112576268,表示输出维度和输入维度是否一致,True则表示一致。
函数返回的conf的shape是torch.Size([50000, 1]),每一行是x[:, 5:]的最大值,j是每一行最大值的索引,这个索引最终用来表示成nms输出的每个类别的id。

x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

这一行将box,conf,以及j 按照列cat成一个tensor作为网络的输出。[conf.view(-1) > conf_thres]则是筛选出出confidence值大于conf_thres所有box。
代码执行到这一行,基本就算是把置信度大于conf_thres的所有box给筛选出来了,筛选结果存放在x tensor里面。x的0到3列存放box,4列存放conf,5列存放class种类id。

# Filter by classif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

这一行是利用class进行过滤,筛选出指定的class,nms仅仅对指定的class进行nms。

# Check shapen = x.shape[0]  # number of boxesif not n:  # no boxescontinueelif n > max_nms:  # excess boxesx = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

判断x中是否含有box,如果没有box,则进行下一张图片的 nms,如果box的个数超过最大nms个数则按照confidence值降序排列,取出最大nms个数的box做nms。

# Batched NMSc = x[:, 5:6] * (0 if agnostic else max_wh)  # classesboxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scoresi = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

c = x[:, 5:6] * (0 if agnostic else max_wh)这行代码是多类别中应用NMS具体意义可以参考https://blog.csdn.net/flyfish1986/article/details/119177472。
多类别NMS(非极大值抑制)的处理策略是为了让每个类都能独立执行NMS,在所有的边框上添加一个偏移量。偏移量仅取决于类的ID(也就是x[:, 5:6]),并且足够大,以便来自不同类的框不会重叠。
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
这行取出boxes和scores,boxes添加了偏移量c,不通过类别的偏移量大小不一致。
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
调用torch自带的nms接口实现重叠框的抑制,函数返回的是一个tensor i

Tensor: int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores

i的意义是,整型64张量,指示被保留的框的index,另外是按照得分(置信度)从高到低排列。

        if i.shape[0] > max_det:  # limit detectionsi = i[:max_det]

判断是否超过最大nms检测个数,如果超过,则去掉置信度低的。

        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrixweights = iou * scores[None]  # box weightsx[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxesif redundant:i = i[iou.sum(1) > 1]  # require redundancy

这部分暂时没懂,后继懂了再更新

output[xi] = x[i]

x[i]是利用nms的结果i,取出所有nms 结果i对应的box,然后将结果保存到xi张图片对应的output里面。

        if (time.time() - t) > time_limit:print(f'WARNING: NMS time limit {time_limit}s exceeded')break  # time limit exceeded

这行是判断nms是否超时,如果超时,则直接跳出for循环,不进行下一张图片的nms

return output

返回结果,结果是存放在output里面了,output 对应的是一个list,list一个元素对应一张图片的nms结果,对于本文举例的二分类,只有一张图片,所以可以查看output[0],确认nms的输出结果,其shape是torch.Size([1892, 6]),如果打印其中一行,则对应的格式如下,也就是nms结果,每一行的意义:

yolov5 nms 源码理解相关推荐

  1. YoloV3-SPP NMS源码详解

    前言 该文链接至:YOLO-V3-SPP详细解析 本文主要讲解在推理阶段YoloV3-SPP的处理,分以下几点: 预处理数据 推理的NMS处理 关于map的计算,ultralytic版本的源码调用了p ...

  2. STM32 USMART源码理解

    本文是基于正点原子USMART源码理解. 该组件由 ALIENTEK 开发提供,功能类似 linux 的 shell(RTT 的 finsh 也属于此类).USMART 最主要 的功能就是通过串口调用 ...

  3. 从hotspot底层对象结构理解锁膨胀升级过程||深入jdk源码理解longadder的分段cas优化机制——分段CAS优化

    深入jdk源码理解longadder的分段cas优化机制 longadder

  4. faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data)

    转载自:faster rcnn源码理解(二)之AnchorTargetLayer(网络中的rpn_data) - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.n ...

  5. faster rcnn的源码理解(一)SmoothL1LossLayer论文与代码的结合理解

    转载自:faster rcnn的源码理解(一)SmoothL1LossLayer论文与代码的结合理解 - 野孩子的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/u ...

  6. TLD(Tracking-Learning-Detection)学习与源码理解之(六)

    TLD(Tracking-Learning-Detection)学习与源码理解之(六) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

  7. TLD(Tracking-Learning-Detection)学习与源码理解之(五)

    TLD(Tracking-Learning-Detection)学习与源码理解之(五)   zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和 ...

  8. TLD(Tracking-Learning-Detection)学习与源码理解之(四)

    TLD(Tracking-Learning-Detection)学习与源码理解之(四) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

  9. TLD(Tracking-Learning-Detection)学习与源码理解之(三)

    TLD(Tracking-Learning-Detection)学习与源码理解之(三) zouxy09@qq.com http://blog.csdn.net/zouxy09 下面是自己在看论文和这些 ...

最新文章

  1. 设计模式 之美 -- 简单工厂模式
  2. quartz mysql 初始化_quartz2.3.0(十五)执行、暂停、继续执行、清除,花式操作数据库中持久化的job任务...
  3. PGA Usage Larger than PGA_AGGREGATE_TARGET setting?
  4. Jenkins持续集成环境之tomcat的安装和配置
  5. ZZ flex preloader
  6. composer查看当前镜像取消_Docker image镜像常用命令
  7. 默认轮播,鼠标移入停止,移出继续
  8. 【嵌入式】C语言高级编程-container_of宏(04)
  9. 每天一道LeetCode----从数组中选择若干不连续元素使得总和最大
  10. 有趣 IOS 开展 - block 使用具体解释
  11. 复杂度及圈复杂介绍介绍
  12. 如何在iOS和WatchKit中更改图像tintColor
  13. 重庆理工大学图像处理实验二:图像增强
  14. JDK1.8源码(十一)——java.util.TreeMap类
  15. 2020计算机应用模拟题,2020年函授本科计算机应用基础课后作业、模拟题及答案...
  16. java实现飞机大战小游戏(源码+注释)
  17. UVa 12261 High Score (贪心“向左走,向右走”)
  18. 做读书笔记好用的软件分享
  19. 物联网Wi-Fi,SIP 时代即将来临
  20. 你有哪些“相见恨晚”的UE4学习资料?

热门文章

  1. 名帖10 邓石如 篆书《弟子职》
  2. Python生成动态二维码,运用神库:qrcode
  3. cannot import name ‘DistanceMetric‘ from ‘sklearn.metrics‘
  4. 2019考研上海交通大学823计算机通信网真题回忆
  5. 深信服2018年实习生校园招聘总结
  6. 快码加编——代码生成神器
  7. Selenium学习 - 简介
  8. 深圳市专精特新企业申报条件及各区奖励政策重点介绍,补贴20-200万
  9. SQL一对多,子表关联查询
  10. 百度飞桨AI达人创造营 | Day01:创意是如何诞生的