Encoder Proposal in Deformable-DETR (two-stage version)
Deformable-DETR variants: Two-stage Deformable DETR
Preface
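- Two-stage Deformable DETR
The paper covers the two-stage variant only briefly. DETR and its variants can be split into one-stage and two-stage designs. In the one-stage design the decoder queries do not depend on the image. In the original DETR, the content queries are initialized to zero and are not learned, while the positional embeddings are randomly initialized and learned. In one-stage Deformable DETR, both halves come from a single learnable `query_embed` (the `else` branch of `forward` below). The two-stage design resembles R-CNN: the `memory` produced by the encoder serves as a shared feature map for region proposals, and the top-scoring proposals are used to initialize the decoder queries. This speeds up the convergence of the decoder and makes it more stable.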
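For contrast with the two-stage path, the one-stage initialization can be sketched in a few lines of PyTorch. It mirrors the `else` branch of `forward` shown later. The sizes are hypothetical, `query_embedding` is an illustrative name for the learnable table, and `reference_points_fc` stands in for `self.reference_points`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
bs, num_query, embed_dims = 2, 300, 256

# One learnable table holds both halves of each query: [positional | content].
query_embedding = nn.Embedding(num_query, embed_dims * 2)
# Stand-in for self.reference_points: maps a positional query to a normalized (x, y).
reference_points_fc = nn.Linear(embed_dims, 2)


def init_one_stage_queries(bs):
    query_embed = query_embedding.weight                     # (num_query, 2 * embed_dims)
    query_pos, query = torch.split(query_embed, embed_dims, dim=1)
    query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)     # (bs, num_query, embed_dims)
    query = query.unsqueeze(0).expand(bs, -1, -1)             # (bs, num_query, embed_dims)
    # Reference points are predicted from the positional half only.
    reference_points = reference_points_fc(query_pos).sigmoid()  # (bs, num_query, 2)
    return query_pos, query, reference_points


query_pos, query, reference_points = init_one_stage_queries(bs)
```

Every image in the batch receives exactly the same queries and reference points, which is what the two-stage variant changes.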
- DINO groups current query-initialization schemes into three categories:
- static anchors, represented by DETR
- dynamic anchors and contents, represented by Deformable DETR
- dynamic anchors and static contents, proposed by the DINO authors
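A toy sketch can make the three categories concrete by showing which decoder inputs depend on the image. Random tensors stand in for the top-k boxes selected from the encoder, all sizes are made up, and `box_to_query` is a hypothetical single linear layer replacing mmdet's `get_proposal_pos_embed` + `pos_trans` + `pos_trans_norm`. In (a) nothing depends on the image, in (b) both anchors and contents do, and in (c) only the anchors do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, k, d = 2, 10, 32                       # hypothetical toy sizes

# Stand-in for the k boxes picked from the encoder (inverse-sigmoid space).
topk_coords_unact = torch.randn(bs, k, 4)
anchors = topk_coords_unact.sigmoid()      # dynamic anchors (cx, cy, w, h), shape (bs, k, 4)

# (a) static anchors and contents: learned tables, identical for every image.
#     Learned 4-d anchor boxes are DAB-DETR style; DETR itself learns positional embeddings.
anchors_a = nn.Embedding(k, 4).weight.sigmoid().unsqueeze(0).expand(bs, -1, -1)
content_a = nn.Embedding(k, d).weight.unsqueeze(0).expand(bs, -1, -1)

# (b) dynamic anchors and contents (two-stage Deformable DETR): the selected
#     boxes are mapped to both positional and content queries.
box_to_query = nn.Linear(4, 2 * d)         # hypothetical stand-in for pos_trans(...)
query_pos_b, content_b = torch.split(box_to_query(topk_coords_unact), d, dim=2)

# (c) dynamic anchors, static contents (DINO): anchors come from the encoder,
#     contents come from a learned table shared by all images.
content_c = nn.Embedding(k, d).weight.unsqueeze(0).expand(bs, -1, -1)
```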
Source code
- gen_encoder_output_proposals (mmdetection\mmdet\models\utils\transformer.py)
```python
import torch


# get proposals
def gen_encoder_output_proposals(self, memory, memory_padding_mask,
                                 spatial_shapes):
    """Generate proposals from encoded memory.

    Args:
        memory (Tensor) : The output of encoder,
            has shape (bs, num_key, embed_dim). num_key is
            equal the number of points on feature map from
            all level.
        memory_padding_mask (Tensor): Padding mask for memory.
            has shape (bs, num_key).
        spatial_shapes (Tensor): The shape of all feature maps.
            has shape (num_level, 2).

    Returns:
        tuple: A tuple of feature map and bbox prediction.

            - output_memory (Tensor): The input of decoder, \
                has shape (bs, num_key, embed_dim). num_key is \
                equal the number of points on feature map from \
                all levels.
            - output_proposals (Tensor): The normalized proposal \
                after a inverse sigmoid, has shape \
                (bs, num_keys, 4).
    """

    N, S, C = memory.shape
    proposals = []
    _cur = 0
    for lvl, (H, W) in enumerate(spatial_shapes):
        # valid (unpadded) height and width of every image on this level
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
            N, H, W, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H - 1, H, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W - 1, W, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        # box center = pixel center, normalized by the valid size
        scale = torch.cat([valid_W.unsqueeze(-1),
                           valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
        # box width/height = 0.05 * 2^lvl
        wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
        proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
        proposals.append(proposal)
        _cur += (H * W)
    output_proposals = torch.cat(proposals, 1)
    # keep only proposals whose coordinates all lie in (0.01, 0.99)
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(
                                  -1, keepdim=True)
    # inverse sigmoid
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(
        ~output_proposals_valid, float('inf'))

    # zero the memory at padded / invalid positions, then Linear + LayerNorm
    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))
    output_memory = self.enc_output_norm(self.enc_output(output_memory))
    return output_memory, output_proposals
```
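The box geometry built by `gen_encoder_output_proposals` is easy to check on a toy input. For pixel (i, j) on level `lvl`, the proposal is cx = (j + 0.5) / valid_W, cy = (i + 0.5) / valid_H, w = h = 0.05 · 2^lvl, stored in inverse-sigmoid form log(p / (1 - p)). The function below is a stripped-down standalone sketch of only that part. It leaves out `memory`, `enc_output` and `enc_output_norm`, and the name `make_grid_proposals` is hypothetical.

```python
import torch


def make_grid_proposals(memory_padding_mask, spatial_shapes, base_wh=0.05):
    """Hypothetical sketch of the proposal geometry in gen_encoder_output_proposals.

    memory_padding_mask: (N, S) bool, True marks a padded pixel.
    spatial_shapes: list of (H, W) per level, with sum(H * W) == S.
    Returns inverse-sigmoid proposals (N, S, 4) and the validity mask (N, S, 1).
    """
    N = memory_padding_mask.shape[0]
    proposals, cur = [], 0
    for lvl, (H, W) in enumerate(spatial_shapes):
        mask = memory_padding_mask[:, cur:cur + H * W].view(N, H, W)
        valid_H = (~mask[:, :, 0]).sum(1)   # unpadded rows, read from the first column
        valid_W = (~mask[:, 0, :]).sum(1)   # unpadded columns, read from the first row
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H, dtype=torch.float32),
            torch.arange(W, dtype=torch.float32), indexing='ij')
        grid = torch.stack([grid_x, grid_y], -1)                   # (H, W, 2) as (x, y)
        scale = torch.stack([valid_W, valid_H], -1).view(N, 1, 1, 2).float()
        center = (grid.unsqueeze(0) + 0.5) / scale                 # normalized by valid size
        wh = torch.ones_like(center) * base_wh * (2.0 ** lvl)      # size doubles per level
        proposals.append(torch.cat([center, wh], -1).view(N, -1, 4))
        cur += H * W
    proposals = torch.cat(proposals, 1)
    valid = ((proposals > 0.01) & (proposals < 0.99)).all(-1, keepdim=True)
    proposals = torch.log(proposals / (1 - proposals))             # inverse sigmoid
    proposals = proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    proposals = proposals.masked_fill(~valid, float('inf'))
    return proposals, valid


# Two 4x4 images on one level; image 1 is padded to the bottom and to the right.
mask = torch.zeros(2, 4, 4, dtype=torch.bool)
mask[1, 2:, :] = True
mask[1, :, 2:] = True
props, valid = make_grid_proposals(mask.view(2, 16), [(4, 4)])
print(props.shape)  # torch.Size([2, 16, 4])
```

In the padded image only the top-left 2 × 2 pixels stay valid. The centers are normalized by the valid size rather than the padded size, so pixels in the padded area get coordinates above 0.99, and their proposals are filled with `inf`.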
- class DeformableDetrTransformer(Transformer):
```python
def forward(self,
            mlvl_feats,
            mlvl_masks,
            query_embed,
            mlvl_pos_embeds,
            reg_branches=None,
            cls_branches=None,
            **kwargs):
    assert self.as_two_stage or query_embed is not None

    # flatten the multi-level features into one token sequence
    feat_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    spatial_shapes = []
    for lvl, (feat, mask, pos_embed) in enumerate(
            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
        bs, c, h, w = feat.shape
        spatial_shape = (h, w)
        spatial_shapes.append(spatial_shape)
        feat = feat.flatten(2).transpose(1, 2)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        feat_flatten.append(feat)
        mask_flatten.append(mask)
    feat_flatten = torch.cat(feat_flatten, 1)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    spatial_shapes = torch.as_tensor(
        spatial_shapes, dtype=torch.long, device=feat_flatten.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros(
        (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack(
        [self.get_valid_ratio(m) for m in mlvl_masks], 1)

    reference_points = \
        self.get_reference_points(spatial_shapes,
                                  valid_ratios,
                                  device=feat.device)

    feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
    lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
        1, 0, 2)  # (H*W, bs, embed_dims)
    memory = self.encoder(
        query=feat_flatten,
        key=None,
        value=None,
        query_pos=lvl_pos_embed_flatten,
        query_key_padding_mask=mask_flatten,
        spatial_shapes=spatial_shapes,
        reference_points=reference_points,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        **kwargs)

    memory = memory.permute(1, 0, 2)
    bs, _, c = memory.shape
    if self.as_two_stage:
        # two-stage: score and refine one proposal per encoder token,
        # then use the top-k as decoder queries and reference points
        output_memory, output_proposals = \
            self.gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes)
        enc_outputs_class = cls_branches[self.decoder.num_layers](
            output_memory)
        enc_outputs_coord_unact = \
            reg_branches[
                self.decoder.num_layers](output_memory) + output_proposals

        topk = self.two_stage_num_proposals
        topk_proposals = torch.topk(
            enc_outputs_class[..., 0], topk, dim=1)[1]
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords_unact = topk_coords_unact.detach()
        reference_points = topk_coords_unact.sigmoid()
        init_reference_out = reference_points
        pos_trans_out = self.pos_trans_norm(
            self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
        query_pos, query = torch.split(pos_trans_out, c, dim=2)
    else:
        # one-stage: split the learnable query_embed into positional and content halves
        query_pos, query = torch.split(query_embed, c, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos).sigmoid()
        init_reference_out = reference_points

    # decoder
    query = query.permute(1, 0, 2)
    memory = memory.permute(1, 0, 2)
    query_pos = query_pos.permute(1, 0, 2)
    inter_states, inter_references = self.decoder(
        query=query,
        key=None,
        value=memory,
        query_pos=query_pos,
        key_padding_mask=mask_flatten,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reg_branches=reg_branches,
        **kwargs)

    inter_references_out = inter_references
    if self.as_two_stage:
        return inter_states, init_reference_out,\
            inter_references_out, enc_outputs_class,\
            enc_outputs_coord_unact
    return inter_states, init_reference_out, \
        inter_references_out, None, None
```
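The `if self.as_two_stage:` branch boils down to four steps. First, add the regression output to the proposals in inverse-sigmoid space. Next, rank the encoder tokens by the class logit in channel 0. Then gather the top-k boxes, and finally detach them before using their sigmoid as reference points. Below is a standalone sketch of only those steps; the function name is hypothetical and toy tensors replace the real encoder heads.

```python
import torch


def select_two_stage_queries(enc_outputs_class, enc_delta_unact, output_proposals, topk):
    """Hypothetical sketch of the two-stage selection in DeformableDetrTransformer.forward.

    enc_outputs_class: (bs, num_key, num_classes) logits, as from cls_branches[num_layers]
    enc_delta_unact:   (bs, num_key, 4) raw output, as from reg_branches[num_layers]
    output_proposals:  (bs, num_key, 4) proposals in inverse-sigmoid space
    """
    # The regression output is a residual added in inverse-sigmoid (logit) space.
    enc_outputs_coord_unact = enc_delta_unact + output_proposals
    # Rank tokens by the logit in class channel 0, as the code above does.
    topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]  # (bs, topk)
    topk_coords_unact = torch.gather(
        enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
    # Cut the gradient path from the decoder back into the selected boxes.
    topk_coords_unact = topk_coords_unact.detach()
    reference_points = topk_coords_unact.sigmoid()  # (bs, topk, 4), values in (0, 1)
    return topk_proposals, topk_coords_unact, reference_points


cls = torch.tensor([[[0.1], [2.0], [-1.0], [3.0], [0.5]]])  # (1, 5, 1)
delta = torch.arange(5, dtype=torch.float32).view(1, 5, 1).repeat(1, 1, 4).requires_grad_()
props = torch.zeros(1, 5, 4)
idx, coords_unact, ref = select_two_stage_queries(cls, delta, props, topk=2)
print(idx.tolist())                    # [[3, 1]]
print(coords_unact[0, :, 0].tolist())  # [3.0, 1.0]
print(coords_unact.requires_grad)      # False
```

Because of `detach()`, gradients from the decoder do not flow back through these initial boxes. The encoder heads still receive gradients through `enc_outputs_class` and `enc_outputs_coord_unact`, which `forward` also returns so that they can be supervised separately.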