教程 4: 自定义模型

我们通常把模型的各个组成成分分成6种类型:

  • 编码器(encoder):包括 voxel layer、voxel encoder 和 middle encoder 等进入 backbone 前所使用的基于 voxel 的方法,如 HardVFE 和 PointPillarsScatter。
  • 骨干网络(backbone):通常采用 FCN 网络来提取特征图,如 ResNet 和 SECOND。
  • 颈部网络(neck):位于 backbones 和 heads 之间的组成模块,如 FPN 和 SECONDFPN。
  • 检测头(head):用于特定任务的组成模块,如检测框的预测和掩码的预测。
  • RoI 提取器(RoI extractor):用于从特征图中提取 RoI 特征的组成模块,如 H3DRoIHead 和 PartAggregationROIHead。
  • 损失函数(loss):heads 中用于计算损失函数的组成模块,如 FocalLoss、L1Loss 和 GHMLoss。

开发新的组成模块

添加新建 encoder

接下来我们以 HardVFE 为例展示如何开发新的组成模块。

1. 定义一个新的 voxel encoder(如 HardVFE:即 DV-SECOND 中所提出的 Voxel 特征提取器)

创建一个新文件 mmdet3d/models/voxel_encoders/voxel_encoder.py

import torch.nn as nnfrom ..builder import VOXEL_ENCODERS@VOXEL_ENCODERS.register_module()
class HardVFE(nn.Module):def __init__(self, arg1, arg2):passdef forward(self, x):  # should return a tuplepass

2. 导入新建模块

用户可以通过添加下面这行代码到 mmdet3d/models/voxel_encoders/__init__.py

from .voxel_encoder import HardVFE

或者添加以下的代码到配置文件中,从而能够在避免修改源码的情况下导入新建模块。

custom_imports = dict(imports=['mmdet3d.models.voxel_encoders.HardVFE'],allow_failed_imports=False)

3. 在配置文件中使用 voxel encoder

model = dict(...voxel_encoder=dict(type='HardVFE',arg1=xxx,arg2=xxx),...

添加新建 backbone

接下来我们以 SECOND(Sparsely Embedded Convolutional Detection) 为例展示如何开发新的组成模块。

1. 定义一个新的 backbone(如 SECOND)

创建一个新文件 mmdet3d/models/backbones/second.py

import torch.nn as nnfrom ..builder import BACKBONES@BACKBONES.register_module()
class SECOND(BaseModule):def __init__(self, arg1, arg2):passdef forward(self, x):  # should return a tuplepass

2. 导入新建模块

用户可以通过添加下面这行代码到 mmdet3d/models/backbones/__init__.py

from .second import SECOND

或者添加以下的代码到配置文件中,从而能够在避免修改源码的情况下导入新建模块。

custom_imports = dict(imports=['mmdet3d.models.backbones.second'],allow_failed_imports=False)

3. 在配置文件中使用 backbone

model = dict(...backbone=dict(type='SECOND',arg1=xxx,arg2=xxx),...

添加新建 necks

1. 定义一个新的 neck(如 SECONDFPN)

创建一个新文件 mmdet3d/models/necks/second_fpn.py

from ..builder import NECKS@NECKS.register
class SECONDFPN(BaseModule):def __init__(self,in_channels=[128, 128, 256],out_channels=[256, 256, 256],upsample_strides=[1, 2, 4],norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),upsample_cfg=dict(type='deconv', bias=False),conv_cfg=dict(type='Conv2d', bias=False),use_conv_for_no_stride=False,init_cfg=None):passdef forward(self, X):# implementation is ignoredpass

2. 导入新建模块

用户可以通过添加下面这行代码到 mmdet3D/models/necks/__init__.py

from .second_fpn import SECONDFPN

或者添加以下的代码到配置文件中,从而能够在避免修改源码的情况下导入新建模块。

custom_imports = dict(imports=['mmdet3d.models.necks.second_fpn'],allow_failed_imports=False)

3. 在配置文件中使用 neck

model = dict(...neck=dict(type='SECONDFPN',in_channels=[64, 128, 256],upsample_strides=[1, 2, 4],out_channels=[128, 128, 128]),...

添加新建 heads

接下来我们以 PartA2 Head 为例展示如何开发新的组成模块。

注意:此处展示的 PartA2 RoI Head 将应用于双阶段检测器中,对于单阶段检测器,请参考 mmdet3d/models/dense_heads/ 中所展示的例子。由于这些 heads 简单高效,因此这些 heads 普遍应用在自动驾驶场景下的 3D 检测任务中。

首先,在 mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py 中创建一个新的 bbox head。
PartA2 RoI Head 实现一个新的 bbox head ,并用于目标检测的任务中。
为了实现一个新的 bbox head,通常需要在其中实现三个功能,如下所示,有时该模块还需要实现其他相关的功能,如 lossget_targets

from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead@HEADS.register_module()
class PartA2BboxHead(BaseModule):"""PartA2 RoI head."""def __init__(self,num_classes,seg_in_channels,part_in_channels,seg_conv_channels=None,part_conv_channels=None,merge_conv_channels=None,down_conv_channels=None,shared_fc_channels=None,cls_channels=None,reg_channels=None,dropout_ratio=0.1,roi_feat_size=14,with_corner_loss=True,bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),conv_cfg=dict(type='Conv1d'),norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),loss_cls=dict(type='CrossEntropyLoss',use_sigmoid=True,reduction='none',loss_weight=1.0),init_cfg=None):super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)def forward(self, seg_feats, part_feats):

其次,如果有必要的话,用户还需要实现一个新的 RoI Head,此处我们从 Base3DRoIHead 中继承得到一个新类 PartAggregationROIHead,此时我们就能发现 Base3DRoIHead 已经实现了下面的功能:

from abc import ABCMeta, abstractmethod
from torch import nn as nn@HEADS.register_module()
class Base3DRoIHead(BaseModule, metaclass=ABCMeta):"""Base class for 3d RoIHeads."""def __init__(self,bbox_head=None,mask_roi_extractor=None,mask_head=None,train_cfg=None,test_cfg=None,init_cfg=None):@propertydef with_bbox(self):@propertydef with_mask(self):@abstractmethoddef init_weights(self, pretrained):@abstractmethoddef init_bbox_head(self):@abstractmethoddef init_mask_head(self):@abstractmethoddef init_assigner_sampler(self):@abstractmethoddef forward_train(self,x,img_metas,proposal_list,gt_bboxes,gt_labels,gt_bboxes_ignore=None,**kwargs):def simple_test(self,x,proposal_list,img_metas,proposals=None,rescale=False,**kwargs):"""Test without augmentation."""passdef aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):"""Test with augmentations.If rescale is False, then returned bboxes and masks will fit the scaleof imgs[0]."""pass

接着将会对 bbox_forward 的逻辑进行修改,同时,bbox_forward 还会继承来自 Base3DRoIHead 的其他逻辑,在 mmdet3d/models/roi_heads/part_aggregation_roi_head.py 中,我们实现了新的 RoI Head,如下所示:

from torch.nn import functional as Ffrom mmdet3d.core import AssignResult
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
from mmdet.core import build_assigner, build_sampler
from mmdet.models import HEADS
from ..builder import build_head, build_roi_extractor
from .base_3droi_head import Base3DRoIHead@HEADS.register_module()
class PartAggregationROIHead(Base3DRoIHead):"""Part aggregation roi head for PartA2.Args:semantic_head (ConfigDict): Config of semantic head.num_classes (int): The number of classes.seg_roi_extractor (ConfigDict): Config of seg_roi_extractor.part_roi_extractor (ConfigDict): Config of part_roi_extractor.bbox_head (ConfigDict): Config of bbox_head.train_cfg (ConfigDict): Training config.test_cfg (ConfigDict): Testing config."""def __init__(self,semantic_head,num_classes=3,seg_roi_extractor=None,part_roi_extractor=None,bbox_head=None,train_cfg=None,test_cfg=None,init_cfg=None):super(PartAggregationROIHead, self).__init__(bbox_head=bbox_head,train_cfg=train_cfg,test_cfg=test_cfg,init_cfg=init_cfg)self.num_classes = num_classesassert semantic_head is not Noneself.semantic_head = build_head(semantic_head)if seg_roi_extractor is not None:self.seg_roi_extractor = build_roi_extractor(seg_roi_extractor)if part_roi_extractor is not None:self.part_roi_extractor = build_roi_extractor(part_roi_extractor)self.init_assigner_sampler()def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):"""Forward function of roi_extractor and bbox_head used in bothtraining and testing.Args:seg_feats (torch.Tensor): Point-wise semantic features.part_feats (torch.Tensor): Point-wise part prediction features.voxels_dict (dict): Contains information of voxels.rois (Tensor): Roi boxes.Returns:dict: Contains predictions of bbox_head andfeatures of roi_extractor."""pooled_seg_feats = self.seg_roi_extractor(seg_feats,voxels_dict['voxel_centers'],voxels_dict['coors'][..., 0],rois)pooled_part_feats = self.part_roi_extractor(part_feats, voxels_dict['voxel_centers'],voxels_dict['coors'][..., 0], rois)cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,pooled_part_feats)bbox_results = dict(cls_score=cls_score,bbox_pred=bbox_pred,pooled_seg_feats=pooled_seg_feats,pooled_part_feats=pooled_part_feats)return bbox_results

此处我们省略了与其他功能相关的细节,请参考 此处 获取更多细节。

最后,用户需要在 mmdet3d/models/bbox_heads/__init__.pymmdet3d/models/roi_heads/__init__.py 中添加新模块,使得对应的注册器能够发现并加载该模块。

此外,用户也可以添加以下的代码到配置文件中,从而实现相同的目标。

custom_imports=dict(imports=['mmdet3d.models.roi_heads.part_aggregation_roi_head', 'mmdet3d.models.roi_heads.bbox_heads.parta2_bbox_head'])

PartAggregationROIHead 的配置文件如下所示:

model = dict(...roi_head=dict(type='PartAggregationROIHead',num_classes=3,semantic_head=dict(type='PointwiseSemanticHead',in_channels=16,extra_width=0.2,seg_score_thr=0.3,num_classes=3,loss_seg=dict(type='FocalLoss',use_sigmoid=True,reduction='sum',gamma=2.0,alpha=0.25,loss_weight=1.0),loss_part=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),seg_roi_extractor=dict(type='Single3DRoIAwareExtractor',roi_layer=dict(type='RoIAwarePool3d',out_size=14,max_pts_per_voxel=128,mode='max')),part_roi_extractor=dict(type='Single3DRoIAwareExtractor',roi_layer=dict(type='RoIAwarePool3d',out_size=14,max_pts_per_voxel=128,mode='avg')),bbox_head=dict(type='PartA2BboxHead',num_classes=3,seg_in_channels=16,part_in_channels=4,seg_conv_channels=[64, 64],part_conv_channels=[64, 64],merge_conv_channels=[128, 128],down_conv_channels=[128, 256],bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),shared_fc_channels=[256, 512, 512, 512],cls_channels=[256, 256],reg_channels=[256, 256],dropout_ratio=0.1,roi_feat_size=14,with_corner_loss=True,loss_bbox=dict(type='SmoothL1Loss',beta=1.0 / 9.0,reduction='sum',loss_weight=1.0),loss_cls=dict(type='CrossEntropyLoss',use_sigmoid=True,reduction='sum',loss_weight=1.0)))...)

MMDetection 2.0 支持配置文件之间的继承,使得用户能够更加关注自己的配置文件的修改。
PartA2 Head 的第二阶段主要使用新建的 PartAggregationROIHeadPartA2BboxHead,需要根据对应模块的 __init__ 参数来设置对应的参数。

添加新建 loss

假定用户想要新添一个用于检测框回归的 loss,并命名为 MyLoss
为了添加一个新的 loss ,用于需要在 mmdet3d/models/losses/my_loss.py 中实现对应的逻辑。
装饰器 weighted_loss 能够保证对 batch 中每个样本的 loss 进行加权平均。

import torch
import torch.nn as nnfrom ..builder import LOSSES
from .utils import weighted_loss@weighted_loss
def my_loss(pred, target):assert pred.size() == target.size() and target.numel() > 0loss = torch.abs(pred - target)return loss@LOSSES.register_module()
class MyLoss(nn.Module):def __init__(self, reduction='mean', loss_weight=1.0):super(MyLoss, self).__init__()self.reduction = reductionself.loss_weight = loss_weightdef forward(self,pred,target,weight=None,avg_factor=None,reduction_override=None):assert reduction_override in (None, 'none', 'mean', 'sum')reduction = (reduction_override if reduction_override else self.reduction)loss_bbox = self.loss_weight * my_loss(pred, target, weight, reduction=reduction, avg_factor=avg_factor)return loss_bbox

接着,用户需要将 loss 添加到 mmdet3d/models/losses/__init__.py

from .my_loss import MyLoss, my_loss

此外,用户也可以添加以下的代码到配置文件中,从而实现相同的目标。

custom_imports=dict(imports=['mmdet3d.models.losses.my_loss'])

为了使用该 loss,需要对 loss_xxx 域进行修改。
因为 MyLoss 主要用于检测框的回归,因此需要在对应的 head 中修改 loss_bbox 域的值。

loss_bbox=dict(type='MyLoss', loss_weight=1.0))

【mmdetection3d】——04自定义模型相关推荐

  1. 关于DEDECMS自定义模型当中添加自定义字段后在后台添加内容后不显示解决方案...

    问题:我们自定义模型,添加自定义字段,比如单行文本(varchar)字段时,在后台添加内容,无法显示,但数据库里字段是有数据的. 解决办法:看看你的字段命名是否有大写,如果有全部改成小写就好了. 转载 ...

  2. Qt中的自定义模型类

    文章目录 1 Qt中的通用模型类 1.1 Qt中的通用模型类 1.2 Qt中的变体类型QVariant 2 自定义模型类 2.1 自定义模型类设计分析 2.2 自定义模型类数据层.数据表示层.数据组织 ...

  3. MyBatis-学习笔记04【04.自定义Mybatis框架基于注解开发】

    Java后端 学习路线 笔记汇总表[黑马程序员] MyBatis-学习笔记01[01.Mybatis课程介绍及环境搭建][day01] MyBatis-学习笔记02[02.Mybatis入门案例] M ...

  4. (四)Qt实现自定义模型基于QAbstractTableModel (一般)

    Qt实现自定义模型基于QAbstractTableModel 两个例子 例子1代码 Main.cpp #include <QtGui>#include "currencymode ...

  5. TensorFlow 2.0 - 自定义模型、训练过程

    文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...

  6. 用于将带有查询字符串的复杂对象传递到Web API方法的自定义模型绑定器

    目录 介绍 查询复杂对象的字符串字段 使用和测试FieldValueModelBinder类 FieldValueModelBinder如何工作? 获取源字段和值 将字段部分与对象属性匹配 解析枚举类 ...

  7. (五)Qt实现自定义模型基于QAbstractItemModel

    目录: (一) Qt Model/View 的简单说明 .预定义模型 (二)使用预定义模型 QstringListModel例子 (三)使用预定义模型QDirModel的例子 (四)Qt实现自定义模型 ...

  8. (四)Qt实现自定义模型基于QAbstractTableModel

    目录: (一) Qt Model/View 的简单说明 .预定义模型 (二)使用预定义模型 QstringListModel例子 (三)使用预定义模型QDirModel的例子 (四)Qt实现自定义模型 ...

  9. Qt4_实现自定义模型

    实现自定义模型 Qt的预定义模型为数据的处理和查看提供了很好的方法.但是,有些数据源不能有效地和预定义模型一起工作,这时就需要创建自定义模型,以方便对底层数据源进行优化. 在介绍如何创建自定义模型之前 ...

最新文章

  1. ABAP 程序间的调用
  2. re模块与正则表达式
  3. Data source rejected establishment of connection, message from server: Too many connections
  4. 面试题4:二维数组中的查找
  5. oracle星形转换,Oracle数据仓库博客(转,学)
  6. 查看linux机器性能,Unix Linux 查看机器性能
  7. java根据文件路径读取文件_java根据路径读取文件
  8. 2021年兰州师大附中高考成绩查询,2021年兰州重点高中名单及排名,兰州高中高考成绩排名榜...
  9. DevExpress GridControl Gridview RepositoryItemCheckEdit复选框及获取选择行数据
  10. 深入Python字典的内部实现
  11. 开源 java CMS - FreeCMS2.8 数据对象 site
  12. Bailian2713 肿瘤面积【基础】
  13. java recv failed,java.sql.SQLException: I/O Error: Software caused connection abort: recv failed
  14. NTFS文件系统详细分析
  15. python第三方模块下载方法(最详最细)
  16. 计算机无法安装VC2015,win7系统vc++2015一个或多个问题导致了安装失败的处理步骤...
  17. subtype,supertype 与 subclass,superclass 的异同
  18. DaVinci:自定义常用剪辑快捷键
  19. ACM-ICPC 2018 沈阳赛区网络预赛 F. Fantastic Graph (有上下界可行流)
  20. dvwa下载及安装-图文详解+phpStudy配置

热门文章

  1. Dell Inspiron 灵越 5570 换硬盘
  2. 如何修复手机无服务器,技巧 | 手机无服务没信号?这样做就能修复!
  3. 台式计算机是否属于工装,工装是否属于劳保用品呢
  4. python-py文件在windows下乱码
  5. word插入公式(2):告别空格 ,公式居中,编号自动右对齐(适用于论文)
  6. revit二次开发调整三维视图的视角方向
  7. LayoutInflater.inflate的用法总结
  8. window下redis重启数据丢失(已解决)
  9. 完美世界校招算法题2017
  10. 软考:2021 中级软件评测师报考指南