2021SC@SDUSC
CenterNet之loss计算代码解析(接上文)
代码解析
来自train.py中第173行开始进行loss计算:

# 得到heat map, reg, wh 三个变量
hmap, regs, w_h_ = zip(*outputs)regs = [
_tranpose_and_gather_feature(r, batch['inds']) for r in regs
]
w_h_ = [
_tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
]# 分别计算loss
hmap_loss = _neg_loss(hmap, batch['hmap'])
reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])# 进行loss加权,得到最终loss
loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss

上述transpose_and_gather_feature函数具体实现如下,主要功能是将ground truth中计算得到的对应中心点的值获取。

def _tranpose_and_gather_feature(feat, ind):# ind代表的是ground truth中设置的存在目标点的下角标feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]feat = _gather_feature(feat, ind)return featdef _gather_feature(feat, ind, mask=None):# feat : [bs, wxh, c]dim = feat.size(2)# ind : [bs, index, c]ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)feat = feat.gather(1, ind) # 按照dim=1获取indif mask is not None:mask = mask.unsqueeze(2).expand_as(feat)feat = feat[mask]feat = feat.view(-1, dim)return feat

hmap loss代码
调用:hmap_loss = _neg_loss(hmap, batch[‘hmap’])

def _neg_loss(preds, targets):''' Modified focal loss. Exactly the same as CornerNet.Runs faster and costs a little bit more memoryArguments:preds (B x c x h x w)gt_regr (B x c x h x w)'''pos_inds = targets.eq(1).float()# heatmap为1的部分是正样本neg_inds = targets.lt(1).float()# 其他部分为负样本neg_weights = torch.pow(1 - targets, 4)# 对应(1-Yxyc)^4loss = 0for pred in preds: # 预测值# 约束在0-1之间pred = torch.clamp(torch.sigmoid(pred), min=1e-4, max=1 - 1e-4)pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_indsneg_loss = torch.log(1 - pred) * torch.pow(pred,2) * neg_weights * neg_indsnum_pos = pos_inds.float().sum()pos_loss = pos_loss.sum()neg_loss = neg_loss.sum()if num_pos == 0:loss = loss - neg_loss # 只有负样本else:loss = loss - (pos_loss + neg_loss) / num_posreturn loss / len(preds)


代码和以上公式一一对应,pos代表正样本,neg代表负样本。
reg & wh loss代码
调用:reg_loss = reg_loss(regs, batch[‘regs’], batch[‘ind_masks’])
调用:w_h_loss = reg_loss(w_h, batch['w_h
’], batch[‘ind_masks’])

def _reg_loss(regs, gt_regs, mask):mask = mask[:, :, None].expand_as(gt_regs).float()loss = sum(F.l1_loss(r * mask, gt_regs * mask, reduction='sum') /(mask.sum() + 1e-4) for r in regs)return loss / len(regs)

Anchor Free系列模型11相关推荐

  1. 目标检测 YOLO 系列模型

    前言 YOLO (You Only Look Once) 系列模型追求实时目标检测,因此会在一定程度上牺牲精度,以实现更高的检测速度. 如果你对这篇文章感兴趣,可以点击「[访客必读 - 指引页]一文囊 ...

  2. 清华p-tuning | GPT也能做NLU?清华推出p-tuning方法解决GPT系列模型fine-tuning效果比BERT差问题

    一.概述 title:GPT Understands, Too 论文地址:https://arxiv.org/abs/2103.10385 代码:https://github.com/THUDM/P- ...

  3. 今晚直播 |重磅!Anchor Free系列算法强势来袭!

    日常一问:风口上的"猪"--"Anchor Free"算法,飞起来了吗? 有没有飞起来另说,但是"Anchor Free"算法模型层出不穷, ...

  4. GPT系列模型技术路径演进

    目录 前言 Transformer GPT-1 BERT GPT-2 GPT-3 InstructGPT/ChatGPT GPT-4 类ChatGPT产品 Google Bard(诗人) facebo ...

  5. android11测试版下载,Find X2系列 Android 11 Beta1 测试版发布,你要尝试一下吗?

    原标题:Find X2系列 Android 11 Beta1 测试版发布,你要尝试一下吗? 今年的特殊情况导致安卓11发布日期的一波三折,然而最后谷歌还是在6月11日发布了安卓Beta1版本.有趣的是 ...

  6. Android 2.2 r1 API 中文文档系列(11) —— RadioButton

    一.结构 public class RadioButton extends CompoundButton java.lang.Object    android.view.View          ...

  7. bert中文预训练模型_HFL中文预训练系列模型已接入Transformers平台

    哈工大讯飞联合实验室(HFL)在前期陆续发布了多个中文预训练模型,目前已成为最受欢迎的中文预训练资源之一.为了进一步方便广大用户的使用,借助Transformers平台可以更加便捷地调用已发布的中文预 ...

  8. STM32 基础系列教程 11 – ADC 轮询

    前言 学习stm32 adc模数转换接口使用,学会用STM32对模拟信号时行采样,通过轮询模式得到ADC结果. 示例详解 基于硬件平台: STM32F10C8T6最小系统板, MCU 的型号是 STM ...

  9. mdkstc系列器件支持包下载_Find X2系列 Android 11 Beta1 测试版发布

    OPPO Find X2系列 Android 11 Beta1 测试版已经发布了,诚邀软件开发者下载使用体验!该版本仅提供给开发者提前适配Android 11 Beta 1版本.版本存在已知风险,不保 ...

最新文章

  1. NAR:UNITE真菌鉴定ITS数据库——处理未分类和并行分类(数据库文章阅读笔记Markdown模板)...
  2. python 天气预报
  3. Angular - - ngHref、ngSrc、ngCopy/ngCut/ngPaste
  4. linux关机正确方法
  5. 【Eclipse提高开发速度-插件篇】安装VJET插件,JS等提示开发插件
  6. CTF-window和linux下命令执行的知识
  7. linux http 访问限制,51CTO博客-专业IT技术博客创作平台-技术成就梦想
  8. 堆排序-java实现
  9. 魔兽争霸3地图加密,支持重制版-魔兽争霸3地图加密实操,魔兽地图加密工具
  10. 树莓派简单摄像头录像并保存视频文件
  11. 高通820系列(apq8098平台,androido系统),mmm external/ethtool-5.2/,报错
  12. java 实心圆,liststyletype实心圆小图标大小颜色属性设置
  13. Jupyter Notebook使用的快捷键
  14. jquery 报错提示Uncaught TypeError: $ is not a function
  15. 表格(Excel)分列,拆分文本怎么用
  16. Unity调试Android安装包
  17. java原生开发是什么意思,深入剖析
  18. 计算机连接网络被限制,wifi连接被限制怎么办,教您wifi显示网络受限如何解决
  19. 解决方法 curl: (35) OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to bit.ly:443
  20. keras:神经网络的中间层输出

热门文章

  1. undefined symbol nvic 报错 undefined symbol TIM_Cmd报错
  2. lol老是闪退到桌面_lol闪退到桌面怎么解决
  3. vue 使用fetch 出现问题解决以及 相应知识学习
  4. 利用VBA在Excel中批量画图
  5. 4 基于matplotlib的python数据可视化——导入Excel数据批量制作柱形图
  6. Xcode14 build WebDriverAgent提示“Cannot link directly with dylib/framework“的解决方法
  7. C# 模拟PrintScreen 和 Alt+PrintScreen截取屏幕图片
  8. 神经网络中的Epoch、Iteration、Batchsize
  9. 【Java】图片 base64
  10. 相亲网站平台制作建设,第九篇