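The Multimedia Laboratory at the Chinese University of Hong Kong has published an impressive number of high-quality papers in the deep learning era, so algorithm engineers would do well to study its detection toolbox, https://github.com/open-mmlab/mmdetection, carefully. The paper discussed here, "Gradient Harmonized Single-stage Detector" (https://arxiv.org/abs/1811.05181), introduces a new loss to address two imbalances in object detection: positive versus negative examples, and easy versus hard examples. To start, here is the GHM loss code from mmdetection: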

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..registry import LOSSES


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module
class GHMC(nn.Module):
    """GHM Classification Loss.

    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector".
    https://arxiv.org/abs/1811.05181

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.
    """

    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.edges = torch.arange(bins + 1).float().cuda() / bins
        self.edges[-1] += 1e-6
        if momentum > 0:
            self.acc_sum = torch.zeros(bins).cuda()
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
        Returns:
            The gradient harmonized loss.
        """
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector"
    https://arxiv.org/abs/1811.05181

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        loss_weight (float): The weight of the total GHM-R loss.
    """

    def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        self.edges = torch.arange(bins + 1).float().cuda() / bins
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            self.acc_sum = torch.zeros(bins).cuda()
        self.loss_weight = loss_weight

    # TODO: support reduction parameter
    def forward(self, pred, target, label_weight, avg_factor=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be 4
                or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
        Returns:
            The gradient harmonized loss.
        """
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            weights /= n

        loss = loss * weights
        loss = loss.sum() / tot
        return loss * self.loss_weight
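To make the weighting step in GHMC.forward easier to follow, here is a minimal pure-Python sketch of the same binning logic for the momentum = 0 case, with every sample treated as valid. The helper names (sigmoid, ghmc_grad_norm, ghm_weights) and the toy logits are assumptions made for illustration; they are not part of mmdetection.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def ghmc_grad_norm(logit, target):
    """Gradient norm used by GHM-C: g = |sigmoid(logit) - target|."""
    return abs(sigmoid(logit) - target)


def ghm_weights(grad_norms, bins=10):
    """Per-sample weights, mirroring the loop in GHMC.forward (momentum = 0).

    Hypothetical helper: assumes all samples are valid and g lies in [0, 1].
    """
    edges = [i / bins for i in range(bins + 1)]
    edges[-1] += 1e-6  # so that g == 1.0 still falls into the last bin
    tot = max(len(grad_norms), 1)
    weights = [0.0] * len(grad_norms)
    n = 0  # number of non-empty bins
    for i in range(bins):
        inds = [k for k, g in enumerate(grad_norms)
                if edges[i] <= g < edges[i + 1]]
        if inds:
            n += 1
            for k in inds:
                # Samples in a crowded bin each receive a smaller weight.
                weights[k] = tot / len(inds)
    if n > 0:
        weights = [w / n for w in weights]
    return weights


# Toy batch: eight easy negatives (logit -4, target 0)
# and two hard positives (logit -2, target 1).
logits = [-4.0] * 8 + [-2.0] * 2
targets = [0.0] * 8 + [1.0] * 2
g = [ghmc_grad_norm(p, t) for p, t in zip(logits, targets)]
weights = ghm_weights(g)
print(weights)
```

In this toy batch the eight easy negatives share one bin and each get weight 10 / 8 / 2 = 0.625, while the two hard positives share another bin and each get 10 / 2 / 2 = 2.5. The weights sum to the number of samples, so the crowded easy bin and the sparse hard bin contribute equally to the total. The loop in GHMR.forward bins samples the same way, but computes the gradient norm from the ASL1 loss as |d| / sqrt(mu^2 + d^2).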

To be continued.
