1、前言

由于transformer的成功,现已经在多目标跟踪领域(MOTR)广泛应用,例如TransTrack和TrackFormer,但这两个工作严格来说不能算作端到端的模型,而MOTR的出现弥补了上面两个工作的缺点。本文将对MOTR的网络结构简单说明,然后对推理代码进行解析,如有不对的地方,还请各位大佬指出。
paper : https://arxiv.org/abs/2105.03247
repo :https://github.com/megvii-research/MOTR
可以加v:Rex1586662742或者q群:468713665一起讨论
学习链接:MOTR、MOTR

MOTR中涉及到的Deformerable DRET 可参考:【BEV】学习笔记之 DeformableDETR(原理+代码解析)

2、网络结构

下面根据论文中的图示来分步解析MOTR的网络结构

如上图左边为DETR的解码过程,利用Object queries和Image featue进行Decoder获得目标框。MOTR中利用了这种方法,利用多帧的特征与Track queries进行Decoder从而获得被跟踪的目标, 从代码里面发现这里的track queries其实是包括300个初始化的 detect queries + n个track queries(下文中假定每一帧中有n个track queries),得到了当前帧的目标后,即更新track queries。

针对上图(b)中的Iterative Updata过程,MOTR论文中提供了下图进行说明。

左右两边是一个对应的过程,在t1时刻,首先初始化300个detect queries ,并检测到了两个目标,即Object1和Object2,由于者两个目标在之前从未出现过,因此将其划分为track queries。在t2时刻,再次初始化300个object queries 并检测到了新的目标Object3,由于Object1和object2已经被跟踪上,因此不会再次被检测到。于是在t3时刻得到了3个track queries,在t4时刻,由于Obeject2消失,因此,track queries相应的也要将object3对应的object query删除。

下面将对track query的加入与删除模块进行说明,文中将这个模块命名为 query interaction module(QIM)。

在上图中,通过预测可以得到当前帧的 detect query 和 track query。1、detect 进入到Object Enterance模块, 从detect queries筛选出符合条件的 query作为新的track query。2、track query进入到Object Exit中,筛选出目标消失的track queries,然后在track queries中做self_attn,最终将detect queries以及track queries筛选后的query作为当前帧的track queries。

通过上面几个步骤,最终的网络结构如下图所示:

3、代码解析

在代码开始之前需要对Instances类进行说明,后续的代码中track_instances是 Instances实例化的对象,用于储存每帧中的 track queries,以及每个track queries对应的类别、boxes等信息。

3.1、models/structures/instances.py

class Instances:"""This class represents a list of instances in an image.It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".All fields must have the same ``__len__`` which is the number of instances.储存每一帧的检测以及跟踪信息,包含detection query + track query"""def __init__(...):self._fields: Dict[str, Any] = {}def set(self, name: str, value: Any) -> None:"""Set the field named `name` to `value`.储存检测结果"""self._fields[name] = value# 由于类初始的时候没有定义成员,后面给类添加成员变量的时候会默认调用这个函数def __setattr__(self, name: str, val: Any) -> None:if name.startswith("_"):super().__setattr__(name, val)else:self.set(name, val)  # 将值存到 self._fields中def cat(..)"""合并两个 Instances"""...

3.2、demo.py

class Detector(...):def __init__(...):self.tr_tracker = MOTR()def run(...):for _, cur_img, ori_img in tqdm(self.dataloader):...res = self.model.inference_single_image(...)  # -> models/motr.pytrack_instances = res['track_instances'] # 当前帧的 track_instances# 筛选出track_instances 得分大于阈值的 track querydt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)# 筛选出 dt_instances 中 box面积大于阈值的track querydt_instances = self.filter_dt_by_area(dt_instances, area_threshold)if dump: tracker_outputs = self.tr_tracker.update(dt_instances)  # [n, 6] 当前帧的 阈值与面积超过一定阈值的 [x1,y1,x2,y2,score,id]

3.3、models/motr.py

class MOTR(nn.Module):def __init__(...):self.post_process = TrackerPostProcess()  # 将预测结果恢复到原图尺寸self.track_base = RuntimeTrackerBase()    # 进入到QIM模块之前给track queries进行标记def _generate_empty_tracks(self):...def inference_single_image(...):if track_instances is None:track_instances = self._generate_empty_tracks()  # 在第0帧初始化300个 detection queries# track_instances:上一帧的跟踪结果  len(track_instances) = 300 + n   300为空的detect queries的个数, n为track queries的个数# 利用上一帧的track qeries 和当前帧的图片特征进行decoder,来预测当前帧的检测结果,融合了时序信息。res = self._forward_single_image(img,track_instances=track_instances)  # res:{"pred_logits":"pred_boxes","","hs":"",...}当前帧的目标检测结果res = self._post_process_single_image(res, track_instances, False)   # 本帧检测到的目标,与上一帧的跟踪目标进行匹配track_instances = self.post_process(track_instances, ori_img_size)  # 将当前帧的跟踪结果缩放到图片尺寸大小return ... # -> demo.pydef _forward_single_image(...):features, pos = self.backbone(samples) # 提取图片特征,参考deformable detr# track_instances.query_pos:[300 + n, 512]# hs 每个decode的中间结果 len(hs)==6# inter_references 每个decode的box# enc_outputs_class 每个decode 的类别hs, init_reference, inter_references, ... = self.transformer(...)  # 2->models/deformable_transformer_plus.py#  pred_logits:每个框的类别、pred_boxes:每个框的boxout = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'ref_pts': ref_pts_all[5]}return out # -> inference_single_imagedef _post_process_single_image(....):with torch.no_grad():if self.training:track_scores = frame_res['pred_logits'][0, :].sigmoid().max(dim=-1).valueselse:track_scores = frame_res['pred_logits'][0, :, 0].sigmoid()  # 当前帧检测到所有目标的得分track_instances.scores = track_scores  # [300 + n]   detection queries + tracking queriestrack_instances.pred_logits = frame_res['pred_logits'][0]track_instances.pred_boxes = frame_res['pred_boxes'][0]track_instances.output_embedding = frame_res['hs'][0]if self.training:track_instances = self.criterion.match_for_single_frame(frame_res)else: self.track_base.update(track_instances)   # 根据当前帧的检测结果更新当前帧的track_instances的ID信息tmp['init_track_instances'] = self._generate_empty_tracks()  # 初始化300个detection querytmp['track_instances'] = track_instances  # 当前帧的 track_instances [300 + n]if not is_last:out_track_instances = self.track_embed(tmp)   # QIM网络  -> models/qim.pyelse:...return frame_res  # -> inference_single_imageclass RuntimeTrackerBase(...):def __init__(...):...def update(...):track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0 # 得分大于阈值的设为0, 得分小于阈值超过一定次数后将被过滤for i in range(len(track_instances)):# 当前 query 没有目标 且 得分超过一定值if track_instances.obj_idxes[i] == -1 and track_instances.scores[i] >= self.score_thresh:# 为这个query 设定一个ID ,说明开始跟踪一个目标track_instances.obj_idxes[i] = self.max_obj_idself.max_obj_id += 1 # 最大ID+1# object_query存在目标,但是阈值小于一定值elif track_instances.obj_idxes[i] >= 0 and track_instances.scores[i] < self.filter_score_thresh:track_instances.disappear_time[i] += 1 # 目标消失的次数+1# 当消失的次数大于阈值,则将query设为无目标if track_instances.disappear_time[i] >= self.miss_tolerance:track_instances.obj_idxes[i] = -1

3.4、models/deformable_transformer_plus.py

class DeformableTransformer(nn.Module):def __init__(...):...def forward(...):# encoder 提取图片特征 memory:[1, 23674, 256]memory = self.encoder(...)if self.two_stage:...else:query_embed, tgt = torch.split(query_embed, c, dim=1)  # query_embed:[300 + n, 256]  tgt:[300 + n, 256]if ref_pts is None:...else:reference_points = ref_pts.unsqueeze(0).repeat(bs, 1, 1).sigmoid()  # 归一化坐标# decoder len(hs)==6,每个decode的中间结果,inter_references [6, 1, 300+n, 4] 每个decode的每个object_queries的bboxhs, inter_references = self.decoder(...)return hs, init_reference_out, inter_references_out, None, None  # ->  models/motr.pyclass DeformableTransformerDecoder(nn.Module):def __init__:...def forward(...):"""tgt: [1,300 + n , 256]reference_points:[1,300+n,2]src:[1, 23674, 256]"""for lid, layer in enumerate(self.layers):if reference_points.shape[-1] == 4:...else:reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]  # [1,300,4] 每个object_queries 在每个特征层上找一个对应点output = layer(...)  #  [1,300+n,256]if self.return_intermediate:return torch.stack(intermediate), torch.stack(intermediate_reference_points) #class DeformableTransformerDecoderLayer(...):def __init__(...):...def forward(...):if self.self_cross:return self._forward_self_cross(...)def _forward_self_cross(...):# self attentiontgt = self._forward_self_attn(tgt, query_pos, attn_mask)  # [1,300+n,256]# cross attention# detect/track querise 与 图片特征图进行cross_attn 交叉注意力tgt2 = self.cross_attn(...)   # [1,300+n,256]tgt = self.forward_ffn(tgt) # [1,300+n,256]return tgt def _forward_self_attn(...):if self.extra_track_attn:# tgt:[1,300+n,256]  track queries 内部做self_attntgt = self._forward_track_attn(tgt, query_pos) q = k = self.with_pos_embed(tgt, query_pos)if attn_mask is not None:...else:# detect queries 与 track queries 一起做self-attntgt2 = self.self_attn(...)return self.norm2(tgt)def _forward_track_attn(...):q = k = self.with_pos_embed(tgt, query_pos)  # [1,300+n,256]if q.shape[1] > 300:tgt2 = self.update_attn(...)  # 即有目标的track_queries做self-attn,然后再将track queries与 detect queriestgt = torch.cat([tgt[:, :300],self.norm4(tgt[:, 300:]+self.dropout5(tgt2))], dim=1)return tgt

3.5、models/qim.py

class QueryInteractionModule(...):def __init__(...):...def forward(...):# data = track_instances# data:[300 + n] 在进入 QIM网络之前,已经通过RuntimeTrackerBase更新了data中每个queries的属性active_track_instances = self._select_active_tracks(data) # 挑选出 data里面跟踪到目标的queries  [n]active_track_instances = self._update_track_embedding(active_track_instances) # track queries 之间进行self_attnmerged_track_instances = Instances.cat([init_track_instances, active_track_instances])  # 将300个空的detetion_query  和当前帧最终的track query合并return merged_track_instancesdef _update_track_embedding(...):# 当前帧跟踪到的 track querytgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0]if self.update_query_pos:...track_instances.query_pos[:, :dim // 2] = query_postrack_instances.query_pos[:, dim//2:] = query_feat

4、学习小结

由于数据集的问题,本文只说明了推理过程的部分,按照这个,训练部分的代码可参考推理部分的代码,只是增加了最后计算loss的部分。通过本文,学习了如何利用transformer进行端到端的多目标跟踪,可以看出transformer已经在跟多的领域发挥作用,有机会分享更多的论文学习记录。

【MOT】多目标追踪学习笔记之MOTR相关推荐

  1. Jarry的目标跟踪学习笔记一

    Jarry的目标跟踪学习笔记一 目标跟踪是计算机视觉中的一个重要方向,已经由来已久,并且有着广泛的应用,如:视频监控,人机交互, 无人驾驶等.在我的想象中,自己研究的内容就是,将来钢铁侠头盔里追踪敌人 ...

  2. 《南溪的目标检测学习笔记》——模型预处理的学习笔记

    1 介绍 在目标检测任务中,模型预处理分为两个步骤: 图像预处理:基于图像处理算法 数值预处理:基于机器学习理论 关于图像预处理,请参考<南溪的目标检测学习笔记>--图像预处理的学习笔记 ...

  3. 《南溪的目标检测学习笔记》——COCO数据集的学习笔记

    1 COCO数据集 COCO数据集下载链接:COCO_download 1.1 数据概览 数据集大小 train: 118287张 train+val: 123287张 val: 5000张 目标数量 ...

  4. 《南溪的目标检测学习笔记》的笔记目录

    1 前言 这是<南溪的目标检测学习笔记>的目录~ 2 学习目标检测的思路--"总纲" <南溪的目标检测学习笔记>--目标检测的学习笔记 我在这篇文章中介绍了 ...

  5. 《南溪的目标检测学习笔记》——目标检测模型的设计笔记

    1 南溪学习的目标检测模型--DETR 南溪最赞赏的目标检测模型是DETR, 论文名称:End-to-End Object Detection with Transformers 1.2 decode ...

  6. 《南溪的目标检测学习笔记》——夏侯南溪的CNN调参笔记,加油

    1 致谢 感谢赵老师的教导! 感谢张老师的指导! 2 调参目标 在COCO数据集上获得mAP>=10.0的模型,现在PaddleDetection上的Anchor-Free模型[TTFNet]的 ...

  7. [初窥目标检测]——《目标检测学习笔记(2):浅析Selective Search论文——“Selective Search for object recognition”》

    [初窥目标检测]--<目标检测学习笔记(2):浅析Selective Search论文--Selective Search for object recognition> 本文介绍 前文我 ...

  8. 9月1日目标检测学习笔记——文本检测

    文章目录 前言 一.类型 1.Top-Down 2.Bottom-up 二.基于深度学习的文本检测模型 1.CTPN 2.RRPN 3.FTSN 4.DMPNet 5.EAST 6.SegLink 7 ...

  9. 3D目标检测学习笔记

    博主初学3D目标检测,此前没有相关学习背景,小白一枚-现阶段的学习重点是点云相关的3D检测. 本文是阅读文章:3D Object Detection for Autonomous Driving: A ...

最新文章

  1. 深入理解Linux中的文件权限
  2. JavaScript this指向相关内容
  3. Python数据类型判断常遇到的坑
  4. 开源纯C#工控网关+组态软件(四)上下位机通讯原理
  5. Linux下启动启动tomcat 服务器报错 The file is absent or does not have execute permission
  6. SWT实现Text输入自动提示
  7. 2017.10.8 志愿者招募 失败总结
  8. 各种排序算法总结及C#代码实现
  9. 张掖市职教中心计算机专业,张掖市职教中心参加2021年全市中等职业学校学生教师技能大赛成绩喜人...
  10. Caffe学习-手写数字识别
  11. 禁用word公式编辑器
  12. MTK6589双卡卡1或是卡2拨出电话
  13. 文本特征提取:词袋模型/词集模型,TF-IDF
  14. 华为al00的计算机在哪,(详细)华为畅享8 LDN-AL00的USB调试模式在哪里开启的流程...
  15. 键盘定位板图纸_DIY如何自制专属GH60机械键盘教程【步骤详解】
  16. 发现美,创造美,拥有美^_^.
  17. js-array数组-slice-splice
  18. DNA计算 与 肽展公式 推导 AOPM-A 变胸腺苷, AOPM-O尿胞变腺苷, AOPM-P尿胞变鸟苷, AOPM-M鸟腺苷的 S形螺旋纹 血氧峰 触发器分子式 严谨完整过程
  19. USRP工作流程及各部分功能
  20. 2021计算机技术考研非全日制,2021考研考非全日制还是全日制?盘点你不懂的非全日制深层含义~...

热门文章

  1. egg设置cookie
  2. 使用xgboost进行文本分类
  3. Glid 加载图片不显示(Android9.0无法加载图片)
  4. 移动互联网时代我们如何引爆社群?
  5. 统计一串数字的不重复数字个数
  6. Linux-Squid代理服务器搭建
  7. 什么是双亲委派?如何打破双亲委派?
  8. 有哪些好的上报crash工具:推荐crashlytics
  9. 机械CAD中如何快速绘制相贯线?
  10. appium---TouchAction