参考链接: 睿智的目标检测23——Pytorch搭建SSD目标检测平台
参考链接: 参考源代码: ssd_layers.py

from __future__ import division
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Function
from torch.autograd import Variable
from math import sqrt as sqrt
from itertools import product as product
import numpy as np
from utils.box_utils import decode, nms
from utils.config import Configclass Detect(Function):def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):self.num_classes = num_classesself.background_label = bkg_labelself.top_k = top_kself.nms_thresh = nms_threshif nms_thresh <= 0:raise ValueError('nms_threshold must be non negative.')self.conf_thresh = conf_threshself.variance = Config['variance']def forward(self, loc_data, conf_data, prior_data):loc_data = loc_data.cpu()conf_data = conf_data.cpu()num = loc_data.size(0)  # batch sizenum_priors = prior_data.size(0)output = torch.zeros(num, self.num_classes, self.top_k, 5)conf_preds = conf_data.view(num, num_priors,self.num_classes).transpose(2, 1)# 对每一张图片进行处理for i in range(num):# 对先验框解码获得预测框decoded_boxes = decode(loc_data[i], prior_data, self.variance)conf_scores = conf_preds[i].clone()for cl in range(1, self.num_classes):# 对每一类进行非极大抑制c_mask = conf_scores[cl].gt(self.conf_thresh)scores = conf_scores[cl][c_mask]if scores.size(0) == 0:continuel_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)boxes = decoded_boxes[l_mask].view(-1, 4)# 进行非极大抑制ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)output[i, cl, :count] = \torch.cat((scores[ids[:count]].unsqueeze(1),boxes[ids[:count]]), 1)flt = output.contiguous().view(num, -1, 5)_, idx = flt[:, :, 0].sort(1, descending=True)_, rank = idx.sort(1)flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)return output# Config = {#     'num_classes': 3, # 'num_classes': 21,
#     'feature_maps': [38, 19, 10, 5, 3, 1],
#     'min_dim': 300,
#     'steps': [8, 16, 32, 64, 100, 300],
#     'min_sizes': [30, 60, 111, 162, 213, 264],
#     'max_sizes': [60, 111, 162, 213, 264, 315],
#     'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
#     'variance': [0.1, 0.2],
#     'clip': True,
#     'name': 'VOC',
# }class PriorBox(object):def __init__(self, cfg):super(PriorBox, self).__init__()self.image_size = cfg['min_dim']  # 300self.num_priors = len(cfg['aspect_ratios'])  # 6self.variance = cfg['variance'] or [0.1]  # [0.1, 0.2]self.feature_maps = cfg['feature_maps']  # [38, 19, 10, 5, 3, 1]self.min_sizes = cfg['min_sizes']  # [30, 60, 111, 162, 213, 264]self.max_sizes = cfg['max_sizes']  # [60, 111, 162, 213, 264, 315]self.steps = cfg['steps']  # [8, 16, 32, 64, 100, 300]self.aspect_ratios = cfg['aspect_ratios']  # [[2], [2, 3], [2, 3], [2, 3], [2], [2]]self.clip = cfg['clip']  # Trueself.version = cfg['name']  # VOCfor v in self.variance:if v <= 0:raise ValueError('Variances must be greater than 0')def forward(self):mean = []for k, f in enumerate(self.feature_maps):  # [38, 19, 10, 5, 3, 1]x,y = np.meshgrid(np.arange(f),np.arange(f))  # 笛卡尔坐标形式 38 x 38x = x.reshape(-1)y = y.reshape(-1)for i, j in zip(y,x):f_k = self.image_size / self.steps[k]  # 300 / [8,16,32,64,100,300] 计算每个网格的像素宽度# 计算网格的中心cx = (j + 0.5) / f_k  # 中心点相对于特征图网格单位的横坐标位置cy = (i + 0.5) / f_k  # 中心点相对于特征图网格单位的纵坐标位置# 求短边s_k = self.min_sizes[k]/self.image_sizemean += [cx, cy, s_k, s_k]# 求长边s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))mean += [cx, cy, s_k_prime, s_k_prime]# 获得长方形for ar in self.aspect_ratios[k]:  # [[2], [2, 3], [2, 3], [2, 3], [2], [2]]mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)]  # 获得不同宽高比的先验框mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)]  # 获得不同宽高比的先验框# 获得所有的先验框output = torch.Tensor(mean).view(-1, 4)if self.clip:output.clamp_(max=1, min=0)return outputclass L2Norm(nn.Module):def __init__(self,n_channels, scale):super(L2Norm,self).__init__()self.n_channels = n_channelsself.gamma = scale or Noneself.eps = 1e-10self.weight = nn.Parameter(torch.Tensor(self.n_channels))  # 长度是512的权重 torch.Size([512]) self.reset_parameters()def reset_parameters(self):init.constant_(self.weight,self.gamma)def forward(self, x):norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps  # torch.Size([4, 1, 38, 38])#x /= normx = torch.div(x,norm)  # torch.Size([4, 512, 38, 38])out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * xreturn out  # torch.Size([4, 512, 38, 38])

SSD目标检测算法生成8732个先验框相关推荐

  1. 基于Grad-CAM与KL损失的SSD目标检测算法

    基于Grad-CAM与KL损失的SSD目标检测算法 人工智能技术与咨询 来源:<电子学报>,作者侯庆山等 摘 要: 鉴于Single Shot Multibox Detector (SSD ...

  2. 基于神经网络的目标检测论文之目标检测方法:改进的SSD目标检测算法

    4.2 改进的SSD 上一章我们了解到,物体识别检测算法是在传统CNN算法基础上加上目标区域建议策略和边框回归算法得到的.前辈们的工作主要体现在目标区域建议的改进策略上,从最开始的穷举建议框,到划分图 ...

  3. (20)目标检测算法之YOLOv5计算预选框、详解anchor计算

    目标检测算法之YOLOv5计算预选框.详解anchor计算 单节段目标检测算法中:预选框的设定直接影响最终的检测精度 众所周知,yolov5中采用自适应调整预选框anchor的大小,但万事开头难,配置 ...

  4. 层与特征融合_【计算机系统应用】(第122期)感受野特征增强的 SSD 目标检测算法...

    点击上方"蓝字",关注我们吧! 目标检测是计算机视觉领域的一项重要任务, 是 生活中如实例分割[1] , 面部分析[2] , 汽车自动驾驶[3].视 频分析[4] 等各种视觉应用的 ...

  5. [RCNN]-[YOLO]-[SSD]目标检测算法

    原文链接:http ://chuansong.me/n/353443351445 转载自深度学习大讲堂公众号    开始本文内容之前,我们先来看一下上边左侧的这张图,从图中你看到了什么物体?他们在什么 ...

  6. SSD目标检测算法原理(上)

    目录 一.目标检测概述 1.1 项目演示介绍 1.2 图片识别背景 1.3 目标检测定义 二.目标检测算法原理 2.1 任务描述 2.2 目标检测算法必备基础 2.3目标检测算法模型输出 目标检测 - ...

  7. 一文弄懂SSD目标检测算法

    SSD是YOLO之后又一个引人注目的目标检测结构,它沿用了YOLO中直接回归 bbox和分类概率的方法,同时又参考了Faster R-CNN,大量使用anchor来提升识别准确度.通过把这两种结构相结 ...

  8. 睿智的目标检测23——Pytorch搭建SSD目标检测平台

    睿智的目标检测23--Pytorch搭建SSD目标检测平台 学习前言 什么是SSD目标检测算法 源码下载 SSD实现思路 一.预测部分 1.主干网络介绍 2.从特征获取预测结果 3.预测结果的解码 4 ...

  9. Pytorch搭建SSD目标检测平台

    学习前言 什么是SSD目标检测算法 源码下载 SSD实现思路 一.预测部分 1.主干网络介绍 2.从特征获取预测结果 3.预测结果的解码 4.在原图上进行绘制 二.训练部分 1.真实框的处理 2.利用 ...

  10. OpenMMLab 实战营打卡 - 第 四 课 目标检测算法基础

    (四)计算机视觉之目标检测算法基础 目录 前言 一.目标检测是什么? 1.目标检测VS图像分类 2.检测最朴素方法--滑窗 Sliding Window (1)滑窗的效率问题 (2)改进 3.目标检测 ...

最新文章

  1. 3d打印主要的切片参数类型_3D打印机切片参数详情说明
  2. 并发队列、线程池、锁
  3. eclipse 使用 maven 无法编译 jsp 文件的问题
  4. 企业ERP选型的两难困惑
  5. Jenkins持续集成项目搭建与实践——基于Python Selenium自动化测试(自由风格)
  6. 普华永道报告:三波自动化浪潮将依次出现,人类工作将显著受到影响
  7. Apache Nutch 1.3 学习笔记十一(页面评分机制 LinkRank 介绍)
  8. 洛谷P1156 垃圾陷阱【线性dp】
  9. STM32F1系列与STM32F4系列的GPIO
  10. 虚拟机实现远程桌面连接
  11. shopex4.8.5 php5.6,ShopEx(网上商店系统)
  12. Unity webGl 鼠标手指触屏控制相机围绕物体 360度旋转
  13. 010-flutter dart代码后台执行,没有界面的情况下
  14. 【2021版】想要专升本你不得不看的全干货_吐血整理_专升本_计算机文化基础(二)
  15. 日语学习的实用网址大全
  16. 输入3×4的矩阵 将值为负的位置和值输出
  17. modelmapper属性匹配问题分析
  18. 想进BAT一线互联网大厂,该怎么准备技术面试?一位6年老Android的面经总结(附300+面试题)
  19. 数仓搭建——DWD层
  20. 计算机原理与应用 第二章——ARM处理器

热门文章

  1. [乐意黎转载]从零开始学习jQuery (十) jQueryUI常用功能实战
  2. 卷积编码verilog实现
  3. 清除电脑多余垃圾--清除垃圾.bat文件 附保姆级操作步骤
  4. 没有他的帅气,也要像他那般努力!(转载)
  5. SQLServer中sp_Who、sp_Who2和sp_WhoIsActive介绍和查看监视运行
  6. android 去掉google搜索,Android 7.1 去掉桌面上的谷歌搜索框
  7. c语言中%d %%d %%%d和\\%d的区别
  8. 14种最佳的PHP帮助台脚本和5种免费选项
  9. 轻巧和实用并存——360安全卫士极速版试用报告
  10. vue+echarts绘制中国地图,动态配置省份颜色和城市标点