DETR: End-to-End Object Detection with Transformers 网络解析
说明:

  1. 个人理解,如有错误请及时提出。
  2. 由于自己电脑驱动较低不满足440及以上,所以目前网络中张量的具体维度不太清楚,后续如有条件再更新博客。
  3. 不得不感叹论文作者的数学功底扎实、知识涉猎广博。
  4. 太暴力了,依靠一些精巧的结构和强大的硬件支持替代了大家精心设计的anchor、nms等结构。
  5. 对于小目标、多目标场景仍然较差。
  6. 论文涉及的东西很多,本博客逐步添加,欢迎提出修改意见。

资源:

  1. 论文地址
  2. github地址
  3. 李宏毅老师Transformer视频B站版
  4. Yannic Kilcher大神的论文讲解,“你懂得”网站的地址。另外提供国内CSDN下载,欢迎支持

本博客目录

  • 总体结构
  • Transformer
    • Multi-Head Self-Attention
    • ADD & Norm
    • FFN
    • Spatial positional encoding
    • Object queries
  • Bipartite Matching
  • 损失函数

总体结构


  1. 整个网络思路非常清晰,首先是将图像[3, H, W]送入常用的卷积神经网络backbone来提取特征,以resnet为例,那么在backbone最后这里得到的特征图[2048, H/32, W/32]
  2. 然后将来自backbone的特征图展开,变为[C, HW]的形状,送入由编码器和解码器组成的Transformer结构中
  3. 最后是将预测得到的“集合”与真值进行匹配(匈牙利算法),通过最小化损失(代价)来一次性预测检测(实例分割)结果

Transformer


如上图所示,左图为DETR中的transformer结构,右图为文章Attention Is All You Need中的结构图,基本上还是一致的。
李宏毅老师讲到transformer实际上是seq2seq model with ‘self-attention’,所以下面着重来讲一下DETR里面涉及到的一些细节问题

Multi-Head Self-Attention

其实NLP中transformer需要处理的是一些序列数据,那么为了处理序列数据首先可以想到的就是RNN结构,这种结构考虑了上下文关系,相对于CNN(感受野问题)来讲具有优势。但是如下图所示的RCNN模块存在一个问题:任务难以并行

因为an to bn 的计算依赖 an-1 to bn-1的中间结果,这就意味着任务必须是串行的,这对于目前所说的大数据,并行计算,云计算来讲是不合适的。而相对于RNN来讲CNN则更加适合并行计算,那么self-attention模块就是一个典型的合理解。如下图所示:

那么这里的核心思想就是怎么去构建这个self-attention结构。(x1, x2, x3, x4)表示我们的输入,(b1, b2, b3, b4)表示我们的输出

这里q, k, v可以按照词袋模型的思路来理解,那么q是输入的矩阵,而k则是数据字典,那么v就是我们将将输入的矩阵q通过对应数据字典进行编码后形成的新的特征矩阵。因此他们的公式可以分别表示为
q i = W q a i q^i = W^qa^i qi=Wqai
k i = W k a i k^i = W^ka^i ki=Wkai
v i = W v a i v^i = W^va^i vi=Wvai
有了这个数学表示,那么接下来我们需要的是拿q1与(k1, k2, k3, k4)进行点乘,然后是q2, q3, q4,这个过程表示如下

其中
α 1 , i = q 1 ⋅ k i / d \alpha_{1,i} = q^1\cdot k^i/ \sqrt{d} α1,i=q1ki/d


α ^ 1 , i = e x p ( α 1 , i ) / ∑ j e x p ( α 1 , j ) \hat{\alpha}_{1,i} = exp(\alpha_{1,i})/\sum_j{exp(\alpha_{1,j})} α^1,i=exp(α1,i)/jexp(α1,j)
d是q和k的维度。到这里我们获得了输入x1与四个数据字典k相关的“权值”,那么为了获得最后的“词袋模型编码后的向量”,就需要alpha-head与v进行一定的操作,如下图所示

b 1 = ∑ i α ^ 1 , i v i b^1 = \sum_i{\hat{\alpha}_{1,i}v^i} b1=iα^1,ivi
同理很容易就可以计算出(b1, b2, b3, b4),具体的矩阵推导可以查看李宏毅老师的视频和PPT,很容易可以理解为什么说这个结构可以替代RNN进行并行加速。而multi-head self-attention就很好理解了,就是多层head的堆叠,这是深度学习中很常见的网络构建方式,下图是一个2 heads的例子,可以与上图进行对比,很明显可以明白差异。

ADD & Norm

这里Add就是矩阵的加法,Norm指的是Layer Norm,其主要是Batch Norm是在一个batch之间来正则化,而Layer Norm则只是考虑一个图上的正则化。具体可以阅读文献。

FFN

DETR中的FFN实质上就是FC+ReLu+FC这种形式

Spatial positional encoding

引入这个张量的原因是因为输入Transformer的张量被转换成了[c, HW],对于图像来说就失去了像素的空间分布信息,这不符合Transformer处理序列数据的初衷,那么就势必要引入位置编码。
这个张量作者做了两种尝试,不过由于实验效果基本一致,所以就采用了人工生成的方式。
一种是学习得到:

class PositionEmbeddingLearned(nn.Module):"""Absolute pos embedding, learned."""def __init__(self, num_pos_feats=256):super().__init__()self.row_embed = nn.Embedding(50, num_pos_feats)self.col_embed = nn.Embedding(50, num_pos_feats)self.reset_parameters()def reset_parameters(self):nn.init.uniform_(self.row_embed.weight)nn.init.uniform_(self.col_embed.weight)def forward(self, tensor_list: NestedTensor):x = tensor_list.tensorsh, w = x.shape[-2:]i = torch.arange(w, device=x.device)j = torch.arange(h, device=x.device)x_emb = self.col_embed(i)y_emb = self.row_embed(j)pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1),y_emb.unsqueeze(1).repeat(1, w, 1),], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)return pos

第二种是人工生成:

class PositionEmbeddingSine(nn.Module):"""This is a more standard version of the position embedding, very similar to the oneused by the Attention is all you need paper, generalized to work on images."""def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):super().__init__()self.num_pos_feats = num_pos_featsself.temperature = temperatureself.normalize = normalizeif scale is not None and normalize is False:raise ValueError("normalize should be True if scale is passed")if scale is None:scale = 2 * math.piself.scale = scaledef forward(self, tensor_list: NestedTensor):x = tensor_list.tensorsmask = tensor_list.maskassert mask is not Nonenot_mask = ~masky_embed = not_mask.cumsum(1, dtype=torch.float32)x_embed = not_mask.cumsum(2, dtype=torch.float32)if self.normalize:eps = 1e-6y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)pos_x = x_embed[:, :, :, None] / dim_tpos_y = y_embed[:, :, :, None] / dim_tpos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)return pos

人工生成位置编码的方式还是延续了NLP中Transformer位置编码的生成方式,不同的是因为图像是三维[C, H, W],因此除了通道维之外DETR在H和W方向上分别进行了编码。核心公式表示为
P E p o s , 2 i = s i n ( p o s / 1000 0 2 i / d m o d e l ) PE_{pos,2i} = sin(pos/10000^{2i/d_{model}}) PEpos,2i=sin(pos/100002i/dmodel)
P E p o s , 2 i + 1 = c o s ( p o s / 1000 0 2 i / d m o d e l ) PE_{pos,2i+1} = cos(pos/10000^{2i/d_{model}}) PEpos,2i+1=cos(pos/100002i/dmodel)
至于说为什么这样子编码可以表示位置不同,可以参考这个博客

Object queries

这个张量是在解码过程中引入的,它的维度和输出的目标集合数量是一致的,可以大致理解为“向量表示的图像上的关注点”。DETR中是通过学习得到的,初始化代码如下所示

self.query_embed = nn.Embedding(num_queries, hidden_dim)

调用的时候

hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

Bipartite Matching

不同于常见的检测器,DETR没有使用NMS,通常来讲预测出来的目标集合为N,每个元素是(类别, 坐标)即(c, b)。而真值NGT的数目通常来讲每张图像上数目是不同的,这里就引入了几个问题:

  1. N的个数是一个超参,比如DETR设置的是100
  2. 论文中讲到NGT数目不等于N时就用(null,b)来填充
  3. N和NGT匹配过程使用的是匈牙利算法,代码中为了加快速度进行计算时NGT其实没有进行填充,N中元素没有成功与NGT配对的就被视作背景
  4. 不得不说构思非常巧妙,转化为了求取最小代价的过程,但是也不得不说这个是大力出奇迹,DETR训练的缓慢本人猜测与这个相关,等有环境可以运行代码的时候测试下

损失函数

看到网上很多分析DETR损失函数的,所以这里我就不介绍了,如果有需要再说吧。不过值得一提的是,DETR也用到了这个思想:从不同深度提取特征图进行损失计算,可以缩短损失反向传播到不同深度的路径,降低梯度消失造成的影响,加快网络收敛并在一定程度上提高精度。

暂时更新到这里…

DETR: End-to-End Object Detection with Transformers [暴力美学]相关推荐

  1. End-to-End Object Detection with Transformers[DETR]

    End-to-End Object Detection with Transformers[DETR] 背景 概述 相关技术 输入 提取特征 获取position_embedding transfor ...

  2. 论文解读:DETR 《End-to-end object detection with transformers》,ECCV 2020

    论文解读:DETR <End-to-end object detection with transformers>,ECCV 2020 0. 论文基本信息 1. 论文解决的问题 2. 论文 ...

  3. 论文阅读:DETR:End-to-End Object Detection with Transformers

    题目:End-to-End Object Detection with Transformers 来源:Facebook AI ECCV2020 论文链接:https://arxiv.org/abs/ ...

  4. End-to-End Object Detection with Transformers,DETR论文学习

    End-to-End Object Detection with Transformers,DETR论文学习 1. 引言 2. 本论文发表前的目标检测策略(非端到端的目标检测策略) 2.1 目标检测的 ...

  5. End-to-End Object Detection with Transformers的部分解读

    Transformer+Detection:引入视觉领域的首创DETR 也没有精力看原文了,直接看了博客: https://mp.weixin.qq.com/s?__biz=MzI5MDUyMDIxN ...

  6. 论文阅读:DEFORMABLE DETR:DEFORMABLE DETR: DEFORMABLE TRANSFORMERSFOR END-TO-END OBJECT DETECTION

    题目:DEFORMABLE DETR:DEFORMABLE DETR: DEFORMABLE TRANSFORMERSFOR END-TO-END OBJECT DETECTION 来源:ICLA 是 ...

  7. Deformable DETR: DEFORMABLE TRANSFORMERSFOR END-TO-END OBJECT DETECTION(论文阅读)

    Deformable DETR 是商汤Jifeng Dai 团队于2021年发表在ICLR 上的文章,是针对Detr 的改进. 论文:<DEFORMABLE DETR: DEFORMABLE T ...

  8. 【论文阅读】【3d目标检测】Group-Free 3D Object Detection via Transformers

    论文标题:Group-Free 3D Object Detection via Transformers iccv2021 本文主要是针对votenet等网络中采用手工group的问题提出的改进 我们 ...

  9. DETR:End-to-End Object Detection with Transformers

    最前面是论文翻译,中间是背景+问题+方法步骤+实验过程,最后是文中的部分专业名词介绍(水平线分开,能力有限,部分翻译可能不太准确) 使用Transformers进行端到端目标检测 摘要 我们提出了一种 ...

最新文章

  1. 【TensorFlow】理解tf.nn.conv2d方法 ( 附代码详解注释 )
  2. SQL中distinct的用法
  3. 转:字体集选择font-family
  4. 基于Unity的弹幕游戏多人联机尝试
  5. SQL优化之列裁剪和投影消除
  6. NOIP模拟测试5「星际旅行·砍树·超级树」
  7. git 提交代码到新的库,不保留原来的提交历史记录
  8. 计算机考试考什么二级,计算机国家二级考试会考什么内容?怎么考?
  9. 微信小程序 API-转发(Share)
  10. POJ 1053 Set Me G++
  11. 关于苹果开发者账号(appleid)的问题修复
  12. vue-element-admin之修改登录页面背景
  13. 原型工具axure7.0
  14. 自然语言处理3 -- 词性标注
  15. 移动支付服务Dwolla宣布10美元以下交易不收费
  16. 深入浅出Flask PIN
  17. 华师大 OJ 3040
  18. squirrelmail(小松鼠web邮件系统)
  19. 战地4网页怎么换服务器地址,战地4设置服务器地址
  20. 203、商城业务-商品详情-环境搭建

热门文章

  1. 【vxe-table】和【el-table】调整列(单元格)背景色,指定列背景色设置,两层或多层表头也适用
  2. 低压铜排、电缆接头温度监测的应用场景及解决方案有哪些?
  3. 【迅为iMX6Q】开发板:擦灰后再次开箱上电
  4. 国民技术对比STM移植参考------N32G45X系列对比STM32F10X
  5. C语言打印心图案----真好玩
  6. c++const指针与函数调用
  7. C++ const对象与非const对象的相互调用、const成员函数与非const成员函数的相互调用
  8. 【wp】ZJCTF-我的方便面没有调料包战队
  9. [230516] TPO71 | 2022年托福阅读真题第4/36篇 | Electrical Energy from the Ocean | 11:50
  10. Java知识体系最强总结(2020版)