OHEM的pytorch代码实现细节
详细解读一下OHEM的实现代码:
def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):"""Arguments:batch_size (int): number of sampled rois for bbox head trainingloc_pred (FloatTensor): [R, 4], location of positive roisloc_target (FloatTensor): [R, 4], location of positive roispos_mask (FloatTensor): [R], binary mask for sampled positive roiscls_pred (FloatTensor): [R, C]cls_target (LongTensor): [R]Returns:cls_loss, loc_loss (FloatTensor)"""ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)#这里先暂存下正常的分类loss和回归lossloss = ohem_cls_loss + ohem_loc_loss#然后对分类和回归loss求和sorted_ohem_loss, idx = torch.sort(loss, descending=True)#再对loss进行降序排列keep_num = min(sorted_ohem_loss.size()[0], batch_size)#得到需要保留的loss数量if keep_num < sorted_ohem_loss.size()[0]:#这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留keep_idx_cuda = idx[:keep_num]#保留到需要keep的数目ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]#分类和回归保留相同的数目cls_loss = ohem_cls_loss.sum() / keep_numloc_loss = ohem_loc_loss.sum() / keep_num#然后分别对分类和回归loss求均值return cls_loss, loc_loss
OHEM的pytorch代码实现细节相关推荐
- BNN Pytorch代码阅读笔记
BNN Pytorch代码阅读笔记 这篇博客来写一下我对BNN(二值化神经网络)pytorch代码的理解,我是第一次阅读项目代码,所以想仔细的自己写一遍,把细节理解透彻,希望也能帮到大家! 论文链接: ...
- 论文学习笔记: Learning Multi-Scale Photo Exposure Correction(含pytorch代码复现)
论文学习笔记: Learning Multi-Scale Photo Exposure Correction--含pytorch代码复现 本章工作: 论文摘要 训练数据集 网络设计原理 补充知识:拉普 ...
- YOLOv2---优图代码+实现细节
---- 参考链接: https://blog.csdn.net/lanran2/article/details/82826045 https://zhuanlan.zhihu.com/p/35325 ...
- VITAL Tracker Pytorch 代码阅读笔记
VITAL Tracker Pytorch 代码阅读笔记 论文链接:https://arxiv.org/pdf/1804.04273.pdf 代码链接:https://github.com/abner ...
- Transformer Pytorch代码实现以及理解
Transformer结构 论文:Attention is all you need Transformer模型是2017年Google公司在论文<Attention is All ...
- python实现胶囊网络_Capsule Network胶囊网络解读与pytorch代码实现
本文是论文<Dynamic Routing between Capsules>的论文解读与pytorch代码实现. 如需转载本文或代码请联系作者 @Riroaki 并声明. 众所周知,卷积 ...
- 【3D计算机视觉】从PointNet到PointNet++理论及pytorch代码
从PointNet到PointNet++理论及代码详解 1. 点云是什么 1.1 三维数据的表现形式 1.2 为什么使用点云 1.3 点云上以往的相关工作 2. PointNet 2.1 基于点云的置 ...
- 将卷积引入transformer中VcT(Introducing Convolutions to Vision Transformers)的pytorch代码详解
文章目录 1. Motivation: 2. Method 2.1 Convolutional Token Embedding 模块 2.2 Convolutional Projection For ...
- PyTorch代码调试利器_TorchSnooper
GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...
最新文章
- 使用模式创建一个面向服务的组件中间件
- 第04讲: 基础探究,Session 与 Cookies
- android定时循环,Android AlarmManager实现定时循环后台任务
- python创建sqlite3 unicode error_python/sqlite3:发生异常:sqlite3.operationalerror
- javascript教程:console.log 详解
- es6 模块的整体加载
- 矩池云安装gcc4.9和g++4.9简单教程
- diffpatch升级_Tinker资源补丁原理解析
- Flash MX本地保存数据的三种方法
- 音创服务器系统手动加歌,音创ktv点歌系统的教程
- 计算机质保试题及答案,质量体系、国军标体系试卷(质保部出)
- 电商数据分析项目总结!
- uC/OS-II任务调度之就绪表及最高优先级任务判定算法
- MX_Player_Pro_专业精简版AC3/DTS/EAC3 By.SOLDIER-就要应用网91apps.cn
- [Boston Legal][S02E02]Allan Shore在Kelly Nolan被控杀夫一案中的结案陈词
- 抖音SEO之关键词排名优化详解【从入门到精通】
- adlds文件服务器,Windows轻型目录(AD LDS)的备份恢复
- 转载 Package CJK Error: Invalid character code错误
- 疫情防控信息管理系统
- 接口测试面试题及参考答案,就等你来看~