书接上回,上篇博客中我们学习到了Encoder模块,接下来我们来学习Decoder模块其代码是如何实现的。
其实Deformable-DETR最大的创新在于其提出了可变形注意力模型以及多尺度融合模块:
其主要表现在Backbone模块以及self-attention核cross-attention的计算上。这些方法都在DINO-DETR中得到继承,此外DAB-DETR中的Anchor Query设计与bounding box强化机制也有涉及。

Encoder模块

首先经过Encoder后的输出结果为 memory:torch.Size([2, 9620, 256]),其分别代表不同level的特征信息:tensor([ 0, 7220, 9044, 9500], device=‘cuda:0’)

memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

Two-Stage

核心思想:
Encoder会生成特征memory,再自己生成初步proposals(其实就是特征图上的点坐标 xywh)。
然后分别使用非共享检测头的分类分支对memory进行分类预测,得到对每个类别的分类结果;
再用回归分支进行回归预测,得到proposals的偏移量(xywh)。再用初步proposals偏移量 得到第一个阶段的预测proposals。
然后选取top-k个分数最高的那批预测proposals作为Decoder的参考点。
并且,Decoder的object query和 query pos都是由参考点通过位置嵌入(position embedding)再接上一个全连接层 + LN层处理生成的。

Two-Stage主要是应用在初始化参考点坐标上。
one-stage的参考点是get_reference_points函数生成的,而two-stage参考点是通过gen_encoder_output_proposals函数生成的。

one-stage初始化方法

def get_reference_points(spatial_shapes, valid_ratios, device):"""生成参考点   reference points  为什么参考点是中心点?  为什么要归一化?spatial_shapes: 4个特征图的shape [4, 2]valid_ratios: 4个特征图中非padding部分的边长占其边长的比例  [bs, 4, 2]  如全是1device: cuda:0"""reference_points_list = []# 遍历4个特征图的shape  比如 H_=100  W_=150for lvl, (H_, W_) in enumerate(spatial_shapes):# 0.5 -> 99.5 取100个点  0.5 1.5 2.5 ... 99.5# 0.5 -> 149.5 取150个点 0.5 1.5 2.5 ... 149.5# ref_y: [100, 150]  第一行:150个0.5  第二行:150个1.5 ... 第100行:150个99.5# ref_x: [100, 150]  第一行:0.5 1.5...149.5   100行全部相同ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))# [100, 150] -> [bs, 15000]  150个0.5 + 150个1.5 + ... + 150个99.5 -> 除以100 归一化ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)# [100, 150] -> [bs, 15000]  100个: 0.5 1.5 ... 149.5  -> 除以150 归一化ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)# [bs, 15000, 2] 每一项都是xyref = torch.stack((ref_x, ref_y), -1)reference_points_list.append(ref)# list4: [bs, H/8*W/8, 2] + [bs, H/16*W/16, 2] + [bs, H/32*W/32, 2] + [bs, H/64*W/64, 2] -># [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 2]reference_points = torch.cat(reference_points_list, 1)# reference_points: [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 2] -> [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 1, 2]# valid_ratios: [1, 4, 2] -> [1, 1, 4, 2]# 复制4份 每个特征点都有4个归一化参考点 -> [bs, H/8*W/8+H/16*W/16+H/32*W/32+H/64*W/64, 4, 2]reference_points = reference_points[:, :, None] * valid_ratios[:, None]# 4个flatten后特征图的归一化参考点坐标return reference_points

Two-Stage参考点初始化方法

def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):"""得到第一阶段预测的所有proposal box output_proposals和处理后的Encoder输出output_memorymemory: Encoder输出特征  [bs, H/8 * W/8 + ... + H/64 * W/64, 256]memory_padding_mask: Encoder输出特征对应的mask [bs, H/8 * W/8 + H/16 * W/16 + H/32 * W/32 + H/64 * W/64]spatial_shapes: [4, 2] backbone输出的4个特征图的shape"""N_, S_, C_ = memory.shape  # bs  H/8 * W/8 + ... + H/64 * W/64  256base_scale = 4.0proposals = []_cur = 0   # 帮助找到mask中每个特征图的初始indexfor lvl, (H_, W_) in enumerate(spatial_shapes):  # 如H_=76  W_=112# 1、生成所有proposal box的中心点坐标xy# 展平后的mask [bs, 76, 112, 1]mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)# grid_y = [76, 112]   76行112列  第一行全是0  第二行全是1 ... 第76行全是75# grid_x = [76, 112]   76行112列  76行全是 0 1 2 ... 111grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))# grid = [76, 112, 2(xy)]   这个特征图上的所有坐标点x,ygrid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)  # [bs, 1, 1, 2(xy)]# [76, 112, 2(xy)] -> [1, 76, 112, 2] + 0.5 得到所有网格中心点坐标  这里和one-stage的get_reference_points函数原理是一样的grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale# 2、生成所有proposal box的宽高wh  第i层特征默认wh = 0.05 * (2**i)wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)# 3、concat xy+wh -> proposal xywh [bs, 76x112, 4(xywh)]proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)proposals.append(proposal)_cur += (H_ * W_)# concat 4 feature map proposals [bs, H/8 x W/8 + ... + H/64 x W/64] = [bs, 11312, 4]output_proposals = torch.cat(proposals, 1)# 筛选一下 xywh 都要处于(0.01,0.99)之间output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)#用log(x/1-x)output_proposals = torch.log(output_proposals / (1 - output_proposals))# mask的地方是无效的 直接用inf代替output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))# 再按条件筛选一下 不符合的用用inf代替output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))output_memory = memoryoutput_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))# 对encoder输出进行处理:全连接层 + LayerNormoutput_memory = self.enc_output_norm(self.enc_output(output_memory))return output_memory, output_proposals

for循环里是对不同level的所有格点创建不同尺寸的anchor框,scale其实是对有效区域的处理,后续对output_proposals的处理是筛选掉边界附近的候选,输出是对应位置的特征和编码后的proposal, 对应位置的特征用于映射proposal的类别score以及校正偏差。值得注意的是proposal并没有直接使用原始坐标,而是进行了log的编码, 在forward中的two_stage情况提取reference_points是使用sigmoid函数进行了解码,我们假设偏置量为0,可以发现:

所谓的双阶段其实就是在Encoder后不是将数据直接送入Decoder,而是送入MLP与全连接层进行分类与回归后再送入Decoder。

enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) #torch.Size([2, 9620, 91])
enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals#torch.Size([2, 9620, 4])

随后选择topk

topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]#torch.Size([2, 300])
#torch.Size([2, 300])
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
#torch.Size([2, 300, 4])
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
#torch.Size([2, 300, 4])
topk_coords_unact = topk_coords_unact.detach()

将其进行sigmoid,由于gen_encoder_output_proposals进行了log,此时sigmoid刚好可以变回初始值

reference_points = topk_coords_unact.sigmoid() #torch.Size([2, 300, 4])

随后得到初始化参考点坐标信息:
层归一化定义:

self.pos_trans_norm = nn.LayerNorm(d_model * 2)
#torch.Size([2, 300, 4])
init_reference_out = reference_points
#pos_trans_norm是层归一化,得到结果torch.Size([2, 300, 512])
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)

最终得到:query_embed torch.Size([2, 300, 256]),tgt torch.Size([2, 300, 256])

Decoder模块

终于,进入了Decoder模块,我们首先来看其传入的参数:
tgt:torch.Size([2, 300, 256])
reference_points:torch.Size([2, 300, 4])
memory:torch.Size([2, 9620, 256])
spatial_shapes:

tensor([[76, 95],[38, 48],[19, 24],[10, 12]], device='cuda:0')

level_start_index:tensor([ 0, 7220, 9044, 9500], device=‘cuda:0’)
query_embed:torch.Size([2, 300, 256])
mask_flatten:torch.Size([2, 9620])

hs, inter_references = self.decoder(tgt, reference_points, memory,spatial_shapes, level_start_index, valid_ratios, query_pos=query_embed if not self.use_dab else None, src_padding_mask=mask_flatten)

进入Decoder层:
其后就与DAB-DETR一致了,只是将cross_attention替换为可变形注意力。

DAB-Deformable-DETR源码学习记录之模型构建(二)相关推荐

  1. Spark-Core源码学习记录 3 SparkContext、SchedulerBackend、TaskScheduler初始化及应用的注册流程

    Spark-Core源码学习记录 该系列作为Spark源码回顾学习的记录,旨在捋清Spark分发程序运行的机制和流程,对部分关键源码进行追踪,争取做到知其所以然,对枝节部分源码仅进行文字说明,不深入下 ...

  2. Deformable DETR源码解读

    文章目录 一:网络创新点 二:流程详解 [part 1]deformable_detr模块 [part 2]deformable_transformer模块 [part3]Encoder模块 [par ...

  3. microsoft 的gpt2模型源码学习记录

    相关链接: gpt2论文传送门 microsoft Deepspeed gpt2源码传送 微软 Deepspeed 中集成的 gpt2 代码感觉比 huggingface 的代码可读性要强很多,这里只 ...

  4. Deformable detr源码分析

    D:\Anaconda\envs\pytorch\Lib\site-packages\mmcv\cnn\bricks\transformer.py 1.backbone Deformable detr ...

  5. Swoole源码学习记录(五)——锁和信号(二)

    Swoole版本:1.7.4-stable Github地址: https://github.com/LinkedDestiny/swoole-src-analysis 二.Mutex互斥锁 接下来是 ...

  6. aqs clh java_Java并发包源码学习之AQS框架(二)CLH lock queue和自旋锁

    上一篇文章提到AQS是基于CLH lock queue,那么什么是CLH lock queue,说复杂很复杂说简单也简单, 所谓大道至简: CLH lock queue其实就是一个FIFO的队列,队列 ...

  7. Spring 源码学习一: 使用Gradle 构建Spring 源码环境

    Gradle安装 下载Gradle: https://gradle.org/releases/ 选择安装的版本: 6.x 以上 选择版本后,点击下载. 配置环境变量: unzip gradle-6.8 ...

  8. DAB-Deformable-DETR代码学习记录之模型构建

    DAB-DETR的作者在Deformable-DETR基础上,将DAB-DETR的思想融入到了Deformable-DETR中,取得了不错的成绩.今天博主通过源码来学习下DAB-Deformable- ...

  9. 【博学谷学习记录】超强总结,用心分享 | 架构师 Mybatis源码学习总结

    Mybatis源码学习 文章目录 Mybatis源码学习 一.Mybatis架构设计 二.源码剖析 1.如何解析的全局配置文件 解析配置文件源码流程 2.如何解析的映射配置文件 Select inse ...

最新文章

  1. 回归评估+解释方差分
  2. Python学习(四) —— 编码
  3. linux 上下文切换监控,[Linux] 查看进程的上下文切换pidstat
  4. 种子接近,随机数也接近吗_接近代码,接近爸爸
  5. 解密Oracle备份工具-exp/imp
  6. 调用百度API实现人像动漫化(C++)
  7. 路长全讲座免费在线学习 免费下载
  8. 计算机grand,The Grand
  9. 热式气体质量流量计检定规程_新品发布:西尼尔ST51/54热式质量流量计
  10. 三节串联锂电池充电管理芯片,IC电路图,BOM表
  11. graphpad两组t检验_GraphPad prism -- t检验操作步骤解析~
  12. webdriver中的截图截图方法
  13. @guardedby同步注解
  14. 传奇服务器的角色文件在,传奇版本等一些软件放到服务器里的方法
  15. 打峡谷之巅有眼缘 那不如我们自己写个猜数字 C语言
  16. 已有企业认证的微信公众号快速创建一个企业小程序
  17. Android培训武汉,武汉安卓培训之Android如何使用样式创建半透明窗体
  18. TCP和UDP编程的区别,步骤
  19. Swift5.1 语言指南(二十) 类型转换
  20. Python 水果出库

热门文章

  1. 用python画漂亮的图案-使用 Python Turtle 设计简单而又美丽的图形
  2. OSChina 周日乱弹 —— 怎样判别你是她的男神
  3. guid linux 识别的分区表_GUID分区与MBR分区有什么区别?
  4. 简单教学管理系统画E-R关系图
  5. linux检查邮件命令,linux下mail 邮件查看命令
  6. 聚宝加油卡,2022年独一无二的翻身机会
  7. 如何修改别人的神经网络,人工神经网络通过调整
  8. Scrapy中对xpath使用re
  9. [PC] 微软账号连接不上
  10. Golang如何实现排序