SOLO: Segmenting Objects by Locations, code reading notes (ECCV 2020)
Segmenting Objects by Locations
If this post helps you, I would appreciate a like~
Table of Contents
- SOLO head network structure
- Loss function
- Positive sample selection
- 1. SOLO/mmdet/models/detectors/single_stage_ins.py
- 2. SOLO/mmdet/models/anchor_heads/solo_head.py
- 3. SOLO/mmdet/core/post_processing/matrix_nms.py
- 4. SOLO/configs/solo/solo_r50_fpn_8gpu_1x.py
- 5. SOLO/mmdet/models/anchor_heads/__init__.py
SOLO head network structure
Loss function
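In brief, as read from the loss code in solo_head.py below: the category branch is trained with FocalLoss and the mask branch with a Dice loss weighted by ins_loss_weight = 3, so the total loss is

$$L = L_{cate} + \lambda\, L_{mask}, \qquad \lambda = 3$$

where the per-instance Dice loss between a predicted soft mask $p$ and the binary target $q$ is

$$L_{Dice} = 1 - \frac{2\sum_{x,y} p_{x,y}\, q_{x,y}}{\sum_{x,y} p_{x,y}^{2} + \sum_{x,y} q_{x,y}^{2}}$$

(the implementation adds a small 0.001 to both denominator terms for numerical stability).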
Positive sample selection
The paper's original wording on this did not fully make sense to me at first. After reading the code, however, I came away with a new understanding of how positive samples are selected and how the idea fits into a fully convolutional network. It is a nice case of theory meeting practice: the code helped me reflect on and deepen my understanding of the paper.
FCOS and PolarMask also use a form of center sampling. As those papers point out, a fully convolutional network could take every point inside the gt box as a positive example, but that would be computationally expensive, and points near the bbox border regress poorly. Sampling positives around the mass center (SOLO centers the region on the mass center) is therefore very reasonable.
The figures below are borrowed from an excellent reposted blog: blog link
As shown in the figure, the blue boxes in the original image are the grid cells the image is divided into, here 5x5. The green box is the object's gt box, the yellow box is the gt box shrunk to 0.2x, and the red boxes are the grid cells responsible for predicting this instance.
The black-and-white images below visualize the targets of the mask branch; the different channels are concatenated for display. In the first image on the left there is one instance whose gt box, shrunk to 0.2x, covers two grid cells, so those two cells are responsible for predicting it.
In the mask branch below it, only two FPN outputs are matched to this instance, so the channels corresponding to the red cells predict its mask. In the second image, instances of different sizes are spread over the image, and you can see that the FPN mask outputs handle instances from small to large scales.
The next figure, from the original paper, also shows clearly how instances are routed to channels according to gt_areas and the grid cell the instance falls in. First, each gt is assigned to an FPN level based on gt_areas. Then, within a level, each GT claims the grid cells covered by the 0.2x-scaled region around its mass center (with the gt rescaled to the size of that FPN level's output feature map); a minimal sketch of this assignment follows.
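The sketch below follows the logic of solo_target_single further down (scipy's center_of_mass and the 0.2 sigma are taken from the real code; the helper name positive_cells and its standalone form are mine, for illustration only):

from scipy import ndimage

def positive_cells(seg_mask, gt_box, num_grid, upsampled_size, sigma=0.2):
    # seg_mask: binary (H, W) mask on the padded image
    # gt_box:   [x1, y1, x2, y2] on the padded image
    # num_grid: S of this FPN level (40, 36, 24, 16 or 12)
    # upsampled_size: (H, W) of the padded image
    x1, y1, x2, y2 = gt_box
    half_w = 0.5 * (x2 - x1) * sigma
    half_h = 0.5 * (y2 - y1) * sigma

    # mass center of the instance mask, in image coordinates
    center_h, center_w = ndimage.center_of_mass(seg_mask)

    # map an image coordinate to a grid index
    to_grid = lambda v, size: int((v / size) // (1. / num_grid))
    coord_h = to_grid(center_h, upsampled_size[0])
    coord_w = to_grid(center_w, upsampled_size[1])

    # the sigma-scaled center region, clipped to at most one cell around the center cell
    top   = max(max(0, to_grid(center_h - half_h, upsampled_size[0])), coord_h - 1)
    down  = min(min(num_grid - 1, to_grid(center_h + half_h, upsampled_size[0])), coord_h + 1)
    left  = max(max(0, to_grid(center_w - half_w, upsampled_size[1])), coord_w - 1)
    right = min(min(num_grid - 1, to_grid(center_w + half_w, upsampled_size[1])), coord_w + 1)

    return [(i, j) for i in range(top, down + 1) for j in range(left, right + 1)]

Every returned cell (i, j) becomes a positive: channel k = i * num_grid + j of the mask branch gets the (rescaled) instance mask as its target, and cate_label[i, j] gets the class label.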
1. SOLO/mmdet/models/detectors/single_stage_ins.py
single_stage_ins.py wires up the backbone (ResNet), the neck (FPN) and the head (solo_head) and implements their forward pass.
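Condensed to its core, the training flow implemented below is roughly this (a simplified sketch; the optional mask_feat_head branch and all shape annotations are omitted):

def forward_train(self, img, img_metas, gt_bboxes, gt_labels,
                  gt_bboxes_ignore=None, gt_masks=None):
    x = self.extract_feat(img)   # backbone (ResNet) + neck (FPN): 5 feature maps
    outs = self.bbox_head(x)     # SOLO head: (ins_pred, cate_pred), 5 levels each
    loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
    return self.bbox_head.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)

The full file, with the author's shape annotations, follows: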
import torch.nn as nn
from mmdet.core import bbox2result
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
import pdb


@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):def __init__(self,backbone,neck=None,bbox_head=None,mask_feat_head=None,train_cfg=None,test_cfg=None,pretrained=None):super(SingleStageInsDetector, self).__init__()self.backbone = builder.build_backbone(backbone) # 1.build_backbone --> resnetif neck is not None:self.neck = builder.build_neck(neck) # 2.build_neck --> fpnif mask_feat_head is not None:self.mask_feat_head = builder.build_head(mask_feat_head)#pdb.set_trace()self.bbox_head = builder.build_head(bbox_head) # 3.build_head --> solo headself.train_cfg = train_cfgself.test_cfg = test_cfgself.init_weights(pretrained=pretrained) # 'torchvision://resnet50'def init_weights(self, pretrained=None):super(SingleStageInsDetector, self).init_weights(pretrained)self.backbone.init_weights(pretrained=pretrained)if self.with_neck:if isinstance(self.neck, nn.Sequential):for m in self.neck:m.init_weights()else:self.neck.init_weights()if self.with_mask_feat_head:if isinstance(self.mask_feat_head, nn.Sequential):for m in self.mask_feat_head:m.init_weights()else:self.mask_feat_head.init_weights()#pdb.set_trace()self.bbox_head.init_weights()# forward提取 backbone 和 neck的特征 def extract_feat(self, img):x = self.backbone(img) # resnet forward if self.with_neck:x = self.neck(x) # fpn forwardreturn x'''after neck feature map:x(Pdb) x[0].shapetorch.Size([2, 256, 200, 304])(Pdb) x[1].shapetorch.Size([2, 256, 100, 152])(Pdb) x[2].shapetorch.Size([2, 256, 50, 76])(Pdb) x[3].shapetorch.Size([2, 256, 25, 38])(Pdb) x[4].shapetorch.Size([2, 256, 13, 19])'''def forward_dummy(self, img):x = self.extract_feat(img)outs = self.bbox_head(x)return outsdef forward_train(self,img,img_metas,gt_bboxes,gt_labels,gt_bboxes_ignore=None,gt_masks=None):# 1. img # eg. [torch.Size([2, 3, 800, 1216]) represents the max size of h and w in the img batch_size# 2. img_metas# eg.#[# {'filename': 'data/coco2017/train2017/000000559012.jpg', # 'ori_shape': (508, 640, 3), # 'img_shape': (800, 1008, 3), # 'pad_shape': (800, 1216, 3), # 'scale_factor': 1.8823529411764706, # 'flip': False, # 'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), # 'std': array([58.395, 57.12 , 57.375], dtype=float32), # 'to_rgb': True}}, ## {'filename': 'data/coco2017/train2017/000000532426.jpg', # 'ori_shape': (333, 640, 3), 'img_shape': (753, 1333, 3), # 'pad_shape': (800, 1088, 3), 'scale_factor': 2.4024024024024024,# 'flip': False, # 'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), # 'std': array([58.395, 57.12 , 57.375], dtype=float32), # 'to_rgb': True}}# ]# 3. gt_bboxes# eg.# gt_bboxes represents 'bbox' of coco datasets# type(gt_bboxes) --> list # len(gt_bboxes) --> batch_size(ie. img per gpu) eg. 2# type(gt_bboxes[idx]) --> tensor# gt_bboxes[idx].size() --> [instances, 4] '4' represents [x1, y1, x2, y2]# [6, 4] [9, 4]# 4. gt_labels# eg.# gt_labels represents 'category_id' of coco datasets# type(gt_labels) --> list # len(gt_labels) --> batch_size(img per gpu) eg. 2# type(gt_labels[idx]) --> tensor# gt_labels[idx].size() --> instances eg. how many categories gt_bboxes[7 or 13, 4] --> gt_labels[7 or 13]# 6 , 9# 5. gt_masks# eg.# type(gt_masks) --> list# len(gt_masks) --> batch_size(img per gpu) eg. 2# type(gt_bboxes[idx]) --> list# (6, 800, 1216) (9, 800, 1088) represents (instances of pad_shape, w, h)x = self.extract_feat(img) # forward backbone and fpn# solo_head forwardouts = self.bbox_head(x) # forward solo_head# outs eg. 
各五层# 1.ins_pred:# outs[0][0].size() --> torch.Size([2, 1600, 200, 336])# outs[0][1].size() --> torch.Size([2, 1296, 200, 336]) # outs[0][2].size() --> torch.Size([2, 1024, 100, 168])# outs[0][3].size() --> torch.Size([2, 256, 50, 84])# outs[0][4].size() --> torch.Size([2, 144, 50, 84])# # 2.cate_pred:# outs[1][0].size() --> torch.Size([2, 80, 40, 40])# outs[1][1].size() --> torch.Size([2, 80, 36, 36])# outs[1][2].size() --> torch.Size([2, 80, 24, 24])# outs[1][3].size() --> torch.Size([2, 80, 24, 24])# outs[1][4].size() --> torch.Size([2, 80, 12, 12])# if self.with_mask_feat_head:mask_feat_pred = self.mask_feat_head(x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1])loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)else:loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg) # tuple len(outs) = 2 len(loss_inputs) = 7# compute SOLO losslosses = self.bbox_head.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)return lossesdef simple_test(self, img, img_meta, rescale=False):x = self.extract_feat(img)outs = self.bbox_head(x, eval=True) # when testing , eval = True rescale=Trueif self.with_mask_feat_head: # Falsemask_feat_pred = self.mask_feat_head(x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1])seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)else:seg_inputs = outs + (img_meta, self.test_cfg, rescale) # forward backbone fpn and solo_head seg_result = self.bbox_head.get_seg(*seg_inputs) # get_seg()return seg_result def aug_test(self, imgs, img_metas, rescale=False):raise NotImplementedError
2. SOLO/mmdet/models/anchor_heads/solo_head.py
Note: the tensors printed during one forward pass are listed at the bottom of this section.
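Before diving into the file, here is a minimal sketch of what one head level computes for a single FPN feature map, following the vanilla SOLO forward_single (the CoordConv concatenation and the two branches are from the code; the function name and argument list are mine, and the author's self-attention modification of the mask branch is left out):

import torch
import torch.nn.functional as F

def solo_level_sketch(x, num_grid, ins_convs, cate_convs, solo_ins, solo_cate):
    # x: [N, 256, H, W] feature map of one FPN level

    # ---- mask branch: CoordConv, stacked convs, upsample, 1x1 conv to S*S channels ----
    x_range = torch.linspace(-1, 1, x.shape[-1], device=x.device)
    y_range = torch.linspace(-1, 1, x.shape[-2], device=x.device)
    y, x_coord = torch.meshgrid(y_range, x_range)
    coord = torch.cat([x_coord.expand(x.shape[0], 1, -1, -1),
                       y.expand(x.shape[0], 1, -1, -1)], 1)
    ins = torch.cat([x, coord], 1)                         # [N, 258, H, W]
    for conv in ins_convs:                                 # 7 stacked 3x3 ConvModules
        ins = conv(ins)
    ins = F.interpolate(ins, scale_factor=2, mode='bilinear')
    ins_pred = solo_ins(ins)                               # 1x1 conv -> [N, S*S, 2H, 2W]

    # ---- category branch: resize to the S x S grid, stacked convs, 80-way prediction ----
    cate = F.interpolate(x, size=(num_grid, num_grid), mode='bilinear')
    for conv in cate_convs:
        cate = conv(cate)
    cate_pred = solo_cate(cate)                            # 3x3 conv -> [N, 80, S, S]
    return ins_pred, cate_pred

The full, annotated file follows: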
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
import pdb
import math
INF = 1e8

from scipy import ndimage


def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=1)
    # (a == b) on tensors returns a bool matrix (True/False); adding .float() turns it into 1/0.
    # (hmax[:, :, :-1, :-1] == heat).bool() would convert it back.
    keep = (hmax[:, :, :-1, :-1] == heat).float()
    # after max_pool2d, only the local maximum inside each 2x2 window stays non-zero
    return heat * keep


def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1)             # [instances, w * h]
    target = target.contiguous().view(target.size()[0], -1).float()  # [instances, w * h]

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    e = (2 * a) / (b + c)
    print('dice_loss:', 1 - e)
    # pdb.set_trace()  # [24]
    return 1 - e


@HEADS.register_module
class SOLOHead(nn.Module):def __init__(self,num_classes,in_channels,seg_feat_channels=256,stacked_convs=4,strides=(4, 8, 16, 32, 64),base_edge_list=(16, 32, 64, 128, 256),scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),sigma=0.4,num_grids=None,cate_down_pos=0,with_deform=False,loss_ins=None,loss_cate=None,conv_cfg=None,norm_cfg=None):super(SOLOHead, self).__init__()self.num_classes = num_classes # 81self.seg_num_grids = num_grids # [40, 36, 24, 16, 12]self.cate_out_channels = self.num_classes - 1 # 80self.in_channels = in_channels #256self.seg_feat_channels = seg_feat_channels # 256self.stacked_convs = stacked_convs # 7self.strides = strides # [8, 8, 16, 32, 32]self.sigma = sigma # 0.2self.cate_down_pos = cate_down_pos # 0self.base_edge_list = base_edge_list # (16, 32, 64, 128, 256)self.scale_ranges = scale_ranges # ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))self.with_deform = with_deform #False#loss_cate: {'type': 'FocalLoss', 'use_sigmoid': True, 'gamma': 2.0, 'alpha': 0.25, 'loss_weight': 1.0}self.loss_cate = build_loss(loss_cate) # FocalLoss() <class 'mmdet.models.losses.focal_loss.FocalLoss'>self.ins_loss_weight = loss_ins['loss_weight'] # 3self.conv_cfg = conv_cfgself.norm_cfg = norm_cfgself._init_layers()#pdb.set_trace()# init ins_convs, cate_convs, solo_ins_list, solo_catedef _init_layers(self):norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)self.ins_convs = nn.ModuleList()self.cate_convs = nn.ModuleList()for i in range(self.stacked_convs):# coorconv要加x y 2维chn = self.in_channels + 2 if i == 0 else self.seg_feat_channelsself.ins_convs.append(ConvModule(chn,self.seg_feat_channels,3,stride=1,padding=1,norm_cfg=norm_cfg,bias=norm_cfg is None))chn = self.in_channels if i == 0 else self.seg_feat_channelsself.cate_convs.append(ConvModule(chn,self.seg_feat_channels,3,stride=1,padding=1,norm_cfg=norm_cfg,bias=norm_cfg is None))self.solo_ins_list = nn.ModuleList()# 修改 [h, w, 256] --> [h, w, min(h/s, w/s)^2] self.solo_sa_module = nn.ModuleList()# [h, w , 256] ---> [h, w, s*s]# 修改'''for seg_num_grid in self.seg_num_grids:self.solo_ins_list.append(nn.Conv2d(self.seg_feat_channels, seg_num_grid**2, 1))'''for seg_num_grid in self.seg_num_grids:self.solo_ins_list.append(nn.Conv2d(seg_num_grid**2, seg_num_grid**2, 1))# [h, w, 256] --> [h, w, s]self.solo_cate = nn.Conv2d(self.seg_feat_channels, self.cate_out_channels, 3, padding=1)#pdb.set_trace()#初始化权重def init_weights(self):for m in self.ins_convs:normal_init(m.conv, std=0.01)for m in self.cate_convs:normal_init(m.conv, std=0.01)bias_ins = bias_init_with_prob(0.01) # bias_insfor m in self.solo_ins_list: normal_init(m, std=0.01, bias=bias_ins)bias_cate = bias_init_with_prob(0.01) # -4.59511985013459normal_init(self.solo_cate, std=0.01, bias=bias_cate)#pdb.set_trace()def forward(self, feats, eval=False):new_feats = self.split_feats(feats) # 先对feats[0] 以及 feats[4]进行插值 进行缩放# feats:# (Pdb) feats[0].size()# torch.Size([2, 256, 200, 304]) ---> new_feats[0] [2, 256, 100, 152] 缩小# (Pdb) feats[1].size()# torch.Size([2, 256, 100, 152])# (Pdb) feats[3].size()# torch.Size([2, 256, 25, 38])# (Pdb) feats[4].size()# torch.Size([2, 256, 13, 19]) ---> new_feats[4] [2, 256, 25, 38] 放大featmap_sizes = [featmap.size()[-2:] for featmap in new_feats] # h, w# featmap_sizes = [# torch.Size([100, 152]), # torch.Size([100, 152]), # torch.Size([50, 76]), # torch.Size([25, 38]), # torch.Size([25, 38]# )]upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2) # upsampled_size表示原来的最大的fpn层上的 feature map的siz: eg. 
[320, 200]ins_pred, cate_pred = multi_apply(self.forward_single, new_feats, list(range(len(self.seg_num_grids))),eval=eval, upsampled_size=upsampled_size)return ins_pred, cate_preddef split_feats(self, feats):#len(feats) = 5 (tuple)#pdb.set_trace()# 缩小的插值 scale_factor=0.5# {'P2': 8, 'P3': 8, 'P4': 16, 'P5': 32, 'P6': 32} ---> 可以推出这次输入的图片 [, ] --> fpn缩放return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'), # torch.Size([2, 256, 160, 100])feats[1], # torch.Size([2, 256, 160, 100])feats[2], # torch.Size([2, 256, 80, 50])feats[3], # torch.Size([2, 256, 40, 25])F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))# torch.Size([2, 256, 40, 25])def forward_single(self, x, idx, eval=False, upsampled_size=None):# 执行5次 对应FPN的5层 分别构造head# x = torch.Size([2, 256, 160, 100]) # idx = 0# upsampled_size = (320, 200)#pdb.set_trace()ins_feat = xdevice = ins_feat.deviceprint(device)cate_feat = x# ins branch# concat CoordConvx_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)y, x = torch.meshgrid(y_range, x_range)y = y.expand([ins_feat.shape[0], 1, -1, -1]) # N, 1, h/strides, w/stridesx = x.expand([ins_feat.shape[0], 1, -1, -1]) # N, 1, h/strides, w/stridescoord_feat = torch.cat([x, y], 1) # [N, 2, w, h]# channels: 256 --> 258 [N, 256, w, h] --> [N, 258, w, h]ins_feat = torch.cat([ins_feat, coord_feat], 1)# in_convs 7个conv forwardfor i, ins_layer in enumerate(self.ins_convs):ins_feat = ins_layer(ins_feat)#pdb.set_trace()# 第一次修改sa_feat = []# [152, 100] --> [160, 120]sa_h = math.ceil(ins_feat.size()[2] / self.seg_num_grids[idx]) #if (ins_feat.size()[2] % self.seg_num_grids[idx]) != 0:# sa_h = sa_h + 1sa_w = math.ceil(ins_feat.size()[3] / self.seg_num_grids[idx])#if (ins_feat.size()[3] % self.seg_num_grids[idx]) != 0:#sa_w = sa_w + 1# interpolate# 插值后: ins_feat [2, 256, 160, 120]ins_feat = F.interpolate(ins_feat, size=(self.seg_num_grids[idx] * sa_h, self.seg_num_grids[idx] * sa_w), mode='bilinear') # ins_sa_feat [2, 40*40, 160, 120]#ins_sa_feat = torch.zeros(ins_feat.size()[0], self.seg_num_grids[idx] * self.seg_num_grids[idx], ins_feat.size()[2], ins_feat.size()[3],device=device)seg_num_grids = self.seg_num_grids[idx]abc = []for i in range(seg_num_grids):for j in range(seg_num_grids):weight = ins_feat[:, :, i * sa_h : (i + 1) * sa_h, j * sa_w : (j + 1) * sa_w].repeat(1, 1, seg_num_grids, seg_num_grids)abc.append((weight * ins_feat).sum(1))ins_pred = torch.stack(abc, dim=1)#print(ins_pred.shape)'''基于boss方法的改进,此部分可以直接跳过~# 第一次修改速度太慢for i in range(seg_num_grids * seg_num_grids):grid_in_row = i % seg_num_gridsrow = i // seg_num_grids sa = ins_feat[:, :, row*sa_h : row*sa_h + sa_h, grid_in_row*sa_w : grid_in_row*sa_w + sa_w].cuda()for j in range(seg_num_grids):for k in range(seg_num_grids):ins_sa_feat[:, i, j*sa_h : j*sa_h + sa_h, k*sa_w : k*sa_w + sa_w] = (sa * ins_feat[:, :, j*sa_h : j*sa_h + sa_h, k*sa_w : k*sa_w + sa_w]).sum(dim = 1) ins_sa_feat = ins_sa_feat.cuda()''''''# 第二次修改# --------------------------------------------------------------------------------------------------------------------## 1. 
分成 sa_h * sa_w 个 seg_num_grids * seg_num_grids的mask特征图 # --------------------------------------------------------------------------------------------------------------------#mask_list =[]for i in range(sa_h):for j in range(sa_w):mask_list.append(ins_feat[:, :, i::sa_h, j::sa_w]) # mask_list[i].size() = [n, 256, seg_num_grids, seg_num_grids]#print(len(mask_list)) # len = sa_h * sa_w #pdb.set_trace()# --------------------------------------------------------------------------------------------------------------------## 2. sa_h * sa_w 的self-attention# --------------------------------------------------------------------------------------------------------------------#all_sa_feat = []per_sa_feat = []for i in range(sa_h * sa_w):ori_n = mask_list[i].size()[0]ori_c = mask_list[i].size()[1]n_c_hw = mask_list[i].reshape(ori_n, ori_c, -1) # [n, c, hw]#tmp_n_c_hw = n_c_hw.clone()n_c_hw_T = n_c_hw.permute(0, 2, 1) #[n, hw, c]tmp = torch.matmul(n_c_hw_T, n_c_hw) # [n, hw, c] x [n, c, hw] == [n, hw, hw]stack_sa_feat = tmp.reshape(ori_n, seg_num_grids * seg_num_grids, seg_num_grids, -1) # [n, s*s, s, s] all_sa_feat.append(stack_sa_feat)# --------------------------------------------------------------------------------------------------------------------## 3. 将同一行的seg_num_grids个元素矩阵先拼接 eg: xxxxyyyyzzzzccccc --> xyzc xyzc xyzc# --------------------------------------------------------------------------------------------------------------------#cat_all_row_feat = []for i in range(0, sa_w * sa_h, sa_w):cat_row_feat = torch.cat([feat for feat in all_sa_feat[i : i + sa_w]], dim = 3)cat_all_row_feat.append(cat_row_feat) #print(len(cat_all_row_feat))#pdb.set_trace()# --------------------------------------------------------------------------------------------------------------------## 4. 先交换cat_all_row_feat中的每一列# --------------------------------------------------------------------------------------------------------------------#all_new_row_feat_list = [] #交换好后的4个tensor的新行 xyxy abab cdcd fgfgfor i in range(0, len(cat_all_row_feat)):per_new_row_feat_list = [] # eg. xyxy or abab for j in range(0, seg_num_grids):per_row_feat = cat_all_row_feat[i][:, :, :, j::seg_num_grids] # Tensorper_new_row_feat_list.append(per_row_feat)all_new_row_feat_list.append(torch.cat(per_new_row_feat_list, dim = 3)) # 交换好后#print('len(all_new_row_feat_list):', len(all_new_row_feat_list))#pdb.set_trace()# --------------------------------------------------------------------------------------------------------------------## 5. 在此基础上继续在列上拼接 # --------------------------------------------------------------------------------------------------------------------##for feat in all_new_row_feat_list:#print(feat.size())cat_all_col_feat = torch.cat([feat for feat in all_new_row_feat_list], dim = 2)#print('cat_all_col_feat.size():', cat_all_col_feat.size())#pdb.set_trace()# --------------------------------------------------------------------------------------------------------------------## 6. 交换行# --------------------------------------------------------------------------------------------------------------------#per_new_col_feat_list = [] #交换好后的4个tensor的新行 xyxy abab cdcd fgfgfor i in range(0, seg_num_grids):# eg. 
xyxy # abab per_col_feat = cat_all_col_feat[:, :, i::seg_num_grids, :] # Tensorper_new_col_feat_list.append(per_col_feat)all_new_col_feat = torch.cat(per_new_col_feat_list, dim = 2) # 交换好后ins_sa_feat = all_new_col_feat.to(device)#print('ins_sa_feat.size(): ', ins_sa_feat.size())#print(ins_sa_feat)#pdb.set_trace()'''# --------------------------------------------------------------------------------------------------------------------## 修改截止# --------------------------------------------------------------------------------------------------------------------## w x h x 256 --> 2w x 2h x 256#ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')ins_pred = F.interpolate(ins_pred, scale_factor=2, mode='bilinear')# eg. torch.Size([2, 1600 or 1296 or 576 or 256 or 144, 2H/strides, 2W/strides])# 新的修改ins_pred = self.solo_ins_list[idx](ins_pred) # [N, 256, 2w, 2h] --> [N, S*S, 2w, 2h] eg. torch.Size([2, 1600, 200, 304])# cate branchfor i, cate_layer in enumerate(self.cate_convs):if i == self.cate_down_pos: # when i == 0seg_num_grid = self.seg_num_grids[idx] # [40, 36, 24, 16, 12]cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear') # 缩放cate_feat = cate_layer(cate_feat)# channels: 256 --> 80cate_pred = self.solo_cate(cate_feat)if eval:ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear') # 注意:把5个fpn层全部插值成同一个尺寸!根据upsampled_size, eval时放大到原图的1/4 eg. [1, 1600, 200, 304]cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1) # [N, h, w, c] eg. [1, 40, 40, 80]# 返回 分类和实例的最后一层结果。return ins_pred, cate_preddef loss(self,ins_preds,cate_preds,gt_bbox_list,gt_label_list,gt_mask_list,img_metas,cfg,gt_bboxes_ignore=None):featmap_sizes = [featmap.size()[-2:] for featmap inins_preds]ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(self.solo_target_single,gt_bbox_list,gt_label_list,gt_mask_list,featmap_sizes=featmap_sizes)#testins_labels = []temp_2 = []#ins_labels_2 =[] # 循环 5次# ins_labels_level :# eg. 
ins_labels_level[0].size() torch.Size([1296, 200, 272]) # ins_labels_level[1].size() torch.Size([1296, 200, 272])for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list),zip(*ins_ind_label_list)):temp = []#pdb.set_trace()for ins_labels_level_img, ins_ind_labels_level_img in zip(ins_labels_level, ins_ind_labels_level):temp.append(ins_labels_level_img[ins_ind_labels_level_img, ...]) # [instances, 200, 304]#pdb.set_trace()temp_2 = torch.cat(temp, 0) # batch_size的每个图片的每一层ins_labels.append(temp_2)# ins'''# zip() 与 zip(*)相反ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]for ins_labels_level_img, ins_ind_labels_level_img inzip(ins_labels_level, ins_ind_labels_level)], 0)for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))] # len(ins_label_list) = batchsize''''''temp_2 = [] for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list)):temp = []for ins_preds_level_img, ins_ind_labels_level_img in zip(ins_preds_level, ins_ind_labels_level):temp.append(ins_preds_level_img[ins_ind_labels_level_img, ...])temp_2 = torch.cat(temp, 0)ins_preds.append(temp_2)pdb.set_trace()'''ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]for ins_preds_level_img, ins_ind_labels_level_img inzip(ins_preds_level, ins_ind_labels_level)], 0)for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]#pdb.set_trace()ins_ind_labels = [] temp_2 = []for ins_ind_labels_level in zip(*ins_ind_label_list):temp = [] for ins_ind_labels_level_img in ins_ind_labels_level:temp.append(ins_ind_labels_level_img.flatten())temp_2 = torch.cat(temp)ins_ind_labels.append(temp_2)#pdb.set_trace()'''ins_ind_labels = [torch.cat([ins_ind_labels_level_img.flatten()for ins_ind_labels_level_img in ins_ind_labels_level])for ins_ind_labels_level in zip(*ins_ind_label_list)]'''flatten_ins_ind_labels = torch.cat(ins_ind_labels) # 3872 * batch_sizenum_ins = flatten_ins_ind_labels.sum() # 计算有多少正样本 相当于把元素是True的加起来#pdb.set_trace()# dice lossloss_ins = []# 对于ins 使用 gt ins_labels 与 pre ins_preds 求lossfor input, target in zip(ins_preds, ins_labels): # ins_preds 与 ins_labels维度一样, ins_preds[0]数值, ins_labels[0]是0,1if input.size()[0] == 0: # no inscontinueinput = torch.sigmoid(input) # sigmoidloss_ins.append(dice_loss(input, target))loss_ins = torch.cat(loss_ins).mean()loss_ins = loss_ins * self.ins_loss_weightprint('loss_ins: ', loss_ins)# catecate_labels = [torch.cat([cate_labels_level_img.flatten()for cate_labels_level_img in cate_labels_level])for cate_labels_level in zip(*cate_label_list)]flatten_cate_labels = torch.cat(cate_labels) # 3872 * batch_size# 对于cate 同样使用gt cate_labels 与 pre cate_preds求losscate_preds = [cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels) # [s*s , C]for cate_pred in cate_preds]'''(Pdb) cate_preds[0].size()torch.Size([3200, 80]) 3200 = 1600 *2 --> [40, 40, 80](Pdb) cate_preds[1].size()torch.Size([2592, 80])(Pdb) cate_preds[2].size()torch.Size([1152, 80])(Pdb) cate_preds[3].size()torch.Size([512, 80])(Pdb) cate_preds[4].size()torch.Size([288, 80])'''flatten_cate_preds = torch.cat(cate_preds) # [3782 * instance, 80] 5个fpn最后的feature map的channel相加loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1) # num_ins表示的是ins_preds[0:4]上的的第一维度相加, 表示一共实例的个数。return dict(loss_ins=loss_ins,loss_cate=loss_cate)def solo_target_single(self,gt_bboxes_raw,gt_labels_raw,gt_masks_raw,featmap_sizes=None):# 每次读取一张图片,根据gt_areas算图中的每一个实例在FPN的哪一层# 
gt_bboxes_raw.size() --> [7, 4]# gt_labels_raw --> 7# gt_masks_raw --> [7, 800, 1024]# featmap_sizes --> [torch.Size([200, 336]), torch.Size([200, 336]), torch.Size([100, 168]), torch.Size([50, 84]), torch.Size([50, 84])]device = gt_labels_raw[0].device # cuda# ins# compute the gt_areas of per gt in one img.# gt_areas.size() --> [instances]gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))ins_label_list = []cate_label_list = []ins_ind_label_list = []for (lower_bound, upper_bound), stride, featmap_size, num_grid \in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device) # eg. [40 * 40, 200, 336]cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device) # [40, 40]ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device) # [1600]# nonzero()返回非0索引的位置。# flatten()展平操作hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten() # 代表在这一层 预测的实例的gt索引 也就是哪一个示例会出现在这层#pdb.set_trace()if len(hit_indices) == 0:ins_label_list.append(ins_label)cate_label_list.append(cate_label)ins_ind_label_list.append(ins_ind_label)continuegt_bboxes = gt_bboxes_raw[hit_indices] # store gt_bboxes[x1,y1,x2,y2] when gt_areas belong to [lower_bound , upper_bound] ---> eg.[1, 4]gt_labels = gt_labels_raw[hit_indices] # [instances] when gt_areas belong to [lower_bound , upper_bound ---> eg.[57]gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...] # [instances , w, h] --> eg. [1, 800, 1216]half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma # self.sigma = 0.2half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigmaoutput_stride = stride / 2 # 每次只挑出一个instancefor seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):if seg_mask.sum() < 10:continue# mass centerupsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)center_h, center_w = ndimage.measurements.center_of_mass(seg_mask) # 算质心coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid)) # 将质心 转化为 num_grid的坐标 eg. [659, 398] --> [29, 11] when num_grid = 36coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))# left, top, right, downtop_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))top = max(top_box, coord_h-1) # 6down = min(down_box, coord_h+1) # 8left = max(coord_w-1, left_box) # 6right = min(right_box, coord_w+1) # 8# catecate_label[top:(down+1), left:(right+1)] = gt_label # eg. 将[6,8]# insseg_mask = mmcv.imrescale(seg_mask, scale=1. 
/ output_stride) # [800, 1088] --> [50, 68] 因为是[2h, 2w] 因此少缩小2倍seg_mask = torch.Tensor(seg_mask)for i in range(top, down+1):for j in range(left, right+1):label = int(i * num_grid + j)ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask # 存储在 s*s的某个通道上ins_ind_label[label] = True # s*s 中哪一个网格有实例ins_label_list.append(ins_label)cate_label_list.append(cate_label)ins_ind_label_list.append(ins_ind_label)#pdb.set_trace()return ins_label_list, cate_label_list, ins_ind_label_listdef get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None): # len(seg_preds):5 len(cate_preds):5#pdb.set_trace()assert len(seg_preds) == len(cate_preds)num_levels = len(cate_preds) # 5featmap_size = seg_preds[0].size()[-2:] # max fpn feature map size : [200, 304]result_list = []for img_id in range(len(img_metas)):cate_pred_list = [cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)]seg_pred_list = [seg_preds[i][img_id].detach() for i in range(num_levels)]img_shape = img_metas[img_id]['img_shape']scale_factor = img_metas[img_id]['scale_factor']ori_shape = img_metas[img_id]['ori_shape']cate_pred_list = torch.cat(cate_pred_list, dim=0) #每次读取one img, 因此cate_pred_list.size() --> [3872, 80]seg_pred_list = torch.cat(seg_pred_list, dim=0)result = self.get_seg_single(cate_pred_list, seg_pred_list,featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)result_list.append(result)#pdb.set_trace()#pdb.set_trace()return result_list# 对于每一个图片。def get_seg_single(self,cate_preds, # [3872, 80]seg_preds, # eg. [3872, 200, 304]featmap_size, # eg. [200, 304] max feature map in FPNimg_shape, # eg. [800, 1199, 3]ori_shape, # eg. [427, 640, 3]scale_factor,cfg,rescale=False, debug=False):assert len(cate_preds) == len(seg_preds)#pdb.set_trace()# overall info.h, w, _ = img_shapeupsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4) # eg. [800, 1216]# process.inds = (cate_preds > cfg.score_thr) # 第一次筛选 eg. [3872, 80] score_thr = 0.1 inds 是 bool类型 # category scores.cate_scores = cate_preds[inds] # eg.[507] cate_scores是数值,维度是[num[True]](我认为还降维了), 根据cate_preds[inds] 在对于true的地方输出if len(cate_scores) == 0: return None# category labels.inds = inds.nonzero() # 返回inds[i]为True的索引 inds.nonzero().size() --> [507, 2]cate_labels = inds[:, 1] # inds的第二列是代表的[80]中的类别。 cate_labels --> [507]# strides.size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0) # tensor([1600, 2896, 3472, 3728, 3872], device='cuda:0')strides = cate_scores.new_ones(size_trans[-1]) # [3872] 全为1n_stage = len(self.seg_num_grids) # 5strides[:size_trans[0]] *= self.strides[0] # 前1600个元素由 1 变成 8for ind_ in range(1, n_stage):strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_] # eg. 为1600 ~ 2896的1296个元素 赋值strides = strides[inds[:, 0]] # strides.size() --> [507] inds[:, 0] 表示第几个grid_cell# masks. seg_preds = seg_preds[inds[:, 0]] # [3872, 200, 304] --> [507, 200, 304]seg_masks = seg_preds > cfg.mask_thr # mask_thr = 0.5 bool [507, 200, 304] --> binary mask 二值化的作用!sum_masks = seg_masks.sum((1, 2)).float() # [507, 200, 304] ---> [507] sum(1,2)表示对每一个channcel内的[H * W]的每个元素求和# filter.keep = sum_masks > strides #bool [507]if keep.sum() == 0:return None#过滤seg_masks = seg_masks[keep, ...] 
# bool [keep.size(), 200, 304] seg_mask[True]的位置保持原来的seg_mask的值(T or F), seg_mask[False]的位置直接取舍不记录。seg_preds = seg_preds[keep, ...]sum_masks = sum_masks[keep]cate_scores = cate_scores[keep]cate_labels = cate_labels[keep]# mask scoring.seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks # eg, [507] 每一个channel上的对应元素相乘再求和最后除以cate_scores *= seg_scores # why?# sort and keep top nms_presort_inds = torch.argsort(cate_scores, descending=True)if len(sort_inds) > cfg.nms_pre: # 筛选前500sort_inds = sort_inds[:cfg.nms_pre]seg_masks = seg_masks[sort_inds, :, :]seg_preds = seg_preds[sort_inds, :, :]sum_masks = sum_masks[sort_inds]cate_scores = cate_scores[sort_inds]cate_labels = cate_labels[sort_inds]#pdb.set_trace()# Matrix NMScate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)# filter.keep = cate_scores >= cfg.update_thrif keep.sum() == 0:return Noneseg_preds = seg_preds[keep, :, :]cate_scores = cate_scores[keep]cate_labels = cate_labels[keep]# sort and keep top_ksort_inds = torch.argsort(cate_scores, descending=True)if len(sort_inds) > cfg.max_per_img:sort_inds = sort_inds[:cfg.max_per_img]seg_preds = seg_preds[sort_inds, :, :]cate_scores = cate_scores[sort_inds]cate_labels = cate_labels[sort_inds]seg_preds = F.interpolate(seg_preds.unsqueeze(0),size=upsampled_size_out,mode='bilinear')[:, :, :h, :w]seg_masks = F.interpolate(seg_preds,size=ori_shape[:2],mode='bilinear').squeeze(0)seg_masks = seg_masks > cfg.mask_thr#pdb.set_trace()return seg_masks, cate_labels, cate_scores#----------------------------------------------------------------------------------------#
#self.ins_convs:
'''
ModuleList((0): ConvModule((conv): Conv2d(258, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(1): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(2): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(3): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(4): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(5): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(6): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))
)'''#----------------------------------------------------------------------------------------#
#self.cate_convs
'''ModuleList((0): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(1): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(2): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(3): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(4): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(5): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))(6): ConvModule((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gn): GroupNorm(32, 256, eps=1e-05, affine=True)(activate): ReLU(inplace=True))'''
#----------------------------------------------------------------------------------------#
# self.solo_ins_list
'''
ModuleList((0): Conv2d(256, 1600, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(256, 1296, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(256, 576, kernel_size=(1, 1), stride=(1, 1))(3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))(4): Conv2d(256, 144, kernel_size=(1, 1), stride=(1, 1))
)
''''''
ins_pred:(Pdb) ins_pred[0].size()torch.Size([2, 1600, 200, 304])(Pdb) ins_pred[1].size()torch.Size([2, 1296, 200, 304])(Pdb) ins_pred[2].size()torch.Size([2, 576, 100, 152])(Pdb) ins_pred[3].size()torch.Size([2, 256, 50, 76])(Pdb) ins_pred[4].size()torch.Size([2, 144, 50, 76])''''''
cate_pred:(Pdb) cate_pred[0].size()torch.Size([2, 80, 40, 40])(Pdb) cate_pred[1].size()torch.Size([2, 80, 36, 36])(Pdb) cate_pred[2].size()torch.Size([2, 80, 24, 24])(Pdb) cate_pred[3].size()torch.Size([2, 80, 16, 16])(Pdb) cate_pred[4].size()torch.Size([2, 80, 12, 12])'''
#----------------------------------------------------------------------------------------##def loss'''
ins_labels(Pdb) ins_labels[0].size()torch.Size([1, 200, 272])(Pdb) ins_labels[1].size()torch.Size([0, 200, 272])(Pdb) ins_labels[2].size()torch.Size([16, 100, 136])(Pdb) ins_labels[3].size()torch.Size([39, 50, 68])(Pdb) ins_labels[4].size()torch.Size([18, 50, 68])
''''''
ins_preds:(Pdb) ins_preds[0].size()torch.Size([1, 200, 272])(Pdb) ins_preds[1].size()torch.Size([0, 200, 272])(Pdb) ins_preds[2].size()torch.Size([6, 100, 136])(Pdb) ins_preds[3].size()torch.Size([10, 50, 68])(Pdb) ins_preds[4].size()torch.Size([6, 50, 68])
''''''
ins_ind_labels:(Pdb) ins_ind_labels[0].size()torch.Size([1600])(Pdb) ins_ind_labels[1].size()torch.Size([1296])(Pdb) ins_ind_labels[2].size()torch.Size([576])(Pdb) ins_ind_labels[3].size()torch.Size([256])(Pdb) ins_ind_labels[4].size()torch.Size([144])''''''
cate_labels:(Pdb) cate_labels[0].size()torch.Size([1600])(Pdb) cate_labels[1].size()torch.Size([1296])(Pdb) cate_labels[2].size()torch.Size([576])(Pdb) cate_labels[3].size()torch.Size([256])(Pdb) cate_labels[4].size()torch.Size([144])''''''
get_seg
cfg:{'nms_pre': 500, 'score_thr': 0.1,'mask_thr': 0.5, 'update_thr': 0.05, 'kernel': 'gaussian', 'sigma': 2.0, 'max_per_img': 100}''''''
sum_masks
tensor([ 96., 96., 82., 82., 82., 108., 108., 108., 86.,86., 86., 208., 227., 227., 227., 134., 134., 88.,28., 79., 79., 231., 231., 231., 189., 189., 31.,31., 125., 125., 125., 158., 158., 194., 99., 99.,74., 159., 37., 37., 37., 39., 39., 275., 50.,31., 64., 64., 64., 64., 66., 66., 66., 66.,91., 91., 91., 93., 192., 192., 192., 46., 46.,46., 39., 39., 51., 51., 87., 140., 181., 199.,50., 50., 50., 50., 76., 20., 88., 88., 84.,84., 84., 236., 236., 94., 211., 211., 252., 85.,98., 56., 96., 96., 60., 60., 60., 53., 84.,84., 84., 84., 258., 267., 304., 90., 105., 105.,105., 75., 75., 75., 53., 53., 84., 84., 132.,274., 274., 259., 259., 296., 296., 296., 272., 272.,272., 272., 112., 117., 50., 87., 143., 143., 80.,88., 88., 273., 273., 320., 320., 294., 364., 313.,355., 302., 353., 353., 67., 67., 42., 32., 32.,61., 61., 61., 61., 68., 68., 68., 68., 168.,168., 168., 28., 28., 28., 67., 71., 139., 282.,304., 94., 169., 135., 135., 286., 331., 100., 100.,100., 95., 95., 172., 277., 277., 277., 371., 380.,92., 92., 160., 394., 394., 395., 132., 132., 157.,295., 282., 452., 468., 66., 66., 209., 73., 73.,73., 352., 360., 333., 25., 205., 229., 229., 229.,491., 491., 488., 488., 488., 449., 449., 234., 255.,255., 255., 255., 630., 514., 514., 514., 481., 481.,481., 871., 1029., 260., 260., 260., 260., 639., 514.,484., 168., 168., 415., 81., 1120., 1232., 418., 418.,128., 141., 242., 242., 91., 57., 57., 80., 80.,80., 621., 1248., 1315., 199., 304., 210., 78., 54.,54., 62., 62., 62., 622., 697., 697., 663., 663.,149., 118., 108., 109., 109., 202., 218., 218., 275.,275., 357., 357., 357., 361., 361., 102., 111., 111.,448., 279., 356., 347., 347., 271., 293., 288., 288.,288., 277., 277., 271., 271., 131., 131., 162., 162.,162., 132., 132., 107., 362., 452., 452., 571., 361.,360., 438., 714., 404., 427., 613., 395., 411., 438.,438., 471., 529., 546., 52., 52., 85., 85., 85.,181., 181., 336., 359., 183., 353., 370., 98., 98.,98., 191., 191., 268., 268., 340., 340., 736., 346.,380., 94., 94., 94., 179., 179., 412., 437., 437.,437., 1087., 560., 398., 925., 925., 802., 802., 802.,375., 834., 847., 512., 944., 508., 48., 274., 82.,82., 82., 482., 444., 491., 491., 1281., 679., 679.,571., 571., 571., 1403., 583., 647., 1429., 940., 721.,721., 313., 1953., 3322., 3694., 3694., 2245., 2187., 1180.,3924., 3924., 3963., 1622., 2566., 3506., 1246., 2082., 4032.,4067., 474., 567., 567., 1675., 2513., 3013., 1489., 709.,900., 900., 769., 2537., 689., 1485., 2476., 416., 1449.,706., 2477., 3185., 3221., 413., 2756., 3230., 3230., 3156.,424., 465., 2933., 2846., 474., 474., 940., 940., 851.,851., 851., 553., 1572., 5856., 3666., 4373., 3937., 2129.,4194., 4586., 2788., 2683., 4081., 3171., 3171., 3894., 4206.,1353., 1984., 3575., 3303., 2040., 3688., 3688., 7555., 8147.,9637., 10042., 7735., 9848., 10357., 6124., 10311., 10753., 5137.,4384., 6858., 4768., 4397., 6499., 10237., 10237., 9333., 9333.,9033., 9723., 9955.], device='cuda:0')''''''strides
tensor([ 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 16., 16.,16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32., 32.,32., 32., 32.], device='cuda:0')'''
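To close this section, the inference path in get_seg_single above, step by step:
1. Threshold the flattened category scores with score_thr (0.1) to pick candidate (grid cell, class) pairs.
2. Recover each candidate's class label and FPN stride from its flattened index.
3. Binarize the corresponding mask channels with mask_thr (0.5) and drop candidates whose mask area is not larger than their stride.
4. Mask scoring: multiply each category score by the mean mask confidence inside the binary mask.
5. Sort by score, keep at most nms_pre (500) candidates, then run Matrix NMS to decay the scores of overlapping same-class masks.
6. Filter by update_thr (0.05), keep the top max_per_img (100), upsample the masks to 4x the largest feature map, crop to img_shape, resize to ori_shape, and binarize again with mask_thr.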
3. SOLO/mmdet/core/post_processing/matrix_nms.py
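Before the annotated source, a stripped-down sketch of the Matrix NMS computation it implements (Gaussian kernel only; the helper name matrix_nms_sketch is mine, and the sum_masks argument is dropped):

import torch

def matrix_nms_sketch(seg_masks, cate_labels, cate_scores, sigma=2.0):
    # seg_masks:   (n, h, w) bool masks, already sorted by descending score
    # cate_labels: (n,) class labels; cate_scores: (n,) scores
    n = len(cate_labels)
    masks = seg_masks.reshape(n, -1).float()
    sum_masks = masks.sum(1)

    inter = torch.mm(masks, masks.t())                      # pairwise intersections
    union = sum_masks[None, :] + sum_masks[:, None] - inter
    iou = (inter / union).triu(diagonal=1)                  # keep only "higher score vs lower score" pairs

    same = (cate_labels[None, :] == cate_labels[:, None]).float().triu(diagonal=1)
    decay_iou = iou * same                                   # only same-class overlap suppresses a mask

    # compensation: how strongly each mask was itself suppressed by higher-scored masks;
    # it softens that mask's suppressing effect on the masks below it
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou[:, None].expand(n, n)

    decay = torch.exp(-sigma * decay_iou ** 2)
    compensate = torch.exp(-sigma * compensate_iou ** 2)
    decay_coeff, _ = (decay / compensate).min(0)             # worst-case decay per mask
    return cate_scores * decay_coeff

The annotated original: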
import torch
import pdbdef matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):"""Matrix NMS for multi-class masks.Args:seg_masks (Tensor): shape (n, h, w) boolcate_labels (Tensor): shape (n), mask labels in descending ordercate_scores (Tensor): shape (n), mask scores in descending orderkernel (str): 'linear' or 'gauss' sigma (float): std in gaussian methodsum_masks (Tensor): The sum of seg_masksReturns:Tensor: cate_scores_update, tensors of shape (n)"""pdb.set_trace()n_samples = len(cate_labels) # 最多 500if n_samples == 0:return []if sum_masks is None:sum_masks = seg_masks.sum((1, 2)).float()seg_masks = seg_masks.reshape(n_samples, -1).float() # [500, 60800] 相当于把同一个实例的特征展平# inter. 注: 矩阵相乘就表示了每一个channel上某一个实例的掩码所在所在位置上的值(1or0)与其他通道的mask所在位置的值相乘# 2个特例:# 就算相同类别,如果位置不同,那么他们inter也是0,如果位置相同,就涉及到了NMS筛选的范畴# (1)如果他们位置不同,那么就必定是为0的,不能仅仅考虑类别相同! # (2)并且可能不同的实例一大一小,但是他们位置有相交,那么也有交集!不同实例相同位置的IOU排除方法见下面的label_matrix的使用。inter_matrix = torch.mm(seg_masks, seg_masks.transpose(1, 0)) # [500 , 60800] @ [60800 , 500] = [500, 500] # union.sum_masks_x = sum_masks.expand(n_samples, n_samples) # [500, 500]# iou.# 掩码值相加代表了union 取上三角(转置肯定有重复。)iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)# label_specific matrix.cate_labels_x = cate_labels.expand(n_samples, n_samples) # [500, 500]# 每i行的元素(1 or 0),1表示和第i个mask类别一样的。 并且使用了triu方法,进一步的得到分数比他低的的mask(triu方法的妙用)# 因此在已经排除了同一种label不同位置的情况,这一步就是排除同一个位置,不同label,它们的iou也要置于0label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1) # [500, 500] # IoU decay # iou_matrix * label_matrix是为了保留同一种小于最大scores的label的iou。# 因为之前算的iou的inter部分有可能一大一小的实例,但是他们位置上有重叠,因此还有iou并不等于0,要进行惩罚# 而消除不同label的iou(因为nms就是对同一个类别的scores高低的mask/box进行筛选最后剩下一个)# 第一个式子排除结束。得到同种mask同一位置的IOU,每i行表示与第i个mask的iou。decay_iou = iou_matrix * label_matrix'''(Pdb) decay_iou = (iou_matrix * label_matrix) 上三角。tensor([[0.0000, 0.8036, 0.5017, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.4816, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0127],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],device='cuda:0')'''# IoU compensationcompensate_iou, _ = (iou_matrix * label_matrix).max(0) # fast-nms 按列取最大值(不是同一类的mask直接就0不考虑了),第i列表示第i个mask与跟它同种mask最大的scores最大的iou值# 分析:# eg. 
# 前3列都是第一个mask的预测,按照scores排列第一个是最大的,所以第一列的max就是0;# 注意看第三列,max是0.5017,这个0.5是和第一个mask相比的,而不取0.47(如果thr是0.5就不会被排除)。# 这就是**fast-nms尽可能去掉更多的框的核心思想**。'''compensate_ioutensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.8036, 0.8036, 0.8036, ..., 0.8036, 0.8036, 0.8036],[0.5017, 0.5017, 0.5017, ..., 0.5017, 0.5017, 0.5017],...,[0.0021, 0.0021, 0.0021, ..., 0.0021, 0.0021, 0.0021],[0.0054, 0.0054, 0.0054, ..., 0.0054, 0.0054, 0.0054],[0.0193, 0.0193, 0.0193, ..., 0.0193, 0.0193, 0.0193]],'''compensate_iou = compensate_iou.expand(n_samples, n_samples).transpose(1, 0)# matrix nmsif kernel == 'gaussian': decay_matrix = torch.exp(-1 * sigma * (decay_iou ** 2))compensate_matrix = torch.exp(-1 * sigma * (compensate_iou ** 2))decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) # 分析min(0)按列取最小的作用:# 如下面的eg, 因为经过了指数函数e, 原来为0的表示最大的score或者无iou缩减的值就要变为1。原来对于每一个mask,次大的得分的scores就会变小。# 按列取最小应该算出对每一个mask的scores抑制的大小。(这里的decay_iou只会算同label的mask了。)'''(Pdb) decay_matrix / compensate_matrixtensor([[1.0000, 0.2748, 0.6044, ..., 1.0000, 1.0000, 1.0000],[3.6388, 3.6388, 2.2883, ..., 3.6388, 3.6388, 3.6388],[1.6545, 1.6545, 1.6545, ..., 1.6545, 1.6545, 1.6545],...,[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 0.9997],[1.0001, 1.0001, 1.0001, ..., 1.0001, 1.0001, 1.0001],[1.0007, 1.0007, 1.0007, ..., 1.0007, 1.0007, 1.0007]],device='cuda:0')'''pdb.set_traceelif kernel == 'linear':decay_matrix = (1-decay_iou)/(1-compensate_iou)decay_coefficient, _ = decay_matrix.min(0)else:raise NotImplementedError# update the score.cate_scores_update = cate_scores * decay_coefficient # soft-nms的方法 让相同的label但是scores低与max的变小。pdb.set_trace()return cate_scores_updatedef multiclass_nms(multi_bboxes,multi_scores,score_thr,nms_cfg,max_num=-1,score_factors=None):"""NMS for multi-class bboxes.Args:multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)multi_scores (Tensor): shape (n, #class), where the 0th columncontains scores of the background class, but this will be ignored.score_thr (float): bbox threshold, bboxes with scores lower than itwill not be considered.nms_thr (float): NMS IoU thresholdmax_num (int): if there are more than max_num bboxes after NMS,only top max_num will be kept.score_factors (Tensor): The factors multiplied to scores beforeapplying NMSReturns:tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labelsare 0-based."""num_classes = multi_scores.shape[1]bboxes, labels = [], []nms_cfg_ = nms_cfg.copy()nms_type = nms_cfg_.pop('type', 'nms')nms_op = getattr(nms_wrapper, nms_type)for i in range(1, num_classes):cls_inds = multi_scores[:, i] > score_thrif not cls_inds.any():continue# get bboxes and scores of this classif multi_bboxes.shape[1] == 4:_bboxes = multi_bboxes[cls_inds, :]else:_bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]_scores = multi_scores[cls_inds, i]if score_factors is not None:_scores *= score_factors[cls_inds]cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)cls_dets, _ = nms_op(cls_dets, **nms_cfg_)cls_labels = multi_bboxes.new_full((cls_dets.shape[0], ),i - 1,dtype=torch.long)bboxes.append(cls_dets)labels.append(cls_labels)if bboxes:bboxes = torch.cat(bboxes)labels = torch.cat(labels)if bboxes.shape[0] > max_num:_, inds = bboxes[:, -1].sort(descending=True)inds = inds[:max_num]bboxes = bboxes[inds]labels = labels[inds]else:bboxes = multi_bboxes.new_zeros((0, 5))labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)return bboxes, labels'''
(Pdb) cate_scores * decay_coefficient
tensor([0.7593, 0.1010, 0.1081, 0.5393, 0.0926, 0.0885, 0.4901, 0.4664, 0.4540,0.0755, 0.4385, 0.3944, 0.0726, 0.0748, 0.0986, 0.0551, 0.0835, 0.0694,0.0822, 0.3194, 0.0600, 0.3141, 0.0594, 0.3115, 0.3114, 0.3086, 0.0771,0.0792, 0.0597, 0.0512, 0.0569, 0.3018, 0.0461, 0.0550, 0.0537, 0.0662,0.0580, 0.0644, 0.0503, 0.2881, 0.2839, 0.2830, 0.0561, 0.1310, 0.2692,0.0652, 0.0694, 0.0505, 0.0410, 0.0464, 0.0665, 0.0409, 0.2440, 0.0407,0.0464, 0.0410, 0.2291, 0.0447, 0.1051, 0.2260, 0.2241, 0.2236, 0.2233,0.0529, 0.1370, 0.2200, 0.0540, 0.0532, 0.0473, 0.0530, 0.2168, 0.2134,0.0678, 0.0478, 0.0384, 0.0407, 0.1161, 0.0320, 0.0619, 0.2025, 0.0388,0.0331, 0.0493, 0.0866, 0.0849, 0.0413, 0.0593, 0.0593, 0.0388, 0.0389,0.0738, 0.1875, 0.0674, 0.1145, 0.0588, 0.0806, 0.1797, 0.0382, 0.1776,0.1751, 0.0489, 0.0511, 0.1743, 0.0815, 0.1741, 0.0582, 0.0925, 0.0317,0.0318, 0.1661, 0.1645, 0.0297, 0.1634, 0.1629, 0.0446, 0.0389, 0.0318,0.1611, 0.1445, 0.0564, 0.0337, 0.1564, 0.1563, 0.0331, 0.1556, 0.0605,0.1533, 0.1526, 0.0254, 0.1477, 0.0477, 0.1507, 0.0379, 0.1504, 0.0312,0.0492, 0.1478, 0.0248, 0.1466, 0.0412, 0.0278, 0.0301, 0.0973, 0.0297,0.1449, 0.0219, 0.0616, 0.0348, 0.0274, 0.0721, 0.0425, 0.1388, 0.0409,0.0231, 0.0848, 0.1382, 0.0488, 0.0265, 0.0326, 0.1361, 0.0220, 0.0898,0.0259, 0.0259, 0.0268, 0.0563, 0.1345, 0.1344, 0.0220, 0.0319, 0.0512,0.1330, 0.0265, 0.0458, 0.0277, 0.0257, 0.0245, 0.0280, 0.1300, 0.0402,0.0307, 0.0460, 0.0315, 0.0277, 0.0173, 0.0657, 0.0251, 0.0230, 0.1267,0.1263, 0.0789, 0.0680, 0.0559, 0.0196, 0.0247, 0.0987, 0.1243, 0.0254,0.1033, 0.1235, 0.1234, 0.1233, 0.1232, 0.0211, 0.0351, 0.1230, 0.1225,0.0211, 0.1211, 0.0752, 0.1207, 0.0759, 0.1200, 0.0432, 0.1198, 0.1191,0.0215, 0.0458, 0.1184, 0.0221, 0.1175, 0.0706, 0.0312, 0.1170, 0.1169,0.0257, 0.1167, 0.1166, 0.0193, 0.0641, 0.1151, 0.0692, 0.0873, 0.0289,0.0330, 0.1137, 0.0447, 0.0257, 0.0675, 0.1123, 0.0252, 0.0519, 0.0219,0.0188, 0.0327, 0.1117, 0.1117, 0.0921, 0.0403, 0.0270, 0.0230, 0.0641,0.0273, 0.1099, 0.0201, 0.0322, 0.1091, 0.1090, 0.0229, 0.1089, 0.0187,0.0216, 0.0307, 0.0513, 0.1080, 0.0260, 0.0855, 0.0441, 0.0188, 0.0972,0.1068, 0.0417, 0.0206, 0.0394, 0.0214, 0.0427, 0.0170, 0.0311, 0.0481,0.0196, 0.1049, 0.1051, 0.1049, 0.0295, 0.0347, 0.0226, 0.0667, 0.0199,0.1041, 0.0246, 0.1038, 0.0241, 0.1033, 0.1028, 0.0212, 0.1021, 0.1022,0.1019, 0.0413, 0.0388, 0.0343, 0.0967, 0.0925, 0.0654, 0.1009, 0.0301,0.1007, 0.0986, 0.0474, 0.0583, 0.0990, 0.0273, 0.0989, 0.0737, 0.0689,0.0187, 0.0231, 0.0982, 0.0522, 0.0132, 0.0973, 0.0387, 0.0971, 0.0937,0.0968, 0.0189, 0.0218, 0.0933, 0.0219, 0.0199, 0.0957, 0.0475, 0.0266,0.0950, 0.0389, 0.0454, 0.0262, 0.0641, 0.0870, 0.0212, 0.0187, 0.0834,0.0931, 0.0431, 0.0929, 0.0929, 0.0703, 0.0193, 0.0459, 0.0211, 0.0926,0.0925, 0.0923, 0.0371, 0.0420, 0.0224, 0.0196, 0.0919, 0.0336, 0.0917,0.0894, 0.0569, 0.0832, 0.0328, 0.0249, 0.0263, 0.0181, 0.0410, 0.0906,0.0159, 0.0402, 0.0183, 0.0168, 0.0171, 0.0204, 0.0160, 0.0897, 0.0323,0.0173, 0.0240, 0.0708, 0.0894, 0.0892, 0.0892, 0.0283, 0.0186, 0.0172,0.0882, 0.0160, 0.0179, 0.0522, 0.0511, 0.0177, 0.0877, 0.0418, 0.0155,0.0606, 0.0868, 0.0867, 0.0485, 0.0258, 0.0143, 0.0359, 0.0804, 0.0457,0.0835, 0.0678, 0.0177, 0.0193, 0.0250, 0.0477, 0.0289, 0.0247, 0.0839,0.0836, 0.0680, 0.0423, 0.0147, 0.0649, 0.0824, 0.0178, 0.0299, 0.0219,0.0161, 0.0152, 0.0422, 0.0242, 0.0266, 0.0808, 0.0453, 0.0557, 0.0807,0.0222, 0.0154, 0.0217, 0.0134, 0.0600, 0.0447, 0.0231, 0.0162, 0.0759,0.0292, 0.0229, 0.0790, 0.0380, 0.0216, 0.0505, 0.0786, 0.0556, 
0.0281,0.0469, 0.0556, 0.0233, 0.0726, 0.0175, 0.0303, 0.0774, 0.0770, 0.0462,0.0285, 0.0731, 0.0333, 0.0712, 0.0232, 0.0318, 0.0756, 0.0361, 0.0382,0.0751, 0.0627, 0.0749, 0.0565, 0.0470, 0.0228, 0.0193, 0.0294, 0.0442,0.0434, 0.0538, 0.0726, 0.0562, 0.0260, 0.0227, 0.0721, 0.0325, 0.0717,0.0604, 0.0696, 0.0700, 0.0588, 0.0234, 0.0229, 0.0195, 0.0683, 0.0350,0.0359, 0.0378, 0.0688, 0.0407, 0.0671], device='cuda:0')
(Pdb) cate_scores
tensor([0.7593, 0.6425, 0.6012, 0.5393, 0.5195, 0.4914, 0.4901, 0.4664, 0.4540,0.4468, 0.4385, 0.3944, 0.3913, 0.3701, 0.3569, 0.3558, 0.3473, 0.3448,0.3417, 0.3194, 0.3147, 0.3141, 0.3134, 0.3115, 0.3114, 0.3086, 0.3071,0.3065, 0.3050, 0.3035, 0.3025, 0.3018, 0.3017, 0.3003, 0.2977, 0.2969,0.2946, 0.2934, 0.2930, 0.2881, 0.2839, 0.2830, 0.2733, 0.2713, 0.2692,0.2640, 0.2634, 0.2615, 0.2544, 0.2502, 0.2472, 0.2443, 0.2440, 0.2430,0.2317, 0.2306, 0.2291, 0.2265, 0.2262, 0.2260, 0.2241, 0.2236, 0.2233,0.2222, 0.2207, 0.2202, 0.2192, 0.2191, 0.2173, 0.2169, 0.2168, 0.2134,0.2130, 0.2112, 0.2105, 0.2093, 0.2073, 0.2070, 0.2031, 0.2025, 0.2007,0.1998, 0.1989, 0.1978, 0.1951, 0.1939, 0.1920, 0.1917, 0.1895, 0.1893,0.1876, 0.1875, 0.1847, 0.1839, 0.1827, 0.1817, 0.1797, 0.1786, 0.1776,0.1751, 0.1748, 0.1746, 0.1743, 0.1743, 0.1741, 0.1723, 0.1704, 0.1701,0.1675, 0.1661, 0.1645, 0.1642, 0.1634, 0.1629, 0.1625, 0.1623, 0.1618,0.1611, 0.1607, 0.1599, 0.1583, 0.1564, 0.1563, 0.1557, 0.1556, 0.1541,0.1533, 0.1526, 0.1518, 0.1514, 0.1512, 0.1507, 0.1505, 0.1504, 0.1503,0.1499, 0.1478, 0.1476, 0.1466, 0.1461, 0.1458, 0.1453, 0.1452, 0.1452,0.1449, 0.1438, 0.1419, 0.1411, 0.1405, 0.1392, 0.1391, 0.1388, 0.1386,0.1385, 0.1383, 0.1382, 0.1379, 0.1367, 0.1363, 0.1361, 0.1357, 0.1355,0.1352, 0.1352, 0.1349, 0.1348, 0.1345, 0.1344, 0.1335, 0.1335, 0.1333,0.1330, 0.1326, 0.1323, 0.1317, 0.1313, 0.1312, 0.1309, 0.1301, 0.1298,0.1293, 0.1283, 0.1282, 0.1282, 0.1280, 0.1280, 0.1277, 0.1268, 0.1267,0.1263, 0.1261, 0.1259, 0.1259, 0.1255, 0.1253, 0.1245, 0.1243, 0.1238,0.1237, 0.1235, 0.1234, 0.1233, 0.1233, 0.1231, 0.1230, 0.1230, 0.1226,0.1215, 0.1211, 0.1211, 0.1207, 0.1201, 0.1200, 0.1198, 0.1198, 0.1191,0.1187, 0.1186, 0.1184, 0.1183, 0.1175, 0.1173, 0.1172, 0.1170, 0.1169,0.1168, 0.1167, 0.1166, 0.1164, 0.1153, 0.1151, 0.1150, 0.1145, 0.1141,0.1140, 0.1137, 0.1133, 0.1131, 0.1128, 0.1123, 0.1123, 0.1123, 0.1120,0.1119, 0.1117, 0.1117, 0.1117, 0.1112, 0.1111, 0.1111, 0.1107, 0.1104,0.1103, 0.1099, 0.1097, 0.1093, 0.1091, 0.1090, 0.1089, 0.1089, 0.1086,0.1082, 0.1082, 0.1082, 0.1080, 0.1080, 0.1076, 0.1074, 0.1074, 0.1071,0.1068, 0.1068, 0.1066, 0.1065, 0.1063, 0.1062, 0.1060, 0.1056, 0.1056,0.1054, 0.1053, 0.1051, 0.1049, 0.1049, 0.1044, 0.1044, 0.1041, 0.1041,0.1041, 0.1038, 0.1038, 0.1034, 0.1033, 0.1028, 0.1028, 0.1023, 0.1022,0.1022, 0.1021, 0.1019, 0.1017, 0.1015, 0.1015, 0.1011, 0.1009, 0.1007,0.1007, 0.0996, 0.0996, 0.0993, 0.0990, 0.0990, 0.0989, 0.0988, 0.0988,0.0987, 0.0983, 0.0982, 0.0978, 0.0978, 0.0973, 0.0972, 0.0971, 0.0969,0.0968, 0.0965, 0.0963, 0.0958, 0.0958, 0.0958, 0.0957, 0.0957, 0.0955,0.0950, 0.0947, 0.0946, 0.0942, 0.0940, 0.0940, 0.0938, 0.0935, 0.0933,0.0931, 0.0930, 0.0929, 0.0929, 0.0928, 0.0928, 0.0928, 0.0927, 0.0926,0.0925, 0.0923, 0.0923, 0.0923, 0.0922, 0.0919, 0.0919, 0.0919, 0.0917,0.0916, 0.0913, 0.0912, 0.0911, 0.0908, 0.0907, 0.0907, 0.0906, 0.0906,0.0905, 0.0905, 0.0904, 0.0902, 0.0902, 0.0901, 0.0901, 0.0897, 0.0896,0.0895, 0.0895, 0.0894, 0.0894, 0.0892, 0.0892, 0.0889, 0.0884, 0.0883,0.0882, 0.0881, 0.0879, 0.0878, 0.0878, 0.0877, 0.0877, 0.0875, 0.0873,0.0870, 0.0868, 0.0867, 0.0867, 0.0865, 0.0862, 0.0859, 0.0856, 0.0856,0.0856, 0.0856, 0.0852, 0.0852, 0.0852, 0.0849, 0.0847, 0.0843, 0.0839,0.0836, 0.0834, 0.0833, 0.0831, 0.0830, 0.0824, 0.0824, 0.0822, 0.0818,0.0818, 0.0815, 0.0815, 0.0813, 0.0811, 0.0808, 0.0808, 0.0807, 0.0807,0.0807, 0.0806, 0.0802, 0.0800, 0.0800, 0.0798, 0.0796, 0.0793, 0.0792,0.0791, 0.0790, 0.0790, 0.0790, 0.0789, 0.0787, 0.0786, 0.0784, 
0.0783,0.0779, 0.0778, 0.0778, 0.0777, 0.0775, 0.0774, 0.0774, 0.0770, 0.0768,0.0767, 0.0766, 0.0763, 0.0760, 0.0759, 0.0756, 0.0756, 0.0755, 0.0754,0.0752, 0.0751, 0.0749, 0.0749, 0.0740, 0.0738, 0.0736, 0.0736, 0.0733,0.0729, 0.0729, 0.0726, 0.0723, 0.0722, 0.0721, 0.0721, 0.0721, 0.0717,0.0715, 0.0714, 0.0714, 0.0713, 0.0713, 0.0712, 0.0707, 0.0704, 0.0694,0.0690, 0.0689, 0.0688, 0.0678, 0.0671], device='cuda:0')''''''
(Pdb) ans[0].sum()
tensor(26, device='cuda:0')
(Pdb) ans[1].sum()
tensor(26, device='cuda:0')
(Pdb) ans[2].sum()
tensor(26, device='cuda:0')
(Pdb) label_matrix[0].sum()
tensor(25., device='cuda:0')
(Pdb) label_matrix[1].sum()
tensor(24., device='cuda:0')
(Pdb) label_matrix[2].sum()
tensor(23., device='cuda:0')''''''(Pdb) iou_matrix
tensor([[0.0000, 0.9618, 0.9262, ..., 0.5556, 0.0000, 0.0000],[0.0000, 0.0000, 0.9157, ..., 0.5608, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.5495, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0082],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],device='cuda:0')
(Pdb) label_matrix
tensor([[0., 1., 1., ..., 0., 0., 0.],[0., 0., 1., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
(Pdb) iou_matrix * label_matrix
tensor([[0.0000, 0.9618, 0.9262, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.9157, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],device='cuda:0')''''''
(Pdb) decay_iou
tensor([[0.0000, 0.9618, 0.9262, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.9157, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],device='cuda:0')
(Pdb) compensate_iou
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],[0.9618, 0.9618, 0.9618, ..., 0.9618, 0.9618, 0.9618],[0.9262, 0.9262, 0.9262, ..., 0.9262, 0.9262, 0.9262],...,[0.1814, 0.1814, 0.1814, ..., 0.1814, 0.1814, 0.1814],[0.5750, 0.5750, 0.5750, ..., 0.5750, 0.5750, 0.5750],[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],device='cuda:0')''''''
(Pdb) decay_matrix
tensor([[1.0000, 0.1572, 0.1798,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.1869,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')
(Pdb) compensate_matrix
tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.1572, 0.1572, 0.1572,  ..., 0.1572, 0.1572, 0.1572],
        [0.1798, 0.1798, 0.1798,  ..., 0.1798, 0.1798, 0.1798],
        ...,
        [0.9363, 0.9363, 0.9363,  ..., 0.9363, 0.9363, 0.9363],
        [0.5162, 0.5162, 0.5162,  ..., 0.5162, 0.5162, 0.5162],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')
(Pdb) decay_coefficient
tensor([1.0000, 0.1572, 0.1798, 1.0000, 0.1783, 0.1802, 1.0000, 1.0000, 1.0000,
        0.1689, 1.0000, 1.0000, 0.1855, 0.2022, 0.2764, 0.1547, 0.2404, 0.2013,
        ...,
        0.8445, 0.9756, 0.9810, 0.8240, 0.3274, 0.3215, 0.2753, 0.9701, 0.5041,
        0.5205, 0.5485, 1.0000, 0.5994, 1.0000], device='cuda:0')
'''
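Putting the dumps together: decay_iou is the upper-triangular same-class IoU (iou_matrix * label_matrix); row i of compensate_iou repeats the largest IoU that candidate i already has with a higher-scored candidate of the same class; with the gaussian kernel and sigma=2.0 used in the test_cfg of section 4, decay_matrix = exp(-sigma * decay_iou²), compensate_matrix = exp(-sigma * compensate_iou²), and decay_coefficient is the column-wise minimum of their ratio, which finally rescales cate_scores. The snippet below is only a toy sketch (not the repo code) of these relations; it reproduces the first three printed decay_coefficient values from the 3x3 corner of decay_iou shown above.

import torch

# Toy check of the relations between the tensors printed above (gaussian kernel, sigma=2.0).
sigma = 2.0
decay_iou = torch.tensor([[0.0000, 0.9618, 0.9262],
                          [0.0000, 0.0000, 0.9157],
                          [0.0000, 0.0000, 0.0000]])          # iou_matrix * label_matrix (upper triangular)
compensate_iou, _ = decay_iou.max(0)                          # best overlap each candidate already suffered
compensate_iou = compensate_iou.expand(3, 3).transpose(1, 0)  # broadcast along rows
decay_matrix = torch.exp(-sigma * decay_iou ** 2)             # e.g. exp(-2 * 0.9618**2) ≈ 0.1572
compensate_matrix = torch.exp(-sigma * compensate_iou ** 2)
decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
print(decay_coefficient)  # tensor([1.0000, 0.1572, 0.1798]) -- matches the dump above
# the updated scores are then cate_scores * decay_coefficient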
4. SOLO/configs/solo/solo_r50_fpn_8gpu_1x.py
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',  # 'SOLOHead' is the class registered in solo_head.py; change type to your own SOLOHead_xx (registered as in section 5) to swap in a custom head
        num_classes=81,
        in_channels=256,
        stacked_convs=7,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        with_deform=False,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco2017/'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[9, 11])
# checkpoint
checkpoint_config = dict(interval=1)  # save a checkpoint to work_dir every `interval` epochs
# yapf:disable
log_config = dict(
    interval=1,  # print a log entry every `interval` iterations
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solo_release_r50_fpn_8gpu_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
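The config above is everything the registry-based builders from section 1 need. As a quick sanity check you can load it with mmcv and build the SOLO detector outside the training scripts; this is a minimal sketch assuming the standard mmdetection v1.x API that this repo is built on (tools/train.py does essentially the same thing). Note that lr=0.01 is tuned for 8 GPUs x 2 imgs_per_gpu; scale it down linearly when debugging on fewer GPUs.

from mmcv import Config

from mmdet.models import build_detector

# Load the config above and build backbone + FPN + SOLOHead through the registry,
# exactly as SingleStageInsDetector does in section 1.
cfg = Config.fromfile('configs/solo/solo_r50_fpn_8gpu_1x.py')
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(model.bbox_head)  # SOLOHead with num_grids=[40, 36, 24, 16, 12]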
5. SOLO/mmdet/models/anchor_heads/__init__.py
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .ssd_head import SSDHead
from .solo_head import SOLOHead
from .solov2_head import SOLOv2Head
from .solov2_light_head import SOLOv2LightHead
from .decoupled_solo_head import DecoupledSOLOHead
from .decoupled_solo_light_head import DecoupledSOLOLightHead
from .solo_head_xx import SOLOHead_xx  # register your own modified head module here
__all__ = [
    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
    'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'SOLOv2LightHead',
    'DecoupledSOLOHead', 'DecoupledSOLOLightHead',
    'SOLOHead_xx'
]
# Then implement SOLOHead_xx.py (adjusting the corresponding super() calls as needed): the official
# files stay untouched while you make your own small modifications, as sketched below.
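For reference, a registered SOLOHead_xx.py can be as small as the following. The file and class names are the placeholders used in this post, and the code is only a sketch of the usual mmdetection v1.x registration pattern, not code taken from the repo:

# SOLO/mmdet/models/anchor_heads/solo_head_xx.py (hypothetical custom head)
from ..registry import HEADS
from .solo_head import SOLOHead


@HEADS.register_module
class SOLOHead_xx(SOLOHead):
    """Starts as a thin wrapper over SOLOHead; override only the methods you want to change."""

    def __init__(self, *args, **kwargs):
        super(SOLOHead_xx, self).__init__(*args, **kwargs)

With the import and __all__ entry above in place, switching the config's bbox_head type from 'SOLOHead' to 'SOLOHead_xx' is enough to train with the custom head.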