详细解读一下OHEM的实现代码:

def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):"""Arguments:batch_size (int): number of sampled rois for bbox head trainingloc_pred (FloatTensor): [R, 4], location of positive roisloc_target (FloatTensor): [R, 4], location of positive roispos_mask (FloatTensor): [R], binary mask for sampled positive roiscls_pred (FloatTensor): [R, C]cls_target (LongTensor): [R]Returns:cls_loss, loc_loss (FloatTensor)"""ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)#这里先暂存下正常的分类loss和回归lossloss = ohem_cls_loss + ohem_loc_loss#然后对分类和回归loss求和sorted_ohem_loss, idx = torch.sort(loss, descending=True)#再对loss进行降序排列keep_num = min(sorted_ohem_loss.size()[0], batch_size)#得到需要保留的loss数量if keep_num < sorted_ohem_loss.size()[0]:#这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留keep_idx_cuda = idx[:keep_num]#保留到需要keep的数目ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]#分类和回归保留相同的数目cls_loss = ohem_cls_loss.sum() / keep_numloc_loss = ohem_loc_loss.sum() / keep_num#然后分别对分类和回归loss求均值return cls_loss, loc_loss

OHEM的pytorch代码实现细节相关推荐

  1. BNN Pytorch代码阅读笔记

    BNN Pytorch代码阅读笔记 这篇博客来写一下我对BNN(二值化神经网络)pytorch代码的理解,我是第一次阅读项目代码,所以想仔细的自己写一遍,把细节理解透彻,希望也能帮到大家! 论文链接: ...

  2. 论文学习笔记: Learning Multi-Scale Photo Exposure Correction(含pytorch代码复现)

    论文学习笔记: Learning Multi-Scale Photo Exposure Correction--含pytorch代码复现 本章工作: 论文摘要 训练数据集 网络设计原理 补充知识:拉普 ...

  3. YOLOv2---优图代码+实现细节

    ---- 参考链接: https://blog.csdn.net/lanran2/article/details/82826045 https://zhuanlan.zhihu.com/p/35325 ...

  4. VITAL Tracker Pytorch 代码阅读笔记

    VITAL Tracker Pytorch 代码阅读笔记 论文链接:https://arxiv.org/pdf/1804.04273.pdf 代码链接:https://github.com/abner ...

  5. Transformer Pytorch代码实现以及理解

    Transformer结构​​​​​​​ 论文:Attention is all you need Transformer模型是2017年Google公司在论文<Attention is All ...

  6. python实现胶囊网络_Capsule Network胶囊网络解读与pytorch代码实现

    本文是论文<Dynamic Routing between Capsules>的论文解读与pytorch代码实现. 如需转载本文或代码请联系作者 @Riroaki 并声明. 众所周知,卷积 ...

  7. 【3D计算机视觉】从PointNet到PointNet++理论及pytorch代码

    从PointNet到PointNet++理论及代码详解 1. 点云是什么 1.1 三维数据的表现形式 1.2 为什么使用点云 1.3 点云上以往的相关工作 2. PointNet 2.1 基于点云的置 ...

  8. 将卷积引入transformer中VcT(Introducing Convolutions to Vision Transformers)的pytorch代码详解

    文章目录 1. Motivation: 2. Method 2.1 Convolutional Token Embedding 模块 2.2 Convolutional Projection For ...

  9. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

最新文章

  1. 使用模式创建一个面向服务的组件中间件
  2. 第04讲: 基础探究,Session 与 Cookies
  3. android定时循环,Android AlarmManager实现定时循环后台任务
  4. python创建sqlite3 unicode error_python/sqlite3:发生异常:sqlite3.operationalerror
  5. javascript教程:console.log 详解
  6. es6 模块的整体加载
  7. 矩池云安装gcc4.9和g++4.9简单教程
  8. diffpatch升级_Tinker资源补丁原理解析
  9. Flash MX本地保存数据的三种方法
  10. 音创服务器系统手动加歌,音创ktv点歌系统的教程
  11. 计算机质保试题及答案,质量体系、国军标体系试卷(质保部出)
  12. 电商数据分析项目总结!
  13. uC/OS-II任务调度之就绪表及最高优先级任务判定算法
  14. MX_Player_Pro_专业精简版AC3/DTS/EAC3 By.SOLDIER-就要应用网91apps.cn
  15. [Boston Legal][S02E02]Allan Shore在Kelly Nolan被控杀夫一案中的结案陈词
  16. 抖音SEO之关键词排名优化详解【从入门到精通】
  17. adlds文件服务器,Windows轻型目录(AD LDS)的备份恢复
  18. 转载 Package CJK Error: Invalid character code错误
  19. 疫情防控信息管理系统
  20. 接口测试面试题及参考答案,就等你来看~

热门文章

  1. caller和callee的使用方法
  2. 项目管理中职能型、矩阵型、项目型组织结构的优缺点
  3. 查找技术——折半查找(二分查找)
  4. 如何升级libc.so.6以及升级后引发的灾难
  5. 微软surface3开启硬件虚拟化
  6. (二)Linux物理内存初始化
  7. 【生信】KEGG数据库在线使用
  8. mac操作系统快捷键总结
  9. Filebeat+Kafka+ELK日志采集(一)
  10. java 对象锁和类锁的区别