目录

  • 前言
  • 一、Backbone整体结构
  • 一、CNN-Backbone
  • 二、Positional Encoding
  • Reference

前言

最近在看DETR的源码,断断续续看了一星期左右,把主要的模型代码理清了。一直在考虑以什么样的形式写一写DETR的源码解析。考虑的一种形式是像之前写的YOLOv5那样的按文件逐行写,一种是想把源码按功能模块串起来。考虑了很久还是决定按第二种方式,一是因为这种方式可能会更省时间,另外就是也方便我整体再理解一下吧。

我觉得看代码就是要看到能把整个模型分功能拆开,最后再把所有模块串起来,这样才能达到事半功倍。

另外一点我觉得很重要的是:拿到一个开源项目代码,要有马上配置环境能够正常运行Debug,并且通过解析train.py马上找到主要模型相关的内容,然后着重关注模型方面的解析,像一些日志、计算mAP、画图等等代码,完全可以不看,可以省很多时间,所以以后我讲解源码都会把无关的代码完全剥离,不再讲解,全部精力关注模型、改进、损失等内容。

这一节主要讲一下DETR的Backbone部分,包括CNN和位置编码两个模块的代码。主要涉及models/backbone.py和models/position_encoding.py两个文件。

Github注释版源码:HuKai97/detr-annotations

一、Backbone整体结构

整个Backbone主要包括CNN特征提取和位置编码两个部分。代码还是比较简单的,下面开始解析源码。

首先是调用models/Backbone.py中的build_backbone函数创建Backbone:

def build_backbone(args):# 搭建backbone# 位置编码  PositionEmbeddingSine()position_embedding = build_position_encoding(args)train_backbone = args.lr_backbone > 0   # 是否需要训练backbone  Truereturn_interm_layers = args.masks       # 是否需要返回中间层结果 目标检测False  分割True# 生成backbone  resnet50backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)# 将backbone输出与位置编码相加   0: backbone   1: PositionEmbeddingSine()model = Joiner(backbone, position_embedding)model.num_channels = backbone.num_channels   # 512return model

这里首先调用build_position_encoding函数生成正余弦位置编码position_embedding:[bs,256,H/32, W/32],其中256前128是y方向位置编码,后128是x方向位置编码;再调用Backbone类生成ResNet50对输入数据进行特征提取得到特征图[bs,2048,H/32, W/32]。最后Joiner将两者合并存储起来,方便后续使用。

一、CNN-Backbone

创建ResNet50,先调用Backbone类:

class Backbone(BackboneBase):"""ResNet backbone with frozen BatchNorm."""def __init__(self, name: str,train_backbone: bool,return_interm_layers: bool,dilation: bool):# 直接掉包 调用torchvision.models中的backbonebackbone = getattr(torchvision.models, name)(replace_stride_with_dilation=[False, False, dilation],pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)# resnet50  2048num_channels = 512 if name in ('resnet18', 'resnet34') else 2048super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

这个类是继承自BackboneBase类的,而且CNN直接调用的就是torchvision.models中的模型,所以直接看BackboneBase类:

class BackboneBase(nn.Module):def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):super().__init__()for name, parameter in backbone.named_parameters():# layer0 layer1不需要训练 因为前面层提取的信息其实很有限 都是差不多的 不需要训练if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:parameter.requires_grad_(False)# False 检测任务不需要返回中间层if return_interm_layers:return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}else:return_layers = {'layer4': "0"}# 检测任务直接返回layer4即可  执行torchvision.models._utils.IntermediateLayerGetter这个函数可以直接返回对应层的输出结果self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)self.num_channels = num_channelsdef forward(self, tensor_list: NestedTensor):"""tensor_list: pad预处理之后的图像信息tensor_list.tensors: [bs, 3, 608, 810]预处理后的图片数据 对于小图片而言多余部分用0填充tensor_list.mask: [bs, 608, 810] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)"""# 取出预处理后的图片数据 [bs, 3, 608, 810] 输入模型中  输出layer4的输出结果 dict '0'=[bs, 2048, 19, 26]xs = self.body(tensor_list.tensors)# 保存输出数据out: Dict[str, NestedTensor] = {}for name, x in xs.items():m = tensor_list.mask  # 取出图片的mask [bs, 608, 810] 知道图片哪些区域是有效的 哪些位置是pad之后的无效的assert m is not None# 通过插值函数知道卷积后的特征的mask  知道卷积后的特征哪些是有效的  哪些是无效的# 因为之前图片输入网络是整个图片都卷积计算的 生成的新特征其中有很多区域都是无效的mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]out[name] = NestedTensor(x, mask)# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]return out

这个类还是在调用torchvision.models中的模型,然后再把预处理后的图片数据[bs, 3, 608, 810]和mask数据[bs, 608, 810]输入到模型中(这个图片数据是经过pad填充的数据,而mask数据就是记录这些图片哪些像素位置是pad的,为True,没用pad的真实有效数据就为False)。经过前向传播,再调用IntermediateLayerGetter函数把对应层特征图提取出来,得到原图32倍下采样的特征图[bs, 2048, 19, 26],以及这张特征图对应的mask[bs, 19, 26]。

二、Positional Encoding

Positional Encoding 就是位置编码。这里主要是调用models/position_encoding.py中的build_position_encoding函数创建位置编码:

def build_position_encoding(args):"""创建位置编码args: 一系列参数  args.hidden_dim: transformer中隐藏层的维度   args.position_embedding: 位置编码类型 正余弦sine or 可学习learned"""# N_steps = 128 = 256 // 2  backbone输出[bs,256,25,34]  256维度的特征# 而传统的位置编码应该也是256维度的, 但是detr用的是一个x方向和y方向的位置编码concat的位置编码方式  这里和ViT有所不同# 二维位置编码   前128维代表x方向位置编码  后128维代表y方向位置编码N_steps = args.hidden_dim // 2if args.position_embedding in ('v2', 'sine'):# TODO find a better way of exposing other arguments# [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码position_embedding = PositionEmbeddingSine(N_steps, normalize=True)elif args.position_embedding in ('v3', 'learned'):position_embedding = PositionEmbeddingLearned(N_steps)else:raise ValueError(f"not supported{args.position_embedding}")return position_embedding

可以看到,源码是实现了两种位置编码,一种是正余弦绝对位置编码,不需要额外的参数学习,另一种是可学习绝对位置编码。原论文用的是正余弦绝对位置编码,而且代码也是默认使用这个的,所以这里主要介绍PositionEmbeddingSine类:

class PositionEmbeddingSine(nn.Module):"""Absolute pos embedding, Sine.  没用可学习参数  不可学习  定义好了就固定了This is a more standard version of the position embedding, very similar to the oneused by the Attention is all you need paper, generalized to work on images."""def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):super().__init__()self.num_pos_feats = num_pos_feats    # 128维度  x/y  = d_model/2self.temperature = temperature        # 常数 正余弦位置编码公式里面的10000self.normalize = normalize            # 是否对向量进行max规范化   Trueif scale is not None and normalize is False:raise ValueError("normalize should be True if scale is passed")if scale is None:# 这里之所以规范化到2*pi  因为位置编码函数的周期是[2pi, 20000pi]scale = 2 * math.pi  # 规范化参数 2*piself.scale = scaledef forward(self, tensor_list: NestedTensor):x = tensor_list.tensors   # [bs, 2048, 19, 26]  预处理后的 经过backbone 32倍下采样之后的数据  对于小图片而言多余部分用0填充mask = tensor_list.mask   # [bs, 19, 26]  用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)assert mask is not Nonenot_mask = ~mask   # True的位置才是真实有效的位置# 考虑到图像本身是2维的 所以这里使用的是2维的正余弦位置编码# 这样各行/列都映射到不同的值 当然有效位置是正常值 无效位置会有重复值 但是后续计算注意力权重会忽略这部分的# 而且最后一个数字就是有效位置的总和,方便max规范化# 计算此时y方向上的坐标  [bs, 19, 26]y_embed = not_mask.cumsum(1, dtype=torch.float32)# 计算此时x方向的坐标    [bs, 19, 26]x_embed = not_mask.cumsum(2, dtype=torch.float32)# 最大值规范化 除以最大值 再乘以2*pi 最终把坐标规范化到0-2pi之间if self.normalize:eps = 1e-6y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)   # 0 1 2 .. 127# 2i/2i+1: 2 * (dim_t // 2)  self.temperature=10000   self.num_pos_feats = d/2dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)   # 分母pos_x = x_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式pos_y = y_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式# x方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)# y方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)# concat: [bs,19,26,128][bs,19,26,128] -> [bs,19,26,256] -> [bs,256,19,26]pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)# [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码return pos

对照公式:

我的几个关键点的理解:

  1. 这里是通过mask来构建位置编码的,mask中记录了特征图中每个像素位置是否是pad的,只有在为False的位置,才是有效的位置,才需要构建位置编码;
  2. 关于最大值规范化 :因为正余弦编码方式,思想就是将各个位置的通过公式映射到 0~2Π 这个范围内(也可以是4Π,6Π,8Π…,因为它是一个周期函数,不过我们一般默认为2Π),所以这里在带入公式之前需要对x_embed、y_embed先进行规范化;
  3. 关于位置编码方式:这里之所以是把x和y分别进行位置编码(二维位置编码),而不是像transformer那样的一维位置编码。主要考虑的是transformer是应用在语言模型中的,天然就是一维的,所以一维可能更适合,而DETR是应用在图像任务中的一个目标检测框架,在图像任务中,当然二维位置编码效果可能会更好点;
  4. 这样,对于每个位置(x,y),其所在列对应的编码值就在通道维度的前128维,其所在行的编码值就在通道这个维度的后128维。这样这个特征图上各个位置就都对应到不同的维度的编码值了。

当然作为学习,也可以看看第二种绝对位置编码方式:可学习位置编码:

class PositionEmbeddingLearned(nn.Module):"""Absolute pos embedding, learned.可以发现整个类其实就是初始化了相应shape的位置编码参数,让后通过可学习的方式学习这些位置编码参数"""def __init__(self, num_pos_feats=256):super().__init__()# nn.Embedding  相当于 nn.Parameter  其实就是初始化函数self.row_embed = nn.Embedding(50, num_pos_feats)self.col_embed = nn.Embedding(50, num_pos_feats)self.reset_parameters()def reset_parameters(self):nn.init.uniform_(self.row_embed.weight)nn.init.uniform_(self.col_embed.weight)def forward(self, tensor_list: NestedTensor):x = tensor_list.tensorsh, w = x.shape[-2:]   # 特征图h wi = torch.arange(w, device=x.device)j = torch.arange(h, device=x.device)x_emb = self.col_embed(i)   # 初始化x方向位置编码y_emb = self.row_embed(j)   # 初始化y方向位置编码# concat x y 方向位置编码pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1),y_emb.unsqueeze(1).repeat(1, w, 1),], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)return pos

可以发现整个类其实就是初始化了相应shape的位置编码参数,然后通过可学习的方式自己学习这些位置编码参数,代码比较简答。

Reference

官方源码: https://github.com/facebookresearch/detr

b站源码讲解: 铁打的流水线工人

知乎【布尔佛洛哥哥】: DETR 源码解读

CSDN【在努力的松鼠】源码讲解: DETR源码笔记(一)

CSDN【在努力的松鼠】源码讲解: DETR源码笔记(二)

CSDN: Transformer中的position encoding(位置编码一)

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(一)、概述与模型推断】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(四)、Detection with Transformer】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成】

【DETR源码解析】二、Backbone模块相关推荐

  1. Android xUtils3源码解析之图片模块

    本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...

  2. Android xUtils3源码解析之注解模块

    本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...

  3. Android xUtils3源码解析之数据库模块

    本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...

  4. 【深度学习模型】智云视图中文车牌识别源码解析(二)

    [深度学习模型]智云视图中文车牌识别源码解析(二) 感受 HyperLPR可以识别多种中文车牌包括白牌,新能源车牌,使馆车牌,教练车牌,武警车牌等. 代码不可谓不混乱(别忘了这是职业公司的准产品级代码 ...

  5. 【DETR源码解析】三、Transformer模块

    目录 前言 一.Transformer整体结构 二.TransformerEncoder 2.1.TransformerEncoderLayer 三.TransformerDecoder 3.1.Tr ...

  6. erlang下lists模块sort(排序)方法源码解析(二)

    上接erlang下lists模块sort(排序)方法源码解析(一),到目前为止,list列表已经被分割成N个列表,而且每个列表的元素是有序的(从大到小) 下面我们重点来看看mergel和rmergel ...

  7. Kubernetes学习笔记之Calico CNI Plugin源码解析(二)

    女主宣言 今天小编继续为大家分享Kubernetes Calico CNI Plugin学习笔记,希望能对大家有所帮助. PS:丰富的一线技术.多元化的表现形式,尽在"360云计算" ...

  8. Mobx 源码解析 二(autorun)

    前言 我们在Mobx 源码解析 一(observable)已经知道了observable 做的事情了, 但是我们的还是没有讲解明白在我们的Demo中,我们在Button 的Click 事件中只是对ba ...

  9. android网络框架retrofit源码解析二

    注:源码解析文章参考了该博客:http://www.2cto.com/kf/201405/305248.html 前一篇文章讲解了retrofit的annotation,既然定义了,那么就应该有解析的 ...

最新文章

  1. Flask-admin 使用总结
  2. 趴在门口的云计算,盯上了屋内狂奔的CDN
  3. C++实现静态顺序表的增删查改以及初始化
  4. 前端学习(3238):react生命周期4
  5. 大家身边极度聪明的人是什么样子?
  6. python如何写代码_如何写出优雅的Python代码?
  7. 【转】C51中断函数的写法
  8. win11快捷键怎么使用 Windows11快捷键的使用方法
  9. 从零开始部署基于阿里容器云的微服务(consul+registrator+template)(一)
  10. iOS开发UI篇—控制器的创建
  11. 「手把手带你学算法」本周小结!(动态规划系列七)
  12. 11.GitLab webhooks
  13. RAID磁盘阵列配置
  14. VMWare虚拟机Ubantu20.10添加中文智能拼音输入法
  15. selenium不定位元素直接操作键盘之Keys.CONTROL
  16. 《跟小智一起学网络》教程目录
  17. C++读取通达信shm.tnf文件股票代码/名称
  18. html页面虚化,css实现背景虚化效果的示例代码
  19. iframe标签全屏
  20. 求税后收入及个人所得税

热门文章

  1. ArcGIS的ArcToolbox执行任务时没反应或图层上有小锁的解决方法
  2. 学习了编程之后,是不是就可以进行APP开发了?
  3. 前端 “一键换肤“ 的几种方案
  4. 挑战杯、互联网+大学生创新创业大赛项目计划书《多功能智能化无人机》
  5. 哔哩哔哩弹幕html,哔哩哔哩bilibili默认关闭弹幕
  6. MTK 平台sensor arch 介绍-hal
  7. 用java实现凯撒加密系统,JAVA如何实现caesar凯撒加密算法
  8. @WebFilter怎么控制多个filter的执行顺序
  9. 张飞老师硬件第十五部视频整理——硬件基础1-2-3
  10. 什么是BFC,如何触发BFC,BFC的作用