"Learning Efficient Object Detection Models with Knowledge Distillation"这篇文章通过知识蒸馏(Knowledge Distillation)与Hint指导学习(Hint Learning),提升了主干精简的多分类目标检测网络的推理精度(文章以Faster RCNN为例),例如Faster RCNN-Alexnet、Faster-RCNN-VGGM等,具体框架如下图所示:

The dark knowledge extracted from the teacher network covers three parts: hints from intermediate feature maps, dark knowledge from the classification layers of the RPN/RCN, and dark knowledge from the regression layers of the RPN/RCN. The details are as follows:

When supervising the student network, the classification loss of the RPN and RCN combines the cross-entropy between the student's softmax output and the hard targets with the cross-entropy between the student's softmax output and the teacher's soft targets:
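Following the paper's formulation (with P_s and P_t the student's and teacher's softmax outputs, y the ground-truth label, and μ a balancing weight between the two terms), the combined classification loss can be written as:

```latex
L_{cls} = \mu \, L_{hard}(P_s, y) + (1 - \mu)\, L_{soft}(P_s, P_t)
```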

Because the classes a detector must discriminate are imbalanced, L_soft assigns different weights to the per-class cross-entropy terms: the background class gets a larger weight of 1.5, while all other classes get a weight of 1.0:
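As a minimal pure-Python sketch of this class-weighted soft loss (the function name and toy distributions are illustrative, not from the paper's code; index 0 is the background class, matching the weighting above):

```python
import math

def weighted_soft_ce(p_teacher, p_student, bg_w=1.5, fg_w=1.0):
    """Class-weighted cross entropy between teacher and student softmax
    outputs; index 0 is the background class and gets the larger weight."""
    loss = 0.0
    for c, (pt, ps) in enumerate(zip(p_teacher, p_student)):
        w = bg_w if c == 0 else fg_w
        loss -= w * pt * math.log(ps)
    return loss

# Toy 3-class example (background + two object classes):
pt = [0.7, 0.2, 0.1]  # teacher softmax output
ps = [0.6, 0.3, 0.1]  # student softmax output
# Up-weighting the background term makes the loss strictly larger
# than the unweighted cross entropy:
assert weighted_soft_ce(pt, ps) > weighted_soft_ce(pt, ps, bg_w=1.0)
```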

The regression loss of the RPN and RCN combines the usual smooth L1 loss with the teacher bounded regression loss defined in the paper:

Here L_sL1 denotes the usual smooth L1 loss and L_b denotes the teacher bounded regression loss defined in the paper. When the L2 distance between the student's regressed location and the ground truth, plus a margin, exceeds the L2 distance between the teacher's regressed location and the ground truth, L_b takes the student-to-ground-truth L2 distance; otherwise L_b is set to 0.
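Following the paper's definition (with R_s and R_t the student's and teacher's regression outputs, y_reg the ground-truth regression target, m the margin, and ν a loss weight):

```latex
L_{reg} = L_{sL1}(R_s, y_{reg}) + \nu\, L_b(R_s, R_t, y_{reg}),
\qquad
L_b(R_s, R_t, y) =
\begin{cases}
\lVert R_s - y \rVert_2^2, & \text{if } \lVert R_s - y \rVert_2^2 + m > \lVert R_t - y \rVert_2^2 \\
0, & \text{otherwise}
\end{cases}
```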

Hint learning computes an L2 loss between intermediate feature maps of the teacher and student networks; a learnable adaptation layer is added to the student so that the guided layer's output feature maps match the dimensions of the teacher's hint:
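With Z the teacher's hint feature map and V the student's guided-layer output after the adaptation layer, the hint loss is the squared L2 distance:

```latex
L_{Hint}(V, Z) = \lVert V - Z \rVert_2^2
```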

Together, knowledge distillation and hint learning improve the compact network's generalization and help it converge faster, and the paper reports strong experimental results; see the experiments section of the paper for details.

Taking SSD as an example, the KD loss and the teacher bounded L2 loss can be implemented as follows:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..box_utils import match, log_sum_exp

eps = 1e-5

def KL_div(p, q, pos_w, neg_w):
    p = p + eps
    q = q + eps
    log_p = p * torch.log(p / q)
    log_p[:, 0] *= neg_w   # background class
    log_p[:, 1:] *= pos_w  # foreground classes
    return torch.sum(log_p)

class MultiBoxLoss(nn.Module):
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 cfg, use_gpu=True, neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes                   # 21
        self.threshold = overlap_thresh                  # 0.5
        self.background_label = bkg_label                # 0
        self.encode_target = encode_target               # False
        self.use_prior_for_matching = prior_for_matching # True
        self.do_neg_mining = neg_mining                  # True
        self.negpos_ratio = neg_pos                      # 3
        self.neg_overlap = neg_overlap                   # 0.5
        self.variance = cfg['variance']
        # soft-target loss
        self.neg_w = neg_w
        self.pos_w = pos_w
        self.Temp = Temp
        self.reg_m = reg_m

    def forward(self, predictions, pred_t, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
                and prior boxes from SSD net.
                conf shape: torch.size(batch_size, num_priors, num_classes)
                loc shape: torch.size(batch_size, num_priors, 4)
                priors shape: torch.size(num_priors, 4)
            pred_t (tuple): teacher's predictions
            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size, num_objs, 5] (last idx is the label).
        """
        loc_data, conf_data, priors = predictions
        num = loc_data.size(0)
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # predictions of teachers
        loc_teach1, conf_teach1 = pred_t[0]

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance,
                  labels, loc_t, conf_t, idx)
        # wrap targets
        with torch.no_grad():
            if self.use_gpu:
                loc_t = loc_t.cuda(non_blocking=True)
                conf_t = conf_t.cuda(non_blocking=True)

        pos = conf_t > 0                        # (1, 0, 1, ...)
        num_pos = pos.sum(dim=1, keepdim=True)  # [num, 1], number of positives

        # Localization Loss (Smooth L1), shape: [batch, num_priors, 4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)  # [batch, num_priors, 1] before expand_as
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

        # knowledge transfer for loc regression (teacher bounded regression loss)
        loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)
        l2_dis_s = (loc_p - loc_t).pow(2).sum(1)
        l2_dis_s_m = l2_dis_s + self.reg_m
        l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)
        l2_num = l2_dis_s_m > l2_dis_t
        l2_loss_teach1 = l2_dis_s[l2_num].sum()
        l2_loss = l2_loss_teach1

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf.float()) - batch_conf.gather(1, conf_t.view(-1, 1)).float()

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples (CrossEntropy)
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)  # [batch, num_priors, cls]
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

        # soft loss for Knowledge Distillation
        conf_p_teach = conf_teach1[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        pt = F.softmax(conf_p_teach / self.Temp, dim=1)
        if self.neg_w > 1.:
            # class-weighted KL divergence (background weighted by neg_w)
            ps = F.softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
        else:
            ps = F.log_softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = nn.KLDivLoss(size_average=False)(ps, pt) * (self.Temp ** 2)
        soft_loss = soft_loss1

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = num_pos.data.sum().float()
        loss_l = loss_l.float() / N
        loss_c = loss_c.float() / N
        l2_loss /= N
        soft_loss /= N
        return loss_l, loss_c, soft_loss, l2_loss

The bounded regression loss can be further abstracted into a standalone function:

def bounded_regress_loss(landmark_gt, landmarks_t, landmarks_s, reg_m=0.5, br_alpha=0.05):
    """Calculate the bounded regress loss for KD."""
    l2_dis_s = (landmark_gt - landmarks_s).pow(2).sum(1)
    l2_dis_s_m = l2_dis_s + reg_m
    l2_dis_t = (landmark_gt - landmarks_t).pow(2).sum(1)
    # penalize only samples where the student (plus margin) is worse than the teacher
    br_loss = l2_dis_s[l2_dis_s_m > l2_dis_t].sum()
    return br_loss * br_alpha

Paper: https://papers.nips.cc/paper/6676-learning-efficient-object-detection-models-with-knowledge-distillation.pdf

PyTorch SSD implementation: https://github.com/amdegroot/ssd.pytorch
