【DETR源码解析】二、Backbone模块
目录
- 前言
- 一、Backbone整体结构
- 一、CNN-Backbone
- 二、Positional Encoding
- Reference
前言
最近在看DETR的源码,断断续续看了一星期左右,把主要的模型代码理清了。一直在考虑以什么样的形式写一写DETR的源码解析。考虑的一种形式是像之前写的YOLOv5那样的按文件逐行写,一种是想把源码按功能模块串起来。考虑了很久还是决定按第二种方式,一是因为这种方式可能会更省时间,另外就是也方便我整体再理解一下吧。
我觉得看代码就是要看到能把整个模型分功能拆开,最后再把所有模块串起来,这样才能达到事半功倍。
另外一点我觉得很重要的是:拿到一个开源项目代码,要有马上配置环境能够正常运行Debug,并且通过解析train.py马上找到主要模型相关的内容,然后着重关注模型方面的解析,像一些日志、计算mAP、画图等等代码,完全可以不看,可以省很多时间,所以以后我讲解源码都会把无关的代码完全剥离,不再讲解,全部精力关注模型、改进、损失等内容。
这一节主要讲一下DETR的Backbone部分,包括CNN和位置编码两个模块的代码。主要涉及models/backbone.py和models/position_encoding.py两个文件。
Github注释版源码:HuKai97/detr-annotations
一、Backbone整体结构
整个Backbone主要包括CNN特征提取和位置编码两个部分。代码还是比较简单的,下面开始解析源码。
首先是调用models/Backbone.py中的build_backbone函数创建Backbone:
def build_backbone(args):# 搭建backbone# 位置编码 PositionEmbeddingSine()position_embedding = build_position_encoding(args)train_backbone = args.lr_backbone > 0 # 是否需要训练backbone Truereturn_interm_layers = args.masks # 是否需要返回中间层结果 目标检测False 分割True# 生成backbone resnet50backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)# 将backbone输出与位置编码相加 0: backbone 1: PositionEmbeddingSine()model = Joiner(backbone, position_embedding)model.num_channels = backbone.num_channels # 512return model
这里首先调用build_position_encoding函数生成正余弦位置编码position_embedding:[bs,256,H/32, W/32],其中256前128是y方向位置编码,后128是x方向位置编码;再调用Backbone类生成ResNet50对输入数据进行特征提取得到特征图[bs,2048,H/32, W/32]。最后Joiner将两者合并存储起来,方便后续使用。
一、CNN-Backbone
创建ResNet50,先调用Backbone类:
class Backbone(BackboneBase):"""ResNet backbone with frozen BatchNorm."""def __init__(self, name: str,train_backbone: bool,return_interm_layers: bool,dilation: bool):# 直接掉包 调用torchvision.models中的backbonebackbone = getattr(torchvision.models, name)(replace_stride_with_dilation=[False, False, dilation],pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)# resnet50 2048num_channels = 512 if name in ('resnet18', 'resnet34') else 2048super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
这个类是继承自BackboneBase类的,而且CNN直接调用的就是torchvision.models中的模型,所以直接看BackboneBase类:
class BackboneBase(nn.Module):def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):super().__init__()for name, parameter in backbone.named_parameters():# layer0 layer1不需要训练 因为前面层提取的信息其实很有限 都是差不多的 不需要训练if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:parameter.requires_grad_(False)# False 检测任务不需要返回中间层if return_interm_layers:return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}else:return_layers = {'layer4': "0"}# 检测任务直接返回layer4即可 执行torchvision.models._utils.IntermediateLayerGetter这个函数可以直接返回对应层的输出结果self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)self.num_channels = num_channelsdef forward(self, tensor_list: NestedTensor):"""tensor_list: pad预处理之后的图像信息tensor_list.tensors: [bs, 3, 608, 810]预处理后的图片数据 对于小图片而言多余部分用0填充tensor_list.mask: [bs, 608, 810] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)"""# 取出预处理后的图片数据 [bs, 3, 608, 810] 输入模型中 输出layer4的输出结果 dict '0'=[bs, 2048, 19, 26]xs = self.body(tensor_list.tensors)# 保存输出数据out: Dict[str, NestedTensor] = {}for name, x in xs.items():m = tensor_list.mask # 取出图片的mask [bs, 608, 810] 知道图片哪些区域是有效的 哪些位置是pad之后的无效的assert m is not None# 通过插值函数知道卷积后的特征的mask 知道卷积后的特征哪些是有效的 哪些是无效的# 因为之前图片输入网络是整个图片都卷积计算的 生成的新特征其中有很多区域都是无效的mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]out[name] = NestedTensor(x, mask)# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]return out
这个类还是在调用torchvision.models中的模型,然后再把预处理后的图片数据[bs, 3, 608, 810]和mask数据[bs, 608, 810]输入到模型中(这个图片数据是经过pad填充的数据,而mask数据就是记录这些图片哪些像素位置是pad的,为True,没用pad的真实有效数据就为False)。经过前向传播,再调用IntermediateLayerGetter函数把对应层特征图提取出来,得到原图32倍下采样的特征图[bs, 2048, 19, 26],以及这张特征图对应的mask[bs, 19, 26]。
二、Positional Encoding
Positional Encoding 就是位置编码。这里主要是调用models/position_encoding.py中的build_position_encoding函数创建位置编码:
def build_position_encoding(args):"""创建位置编码args: 一系列参数 args.hidden_dim: transformer中隐藏层的维度 args.position_embedding: 位置编码类型 正余弦sine or 可学习learned"""# N_steps = 128 = 256 // 2 backbone输出[bs,256,25,34] 256维度的特征# 而传统的位置编码应该也是256维度的, 但是detr用的是一个x方向和y方向的位置编码concat的位置编码方式 这里和ViT有所不同# 二维位置编码 前128维代表x方向位置编码 后128维代表y方向位置编码N_steps = args.hidden_dim // 2if args.position_embedding in ('v2', 'sine'):# TODO find a better way of exposing other arguments# [bs,256,19,26] dim=1时 前128个是y方向位置编码 后128个是x方向位置编码position_embedding = PositionEmbeddingSine(N_steps, normalize=True)elif args.position_embedding in ('v3', 'learned'):position_embedding = PositionEmbeddingLearned(N_steps)else:raise ValueError(f"not supported{args.position_embedding}")return position_embedding
可以看到,源码是实现了两种位置编码,一种是正余弦绝对位置编码,不需要额外的参数学习,另一种是可学习绝对位置编码。原论文用的是正余弦绝对位置编码,而且代码也是默认使用这个的,所以这里主要介绍PositionEmbeddingSine类:
class PositionEmbeddingSine(nn.Module):"""Absolute pos embedding, Sine. 没用可学习参数 不可学习 定义好了就固定了This is a more standard version of the position embedding, very similar to the oneused by the Attention is all you need paper, generalized to work on images."""def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):super().__init__()self.num_pos_feats = num_pos_feats # 128维度 x/y = d_model/2self.temperature = temperature # 常数 正余弦位置编码公式里面的10000self.normalize = normalize # 是否对向量进行max规范化 Trueif scale is not None and normalize is False:raise ValueError("normalize should be True if scale is passed")if scale is None:# 这里之所以规范化到2*pi 因为位置编码函数的周期是[2pi, 20000pi]scale = 2 * math.pi # 规范化参数 2*piself.scale = scaledef forward(self, tensor_list: NestedTensor):x = tensor_list.tensors # [bs, 2048, 19, 26] 预处理后的 经过backbone 32倍下采样之后的数据 对于小图片而言多余部分用0填充mask = tensor_list.mask # [bs, 19, 26] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)assert mask is not Nonenot_mask = ~mask # True的位置才是真实有效的位置# 考虑到图像本身是2维的 所以这里使用的是2维的正余弦位置编码# 这样各行/列都映射到不同的值 当然有效位置是正常值 无效位置会有重复值 但是后续计算注意力权重会忽略这部分的# 而且最后一个数字就是有效位置的总和,方便max规范化# 计算此时y方向上的坐标 [bs, 19, 26]y_embed = not_mask.cumsum(1, dtype=torch.float32)# 计算此时x方向的坐标 [bs, 19, 26]x_embed = not_mask.cumsum(2, dtype=torch.float32)# 最大值规范化 除以最大值 再乘以2*pi 最终把坐标规范化到0-2pi之间if self.normalize:eps = 1e-6y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) # 0 1 2 .. 127# 2i/2i+1: 2 * (dim_t // 2) self.temperature=10000 self.num_pos_feats = d/2dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) # 分母pos_x = x_embed[:, :, :, None] / dim_t # 正余弦括号里面的公式pos_y = y_embed[:, :, :, None] / dim_t # 正余弦括号里面的公式# x方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)# y方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)# concat: [bs,19,26,128][bs,19,26,128] -> [bs,19,26,256] -> [bs,256,19,26]pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)# [bs,256,19,26] dim=1时 前128个是y方向位置编码 后128个是x方向位置编码return pos
对照公式:
我的几个关键点的理解:
- 这里是通过mask来构建位置编码的,mask中记录了特征图中每个像素位置是否是pad的,只有在为False的位置,才是有效的位置,才需要构建位置编码;
- 关于最大值规范化 :因为正余弦编码方式,思想就是将各个位置的通过公式映射到 0~2Π 这个范围内(也可以是4Π,6Π,8Π…,因为它是一个周期函数,不过我们一般默认为2Π),所以这里在带入公式之前需要对x_embed、y_embed先进行规范化;
- 关于位置编码方式:这里之所以是把x和y分别进行位置编码(二维位置编码),而不是像transformer那样的一维位置编码。主要考虑的是transformer是应用在语言模型中的,天然就是一维的,所以一维可能更适合,而DETR是应用在图像任务中的一个目标检测框架,在图像任务中,当然二维位置编码效果可能会更好点;
- 这样,对于每个位置(x,y),其所在列对应的编码值就在通道维度的前128维,其所在行的编码值就在通道这个维度的后128维。这样这个特征图上各个位置就都对应到不同的维度的编码值了。
当然作为学习,也可以看看第二种绝对位置编码方式:可学习位置编码:
class PositionEmbeddingLearned(nn.Module):"""Absolute pos embedding, learned.可以发现整个类其实就是初始化了相应shape的位置编码参数,让后通过可学习的方式学习这些位置编码参数"""def __init__(self, num_pos_feats=256):super().__init__()# nn.Embedding 相当于 nn.Parameter 其实就是初始化函数self.row_embed = nn.Embedding(50, num_pos_feats)self.col_embed = nn.Embedding(50, num_pos_feats)self.reset_parameters()def reset_parameters(self):nn.init.uniform_(self.row_embed.weight)nn.init.uniform_(self.col_embed.weight)def forward(self, tensor_list: NestedTensor):x = tensor_list.tensorsh, w = x.shape[-2:] # 特征图h wi = torch.arange(w, device=x.device)j = torch.arange(h, device=x.device)x_emb = self.col_embed(i) # 初始化x方向位置编码y_emb = self.row_embed(j) # 初始化y方向位置编码# concat x y 方向位置编码pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1),y_emb.unsqueeze(1).repeat(1, w, 1),], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)return pos
可以发现整个类其实就是初始化了相应shape的位置编码参数,然后通过可学习的方式自己学习这些位置编码参数,代码比较简答。
Reference
官方源码: https://github.com/facebookresearch/detr
b站源码讲解: 铁打的流水线工人
知乎【布尔佛洛哥哥】: DETR 源码解读
CSDN【在努力的松鼠】源码讲解: DETR源码笔记(一)
CSDN【在努力的松鼠】源码讲解: DETR源码笔记(二)
CSDN: Transformer中的position encoding(位置编码一)
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(一)、概述与模型推断】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(四)、Detection with Transformer】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成】
【DETR源码解析】二、Backbone模块相关推荐
- Android xUtils3源码解析之图片模块
本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...
- Android xUtils3源码解析之注解模块
本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...
- Android xUtils3源码解析之数据库模块
本文已授权微信公众号<非著名程序员>原创首发,转载请务必注明出处. xUtils3源码解析系列 一. Android xUtils3源码解析之网络模块 二. Android xUtils3 ...
- 【深度学习模型】智云视图中文车牌识别源码解析(二)
[深度学习模型]智云视图中文车牌识别源码解析(二) 感受 HyperLPR可以识别多种中文车牌包括白牌,新能源车牌,使馆车牌,教练车牌,武警车牌等. 代码不可谓不混乱(别忘了这是职业公司的准产品级代码 ...
- 【DETR源码解析】三、Transformer模块
目录 前言 一.Transformer整体结构 二.TransformerEncoder 2.1.TransformerEncoderLayer 三.TransformerDecoder 3.1.Tr ...
- erlang下lists模块sort(排序)方法源码解析(二)
上接erlang下lists模块sort(排序)方法源码解析(一),到目前为止,list列表已经被分割成N个列表,而且每个列表的元素是有序的(从大到小) 下面我们重点来看看mergel和rmergel ...
- Kubernetes学习笔记之Calico CNI Plugin源码解析(二)
女主宣言 今天小编继续为大家分享Kubernetes Calico CNI Plugin学习笔记,希望能对大家有所帮助. PS:丰富的一线技术.多元化的表现形式,尽在"360云计算" ...
- Mobx 源码解析 二(autorun)
前言 我们在Mobx 源码解析 一(observable)已经知道了observable 做的事情了, 但是我们的还是没有讲解明白在我们的Demo中,我们在Button 的Click 事件中只是对ba ...
- android网络框架retrofit源码解析二
注:源码解析文章参考了该博客:http://www.2cto.com/kf/201405/305248.html 前一篇文章讲解了retrofit的annotation,既然定义了,那么就应该有解析的 ...
最新文章
- Flask-admin 使用总结
- 趴在门口的云计算,盯上了屋内狂奔的CDN
- C++实现静态顺序表的增删查改以及初始化
- 前端学习(3238):react生命周期4
- 大家身边极度聪明的人是什么样子?
- python如何写代码_如何写出优雅的Python代码?
- 【转】C51中断函数的写法
- win11快捷键怎么使用 Windows11快捷键的使用方法
- 从零开始部署基于阿里容器云的微服务(consul+registrator+template)(一)
- iOS开发UI篇—控制器的创建
- 「手把手带你学算法」本周小结!(动态规划系列七)
- 11.GitLab webhooks
- RAID磁盘阵列配置
- VMWare虚拟机Ubantu20.10添加中文智能拼音输入法
- selenium不定位元素直接操作键盘之Keys.CONTROL
- 《跟小智一起学网络》教程目录
- C++读取通达信shm.tnf文件股票代码/名称
- html页面虚化,css实现背景虚化效果的示例代码
- iframe标签全屏
- 求税后收入及个人所得税
热门文章
- ArcGIS的ArcToolbox执行任务时没反应或图层上有小锁的解决方法
- 学习了编程之后,是不是就可以进行APP开发了?
- 前端 “一键换肤“ 的几种方案
- 挑战杯、互联网+大学生创新创业大赛项目计划书《多功能智能化无人机》
- 哔哩哔哩弹幕html,哔哩哔哩bilibili默认关闭弹幕
- MTK 平台sensor arch 介绍-hal
- 用java实现凯撒加密系统,JAVA如何实现caesar凯撒加密算法
- @WebFilter怎么控制多个filter的执行顺序
- 张飞老师硬件第十五部视频整理——硬件基础1-2-3
- 什么是BFC,如何触发BFC,BFC的作用