Deformable-DETR Variants: Two-stage Deformable DETR

Preface

  • Two-stage Deformable DETR

The figure above shows the paper's section on the two-stage variant, which the authors describe only briefly. DETR and its variants split into one-stage and two-stage designs. In the one-stage design, the decoder queries are initialized as content queries that are initially set to zero and not learned, plus position embeddings that are randomly initialized and learnable. The two-stage design works like R-CNN: the memory produced by the encoder serves as a shared feature map for region proposals, and those proposals are then used to initialize the queries of the subsequent decoder. This speeds up convergence and improves the stability of the decoder.
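For intuition, here is a minimal sketch of query initialization without proposals (bs, num_queries, and embed_dims are assumed values). It mirrors the else-branch of the mmdetection forward() shown later, where both halves of each query come from a single learnable embedding; original DETR instead keeps the content half at zero and non-learned:

import torch
import torch.nn as nn

embed_dims, num_queries, bs = 256, 300, 2

# One-stage: a learnable table carries both halves of every query.
# It is image-independent, so the decoder must localize from scratch.
query_embed = nn.Embedding(num_queries, embed_dims * 2)
query_pos, query = torch.split(query_embed.weight, embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # (bs, 300, 256)
query = query.unsqueeze(0).expand(bs, -1, -1)          # (bs, 300, 256)

# The two-stage path instead scores the encoder memory, keeps the top-k
# proposals, and derives query_pos and query from them (forward() below).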

  • DINO groups the existing query-initialization methods into three classes:
  1. static anchors, typified by DETR
  2. dynamic anchors and contents, typified by Deformable DETR
  3. dynamic anchors and static contents, proposed by the DINO authors (a sketch follows this list)
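A rough sketch of the third class, assuming it refers to DINO's mixed query selection; every tensor and name below is illustrative, not DINO's actual code:

import torch
import torch.nn as nn

bs, num_tokens, embed_dims, topk = 2, 1000, 256, 300

enc_scores = torch.rand(bs, num_tokens)            # per-token class scores
enc_coords_unact = torch.randn(bs, num_tokens, 4)  # unactivated proposals

# Dynamic anchors: box priors are selected per image from the encoder.
topk_idx = torch.topk(enc_scores, topk, dim=1)[1]
anchors = torch.gather(enc_coords_unact, 1,
                       topk_idx.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()

# Static contents: the content queries stay a learnable embedding shared
# across images, rather than being read out of the encoder features.
content_query = nn.Embedding(topk, embed_dims).weight
content_query = content_query.unsqueeze(0).expand(bs, -1, -1)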

Source Code

  • gen_encoder_output_proposals(mmdetection\mmdet\models\utils\transformer.py)
# get proposals
def gen_encoder_output_proposals(self, memory, memory_padding_mask,
                                 spatial_shapes):
    """Generate proposals from encoded memory.

    Args:
        memory (Tensor): The output of the encoder, has shape
            (bs, num_key, embed_dim). num_key is equal to the number
            of points on the feature maps from all levels.
        memory_padding_mask (Tensor): Padding mask for memory,
            has shape (bs, num_key).
        spatial_shapes (Tensor): The shape of all feature maps,
            has shape (num_level, 2).

    Returns:
        tuple: A tuple of feature map and bbox prediction.

            - output_memory (Tensor): The input of the decoder, has
              shape (bs, num_key, embed_dim). num_key is equal to the
              number of points on the feature maps from all levels.
            - output_proposals (Tensor): The normalized proposal
              after an inverse sigmoid, has shape (bs, num_keys, 4).
    """
    N, S, C = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H, W) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
            N, H, W, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H - 1, H, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W - 1, W, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        scale = torch.cat([valid_W.unsqueeze(-1),
                           valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
        proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
        proposals.append(proposal)
        _cur += (H * W)
    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(
                                  -1, keepdim=True)
    output_proposals = torch.log(output_proposals /
                                 (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))
    output_memory = self.enc_output_norm(self.enc_output(output_memory))
    return output_memory, output_proposals
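As a sanity check on the proposal geometry above, here is a toy run with an assumed 2x3 level-0 map and no padding (so valid_W == W and valid_H == H):

import torch

H, W, lvl = 2, 3, 0
grid_y, grid_x = torch.meshgrid(
    torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W))
grid = torch.stack([grid_x, grid_y], -1)
# Cell centers, normalized by the valid width/height of the level.
centers = (grid + 0.5) / torch.tensor([W, H], dtype=torch.float32)
print(centers.view(-1, 2))
# tensor([[0.1667, 0.2500], [0.5000, 0.2500], [0.8333, 0.2500],
#         [0.1667, 0.7500], [0.5000, 0.7500], [0.8333, 0.7500]])
print(0.05 * 2.0**lvl)  # 0.05: level-0 proposals span 5% of the image side

The (0.01, 0.99) validity window and the inverse sigmoid (torch.log(p / (1 - p))) then mark proposals in padded or out-of-range regions as invalid: their coordinates are set to inf and the corresponding memory features are zeroed, so they contribute nothing useful downstream.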
  • class DeformableDetrTransformer(Transformer):
def forward(self,
            mlvl_feats,
            mlvl_masks,
            query_embed,
            mlvl_pos_embeds,
            reg_branches=None,
            cls_branches=None,
            **kwargs):
    assert self.as_two_stage or query_embed is not None

    feat_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    spatial_shapes = []
    for lvl, (feat, mask, pos_embed) in enumerate(
            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
        bs, c, h, w = feat.shape
        spatial_shape = (h, w)
        spatial_shapes.append(spatial_shape)
        feat = feat.flatten(2).transpose(1, 2)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        feat_flatten.append(feat)
        mask_flatten.append(mask)
    feat_flatten = torch.cat(feat_flatten, 1)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    spatial_shapes = torch.as_tensor(
        spatial_shapes, dtype=torch.long, device=feat_flatten.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros(
        (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack(
        [self.get_valid_ratio(m) for m in mlvl_masks], 1)

    reference_points = \
        self.get_reference_points(spatial_shapes,
                                  valid_ratios,
                                  device=feat.device)

    feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
    lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
        1, 0, 2)  # (H*W, bs, embed_dims)
    memory = self.encoder(
        query=feat_flatten,
        key=None,
        value=None,
        query_pos=lvl_pos_embed_flatten,
        query_key_padding_mask=mask_flatten,
        spatial_shapes=spatial_shapes,
        reference_points=reference_points,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        **kwargs)

    memory = memory.permute(1, 0, 2)
    bs, _, c = memory.shape
    if self.as_two_stage:
        output_memory, output_proposals = \
            self.gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes)
        enc_outputs_class = cls_branches[self.decoder.num_layers](
            output_memory)
        enc_outputs_coord_unact = \
            reg_branches[self.decoder.num_layers](
                output_memory) + output_proposals

        topk = self.two_stage_num_proposals
        topk_proposals = torch.topk(
            enc_outputs_class[..., 0], topk, dim=1)[1]
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords_unact = topk_coords_unact.detach()
        reference_points = topk_coords_unact.sigmoid()
        init_reference_out = reference_points
        pos_trans_out = self.pos_trans_norm(
            self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
        query_pos, query = torch.split(pos_trans_out, c, dim=2)
    else:
        query_pos, query = torch.split(query_embed, c, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos).sigmoid()
        init_reference_out = reference_points

    # decoder
    query = query.permute(1, 0, 2)
    memory = memory.permute(1, 0, 2)
    query_pos = query_pos.permute(1, 0, 2)
    inter_states, inter_references = self.decoder(
        query=query,
        key=None,
        value=memory,
        query_pos=query_pos,
        key_padding_mask=mask_flatten,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reg_branches=reg_branches,
        **kwargs)

    inter_references_out = inter_references
    if self.as_two_stage:
        return inter_states, init_reference_out, \
            inter_references_out, enc_outputs_class, \
            enc_outputs_coord_unact
    return inter_states, init_reference_out, \
        inter_references_out, None, None
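To isolate the selection step in the two-stage branch, here is a toy run of the topk/gather pattern with dummy numbers (only the tensor mechanics match the code above):

import torch

bs, num_tokens, topk = 1, 6, 2
enc_outputs_class = torch.tensor(
    [[[0.1], [0.9], [0.3], [0.7], [0.2], [0.5]]])
enc_outputs_coord_unact = torch.arange(
    bs * num_tokens * 4, dtype=torch.float32).view(bs, num_tokens, 4)

# Pick the indices of the k highest-scoring proposals per image.
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
print(topk_proposals)  # tensor([[1, 3]])

# Gather the 4 box coordinates belonging to those indices.
topk_coords_unact = torch.gather(
    enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))

# detach() stops gradients flowing from the decoder into the proposal
# coordinates; sigmoid() maps them to normalized (cx, cy, w, h).
reference_points = topk_coords_unact.detach().sigmoid()

Note that only the selected boxes are detached; enc_outputs_class and enc_outputs_coord_unact are still returned and supervised with an auxiliary loss, which is what trains the encoder to produce good proposals.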

