写在前面:在读了SSD的论文之后,完整的看了一遍SSD的代码,有了许多体会,以此记录自己的学习过程。
论文传送门:SSD: Single Shot MultiBox Detector
大佬复现的代码:link

一、SSD做了什么

Faster-rcnn有什么缺点:
1、由于是two-stage的网络,会先生成一些proposal,因此速度相对来说还是比较慢。而SSD直接作预测,可以达到50-60FPS的检测速度,达到了实时要求。
2、对小目标检测效果比较差。SSD在多尺度特征上做训练,因此不管是大目标,或者是小目标都有着比较好的检测效果。
图片来源此up主:link(up主的理论和代码讲解非常详细,给予了我很大帮助)

1、SSD结构图如下:


主干网络采用的VGG16:
深一些的特征图感受野比较大,因此用来预测大目标,浅一些的特征图感受野比较小,用来预测小目标。关于感受野的一些年内容可以阅读此博客:VGG网络和感受野的理解.

抽出不同特征图来做回归和预测。

2、Default Box的选择

论文中讲解如下:

直观来看就是下图中所示:

一共会在6个特征图层上生成8732个Default Box。
映射回原图

3、预测过程


和faster-rcnn不同的是,这里边界框回归参数预测是4k,而不是4num_classes*k

4、正负样本的选择

在仔细阅读了源代码后,先依据IOU值给每个Default Box分配对应的GT Box,然后里面IOU大于0.5的设置为正样本,同时,与GT Box最匹配的Default Box也设置为正样本(将正样本充分利用起来)。
负样本数是正样本数量的3倍,先将剩余的Default Box依据置信度排序,取排名靠前的对应数量的Default Box设置为负样本即可。

5、损失计算

和Faster-rcnn一样,类别损失利用正负样本,回归损失利用正样本即可。
具体可参考博客:link.

二、代码部分

大佬复现时,Backbone采用的时Resnet50
代码框架如下:

代码相对比较简单,看的过程中主要有两个部分有点疑惑

1、target部分

在读到损失计算部分时,这里target传入的数据比较疑惑
target_box(这里第一个4为batch_size数)

target_label


原来是在transform部分对target进行了操作,给出注释如下:

import random
import torchvision.transforms as t
from torchvision.transforms import functional as F
from src.utils import dboxes300_coco, calc_iou_tensor, Encoder
import torchclass Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, target=None):for trans in self.transforms:image, target = trans(image, target)return image, target# 有些tensor并不是占用一整块内存,而是由不同的数据块组成,
# 而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式。
class ToTensor(object):def __call__(self, image, target):image = F.to_tensor(image).contiguous()return image, targetclass RandomHorizontalFlip(object):def __init__(self, prob=0.5):self.prob = probdef __call__(self, image, target):if random.random() < self.prob:# 水平翻转图片image = image.flip(-1)bbox = target["boxes"]# 水平翻转boxes信息bbox[:, [0, 2]] = 1.0 - bbox[:, [2, 0]]target["boxes"] = bboxreturn image, targetclass SSDCropping(object):# 对图像进行裁减,该方法放在ToTensor之前def __init__(self):self.sample_options = (None,(0.1, None),(0.3, None),(0.5, None),(0.7, None),(0.9, None),(None, None),)# 8732*4,且都是相对坐标,想转化为绝对坐标,只用乘以原图大小即可# 一次传入batch_size张图片self.dboxes = dboxes300_coco()def __call__(self, image, target):while True:mode = random.choice(self.sample_options)if mode is None:return image, target# return跳出def函数# 这里的高宽是图像大小htot, wtot = target['height_width']min_iou, max_iou = modemin_iou = float('-inf') if min_iou is None else min_ioumax_iou = float('+inf') if max_iou is None else max_ioufor _ in range(5):w = random.uniform(0.3, 1.0)h = random.uniform(0.3, 1.0)# 如果一致,则跳过这次循环if w/h < 0.5 or w/h > 2:continue# left 0 ~ wtot - w, top 0 ~ htot - hleft = random.uniform(0, 1.0 - w)top = random.uniform(0, 1.0 - h)right = left + wbottom = top + h# boxes的坐标是在0-1之间的# 应该是裁减出有效目标bboxes = target["boxes"]ious = calc_iou_tensor(bboxes, torch.tensor([[left, top, right, bottom]]))if not ((ious > min_iou) & (ious < max_iou)).all():continue# 这是算出中心坐标xc = 0.5 * (bboxes[:, 0] + bboxes[:, 2])yc = 0.5 * (bboxes[:, 1] + bboxes[:, 3])masks = (xc > left) & (xc < right) & (yc > top) & (yc < bottom)if not masks.any():continuebboxes[bboxes[:, 0] < left, 0] = leftbboxes[bboxes[:, 1] < top, 1] = topbboxes[bboxes[:, 2] < right, 2] = rightbboxes[bboxes[:, 3] < bottom, 3] = bottom# 去除掉中心不在采样范围的GT Boxbboxes = bboxes[masks, :]# 取出GT Box的标签labels = target['labels']labels = labels[masks]# 计算裁减之后的图像大小left_idx = int(left * wtot)top_idx = int(top * htot)right_idx = int(right * wtot)bottom_idx = int(bottom * htot)image = image.crop((left_idx, top_idx, right_idx, bottom_idx))# 调整之后的bboxes坐标信息,也是在0-1之间bboxes[:, 0] = (bboxes[:, 0] - left) / wbboxes[:, 1] = (bboxes[:, 1] - top) / hbboxes[:, 2] = (bboxes[:, 2] - left) / wbboxes[:, 3] = (bboxes[:, 3] - top) / h# 更新crop之后的GT Box坐标信息和标签信息target['boxes'] = bboxestarget['labels'] = labelsreturn image, targetclass Resize(object):def __init__(self, size=(300, 300)):self.resize = t.Resize(size)def __call__(self, image, target):image = self.resize(image)return image, targetclass ColorJitter(object):"""对图像颜色进行随机调整,该方法应该放在ToTensor之前"""def __init__(self, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05):self.trans = t.ColorJitter(brightness, contrast, saturation, hue)def __call__(self, image, target):image = self.tarns(image)return image, target# 对图像标准化的好处
# 1、提升模型的收敛速度
# 2、提高精度
# 3、防止梯度爆炸
class Normalization(object):def __init__(self, mean=None, std=None):if mean is None:mean = [0.485, 0.456, 0.406]if std is None:std = [0.229, 0.224, 0.225]self.normalize = t.Normalize(mean=mean, std=std)def __call__(self, image, target):image = self.normalize(image)return image, targetclass AssignGTtoDefaultBox(object):def __init__(self):self.default_box = dboxes300_coco()self.encoder = Encoder(self.default_box)# 记录一下单引号和双引号没有区别,不过可以交互使用def __call__(self, image, target):boxes = target['boxes']labels = target["labels"]bboxes_out, labels_out = self.encoder.encode(boxes, labels)target['boxes'] = bboxes_outtarget['labels'] = labels_outreturn image, target

2、读取pth文件的一些理解

    backbone = Backbone()model = SSD300(backbone=backbone, num_classes=num_classes)# model = nn.DataParallel(model)pre_ssd_path = "./src/nvidia_ssdpyt_fp32.pt"if os.path.exists(pre_ssd_path) is False:raise FileNotFoundError("nvidia_ssdpyt_fp32.pt not find in {}".format(pre_ssd_path))pre_model_dict = torch.load(pre_ssd_path, map_location=device)pre_weights_dict = pre_model_dict["model"]# 删除类别预测器权重,注意,回归预测器的权重可以重用,因为不涉及num_classesdel_conf_loc_dict = {}for k, v in pre_weights_dict.items():split_key = k.split(".")if "conf" in split_key:continuedel_conf_loc_dict.update({k: v})missing_keys, unexpected_keys = model.load_state_dict(del_conf_loc_dict, strict=False)# torch.save(model.state_dict(), './aaaaaaaaa.pth')# a = torch.load('./aaaaaaaaa.pth')# c, d = model.load_state_dict(a, strict=False)

由于加载的权重之前训练的类别数和我们现在要预测的类别数不一样,因此我们需要删除掉分类权部分重。
如上图,对于missing_keys, unexpected_keys参数不太理解,调试后发现missing_keys是model里面有的参数,del_conf_loc_dict里面没有的参数。unexpected_keys是del_conf_loc_dict里面有的,model里面没有的参数。

也就是说有14层参数model里有,而del_conf_loc_dict里没有

del_conf_loc_dict里有的,model里也都有

验证如下,我们先保存model的参数文件,再调用上面那个函数,理论上来说c和d都应该是空列表

    backbone = Backbone()model = SSD300(backbone=backbone, num_classes=num_classes)# model = nn.DataParallel(model)pre_ssd_path = "./src/nvidia_ssdpyt_fp32.pt"if os.path.exists(pre_ssd_path) is False:raise FileNotFoundError("nvidia_ssdpyt_fp32.pt not find in {}".format(pre_ssd_path))pre_model_dict = torch.load(pre_ssd_path, map_location=device)pre_weights_dict = pre_model_dict["model"]# 删除类别预测器权重,注意,回归预测器的权重可以重用,因为不涉及num_classesdel_conf_loc_dict = {}for k, v in pre_weights_dict.items():split_key = k.split(".")if "conf" in split_key:continuedel_conf_loc_dict.update({k: v})missing_keys, unexpected_keys = model.load_state_dict(del_conf_loc_dict, strict=False)torch.save(model.state_dict(), './aaaaaaaaa.pth')a = torch.load('./aaaaaaaaa.pth')c, d = model.load_state_dict(a, strict=False)

调试结果如下:

同时要保存或调用整个模型文件,使用函数:

#保存模型
torch.save(model_object,'resnet.pth')
#加载模型
model=torch.load('resnet.pth')

如果只用保存模型参数或则加载模型参数可用使用如下函数:

#将my_resnet模型存储为my_resnet.pth
torch.save(my_resnet.state_dict(),"my_resnet.pth")
#加载resnet,模型存放在my_resnet.pth
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

SSD网络及代码理解相关推荐

  1. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  2. 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构

    Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...

  3. mmdetection学习系列(1)——SSD网络

    1. 概述 本文是本人自学mmdetection的第一篇文章,因为最近一段时间在做目标检测相关的内容,为了更好地研究领域内相关知识,特意花了不少时间熟悉mmdetection框架(https://gi ...

  4. 【神经网络】(11) 轻量化网络MobileNetV1代码复现、解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 复现轻量化神经网络模型 MobileNetV1.为了能将神经网络模型用于移动端(手机)和终端(安防监控.无人驾驶)的实时计算,通常这些设备 ...

  5. Deep Learning论文笔记之(五)CNN卷积神经网络代码理解

    Deep Learning论文笔记之(五)CNN卷积神经网络代码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但 ...

  6. 3D人脸重建——PRNet网络输出的理解

    前言 之前有款换脸软件不是叫ZAO么,分析了一下,它的实现原理绝对是3D人脸重建,而非deepfake方法,找了一篇3D重建的论文和源码看看.这里对源码中的部分函数做了自己的理解和改写. 国际惯例,参 ...

  7. 【深度学习】SSD网络原理

    SSD网络backbone由VGG16网络的全部卷积层,即到conv5为止,去掉之后的全连接层.如下图: 然后是 conv6:3x3x1024; conv7:1x1x1024; conv8:1x1x2 ...

  8. 目标检测算法SSD用于行人检测(二):训练和测试SSD网络

    将Caltech数据集转化为caffe的输入数据格式LMDB请参考上一篇文章:https://blog.csdn.net/sunshine_zkf/article/details/86173247 前 ...

  9. fishnet:论文阅读与代码理解

    fishnet:论文阅读与代码理解 一.论文概述 二.整体框架 三.代码理解 四.总结 fishnet论文地址:http://papers.nips.cc/paper/7356-fishnet-a-v ...

  10. linux io的cfq代码理解一

    内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...

最新文章

  1. 看完这篇,code review 谁敢喷你代码写的烂?怼回去!
  2. 鸟哥的Linux私房菜(基础篇)-第二章、 Linux 如何学习(二.5. 重点回顾)
  3. 【数据结构与算法】之深入解析“买卖股票的最好时机III”的求解思路与算法示例
  4. 聊聊Unity项目管理的那些事:Git-flow和Unity
  5. 面经——Linux相关
  6. 华为手机媒体音量自动静音_华为手机音量键隐藏着四个功能,80%的人只知道第一个!...
  7. python解常微分方程龙格库_求解常微分方程组初值问题的龙格库塔法分析及其C代码...
  8. c语言如何输出10个空格,新人提问:如何将输出时每行最后一个空格删除
  9. 文档隐写溯源技术分析
  10. 主动降噪ANC(Active Noise Control)
  11. 微信公众号吸粉8大策略,实战运营指南
  12. WPF动画——故事板(Storyboard)
  13. SpringBoot海景房出租管理系统+代码讲解
  14. 什么是代理(Proxy)?
  15. KVM(多电脑切换器)
  16. 一款英国折叠车如何在中国城市流行?
  17. 把Maven本地仓库修改为阿里云仓库
  18. 9、Vue自定义指令
  19. armbian 斐讯n1_记录一下斐讯N1盒子刷Armbian的各种坑
  20. java自动违例设计,java学习记录(二):java的违例控制机制

热门文章

  1. Linux安装MySQL5.7
  2. tomato(番茄)固件的简单设置截图
  3. QComboBox自定义设置
  4. 【数据库】SQL语句之修改语句(INSERT,UPDATE,DELETE)
  5. 微信小程序毕业设计 就餐预约点餐小程序毕业设计
  6. matlab 求虚数的反正切,matlab中的反正切函数
  7. 找不到该项目(无法删除文件)
  8. 保姆级教学——虚拟机器人平台vrep(coppeliaSim)的机器人平台搭建
  9. linux 内核代码阅读工具,linux内核源码阅读工具
  10. Linux: dnf