预测流程

  1. 神经网络预测
  2. 剔除低分数值的框
  3. NMS处理

代码展示

根据yolo-v3的输出结构可以看出,最终网络有三个尺度的输出。以输入为416x416x3为例,输出分别为13x13,26x26,52x52,每个grid cell生成3个预测框,总共产生框的数量为 10647 = 3 × (13x13 + 26x26 + 52x52).

    def evaluate(self):predicted_dir_path = './mAP/predicted'ground_truth_dir_path = './mAP/ground-truth'if os.path.exists(predicted_dir_path):shutil.rmtree(predicted_dir_path)if os.path.exists(ground_truth_dir_path):shutil.rmtree(ground_truth_dir_path)if os.path.exists(self.write_image_path):shutil.rmtree(self.write_image_path)os.mkdir(predicted_dir_path)os.mkdir(ground_truth_dir_path)os.mkdir(self.write_image_path)with open(self.annotation_path, 'r') as annotation_file:for num, line in enumerate(annotation_file):annotation = line.strip().split()image_path = annotation[0]print(annotation[0])print(annotation)image_name = image_path.split('/')[-1]print('image_path: ', image_path)image = cv2.imread(image_path)bbox_data_gt = np.array([list(map(float, box.split(','))) for box in annotation[1:]])if len(bbox_data_gt) == 0:bboxes_gt = []classes_gt = []else:bboxes_gt, classes_gt = bbox_data_gt[:, :4], bbox_data_gt[:, 4]# print(bboxes_gt)# print(classes_gt)ground_truth_path = os.path.join(ground_truth_dir_path, str(num) + '.txt')# 处理标签print('=> ground truth of %s:' % image_name)num_bbox_gt = len(bboxes_gt)with open(ground_truth_path, 'w') as f:for i in range(num_bbox_gt):class_name = self.classes[classes_gt[i]]xmin, ymin, xmax, ymax = list(map(str, bboxes_gt[i]))# print(xmin, ymin, xmax, ymax)bbox_mess = ' '.join([class_name, xmin, ymin, xmax, ymax]) + '\n'# print(bbox_mess)f.write(bbox_mess)print('\t' + str(bbox_mess).strip())# 预测图片print('=> predict result of %s:' % image_name)predict_result_path = os.path.join(predicted_dir_path, str(num) + '.txt')bboxes_pr = self.predict(image)if self.write_image:image = utils.draw_bbox(image, bboxes_pr, show_label=self.show_label)cv2.imwrite(self.write_image_path + image_name, image)with open(predict_result_path, 'w') as f:for bbox in bboxes_pr:coor = np.array(bbox[:4], dtype=np.int32)score = bbox[4]class_ind = int(bbox[5])class_name = self.classes[class_ind]score = '%.4f' % scorexmin, ymin, xmax, ymax = list(map(str, coor))bbox_mess = ' '.join([class_name, score, xmin, ymin, xmax, ymax]) + '\n'f.write(bbox_mess)print('\t' + str(bbox_mess).strip())print('===============================================================')

PostProcess

如下面代码所示,主要剔除越界以及低分数的框。

def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold):valid_scale = [0, np.inf]pred_bbox = np.array(pred_bbox)# pred_bbox:[16128, 6]pred_xywh = pred_bbox[:, 0:4]  # 获取坐标pred_conf = pred_bbox[:, 4]  # 获取置信度pred_prob = pred_bbox[:, 5:]  # 获取分类概率# # (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax)pred_coor = np.concatenate([pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5,pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5], axis=-1)# # (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org)org_h, org_w = org_img_shaperesize_ratio = min(input_size / org_w, input_size / org_h)dw = (input_size - resize_ratio * org_w) / 2dh = (input_size - resize_ratio * org_h) / 2pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratiopred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratioprint(pred_coor)print(pred_coor.shape)# # (3) clip some boxes those are out of range# 坐标越界:左上角坐标大于右下角坐标pred_coor = np.concatenate([np.maximum(pred_coor[:, :2], [0, 0]),np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1])], axis=-1)invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]), (pred_coor[:, 1] > pred_coor[:, 3]))pred_coor[invalid_mask] = 0# # (4) discard some invalid boxesbboxes_scale = np.sqrt(np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1))print(bboxes_scale)print(bboxes_scale.shape)# 保证边界框的宽和高在(0,inf)之间, inf为无限大的值scale_mask = np.logical_and((valid_scale[0] < bboxes_scale), (bboxes_scale < valid_scale[1]))# # (5) discard some boxes with low scoresclasses = np.argmax(pred_prob, axis=-1)s = pred_prob[np.arange(len(pred_coor)), classes]print(s.shape)scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes]score_mask = scores > score_thresholdmask = np.logical_and(scale_mask, score_mask)coors, scores, classes = pred_coor[mask], scores[mask], classes[mask]return np.concatenate([coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1)

NMS处理

具体流程如下:

  1. 判断边界框的数目是否大于0,如果不是则结束迭代;
  2. 按照置信度得分排序选出评分最大的边界框 A 并取出;
  3. 计算边界框 A 与剩下所有边界框的 iou 并剔除那些 iou 值高于阈值的边界框,重复上述步骤;
def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):""":param bboxes: (xmin, ymin, xmax, ymax, score, class)Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdfhttps://github.com/bharatsingh430/soft-nms"""# 判定图片中有多少类别classes_in_img = list(set(bboxes[:, 5]))print(classes_in_img)best_bboxes = []for cls in classes_in_img:cls_mask = (bboxes[:, 5] == cls)print(cls_mask)cls_bboxes = bboxes[cls_mask]# 判断边界框的数量是否大于0while len(cls_bboxes) > 0:# 选出对应类别的最大的评分框A,置信度最高的框max_ind = np.argmax(cls_bboxes[:, 4])# 将边界框 A 取出并剔除best_bbox = cls_bboxes[max_ind]best_bboxes.append(best_bbox)cls_bboxes = np.concatenate([cls_bboxes[: max_ind], cls_bboxes[max_ind + 1:]])# 计算剩余框与A框的IOU,并剔除高于阈值的框iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])weight = np.ones((len(iou),), dtype=np.float32)assert method in ['nms', 'soft-nms']if method == 'nms':iou_mask = iou > iou_thresholdweight[iou_mask] = 0.0if method == 'soft-nms':weight = np.exp(-(1.0 * iou ** 2 / sigma))cls_bboxes[:, 4] = cls_bboxes[:, 4] * weightscore_mask = cls_bboxes[:, 4] > 0.cls_bboxes = cls_bboxes[score_mask]return best_bboxes

NMS举例

最后所有取出来的边界框 A 就是我们想要的。不妨举个简单的例子:假如5个边界框及评分为: A: 0.9,B: 0.08,C: 0.8, D: 0.6,E: 0.5,设定的评分阈值为 0.3,计算步骤如下。

  • 步骤1: 边界框的个数为5,满足迭代条件;
  • 步骤2: 按照 socre 排序选出评分最大的边界框 A 并取出;
  • 步骤3: 计算边界框 A 与其他 4 个边界框的 iou,假设得到的 iou 值为:B: 0.1,C: 0.7, D: 0.02, E: 0.09, 剔除边界框 C;
  • 步骤4: 现在只剩下边界框 B、D、E,满足迭代条件;
  • 步骤5: 按照 socre 排序选出评分最大的边界框 D 并取出;
  • 步骤6: 计算边界框 D 与其他 2 个边界框的 iou,假设得到的 iou 值为:B: 0.06,E: 0.8,剔除边界框 E;
  • 步骤7: 现在只剩下边界框 B,满足迭代条件;
  • 步骤8: 按照 socre 排序选出评分最大的边界框 B 并取出;
  • 步骤9: 此时边界框的个数为零,结束迭代。

最后我们得到了边界框 A、B、D,但其中边界框 B 的评分非常低,这表明该边界框是没有物体的,因此应当抛弃掉。

YOLO-V3代码解析系列(六) —— 网络预测(evaluate.py)相关推荐

  1. YOLO-V5 算法和代码解析系列 —— 学习路线规划综述

    目录标题 为什么学习 YOLO-V5 ? 博客文章列表 面向对象 开源项目学习方法 预备知识 项目目录结构 为什么学习 YOLO-V5 ? 算法性能:与YOLO系列(V1,V2,V3,V4)相比,YO ...

  2. 目标检测Tensorflow:Yolo v3代码详解 (2)

    目标检测Tensorflow:Yolo v3代码详解 (2) 三.解析Dataset()数据预处理部分 四. 模型训练 yolo_train.py 五. 模型冻结 model_freeze.py 六. ...

  3. 探索 YOLO v3 实现细节 - 第6篇 预测 (完结)

    YOLO,即You Only Look Once的缩写,是一个基于卷积神经网络(CNN)的物体检测算法.而YOLO v3是YOLO的第3个版本,即YOLO.YOLO 9000.YOLO v3,检测效果 ...

  4. YOLO v3算法解析

    论文:YOLOv3: An Incremental Improvement 论文地址:https://pjreddie.com/media/files/papers/YOLOv3.pdf YOLO系列 ...

  5. 对象检测目标小用什么模型好_自动驾驶目标检测- YOLO v3 深入解析

    从2016年 Joseph Redmon 发布第一代YOLO开始,YOLO已经更新四代了,凭借着在一个网络模型中完成对图像中所有对象边界框和类别预测的独特创新,成为现今使用最广泛也是最快的对象检测算法 ...

  6. YOLO v3代码学习

    本人使用的版本是https://github.com/AlexeyAB/darknet 源码在darknet-master\src目录下 还记得我们用YOLO v3训练自己的数据集的过程,控制台下使用 ...

  7. 经典卷积模型(四)GoogLeNet-Inception(V3)代码解析

    Inception-V3 网络主干依旧由Inception.辅助分类器构成,其中Inception有六类. BasicConv2d 基本卷积模块 BasicConv2d为带BN+relu的卷积. cl ...

  8. yolo系列之yolo v3【深度解析】——讲的挺好,原作者厉害的

    版权申明:转载和引用图片,都必须经过书面同意.获得留言同意即可 本文使用图片多为本人所画,需要高清图片可以留言联系我,先点赞后取图 这篇博文比较推荐的yolo v3代码是qwe的keras版本,复现比 ...

  9. Darknet53(YOLO V3骨干网络)

    YOLO V3算法使用的骨干网络是Darknet53.Darknet53网络的具体结构如图所示,在ImageNet图像分类任务上取得了很好的成绩.在检测任务中,将图中C0后面的平均池化.全连接层和So ...

最新文章

  1. SQL Server 数据库备份
  2. 【DocFX文档翻译】DocFX 入门 (Getting Started with DocFX)
  3. MySQL学习(四、子查询)
  4. 学习Vue.js实战(一)
  5. 修改场景默认pawn的方法
  6. rgss加密文件解包器_Unity AssetBundle高效加密案例分享
  7. SecureWatch和人工智能为疫情期间更安全有效地监控房地产开发提供助力
  8. 次世代的会话管理项目 Spring Session
  9. 怎样判断ajax请求,如何判断一个请求为ajax请求?
  10. UE4学习日记(十一)实现简单的御剑(板)飞行功能
  11. 干货分享|E-prime 3入门手册
  12. LCD12864 并口和串口通用程序
  13. 如何快速获取设备ip地址
  14. Python自动化操作word--批量替换word文档中的文字
  15. 分布式文件存储:FastDFS简单使用与原理分析
  16. nginx中配置不输入端口(指定地址)访问项目的方法
  17. 基于java的奖学金评定系统设计与实现
  18. 安装计算机主板时应注意的问题,电脑DIY:电脑主板的安装以及注意事项
  19. Matlab报错错误使用symengine
  20. 【软件测试】软件测试分类

热门文章

  1. Primary主类和Catagory分类都存在相同事件
  2. 【bzoj1050】 旅行comf
  3. Redis(设置失效时间,RedisDesktopManger远程管理工具)
  4. Eclipse插件Target Management (RSE)
  5. Origin图复制到Word后有大片空白
  6. webscraper多页爬取_爬虫工具实战篇(Web Scraper)- 京东商品信息爬取(原创)
  7. 定量分析双花(双重支付)问题
  8. 视频怎么做GIF表情包?教你一键生成gif动图
  9. 未明学院:12个惊艳的数据可视化经典案例
  10. pcap头文件位置 Linux,pcap文件头的组织格式