Table of Contents

  • What is CenterSample
  • The new heads layer
  • Changes to the loss class
  • Model training and test results

All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.

What is CenterSample

In the original FCOS implementation, every anchor point inside a ground-truth box is counted as a positive sample. However, the region near the edges of a ground-truth box is often still background, so these "positive" points actually fall on background and should be negatives, which confuses the model during training. In the earlier training runs of the original FCOS we saw that, at resize=667, the model needed 24 epochs to approach the performance RetinaNet reaches after 12 epochs. At the larger resize=1000 resolution, performance improved over the first 10 epochs and then started to drop, which suggests that with high-resolution inputs those near-edge positives inside the boxes hurt learning even more (because there are more such anchor points at higher resolutions).
CenterSample restricts positives to a circular region centered at the ground-truth box center and smaller than the box: only anchor points falling inside this circle are counted as positive samples. This relabels most of the near-edge anchor points that actually fall on background as negatives, so the FCOS model converges faster and reaches better final performance. Concretely, we set a hyperparameter center_sample_radius, the base radius of the circle. Then, depending on which FPN level the ground-truth box is assigned to, we multiply center_sample_radius by that level's stride to get the actual radius of the circular region inside the box.
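
To make the geometry concrete, here is a minimal standalone sketch of just the center-sample test. The function name, its arguments, and the 1.5 default radius are my own illustrative choices, not the repository's exact code (the real implementation appears inside the loss class below):

import torch

def center_sample_mask(points, gt_boxes, stride, center_sample_radius=1.5):
    # points: [N, 2] float anchor point (x, y) coordinates on the input image
    # gt_boxes: [M, 4] float ground-truth boxes as (x1, y1, x2, y2)
    # stride: stride of the FPN level these points belong to
    centers = (gt_boxes[:, 0:2] + gt_boxes[:, 2:4]) / 2  # [M, 2] box centers
    radius = stride * center_sample_radius               # true radius at this level
    distances = torch.cdist(points, centers)             # [N, M] point-to-center distances
    return distances < radius                            # True = positive candidate

points = torch.tensor([[28., 32.], [120., 120.]])
gt_boxes = torch.tensor([[0., 0., 64., 64.]])  # box center is (32, 32)
print(center_sample_mask(points, gt_boxes, stride=8))
# tensor([[ True], [False]]): the true radius here is 8 * 1.5 = 12 pixels

A point must pass both this circle test and the usual inside-the-box test before it can become a positive sample.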
At the same time, I also added group normalization layers to the head layers in the code, so the resulting FCOS matches the model configuration in the paper with all improvements enabled.
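
In case group normalization is unfamiliar: nn.GroupNorm normalizes over groups of channels instead of over the batch dimension, so its behavior does not depend on batch size. A quick standalone check (the shapes and the 256-channel/32-group setting are illustrative, matching the head code below):

import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=32, num_channels=256)  # 32 groups of 8 channels each
x = torch.randn(2, 256, 100, 136)  # [batch, channels, H, W]
print(gn(x).shape)  # torch.Size([2, 256, 100, 136]); also works with batch size 1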

The new heads layer

The classification head, regression head, and centerness head are all implemented in a single class, and the centerness head shares the regression head's features.

The heads are implemented as follows:

import math

import torch.nn as nn


class FCOSClsRegCntHead(nn.Module):
    def __init__(self,
                 inplanes,
                 num_classes,
                 num_layers=4,
                 prior=0.01,
                 use_gn=True,
                 cnt_on_reg=True):
        super(FCOSClsRegCntHead, self).__init__()
        self.cnt_on_reg = cnt_on_reg

        # classification branch: num_layers x (conv3x3 [+ GroupNorm] + ReLU)
        cls_layers = []
        for _ in range(num_layers):
            cls_layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            if use_gn:
                cls_layers.append(nn.GroupNorm(32, inplanes))
            cls_layers.append(nn.ReLU(inplace=True))
        self.cls_head = nn.Sequential(*cls_layers)

        # regression branch: same structure as the classification branch
        reg_layers = []
        for _ in range(num_layers):
            reg_layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            if use_gn:
                reg_layers.append(nn.GroupNorm(32, inplanes))
            reg_layers.append(nn.ReLU(inplace=True))
        self.reg_head = nn.Sequential(*reg_layers)

        self.cls_out = nn.Conv2d(inplanes,
                                 num_classes,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.reg_out = nn.Conv2d(inplanes,
                                 4,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.center_out = nn.Conv2d(inplanes,
                                    1,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

        # bias init so the initial classification probability equals prior
        b = -math.log((1 - prior) / prior)
        self.cls_out.bias.data.fill_(b)

    def forward(self, x):
        cls_x = self.cls_head(x)
        reg_x = self.reg_head(x)

        del x

        cls_output = self.cls_out(cls_x)
        reg_output = self.reg_out(reg_x)

        # the centerness head shares features with either the regression
        # branch or the classification branch
        if self.cnt_on_reg:
            center_output = self.center_out(reg_x)
        else:
            center_output = self.center_out(cls_x)

        return cls_output, reg_output, center_output
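
A quick way to sanity-check this head is to run it standalone on a random tensor shaped like one FPN level's feature map; the batch size, 256 channels, and 100x136 spatial size below are illustrative:

import torch

head = FCOSClsRegCntHead(inplanes=256, num_classes=80)
feature = torch.randn(2, 256, 100, 136)  # one FPN level: [batch, channels, H, W]

cls_out, reg_out, center_out = head(feature)
print(cls_out.shape)     # torch.Size([2, 80, 100, 136])
print(reg_out.shape)     # torch.Size([2, 4, 100, 136])
print(center_out.shape)  # torch.Size([2, 1, 100, 136])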

Changes to the loss class

In the loss class, only the ground-truth assignment part needs to change to add the center-sample mechanism, i.e. the get_batch_position_annotations function. The modified get_batch_position_annotations is implemented as follows:

    def get_batch_position_annotations(self, cls_heads, reg_heads,
                                       center_heads, batch_positions,
                                       annotations):
        """
        Assign a ground truth target for each position on feature map
        """
        device = annotations.device
        batch_mi, batch_stride = [], []
        for reg_head, mi, stride in zip(reg_heads, self.mi, self.strides):
            mi = torch.tensor(mi).to(device)
            B, H, W, _ = reg_head.shape
            per_level_mi = torch.zeros(B, H, W, 2).to(device)
            per_level_mi = per_level_mi + mi
            batch_mi.append(per_level_mi)
            per_level_stride = torch.zeros(B, H, W, 1).to(device)
            per_level_stride = per_level_stride + stride
            batch_stride.append(per_level_stride)

        cls_preds, reg_preds, center_preds = [], [], []
        all_points_position, all_points_mi, all_points_stride = [], [], []
        for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi, per_level_stride in zip(
                cls_heads, reg_heads, center_heads, batch_positions, batch_mi,
                batch_stride):
            cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
            reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
            center_pred = center_pred.view(center_pred.shape[0], -1,
                                           center_pred.shape[-1])
            per_level_position = per_level_position.view(
                per_level_position.shape[0], -1, per_level_position.shape[-1])
            per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                             per_level_mi.shape[-1])
            per_level_stride = per_level_stride.view(
                per_level_stride.shape[0], -1, per_level_stride.shape[-1])
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            center_preds.append(center_pred)
            all_points_position.append(per_level_position)
            all_points_mi.append(per_level_mi)
            all_points_stride.append(per_level_stride)

        cls_preds = torch.cat(cls_preds, axis=1)
        reg_preds = torch.cat(reg_preds, axis=1)
        center_preds = torch.cat(center_preds, axis=1)
        all_points_position = torch.cat(all_points_position, axis=1)
        all_points_mi = torch.cat(all_points_mi, axis=1)
        all_points_stride = torch.cat(all_points_stride, axis=1)

        batch_targets = []
        for per_image_position, per_image_mi, per_image_stride, per_image_annotations in zip(
                all_points_position, all_points_mi, all_points_stride,
                annotations):
            per_image_annotations = per_image_annotations[
                per_image_annotations[:, 4] >= 0]
            points_num = per_image_position.shape[0]

            if per_image_annotations.shape[0] == 0:
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6], device=device)
            else:
                annotaion_num = per_image_annotations.shape[0]
                per_image_gt_bboxes = per_image_annotations[:, 0:4]
                candidates = torch.zeros([points_num, annotaion_num, 4],
                                         device=device)
                candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
                per_image_position = per_image_position.unsqueeze(1).repeat(
                    1, annotaion_num, 1)

                if self.use_center_sample:
                    candidates_center = (candidates[:, :, 2:4] +
                                         candidates[:, :, 0:2]) / 2
                    judge_distance = per_image_stride * self.center_sample_radius
                    judge_distance = judge_distance.repeat(1, annotaion_num)

                candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
                candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 0:2]
                candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
                sample_flag = (candidates_min_value[:, :, 0] >
                               0).int().unsqueeze(-1)
                # zero out reg targets for points that fall outside the gt box
                candidates = candidates * sample_flag

                # if center sample is used, also zero out reg targets for
                # points that are not inside the center circle
                if self.use_center_sample:
                    compute_distance = torch.sqrt(
                        (per_image_position[:, :, 0] -
                         candidates_center[:, :, 0])**2 +
                        (per_image_position[:, :, 1] -
                         candidates_center[:, :, 1])**2)
                    center_sample_flag = (compute_distance <
                                          judge_distance).int().unsqueeze(-1)
                    candidates = candidates * center_sample_flag

                # zero out reg targets whose assigned ground truth is not in
                # this level's range mi
                candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
                per_image_mi = per_image_mi.unsqueeze(1).repeat(
                    1, annotaion_num, 1)
                m1_negative_flag = (candidates_max_value[:, :, 0] >
                                    per_image_mi[:, :, 0]).int().unsqueeze(-1)
                candidates = candidates * m1_negative_flag
                m2_negative_flag = (candidates_max_value[:, :, 0] <
                                    per_image_mi[:, :, 1]).int().unsqueeze(-1)
                candidates = candidates * m2_negative_flag

                final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
                final_sample_flag = final_sample_flag > 0
                positive_index = (final_sample_flag ==
                                  True).nonzero().squeeze(dim=-1)

                # if no positive sample is assigned
                if len(positive_index) == 0:
                    del candidates
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)
                else:
                    positive_candidates = candidates[positive_index]
                    del candidates
                    sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                    sample_box_gts = sample_box_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                        -1).unsqueeze(0)
                    sample_class_gts = sample_class_gts.repeat(
                        positive_candidates.shape[0], 1, 1)
                    # 6:l,t,r,b,class_index,center-ness_gt
                    per_image_targets = torch.zeros([points_num, 6],
                                                    device=device)

                    if positive_candidates.shape[1] == 1:
                        # if there is only one candidate per positive sample,
                        # assign the l,t,r,b,class_index,center_ness_gt ground truth
                        # class_index values from 1 to 80 represent the 80 positive classes
                        # class_index value 0 represents the negative class
                        positive_candidates = positive_candidates.squeeze(1)
                        sample_class_gts = sample_class_gts.squeeze(1)
                        per_image_targets[positive_index,
                                          0:4] = positive_candidates
                        per_image_targets[positive_index,
                                          4:5] = sample_class_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))
                    else:
                        # if a positive point has several object candidates,
                        # choose the smallest-area object candidate as the
                        # ground truth for this positive point
                        gts_w_h = sample_box_gts[:, :,
                                                 2:4] - sample_box_gts[:, :,
                                                                       0:2]
                        gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                        positive_candidates_value = positive_candidates.sum(
                            axis=2)
                        # set all negative candidates' areas to 100000000 so the
                        # .min() operation won't choose a negative candidate
                        INF = 100000000
                        inf_tensor = torch.ones_like(gts_area) * INF
                        gts_area = torch.where(
                            torch.eq(positive_candidates_value, 0.),
                            inf_tensor, gts_area)
                        # get the smallest object candidate index
                        _, min_index = gts_area.min(axis=1)
                        candidate_indexes = (torch.linspace(
                            1, positive_candidates.shape[0],
                            positive_candidates.shape[0]) - 1).long()
                        final_candidate_reg_gts = positive_candidates[
                            candidate_indexes, min_index, :]
                        final_candidate_cls_gts = sample_class_gts[
                            candidate_indexes, min_index]
                        # assign the l,t,r,b,class_index,center_ness_gt ground truth
                        per_image_targets[positive_index,
                                          0:4] = final_candidate_reg_gts
                        per_image_targets[positive_index,
                                          4:5] = final_candidate_cls_gts + 1
                        l, t, r, b = per_image_targets[
                            positive_index, 0:1], per_image_targets[
                                positive_index, 1:2], per_image_targets[
                                    positive_index,
                                    2:3], per_image_targets[positive_index,
                                                            3:4]
                        per_image_targets[positive_index, 5:6] = torch.sqrt(
                            (torch.min(l, r) / torch.max(l, r)) *
                            (torch.min(t, b) / torch.max(t, b)))

            per_image_targets = per_image_targets.unsqueeze(0)
            batch_targets.append(per_image_targets)

        batch_targets = torch.cat(batch_targets, axis=0)
        batch_targets = torch.cat([batch_targets, all_points_position],
                                  axis=2)

        # batch_targets shape:[batch_size, points_num, 8]
        # 8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
        return cls_preds, reg_preds, center_preds, batch_targets
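
The last assignment in each branch above writes the center-ness ground truth, sqrt((min(l,r)/max(l,r)) * (min(t,b)/max(t,b))). As a quick standalone check (the l, t, r, b values below are made up), it equals 1.0 at the exact box center and decays toward 0 near a border:

import torch

# distances from one positive point to the left/top/right/bottom box borders
l, t, r, b = (torch.tensor([40.]), torch.tensor([10.]),
              torch.tensor([40.]), torch.tensor([30.]))
centerness_gt = torch.sqrt((torch.min(l, r) / torch.max(l, r)) *
                           (torch.min(t, b) / torch.max(t, b)))
print(centerness_gt)  # tensor([0.5774]) = sqrt(1.0 * (10/30))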

Model training and test results

Retraining the FCOS model, let's first look at the results at resize=667.

| Network | batch | gpu-num | apex | syncbn | epoch5-mAP-mAR-loss | epoch10-mAP-mAR-loss | epoch12-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50-FCOS-myresize667-fastdecode | 24 | 2 | yes | no | 0.272, 0.399, 1.15 | 0.293, 0.422, 1.07 | 0.312, 0.445, 1.06 |
| ResNet101-FCOS-myresize667-fastdecode | 16 | 2 | yes | no | 0.261, 0.390, 1.14 | 0.307, 0.438, 1.06 | 0.325, 0.455, 1.05 |

With CenterSample, the FCOS model now needs only 12 training epochs to outperform a RetinaNet trained for 12 epochs at the same resolution. I then trained ResNet50-FCOS with resize=1000 inputs, initializing the network with the pretrained mAP=0.312 weights from the table above (see the loading sketch after the next table); after 24 epochs the model performance is as follows:

| Network | batch | gpu-num | apex | syncbn | epoch12-mAP-mAR-loss | epoch24-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- |
| ResNet50-FCOS-myresize1000-fastdecode | 16 | 2 | yes | no | 0.352, 0.490, 1.03 | 0.352, 0.491, 1.01 |
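
For completeness, warm-starting the resize=1000 run from the resize=667 checkpoint is an ordinary state-dict load. This is only a sketch: build_fcos() is a hypothetical stand-in for however you construct the ResNet50-FCOS model, the checkpoint file name is illustrative, and I assume the checkpoint was saved as a plain state dict (the repository's training script may store it differently):

import torch

model = build_fcos()  # hypothetical model constructor, not from the repository
checkpoint = torch.load('resnet50_fcos_myresize667_epoch12.pth',
                        map_location='cpu')
# strict=False skips any keys that don't match, e.g. if head shapes differ
model.load_state_dict(checkpoint, strict=False)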

After 24 epochs of training, FCOS reaches an mAP of 35.2, 1.3 percentage points higher than a RetinaNet trained for 24 epochs at the same resolution.
