【MOT】多目标追踪学习笔记之MOTR
1、前言
由于transformer的成功,现已经在多目标跟踪领域(MOTR)广泛应用,例如TransTrack和TrackFormer,但这两个工作严格来说不能算作端到端的模型,而MOTR的出现弥补了上面两个工作的缺点。本文将对MOTR的网络结构简单说明,然后对推理代码进行解析,如有不对的地方,还请各位大佬指出。
paper : https://arxiv.org/abs/2105.03247
repo :https://github.com/megvii-research/MOTR
可以加v:Rex1586662742或者q群:468713665一起讨论
学习链接:MOTR、MOTR
MOTR中涉及到的Deformerable DRET 可参考:【BEV】学习笔记之 DeformableDETR(原理+代码解析)
2、网络结构
下面根据论文中的图示来分步解析MOTR的网络结构
如上图左边为DETR的解码过程,利用Object queries和Image featue进行Decoder获得目标框。MOTR中利用了这种方法,利用多帧的特征与Track queries进行Decoder从而获得被跟踪的目标, 从代码里面发现这里的track queries其实是包括300个初始化的 detect queries + n个track queries(下文中假定每一帧中有n个track queries),得到了当前帧的目标后,即更新track queries。
针对上图(b)中的Iterative Updata过程,MOTR论文中提供了下图进行说明。
左右两边是一个对应的过程,在t1时刻,首先初始化300个detect queries ,并检测到了两个目标,即Object1和Object2,由于者两个目标在之前从未出现过,因此将其划分为track queries。在t2时刻,再次初始化300个object queries 并检测到了新的目标Object3,由于Object1和object2已经被跟踪上,因此不会再次被检测到。于是在t3时刻得到了3个track queries,在t4时刻,由于Obeject2消失,因此,track queries相应的也要将object3对应的object query删除。
下面将对track query的加入与删除模块进行说明,文中将这个模块命名为 query interaction module(QIM)。
在上图中,通过预测可以得到当前帧的 detect query 和 track query。1、detect 进入到Object Enterance模块, 从detect queries筛选出符合条件的 query作为新的track query。2、track query进入到Object Exit中,筛选出目标消失的track queries,然后在track queries中做self_attn,最终将detect queries以及track queries筛选后的query作为当前帧的track queries。
通过上面几个步骤,最终的网络结构如下图所示:
3、代码解析
在代码开始之前需要对Instances类进行说明,后续的代码中track_instances是 Instances实例化的对象,用于储存每帧中的 track queries,以及每个track queries对应的类别、boxes等信息。
3.1、models/structures/instances.py
class Instances:"""This class represents a list of instances in an image.It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".All fields must have the same ``__len__`` which is the number of instances.储存每一帧的检测以及跟踪信息,包含detection query + track query"""def __init__(...):self._fields: Dict[str, Any] = {}def set(self, name: str, value: Any) -> None:"""Set the field named `name` to `value`.储存检测结果"""self._fields[name] = value# 由于类初始的时候没有定义成员,后面给类添加成员变量的时候会默认调用这个函数def __setattr__(self, name: str, val: Any) -> None:if name.startswith("_"):super().__setattr__(name, val)else:self.set(name, val) # 将值存到 self._fields中def cat(..)"""合并两个 Instances"""...
3.2、demo.py
class Detector(...):def __init__(...):self.tr_tracker = MOTR()def run(...):for _, cur_img, ori_img in tqdm(self.dataloader):...res = self.model.inference_single_image(...) # -> models/motr.pytrack_instances = res['track_instances'] # 当前帧的 track_instances# 筛选出track_instances 得分大于阈值的 track querydt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)# 筛选出 dt_instances 中 box面积大于阈值的track querydt_instances = self.filter_dt_by_area(dt_instances, area_threshold)if dump: tracker_outputs = self.tr_tracker.update(dt_instances) # [n, 6] 当前帧的 阈值与面积超过一定阈值的 [x1,y1,x2,y2,score,id]
3.3、models/motr.py
class MOTR(nn.Module):def __init__(...):self.post_process = TrackerPostProcess() # 将预测结果恢复到原图尺寸self.track_base = RuntimeTrackerBase() # 进入到QIM模块之前给track queries进行标记def _generate_empty_tracks(self):...def inference_single_image(...):if track_instances is None:track_instances = self._generate_empty_tracks() # 在第0帧初始化300个 detection queries# track_instances:上一帧的跟踪结果 len(track_instances) = 300 + n 300为空的detect queries的个数, n为track queries的个数# 利用上一帧的track qeries 和当前帧的图片特征进行decoder,来预测当前帧的检测结果,融合了时序信息。res = self._forward_single_image(img,track_instances=track_instances) # res:{"pred_logits":"pred_boxes","","hs":"",...}当前帧的目标检测结果res = self._post_process_single_image(res, track_instances, False) # 本帧检测到的目标,与上一帧的跟踪目标进行匹配track_instances = self.post_process(track_instances, ori_img_size) # 将当前帧的跟踪结果缩放到图片尺寸大小return ... # -> demo.pydef _forward_single_image(...):features, pos = self.backbone(samples) # 提取图片特征,参考deformable detr# track_instances.query_pos:[300 + n, 512]# hs 每个decode的中间结果 len(hs)==6# inter_references 每个decode的box# enc_outputs_class 每个decode 的类别hs, init_reference, inter_references, ... = self.transformer(...) # 2->models/deformable_transformer_plus.py# pred_logits:每个框的类别、pred_boxes:每个框的boxout = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'ref_pts': ref_pts_all[5]}return out # -> inference_single_imagedef _post_process_single_image(....):with torch.no_grad():if self.training:track_scores = frame_res['pred_logits'][0, :].sigmoid().max(dim=-1).valueselse:track_scores = frame_res['pred_logits'][0, :, 0].sigmoid() # 当前帧检测到所有目标的得分track_instances.scores = track_scores # [300 + n] detection queries + tracking queriestrack_instances.pred_logits = frame_res['pred_logits'][0]track_instances.pred_boxes = frame_res['pred_boxes'][0]track_instances.output_embedding = frame_res['hs'][0]if self.training:track_instances = self.criterion.match_for_single_frame(frame_res)else: self.track_base.update(track_instances) # 根据当前帧的检测结果更新当前帧的track_instances的ID信息tmp['init_track_instances'] = self._generate_empty_tracks() # 初始化300个detection querytmp['track_instances'] = track_instances # 当前帧的 track_instances [300 + n]if not is_last:out_track_instances = self.track_embed(tmp) # QIM网络 -> models/qim.pyelse:...return frame_res # -> inference_single_imageclass RuntimeTrackerBase(...):def __init__(...):...def update(...):track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0 # 得分大于阈值的设为0, 得分小于阈值超过一定次数后将被过滤for i in range(len(track_instances)):# 当前 query 没有目标 且 得分超过一定值if track_instances.obj_idxes[i] == -1 and track_instances.scores[i] >= self.score_thresh:# 为这个query 设定一个ID ,说明开始跟踪一个目标track_instances.obj_idxes[i] = self.max_obj_idself.max_obj_id += 1 # 最大ID+1# object_query存在目标,但是阈值小于一定值elif track_instances.obj_idxes[i] >= 0 and track_instances.scores[i] < self.filter_score_thresh:track_instances.disappear_time[i] += 1 # 目标消失的次数+1# 当消失的次数大于阈值,则将query设为无目标if track_instances.disappear_time[i] >= self.miss_tolerance:track_instances.obj_idxes[i] = -1
3.4、models/deformable_transformer_plus.py
class DeformableTransformer(nn.Module):def __init__(...):...def forward(...):# encoder 提取图片特征 memory:[1, 23674, 256]memory = self.encoder(...)if self.two_stage:...else:query_embed, tgt = torch.split(query_embed, c, dim=1) # query_embed:[300 + n, 256] tgt:[300 + n, 256]if ref_pts is None:...else:reference_points = ref_pts.unsqueeze(0).repeat(bs, 1, 1).sigmoid() # 归一化坐标# decoder len(hs)==6,每个decode的中间结果,inter_references [6, 1, 300+n, 4] 每个decode的每个object_queries的bboxhs, inter_references = self.decoder(...)return hs, init_reference_out, inter_references_out, None, None # -> models/motr.pyclass DeformableTransformerDecoder(nn.Module):def __init__:...def forward(...):"""tgt: [1,300 + n , 256]reference_points:[1,300+n,2]src:[1, 23674, 256]"""for lid, layer in enumerate(self.layers):if reference_points.shape[-1] == 4:...else:reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] # [1,300,4] 每个object_queries 在每个特征层上找一个对应点output = layer(...) # [1,300+n,256]if self.return_intermediate:return torch.stack(intermediate), torch.stack(intermediate_reference_points) #class DeformableTransformerDecoderLayer(...):def __init__(...):...def forward(...):if self.self_cross:return self._forward_self_cross(...)def _forward_self_cross(...):# self attentiontgt = self._forward_self_attn(tgt, query_pos, attn_mask) # [1,300+n,256]# cross attention# detect/track querise 与 图片特征图进行cross_attn 交叉注意力tgt2 = self.cross_attn(...) # [1,300+n,256]tgt = self.forward_ffn(tgt) # [1,300+n,256]return tgt def _forward_self_attn(...):if self.extra_track_attn:# tgt:[1,300+n,256] track queries 内部做self_attntgt = self._forward_track_attn(tgt, query_pos) q = k = self.with_pos_embed(tgt, query_pos)if attn_mask is not None:...else:# detect queries 与 track queries 一起做self-attntgt2 = self.self_attn(...)return self.norm2(tgt)def _forward_track_attn(...):q = k = self.with_pos_embed(tgt, query_pos) # [1,300+n,256]if q.shape[1] > 300:tgt2 = self.update_attn(...) # 即有目标的track_queries做self-attn,然后再将track queries与 detect queriestgt = torch.cat([tgt[:, :300],self.norm4(tgt[:, 300:]+self.dropout5(tgt2))], dim=1)return tgt
3.5、models/qim.py
class QueryInteractionModule(...):def __init__(...):...def forward(...):# data = track_instances# data:[300 + n] 在进入 QIM网络之前,已经通过RuntimeTrackerBase更新了data中每个queries的属性active_track_instances = self._select_active_tracks(data) # 挑选出 data里面跟踪到目标的queries [n]active_track_instances = self._update_track_embedding(active_track_instances) # track queries 之间进行self_attnmerged_track_instances = Instances.cat([init_track_instances, active_track_instances]) # 将300个空的detetion_query 和当前帧最终的track query合并return merged_track_instancesdef _update_track_embedding(...):# 当前帧跟踪到的 track querytgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0]if self.update_query_pos:...track_instances.query_pos[:, :dim // 2] = query_postrack_instances.query_pos[:, dim//2:] = query_feat
4、学习小结
由于数据集的问题,本文只说明了推理过程的部分,按照这个,训练部分的代码可参考推理部分的代码,只是增加了最后计算loss的部分。通过本文,学习了如何利用transformer进行端到端的多目标跟踪,可以看出transformer已经在跟多的领域发挥作用,有机会分享更多的论文学习记录。
【MOT】多目标追踪学习笔记之MOTR相关推荐
- Jarry的目标跟踪学习笔记一
Jarry的目标跟踪学习笔记一 目标跟踪是计算机视觉中的一个重要方向,已经由来已久,并且有着广泛的应用,如:视频监控,人机交互, 无人驾驶等.在我的想象中,自己研究的内容就是,将来钢铁侠头盔里追踪敌人 ...
- 《南溪的目标检测学习笔记》——模型预处理的学习笔记
1 介绍 在目标检测任务中,模型预处理分为两个步骤: 图像预处理:基于图像处理算法 数值预处理:基于机器学习理论 关于图像预处理,请参考<南溪的目标检测学习笔记>--图像预处理的学习笔记 ...
- 《南溪的目标检测学习笔记》——COCO数据集的学习笔记
1 COCO数据集 COCO数据集下载链接:COCO_download 1.1 数据概览 数据集大小 train: 118287张 train+val: 123287张 val: 5000张 目标数量 ...
- 《南溪的目标检测学习笔记》的笔记目录
1 前言 这是<南溪的目标检测学习笔记>的目录~ 2 学习目标检测的思路--"总纲" <南溪的目标检测学习笔记>--目标检测的学习笔记 我在这篇文章中介绍了 ...
- 《南溪的目标检测学习笔记》——目标检测模型的设计笔记
1 南溪学习的目标检测模型--DETR 南溪最赞赏的目标检测模型是DETR, 论文名称:End-to-End Object Detection with Transformers 1.2 decode ...
- 《南溪的目标检测学习笔记》——夏侯南溪的CNN调参笔记,加油
1 致谢 感谢赵老师的教导! 感谢张老师的指导! 2 调参目标 在COCO数据集上获得mAP>=10.0的模型,现在PaddleDetection上的Anchor-Free模型[TTFNet]的 ...
- [初窥目标检测]——《目标检测学习笔记(2):浅析Selective Search论文——“Selective Search for object recognition”》
[初窥目标检测]--<目标检测学习笔记(2):浅析Selective Search论文--Selective Search for object recognition> 本文介绍 前文我 ...
- 9月1日目标检测学习笔记——文本检测
文章目录 前言 一.类型 1.Top-Down 2.Bottom-up 二.基于深度学习的文本检测模型 1.CTPN 2.RRPN 3.FTSN 4.DMPNet 5.EAST 6.SegLink 7 ...
- 3D目标检测学习笔记
博主初学3D目标检测,此前没有相关学习背景,小白一枚-现阶段的学习重点是点云相关的3D检测. 本文是阅读文章:3D Object Detection for Autonomous Driving: A ...
最新文章
- 深入理解Linux中的文件权限
- JavaScript this指向相关内容
- Python数据类型判断常遇到的坑
- 开源纯C#工控网关+组态软件(四)上下位机通讯原理
- Linux下启动启动tomcat 服务器报错 The file is absent or does not have execute permission
- SWT实现Text输入自动提示
- 2017.10.8 志愿者招募 失败总结
- 各种排序算法总结及C#代码实现
- 张掖市职教中心计算机专业,张掖市职教中心参加2021年全市中等职业学校学生教师技能大赛成绩喜人...
- Caffe学习-手写数字识别
- 禁用word公式编辑器
- MTK6589双卡卡1或是卡2拨出电话
- 文本特征提取:词袋模型/词集模型,TF-IDF
- 华为al00的计算机在哪,(详细)华为畅享8 LDN-AL00的USB调试模式在哪里开启的流程...
- 键盘定位板图纸_DIY如何自制专属GH60机械键盘教程【步骤详解】
- 发现美,创造美,拥有美^_^.
- js-array数组-slice-splice
- DNA计算 与 肽展公式 推导 AOPM-A 变胸腺苷, AOPM-O尿胞变腺苷, AOPM-P尿胞变鸟苷, AOPM-M鸟腺苷的 S形螺旋纹 血氧峰 触发器分子式 严谨完整过程
- USRP工作流程及各部分功能
- 2021计算机技术考研非全日制,2021考研考非全日制还是全日制?盘点你不懂的非全日制深层含义~...