Table of Contents

  • Anchor free? Anchor based?
  • FCOS ground truth assignment
  • Loss computation
  • Complete loss code

All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.

Anchor free? Anchor based?

First, to be clear: FCOS indeed does not use explicit anchors (prior boxes) the way RetinaNet does. FCOS treats every point on the feature map of each FPN level as a sample, and decides whether that sample is positive or negative according to whether it falls inside or outside an annotated box (note that FCOS has no ignored samples). In this sense, FCOS is indeed anchor free. However, both during ground truth assignment and at test time, FCOS still has to project every feature map point back to an (x, y) coordinate on the input image, so it is not completely "free". More precisely, FCOS is a "point based" detector: you can think of it as a detector with exactly one implicit anchor at every feature map point.
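As a minimal sketch of that projection (my own helper, not code from the repository), a location on a stride-s feature map maps back to input-image coordinates at roughly (s/2 + i·s, s/2 + j·s):

    import torch

    def feature_map_points(feature_h, feature_w, stride):
        # Map every location on a feature_h x feature_w map back to an (x, y)
        # coordinate on the input image, one "implicit anchor" per location.
        shifts_x = (torch.arange(0, feature_w, dtype=torch.float32) + 0.5) * stride
        shifts_y = (torch.arange(0, feature_h, dtype=torch.float32) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(shifts_y, shifts_x)
        return torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)

    # e.g. a 5x5 P7 map with stride 128 gives 25 points on a 640x640 input image
    points = feature_map_points(5, 5, 128)
    print(points.shape)  # torch.Size([25, 2])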

The DETR detector released in 2020 (https://arxiv.org/pdf/2005.12872.pdf) casts object detection as a set prediction problem and uses a Transformer to predict the set of boxes directly, with no need for NMS or any anchor/point prior coordinates, which makes the detector truly "free". Readers who are interested can look into it on their own.

FCOS ground truth assignment

For an input image with several annotated boxes, we first examine every point on every FPN level's feature map: any point that lies outside all annotated boxes becomes a negative sample. Of the remaining points, some may lie inside more than one annotated box at the same time. For each point and each annotated box we then compute l, t, r, b (the distances from the point to the box's left, top, right, and bottom edges) and take the maximum of the four; a box is assigned to a point on a given FPN level only when this maximum falls inside that level's value range below.

# value ranges assigned to P3, P4, P5, P6, P7, from left to right
INF = 100000000
mi = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]]
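A minimal sketch of that range test, with a hypothetical helper in_level_range (the assignment code below uses the same strict inequalities): take the largest of the four l, t, r, b distances and keep the point only on the level whose (low, high) interval contains it.

    import torch

    def in_level_range(ltrb, low, high):
        # ltrb: [N, 4] distances from a point to a box's left/top/right/bottom edges
        max_dist = ltrb.max(dim=-1).values
        return (max_dist > low) & (max_dist < high)

    ltrb = torch.tensor([[30., 40., 20., 50.],       # max distance 50
                         [200., 150., 300., 180.]])  # max distance 300
    print(in_level_range(ltrb, -1, 64))    # tensor([ True, False]) -> P3
    print(in_level_range(ltrb, 256, 512))  # tensor([False,  True]) -> P6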

After this step, the vast majority of points are assigned to at most one box. A few points, however, may still fall inside two boxes at once (this happens when two annotated boxes have similar sizes). For those points we compute the area of each overlapping annotated box and always assign the point to the box with the smallest area. In the implementation below I perform this label assignment for these samples with matrix operations. Although an image usually has only a few dozen to two or three hundred positive samples, assigning their labels with a Python for loop would make training extremely slow, which is worth keeping in mind.
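As a toy illustration of that vectorized trick (hypothetical tensors, not the exact code below): broadcast every box area to every candidate point, push the invalid point/box pairs to a huge constant, and a single min() picks the smallest valid box per point.

    import torch

    INF = 100000000
    # areas: [points_num, gt_num] box areas broadcast to every point (toy values)
    areas = torch.tensor([[1200., 900., 4000.],
                          [1200., 900., 4000.]])
    # valid: [points_num, gt_num] 1 where the point is still a candidate for that box
    valid = torch.tensor([[1., 1., 0.],
                          [0., 0., 1.]])
    # push invalid pairs to INF so min() never selects them
    masked_areas = torch.where(valid > 0, areas, torch.full_like(areas, INF))
    _, assigned_gt = masked_areas.min(dim=1)
    print(assigned_gt)  # tensor([1, 2]): point 0 -> box 1 (area 900), point 1 -> box 2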
For the classification labels, 0 denotes a negative sample and 1 to 80 denote the 80 positive classes; the l, t, r, b and centerness targets are computed exactly as in the formulas of the FCOS paper, with no modification.
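For reference, the centerness target from the FCOS paper is

$$\text{centerness}^{*}=\sqrt{\frac{\min(l^{*},r^{*})}{\max(l^{*},r^{*})}\times\frac{\min(t^{*},b^{*})}{\max(t^{*},b^{*})}}$$

which equals 1 at the center of a box and decays towards 0 near its edges.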

The ground truth assignment is implemented as follows:

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """Assign a ground truth target for each position on feature map"""
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)

        cls_preds, reg_preds, center_preds, all_points_position, all_points_mi = [], [], [], [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6],
                                                device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                # convert each candidate box to l,t,r,b distances from the point
                candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
                candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 2:4]

                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out candidates whose point center lies outside the gt box
                candidates = candidates * sample_flag

                # zero out candidates whose max(l,t,r,b) is outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)

                    if positive_candidates.shape[1] == 1:
                        # only one candidate box for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index, 0:4] = positive_candidates
                        per_image_targets[positive_index, 4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index, 2:3], per_image_targets[
                                        positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point has several candidate objects,
                        # choose the smallest-area object as its ground truth
                        gts_w_h = sample_box_gts[:, :, 2:4] - sample_box_gts[:, :, 0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(axis=2)
                        # set the area of every invalid candidate to INF so that
                        # the .min() operation never selects it
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)

                        # get the index of the smallest-area object candidate
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (torch.linspace(
                            1, positive_candidates.shape[0],
                            positive_candidates.shape[0]) - 1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]

                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index, 0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index, 4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index, 2:3], per_image_targets[
                                        positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position],
                                  axis=2)

        # batch_targets shape:[batch_size, points_num, 8]
        # 8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets

Loss computation

The classification loss is focal loss, computed exactly as in RetinaNet, except that the samples are points instead of anchors.
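For reference, focal loss with the usual alpha = 0.25 and gamma = 2 is

$$FL(p_{t})=-\alpha_{t}(1-p_{t})^{\gamma}\log(p_{t})$$

where $p_{t}$ is the predicted probability for the ground-truth class (i.e. $p$ for positive labels and $1-p$ for negative labels).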

The classification loss is implemented as follows:

    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]

        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()

        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(loss_ground_truth * torch.log(per_image_cls_preds) +
                     (1. - loss_ground_truth) *
                     torch.log(1. - per_image_cls_preds))

        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # according to the original paper, we divide the focal loss
        # by the number of positive samples
        one_image_focal_loss = one_image_focal_loss / positive_points_num

        return one_image_focal_loss

In the FCOS paper the regression loss is an IoU loss; here I use GIoU loss directly. Since the regression loss is only computed on positive samples, and a positive point always lies inside both its ground-truth box and the predicted box decoded from it, the two boxes always intersect. The degenerate non-overlapping case that motivates GIoU therefore never occurs, and the GIoU loss behaves essentially like the IoU loss here.
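For reference, writing A for the predicted box, B for the ground-truth box and C for the smallest box enclosing both, the GIoU loss is

$$GIoU=IoU-\frac{|C\setminus(A\cup B)|}{|C|},\qquad L_{GIoU}=1-GIoU$$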

The regression loss is implemented as follows:

    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive point samples to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5]

        # decode l,t,r,b distances back to boxes around each point center
        pred_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_reg_preds[:, 0:2]
        pred_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_reg_preds[:, 2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:, 0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:, 2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

        # compute pred boxes and gt boxes width/height
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1

        # compute pred boxes area and gt boxes area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

        # compute union area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)

        # compute ious between pred boxes and gt boxes
        ious = overlap_area / union_area

        # smallest enclosing box of each pred/gt box pair
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of the gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss

        return gious_loss

The final multiplication by 2 is there to balance the magnitude of the regression loss against the other losses.

The centerness branch is optimized with a BCE loss. Because the optimization target of the centerness loss is not a stable one, in practice the loss often drops a little early in training and then stops decreasing for a long time; this is normal and nothing to worry about.
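For reference, since the centerness target t is a soft label in (0, 1] rather than a hard 0/1 label, the loss is the usual binary cross entropy evaluated against that soft target:

$$L_{centerness}=-\big(t\log(p)+(1-t)\log(1-p)\big)$$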
The centerness loss is implemented as follows:

    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive point samples to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5:6]

        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num

        return center_ness_loss

Complete loss code

import torch
import torch.nn as nn
import torch.nn.functional as F

INF = 100000000


class FCOSLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 strides=[8, 16, 32, 64, 128],
                 mi=[[-1, 64], [64, 128], [128, 256], [256, 512], [512, INF]],
                 alpha=0.25,
                 gamma=2.,
                 epsilon=1e-4):
        super(FCOSLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h
        self.strides = strides
        self.mi = mi

    def forward(self, cls_heads, reg_heads, center_heads, batch_positions,
                annotations):
        """
        compute cls loss, reg loss and center-ness loss in one batch
        """
        cls_preds, reg_preds, center_preds, batch_targets = self.get_batch_position_annotations(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)

        cls_preds = torch.sigmoid(cls_preds)
        reg_preds = torch.exp(reg_preds)
        center_preds = torch.sigmoid(center_preds)
        batch_targets[:, :, 5:6] = torch.sigmoid(batch_targets[:, :, 5:6])

        device = annotations.device
        cls_loss, reg_loss, center_ness_loss = [], [], []
        valid_image_num = 0
        for per_image_cls_preds, per_image_reg_preds, per_image_center_preds, per_image_targets in zip(
                cls_preds, reg_preds, center_preds, batch_targets):
            positive_points_num = (
                per_image_targets[per_image_targets[:, 4] > 0]).shape[0]
            if positive_points_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
                center_ness_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_preds, per_image_targets)
                one_image_reg_loss = self.compute_one_image_giou_loss(
                    per_image_reg_preds, per_image_targets)
                one_image_center_ness_loss = self.compute_one_image_center_ness_loss(
                    per_image_center_preds, per_image_targets)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
                center_ness_loss.append(one_image_center_ness_loss)

        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num
        center_ness_loss = sum(center_ness_loss) / valid_image_num

        return cls_loss, reg_loss, center_ness_loss

    def compute_one_image_focal_loss(self, per_image_cls_preds,
                                     per_image_targets):
        """
        compute one image focal loss(cls loss)
        per_image_cls_preds:[points_num,num_classes]
        per_image_targets:[points_num,8]
        """
        per_image_cls_preds = torch.clamp(per_image_cls_preds,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_preds.shape[1]

        # generate 80 binary ground truth classes for each point
        loss_ground_truth = F.one_hot(per_image_targets[:, 4].long(),
                                      num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float()

        alpha_factor = torch.ones_like(per_image_cls_preds) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),
                                   alpha_factor, 1. - alpha_factor)
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_preds,
                         1. - per_image_cls_preds)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(loss_ground_truth * torch.log(per_image_cls_preds) +
                     (1. - loss_ground_truth) *
                     torch.log(1. - per_image_cls_preds))

        one_image_focal_loss = focal_weight * bce_loss
        one_image_focal_loss = one_image_focal_loss.sum()
        positive_points_num = per_image_targets[
            per_image_targets[:, 4] > 0].shape[0]
        # according to the original paper, we divide the focal loss
        # by the number of positive samples
        one_image_focal_loss = one_image_focal_loss / positive_points_num

        return one_image_focal_loss

    def compute_one_image_giou_loss(self, per_image_reg_preds,
                                    per_image_targets):
        """
        compute one image giou loss(reg loss)
        per_image_reg_preds:[points_num,4]
        per_image_targets:[points_num,8]
        """
        # only use positive point samples to compute reg loss
        device = per_image_reg_preds.device
        per_image_reg_preds = per_image_reg_preds[per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5]

        # decode l,t,r,b distances back to boxes around each point center
        pred_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_reg_preds[:, 0:2]
        pred_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_reg_preds[:, 2:4]
        gt_bboxes_xy_min = per_image_targets[:, 6:8] - per_image_targets[:, 0:2]
        gt_bboxes_xy_max = per_image_targets[:, 6:8] + per_image_targets[:, 2:4]
        pred_bboxes = torch.cat([pred_bboxes_xy_min, pred_bboxes_xy_max],
                                axis=1)
        gt_bboxes = torch.cat([gt_bboxes_xy_min, gt_bboxes_xy_max], axis=1)

        overlap_area_top_left = torch.max(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        overlap_area_bot_right = torch.min(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                         overlap_area_top_left,
                                         min=0)
        overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

        # compute pred boxes and gt boxes width/height
        pred_bboxes_w_h = pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1
        gt_bboxes_w_h = gt_bboxes[:, 2:4] - gt_bboxes[:, 0:2] + 1

        # compute pred boxes area and gt boxes area
        pred_bboxes_area = pred_bboxes_w_h[:, 0] * pred_bboxes_w_h[:, 1]
        gt_bboxes_area = gt_bboxes_w_h[:, 0] * gt_bboxes_w_h[:, 1]

        # compute union area
        union_area = pred_bboxes_area + gt_bboxes_area - overlap_area
        union_area = torch.clamp(union_area, min=1e-4)

        # compute ious between pred boxes and gt boxes
        ious = overlap_area / union_area

        # smallest enclosing box of each pred/gt box pair
        enclose_area_top_left = torch.min(pred_bboxes[:, 0:2], gt_bboxes[:, 0:2])
        enclose_area_bot_right = torch.max(pred_bboxes[:, 2:4], gt_bboxes[:, 2:4])
        enclose_area_sizes = torch.clamp(enclose_area_bot_right -
                                         enclose_area_top_left,
                                         min=0)
        enclose_area = enclose_area_sizes[:, 0] * enclose_area_sizes[:, 1]
        enclose_area = torch.clamp(enclose_area, min=1e-4)

        gious_loss = 1. - ious + (enclose_area - union_area) / enclose_area
        gious_loss = torch.clamp(gious_loss, min=-1.0, max=1.0)
        # use center_ness_targets as the weight of the gious loss
        gious_loss = gious_loss * center_ness_targets
        gious_loss = gious_loss.sum() / positive_points_num
        gious_loss = 2. * gious_loss

        return gious_loss

    def compute_one_image_center_ness_loss(self, per_image_center_preds,
                                           per_image_targets):
        """
        compute one image center_ness loss(center ness loss)
        per_image_center_preds:[points_num,1]
        per_image_targets:[points_num,8]
        """
        # only use positive point samples to compute center_ness loss
        device = per_image_center_preds.device
        per_image_center_preds = per_image_center_preds[
            per_image_targets[:, 4] > 0]
        per_image_targets = per_image_targets[per_image_targets[:, 4] > 0]
        positive_points_num = per_image_targets.shape[0]

        if positive_points_num == 0:
            return torch.tensor(0.).to(device)

        center_ness_targets = per_image_targets[:, 5:6]

        center_ness_loss = -(
            center_ness_targets * torch.log(per_image_center_preds) +
            (1. - center_ness_targets) *
            torch.log(1. - per_image_center_preds))
        center_ness_loss = center_ness_loss.sum() / positive_points_num

        return center_ness_loss

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """Assign a ground truth target for each position on feature map"""
        device = annotations.device
        batch_mi = []
        for reg_head, mi in zip(reg_heads, self.mi):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)

        cls_preds, reg_preds, center_preds, all_points_position, all_points_mi = [], [], [], [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_annotations in zip(
                all_points_position, all_points_mi, annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6],
                                                device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 2)
                # convert each candidate box to l,t,r,b distances from the point
                candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
                candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 2:4]

                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out candidates whose point center lies outside the gt box
                candidates = candidates * sample_flag

                # zero out candidates whose max(l,t,r,b) is outside this level's mi range
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag == True).nonzero().squeeze(dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)

                    if positive_candidates.shape[1] == 1:
                        # only one candidate box for each positive sample
                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index, 0:4] = positive_candidates
                        per_image_targets[positive_index, 4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index, 2:3], per_image_targets[
                                        positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point has several candidate objects,
                        # choose the smallest-area object as its ground truth
                        gts_w_h = sample_box_gts[:, :, 2:4] - sample_box_gts[:, :, 0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(axis=2)
                        # set the area of every invalid candidate to INF so that
                        # the .min() operation never selects it
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)

                        # get the index of the smallest-area object candidate
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (torch.linspace(
                            1, positive_candidates.shape[0],
                            positive_candidates.shape[0]) - 1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]

                        # assign l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index, 0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index, 4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index, 2:3], per_image_targets[
                                        positive_index, 3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position],
                                  axis=2)

        # batch_targets shape:[batch_size, points_num, 8]
        # 8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets


if __name__ == '__main__':
    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.autograd.Variable(torch.randn(3, 3, image_h, image_w)))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    loss = FCOSLoss(image_w, image_h)
    cls_loss, reg_loss, center_loss = loss(cls_heads, reg_heads, center_heads,
                                           batch_positions, annotations)
    print("2222", cls_loss, reg_loss, center_loss)
