Introduction

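The paper adds embedding learning on top of YOLOv3. Because the network outputs boxes, classes and embeddings at the same time, it achieves one-shot tracking and improves speed. However, once the embeddings are obtained, a matching algorithm is still needed to link detections into tracks, so strictly speaking the method is still two-stage. Since the model is very similar to YOLOv3, this post focuses on the tracking part and its code.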

Paper: Towards Real-Time Multi-Object Tracking
Code: Zhongdao/Towards-Realtime-MOT

Model

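The paper addresses the drawbacks of the SDE (Separate Detection and Embedding) tracking paradigm and proposes the one-shot JDE (Joint Detection and Embedding) paradigm.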

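The model is YOLOv3 with essentially the same network structure. The only difference is an extra branch in each prediction head that learns the embedding. Each prediction head is modeled as a multi-task learning problem, and its output has size (6A + D) × H × W, where A is the number of anchors at that scale and D is the embedding dimension. For every anchor there are 4 box-regression channels and 2 foreground/background classification channels (hence 6A), followed by D embedding channels shared by the whole grid cell.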
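To make the (6A + D) × H × W layout concrete, here is a minimal NumPy sketch that splits a head output the same way the slicing in YOLOLayer.forward does. It assumes A = 4 and D = 512 (the values used in the code). The helper split_head and the grid size are illustrative and not part of the repository.

```python
import numpy as np


def split_head(p_cat, num_anchors=4):
    """Split a JDE-style head output (B, 6A + D, H, W) into box, conf and embedding parts."""
    b, _, h, w = p_cat.shape
    det_ch = 6 * num_anchors                 # 4 box + 2 conf channels per anchor
    p, p_emb = p_cat[:, :det_ch], p_cat[:, det_ch:]
    # (B, A*6, H, W) -> (B, A, 6, H, W) -> (B, A, H, W, 6)
    p = p.reshape(b, num_anchors, 6, h, w).transpose(0, 1, 3, 4, 2)
    p_box = p[..., :4]                       # (B, A, H, W, 4)
    p_conf = p[..., 4:6]                     # (B, A, H, W, 2)
    p_emb = p_emb.transpose(0, 2, 3, 1)      # (B, H, W, D): one embedding per grid cell
    return p_box, p_conf, p_emb


A, D, H, W = 4, 512, 19, 34
out = np.random.rand(2, 6 * A + D, H, W).astype(np.float32)
box, conf, emb = split_head(out, A)
print(box.shape, conf.shape, emb.shape)
```

Note that the embedding part has no anchor dimension: all anchors in a cell share one D-dimensional vector.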

Embedding learning

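Every object in a trajectory has a unique track ID. To obtain a classification signal, a fully connected layer is added after the embedding. Borrowing the idea of object classification (each track ID is treated as a class), this layer turns the embedding into a track-ID classification output. The figure in the paper does not show this layer.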

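After masking, the embeddings at the target locations are extracted and passed through the fully connected classifier, which produces nID outputs (nID is the number of identities in the dataset). During training, the loss is computed on these outputs.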
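A minimal NumPy sketch of this identity-classification loss, assuming embeddings already extracted at target locations, a plain linear classifier (weight, bias) and ID -1 meaning "ignore". The scale factor mirrors emb_scale in YOLOLayer; the function name is hypothetical.

```python
import math
import numpy as np


def id_classification_loss(embeddings, ids, weight, bias):
    """Cross-entropy over identity classes for embeddings taken at target locations.

    embeddings: (N, D); ids: (N,) integer track IDs, -1 = ignore;
    weight: (nID, D) and bias: (nID,) are the fully connected classifier parameters.
    """
    n_id = weight.shape[0]
    # Same scale factor as YOLOLayer.emb_scale
    scale = math.sqrt(2) * math.log(n_id - 1) if n_id > 1 else 1.0
    # L2-normalise, then rescale (like emb_scale * F.normalize(embedding))
    emb = scale * embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    logits = emb @ weight.T + bias                        # (N, nID)
    keep = ids >= 0                                       # mimic ignore_index=-1
    logits, ids = logits[keep], ids[keep]
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(len(ids)), ids].mean())


rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 512))
ids = np.array([0, 3, 7, -1, 2])
W = rng.normal(scale=0.01, size=(10, 512))
b = np.zeros(10)
print(id_classification_loss(emb, ids, W, b))
```

With an all-zero classifier every identity gets the same logit, so the loss equals log(nID). This is a handy sanity check.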

Loss

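The loss has three terms:
Foreground/background classification: cross-entropy
Bounding box regression: smooth-L1
Embedding: the paper uses a loss it denotes L_CE, which is similar to cross-entropy; the code uses plain cross-entropy (nn.CrossEntropyLoss)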
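The three terms are not weighted by hand. YOLOLayer holds learnable parameters s_r, s_c and s_id and sums 0.5 · Σ(e^(−s_i) · L_i + s_i) (the automatic loss balancing the paper bases on task-independent uncertainty). A small sketch of that formula follows; the helper names are hypothetical. Setting the derivative −e^(−s)·L + 1 to zero gives s = log L, so a task whose loss stays large ends up with a smaller weight e^(−s) = 1/L.

```python
import math


def balanced_loss(l_box, l_conf, l_id, s_r=-4.85, s_c=-4.15, s_id=-2.3):
    """0.5 * sum_i(exp(-s_i) * L_i + s_i); the defaults are the initial values used in YOLOLayer."""
    weighted = math.exp(-s_r) * l_box + math.exp(-s_c) * l_conf + math.exp(-s_id) * l_id
    return 0.5 * (weighted + s_r + s_c + s_id)


def best_s(loss_value):
    """For a fixed L > 0, exp(-s) * L + s is minimised at s = log(L)."""
    return math.log(loss_value)


print(balanced_loss(1.0, 2.0, 3.0, 0.0, 0.0, 0.0))  # with all s_i = 0 this is 0.5 * (1 + 2 + 3)
print(best_s(2.5))
```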

The code is in models.py:

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class YOLOLayer(nn.Module):
    def __init__(self, anchors, nC, nID, nE, img_size, yolo_layer):
        super(YOLOLayer, self).__init__()
        self.layer = yolo_layer
        nA = len(anchors)
        self.anchors = torch.FloatTensor(anchors)
        self.nA = nA  # number of anchors (3); note: both the code and the paper actually use 4
        self.nC = nC  # number of classes (80)
        self.nID = nID  # number of identities
        self.img_size = 0
        self.emb_dim = nE
        self.shift = [1, 3, 5]

        self.SmoothL1Loss = nn.SmoothL1Loss()
        self.SoftmaxLoss = nn.CrossEntropyLoss(ignore_index=-1)
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)  # loss function for the ID branch
        self.s_c = nn.Parameter(-4.15 * torch.ones(1))  # -4.15
        self.s_r = nn.Parameter(-4.85 * torch.ones(1))  # -4.85
        self.s_id = nn.Parameter(-2.3 * torch.ones(1))  # -2.3

        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 1

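This defines some parameters and the loss functions. nA is the number of anchors at the current scale; the comment says 3, but both the paper and the code use 4.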

    def forward(self, p_cat, img_size, targets=None, classifier=None, test_emb=False):
        p, p_emb = p_cat[:, :24, ...], p_cat[:, 24:, ...]  # the first 24 channels hold box and conf
        nB, nGh, nGw = p.shape[0], p.shape[-2], p.shape[-1]  # batch size, grid height, grid width

        if self.img_size != img_size:
            create_grids(self, img_size, nGh, nGw)

            if p.is_cuda:
                self.grid_xy = self.grid_xy.cuda()
                self.anchor_wh = self.anchor_wh.cuda()

        # prediction: split the 24 channels into 4 anchors x [box (4), conf (2)]
        p = p.view(nB, self.nA, self.nC + 5, nGh, nGw).permute(0, 1, 3, 4, 2).contiguous()

        p_emb = p_emb.permute(0, 2, 3, 1).contiguous()  # embedding, shape (nB, nGh, nGw, 512)
        p_box = p[..., :4]                               # boxes, shape (nB, 4, nGh, nGw, 4)
        p_conf = p[..., 4:6].permute(0, 4, 1, 2, 3)      # conf, shape (nB, 2, 4, nGh, nGw); 4 = number of anchors

This splits the head output according to the channel layout of the model; each part is annotated in the code comments.

        # Training
        if targets is not None:
            if test_emb:
                tconf, tbox, tids = build_targets_max(targets, self.anchor_vec.cuda(), self.nA, self.nC, nGh, nGw)
            else:
                tconf, tbox, tids = build_targets_thres(targets, self.anchor_vec.cuda(), self.nA, self.nC, nGh, nGw)
            # confidence, box and ID targets; each has one layer per anchor (four in total)
            tconf, tbox, tids = tconf.cuda(), tbox.cuda(), tids.cuda()
            mask = tconf > 0

This builds the training targets (labels).
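build_targets_thres itself is not shown in this post. As an illustration only, assume it labels each anchor by its best IoU with the ground truth: foreground above 0.5, background below 0.4, and ignored (-1) in between. These thresholds are an assumption here. The -1 label is what the ignore_index=-1 of SoftmaxLoss skips. label_anchors is a hypothetical helper that works on a precomputed IoU matrix; it is not the repository function.

```python
import numpy as np


def label_anchors(iou, fg_thresh=0.5, bg_thresh=0.4):
    """Assign 1 (foreground), 0 (background) or -1 (ignore) to each anchor.

    iou: (num_anchors, num_targets) IoU between anchor boxes and ground-truth boxes.
    Returns (labels, best_target): best_target is the index of the highest-IoU ground truth.
    """
    best_iou = iou.max(axis=1)
    best_target = iou.argmax(axis=1)
    labels = np.full(iou.shape[0], -1, dtype=np.int64)   # default: ignored
    labels[best_iou > fg_thresh] = 1
    labels[best_iou < bg_thresh] = 0
    return labels, best_target


iou = np.array([[0.70, 0.10],
                [0.45, 0.20],
                [0.05, 0.30],
                [0.20, 0.55]])
labels, best = label_anchors(iou)
print(labels, best)
```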

            # Compute losses
            nT = sum([len(x) for x in targets])  # number of targets
            nM = mask.sum().float()  # number of anchors (assigned to targets)
            nP = torch.ones_like(mask).sum().float()
            if nM > 0:
                lbox = self.SmoothL1Loss(p_box[mask], tbox[mask])
            else:
                FT = torch.cuda.FloatTensor if p_conf.is_cuda else torch.FloatTensor
                lbox, lconf = FT([0]), FT([0])
            lconf = self.SoftmaxLoss(p_conf, tconf)
            lid = torch.Tensor(1).fill_(0).squeeze().cuda()
            emb_mask, _ = mask.max(1)

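This computes the conf and box losses. The mask keeps only anchors assigned to a target, so the box loss is computed only where an object is present. The conf loss covers all anchors; because SoftmaxLoss is built with ignore_index=-1, anchors labeled -1 are skipped.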
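A NumPy sketch of the two losses on a toy batch of four anchors, assuming the default nn.SmoothL1Loss settings (beta = 1, mean over all elements). The flat (N, 2) logits layout is a simplification of the (nB, 2, A, H, W) tensor in the code, and the function names are hypothetical.

```python
import numpy as np


def smooth_l1(pred, target):
    """Mean smooth-L1 with beta = 1 (the nn.SmoothL1Loss default)."""
    diff = np.abs(pred - target)
    return float(np.where(diff < 1.0, 0.5 * diff ** 2, diff - 0.5).mean())


def conf_cross_entropy(logits, labels):
    """Mean 2-way softmax cross-entropy; labels of -1 are skipped (like ignore_index=-1)."""
    keep = labels >= 0
    logits, labels = logits[keep], labels[keep]
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(len(labels)), labels].mean())


tconf = np.array([1, 0, -1, 1])                 # fg, bg, ignored, fg
p_box = np.array([[0.5, 0.0, 0.0, 0.0]] * 4)
tbox = np.zeros((4, 4))
p_conf = np.zeros((4, 2))                       # equal logits for bg/fg

mask = tconf > 0                                # only anchors assigned to a target
lbox = smooth_l1(p_box[mask], tbox[mask])
lconf = conf_cross_entropy(p_conf, tconf)
print(lbox, lconf)
```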

            # For convenience we use max(1) to decide the id, TODO: more reasonable strategy
            tids, _ = tids.max(1)
            tids = tids[emb_mask]
            embedding = p_emb[emb_mask].contiguous()
            embedding = self.emb_scale * F.normalize(embedding)
            nI = emb_mask.sum().float()

            if test_emb:
                if np.prod(embedding.shape) == 0 or np.prod(tids.shape) == 0:
                    return torch.zeros(0, self.emb_dim + 1).cuda()
                emb_and_gt = torch.cat([embedding, tids.float()], dim=1)
                return emb_and_gt

            if len(embedding) > 1:
                logits = classifier(embedding).contiguous()  # fully connected classification
                lid = self.IDLoss(logits, tids.squeeze())

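This is the ID loss computation. As described in the embedding learning section, the embeddings at target locations are extracted with the mask, classified by the linear layer, and the cross-entropy loss is computed.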
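The masking step is easier to follow with explicit shapes. p_emb has no anchor dimension (one vector per grid cell), so mask.max(1) collapses the anchor axis: a cell is used if any of its anchors is assigned. tids.max(1) picks the largest ID when anchors in the same cell disagree, which is what the TODO comment refers to. A NumPy sketch with hypothetical toy shapes:

```python
import numpy as np

# Toy shapes: batch 1, A = 4 anchors, 2 x 3 grid, D = 8
B, A, H, W, D = 1, 4, 2, 3, 8
rng = np.random.default_rng(0)

mask = np.zeros((B, A, H, W), dtype=bool)
mask[0, 2, 1, 1] = True            # a target assigned to anchor 2 at cell (1, 1)
mask[0, 0, 0, 2] = True            # a target assigned to anchor 0 at cell (0, 2)
tids = np.full((B, A, H, W, 1), -1)
tids[0, 2, 1, 1] = 17
tids[0, 0, 0, 2] = 5
p_emb = rng.normal(size=(B, H, W, D))    # one embedding per grid cell

emb_mask = mask.max(axis=1)              # (B, H, W): True if any anchor in the cell is assigned
cell_ids = tids.max(axis=1)[emb_mask]    # (N, 1): one ID per selected cell
embedding = p_emb[emb_mask]              # (N, D), cells taken in row-major order
embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
print(embedding.shape, cell_ids.ravel())
```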

            # Sum loss components
            loss = torch.exp(-self.s_r) * lbox + torch.exp(-self.s_c) * lconf + torch.exp(-self.s_id) * lid + \
                   (self.s_r + self.s_c + self.s_id)
            loss *= 0.5

            return loss, loss.item(), lbox.item(), lconf.item(), lid.item(), nT

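Finally, the total loss is computed and returned; each term is weighted by one of the learnable parameters s_r, s_c and s_id. The forward pass of the whole network is: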

from collections import OrderedDict

import torch
import torch.nn as nn


class Darknet(nn.Module):
    """YOLOv3 object detection model"""

    def __init__(self, cfg_dict, nID=0, test_emb=False):
        super(Darknet, self).__init__()
        if isinstance(cfg_dict, str):
            cfg_dict = parse_model_cfg(cfg_dict)
        self.module_defs = cfg_dict
        self.module_defs[0]['nID'] = nID
        self.img_size = [int(self.module_defs[0]['width']), int(self.module_defs[0]['height'])]
        self.emb_dim = int(self.module_defs[0]['embedding_dim'])
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        self.loss_names = ['loss', 'box', 'conf', 'id', 'nT']
        self.losses = OrderedDict()
        for ln in self.loss_names:
            self.losses[ln] = 0
        self.test_emb = test_emb

        # The embedding classifier described above; its output size is nID
        self.classifier = nn.Linear(self.emb_dim, nID) if nID > 0 else None

    def forward(self, x, targets=None, targets_len=None):
        self.losses = OrderedDict()
        for ln in self.loss_names:
            self.losses[ln] = 0
        is_training = (targets is not None) and (not self.test_emb)
        # img_size = x.shape[-1]
        layer_outputs = []
        output = []

        for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
            mtype = module_def['type']
            if mtype in ['convolutional', 'upsample', 'maxpool']:
                x = module(x)
            elif mtype == 'route':
                layer_i = [int(x) for x in module_def['layers'].split(',')]
                if len(layer_i) == 1:
                    x = layer_outputs[layer_i[0]]
                else:
                    x = torch.cat([layer_outputs[i] for i in layer_i], 1)
            elif mtype == 'shortcut':
                layer_i = int(module_def['from'])
                x = layer_outputs[-1] + layer_outputs[layer_i]
            elif mtype == 'yolo':
                if is_training:  # get loss
                    targets = [targets[i][:int(l)] for i, l in enumerate(targets_len)]
                    x, *losses = module[0](x, self.img_size, targets, self.classifier)
                    for name, loss in zip(self.loss_names, losses):
                        self.losses[name] += loss
                elif self.test_emb:
                    if targets is not None:
                        targets = [targets[i][:int(l)] for i, l in enumerate(targets_len)]
                    x = module[0](x, self.img_size, targets, self.classifier, self.test_emb)
                else:  # get detections
                    x = module[0](x, self.img_size)
                output.append(x)
            layer_outputs.append(x)

        if is_training:
            self.losses['nT'] /= 3
            output = [o.squeeze() for o in output]
            return sum(output), torch.Tensor(list(self.losses.values())).cuda()
        elif self.test_emb:
            return torch.cat(output, 0)
        return torch.cat(output, 1)

Tracking

The tracking logic is mainly in multitracker.py.

class JDETracker(object):
    def __init__(self, opt, frame_rate=30):
        self.opt = opt
        self.model = Darknet(opt.cfg, nID=14455)
        # load_darknet_weights(self.model, opt.weights)
        self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
        self.model.cuda().eval()

        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.det_thresh = opt.conf_thres
        self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size

        self.kalman_filter = KalmanFilter()

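Note that the code uses nID = 14455. The constructor also sets up containers for tracks in different states (tracked, lost and removed; the track class, STrack, is defined in the same file) and a Kalman filter (KF). The tracking procedure is as follows: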

        with torch.no_grad():
            pred = self.model(im_blob)
        # pred is tensor of all the proposals (default number of proposals: 54264). Proposals have information associated with the bounding box and embeddings
        pred = pred[pred[:, :, 4] > self.opt.conf_thres]  # [p_box, p_conf, p_cls, p_emb]; index 4 is conf, keep proposals above the conf threshold
        # pred now has fewer proposals. Proposals rejected on basis of object confidence score
        if len(pred) > 0:
            dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()  # non-maximum suppression
            # Final proposals are obtained in dets. Information of bounding box and embeddings also included
            # Next step changes the detection scales
            scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
            '''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
            # class_pred is the embeddings.

            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
                          (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
        else:
            detections = []

First, the detections are obtained from the network's predictions; the code comments explain each step in detail.
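The two filtering steps, a confidence threshold followed by non-maximum suppression, can be sketched in NumPy as below. greedy_nms is a plain greedy NMS written for illustration, not the repository's non_max_suppression, and the thresholds are example values. Because the embedding travels in the same row as its box, every surviving detection keeps its own embedding. In the toy rows the embedding starts at column 5 rather than 6, because the class-score column is omitted.

```python
import numpy as np


def greedy_nms(boxes, scores, iou_thresh):
    """Plain greedy NMS on (x1, y1, x2, y2) boxes; returns kept indices, best score first."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter)
        order = rest[iou <= iou_thresh]          # drop boxes overlapping the kept one
    return keep


# Each row: x1, y1, x2, y2, conf, then a (toy, D = 3) embedding
proposals = np.array([
    [10, 10, 50, 90, 0.9, 0.1, 0.2, 0.3],
    [12, 12, 52, 92, 0.8, 0.1, 0.2, 0.3],      # near-duplicate of the first box
    [200, 40, 240, 120, 0.7, 0.5, 0.1, 0.0],
    [300, 40, 340, 120, 0.2, 0.0, 0.0, 1.0],   # below the confidence threshold
])
conf_thres, nms_thres = 0.5, 0.4
pred = proposals[proposals[:, 4] > conf_thres]          # confidence filtering
keep = greedy_nms(pred[:, :4], pred[:, 4], nms_thres)   # suppress overlapping boxes
dets = pred[keep]
print(dets[:, :5])
```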

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                # previous tracks which are not active in the current frame are added in unconfirmed list
                unconfirmed.append(track)
                # print("Should not be here, in unconfirmed")
            else:
                # Active tracks are added to the local list 'tracked_stracks'
                tracked_stracks.append(track)

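Next, the tracks are taken out of the containers and split into normal (activated) tracks and unconfirmed tracks. Unconfirmed tracks are mainly those that contain only their start frame, i.e. tracks whose detection has appeared only once.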

        ''' Step 2: First association, with embedding'''
        # Combining currently tracked_stracks and lost_stracks
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool, self.kalman_filter)

        dists = matching.embedding_distance(strack_pool, detections)
        # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
        # The dists is the list of distances of the detection with the tracks in strack_pool
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
        # The matches is the array for corresponding matches of the detection with the corresponding strack_pool
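STrack.multi_predict advances every track's Kalman state by one frame before matching. The sketch below assumes a DeepSORT-style constant-velocity filter over the state (x, y, a, h, vx, vy, va, vh), with process noise scaled by box height. The weights 1/20 and 1/160 are assumed from DeepSORT's kalman_filter.py, and kf_predict is a hypothetical single-track helper rather than the repository's vectorised method.

```python
import numpy as np

STD_WEIGHT_POSITION = 1.0 / 20    # assumed, as in DeepSORT
STD_WEIGHT_VELOCITY = 1.0 / 160   # assumed, as in DeepSORT

# Constant-velocity transition, dt = 1 frame: position += velocity
F = np.eye(8)
for i in range(4):
    F[i, 4 + i] = 1.0


def kf_predict(mean, covariance):
    """One prediction step: x' = F x, P' = F P F^T + Q, with Q scaled by the box height h."""
    h = mean[3]
    std = np.r_[
        [STD_WEIGHT_POSITION * h, STD_WEIGHT_POSITION * h, 1e-2, STD_WEIGHT_POSITION * h],
        [STD_WEIGHT_VELOCITY * h, STD_WEIGHT_VELOCITY * h, 1e-5, STD_WEIGHT_VELOCITY * h],
    ]
    Q = np.diag(np.square(std))
    return F @ mean, F @ covariance @ F.T + Q


mean = np.array([100.0, 50.0, 0.5, 80.0, 2.0, -1.0, 0.0, 0.0])  # centre moves (+2, -1) px per frame
cov = np.eye(8)
new_mean, new_cov = kf_predict(mean, cov)
print(new_mean[:4])
```

Prediction only grows the uncertainty; it shrinks again when a matched detection updates the track.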

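embedding_distance computes the appearance cost between the tracks' smoothed features (smooth_feat) and the detections' current features (curr_feat). Although the metric argument defaults to 'cosine', it is not passed to cdist, so the distance actually computed is Euclidean. Because the features are L2-normalized, Euclidean distance is a monotonic function of cosine distance (d² = 2(1 − cos θ)), so the ranking is the same. The function: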

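import numpy as np
from scipy.spatial.distance import cdist


def embedding_distance(tracks, detections, metric='cosine'):
    """
    :param tracks: list[STrack]
    :param detections: list[BaseTrack]
    :param metric:
    :return: cost_matrix np.ndarray
    """
    # np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float64)
    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float64)
    # `metric` is not forwarded, so cdist uses its default Euclidean metric;
    # on L2-normalized features this ranks pairs the same way as cosine distance
    cost_matrix = np.maximum(0.0, cdist(track_features, det_features))  # Normalized features
    return cost_matrix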
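A quick numerical check of the relation between Euclidean and cosine distance on unit-length features, using random vectors and NumPy only. The helper names are illustrative. One practical consequence: the matching threshold (0.7 in the call above) is applied on the Euclidean scale. A Euclidean distance of 0.7 corresponds to a cosine distance of 0.7² / 2 = 0.245.

```python
import numpy as np


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def euclidean_cost(track_feats, det_feats):
    """Pairwise Euclidean distance (what cdist computes with its default metric)."""
    diff = track_feats[:, None, :] - det_feats[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def cosine_cost(track_feats, det_feats):
    """Pairwise cosine distance, 1 - cos(theta)."""
    return 1.0 - l2_normalize(track_feats) @ l2_normalize(det_feats).T


rng = np.random.default_rng(0)
tracks = l2_normalize(rng.normal(size=(3, 128)))
dets = l2_normalize(rng.normal(size=(4, 128)))
euc = euclidean_cost(tracks, dets)
cos = cosine_cost(tracks, dets)
# For unit vectors: ||a - b||^2 = 2 - 2 cos(theta) = 2 * cosine_distance
print(np.allclose(euc ** 2, 2 * cos))
```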

Next, motion information from the Kalman filter is fused into the cost matrix:

import numpy as np


def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    # kalman_filter here is the repository's Kalman filter module, which provides the chi2inv95 table
    gating_threshold = kalman_filter.chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position, metric='maha')
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix
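In words, fuse_motion sets a pair's cost to infinity when the squared Mahalanobis distance between the track's predicted box and the detection exceeds the 95% chi-square quantile. Otherwise it blends the two costs as λ · appearance + (1 − λ) · motion, with λ = 0.98. The sketch below works directly with a mean and covariance already projected into measurement space (x, y, a, h); in the repository that projection happens inside kf.gating_distance. CHI2INV95_4DOF = 9.4877 is the 95% chi-square quantile for 4 degrees of freedom, assumed to be what chi2inv95[4] holds. fuse_motion_sketch is hypothetical.

```python
import numpy as np

CHI2INV95_4DOF = 9.4877  # 95% quantile of the chi-square distribution with 4 degrees of freedom


def squared_mahalanobis(mean, cov, measurements):
    """d^2 = (z - mu)^T S^-1 (z - mu) for every measurement row z."""
    d = measurements - mean
    return np.einsum('ij,jk,ik->i', d, np.linalg.inv(cov), d)


def fuse_motion_sketch(emb_cost, proj_means, proj_covs, measurements, lambda_=0.98):
    cost = emb_cost.copy()
    for row, (mu, S) in enumerate(zip(proj_means, proj_covs)):
        gating = squared_mahalanobis(mu, S, measurements)
        cost[row, gating > CHI2INV95_4DOF] = np.inf          # implausible given the motion model
        cost[row] = lambda_ * cost[row] + (1 - lambda_) * gating
    return cost


proj_means = [np.array([100.0, 50.0, 0.5, 80.0])]          # one predicted track
proj_covs = [np.eye(4)]
measurements = np.array([[101.0, 50.0, 0.5, 80.0],         # close to the prediction
                         [130.0, 50.0, 0.5, 80.0]])        # 30 px away
emb_cost = np.array([[0.3, 0.1]])
print(fuse_motion_sketch(emb_cost, proj_means, proj_covs, measurements))
```

The second detection looks more similar in appearance (cost 0.1), but it is gated out because it is too far from where the Kalman filter expects the track.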

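Matching uses the Jonker-Volgenant (JV) algorithm for linear assignment (the function lap.lapjv), which is generally reported to be faster in practice than the Hungarian algorithm. It finds the detection-to-track assignment with the minimum total distance/cost: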

import lap
import numpy as np


def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)  # Jonker-Volgenant (JV) matching
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b
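If the lap package is unavailable, a similar thresholded assignment can be sketched with SciPy's linear_sum_assignment. This is a swapped-in solver, not what the repository uses. The sketch solves the full assignment and then drops pairs whose cost exceeds thresh. lapjv with cost_limit builds the limit into the optimisation itself, so the two can disagree on some matrices. Gated (infinite) entries are replaced by a large finite cost so that the solver always finds an assignment.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def linear_assignment_sketch(cost_matrix, thresh):
    """Solve the assignment, then keep only pairs whose cost is within thresh."""
    n_rows, n_cols = cost_matrix.shape
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), np.arange(n_rows), np.arange(n_cols)
    # Replace inf (gated pairs) with a large finite cost so the solver always succeeds
    finite = np.where(np.isfinite(cost_matrix), cost_matrix, thresh + 1e6)
    rows, cols = linear_sum_assignment(finite)
    ok = finite[rows, cols] <= thresh
    matches = np.stack([rows[ok], cols[ok]], axis=1)
    unmatched_a = np.setdiff1d(np.arange(n_rows), matches[:, 0])
    unmatched_b = np.setdiff1d(np.arange(n_cols), matches[:, 1])
    return matches, unmatched_a, unmatched_b


cost = np.array([[0.2, 0.9, np.inf],
                 [0.8, 0.3, 0.6],
                 [0.9, 0.95, 0.99]])
matches, u_track, u_det = linear_assignment_sketch(cost, thresh=0.7)
print(matches.tolist(), u_track.tolist(), u_det.tolist())
```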

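The matching results are saved.
Unmatched tracks and detections are then matched again using IoU distance.
Unconfirmed tracks that remain unmatched (usually tracks with a detection only in their start frame) are removed.
New tracks are created for the remaining detections.
Tracks that have gone unmatched for too long are removed.
Finally, the track states are updated.
All matching steps above use the JV algorithm; the procedure is similar, so the code is not repeated here.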
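The second association uses IoU distance, i.e. 1 − IoU between each track box and each detection box. The repository relies on cython_bbox for the overlaps. The pure-NumPy sketch below uses continuous coordinates, so its values may differ slightly from cython_bbox's pixel convention; iou_distance here is illustrative.

```python
import numpy as np


def iou_distance(tracks_tlbr, dets_tlbr):
    """1 - IoU between every track box and every detection box, boxes as (x1, y1, x2, y2)."""
    a = np.asarray(tracks_tlbr, dtype=np.float64)[:, None, :]
    b = np.asarray(dets_tlbr, dtype=np.float64)[None, :, :]
    iw = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    ih = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = iw * ih
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return 1.0 - inter / (area_a + area_b - inter)


tracks = [[0, 0, 10, 10], [20, 20, 30, 30]]
dets = [[0, 0, 10, 10], [5, 0, 15, 10], [100, 100, 110, 110]]
print(iou_distance(tracks, dets).round(3))
```

The resulting matrix can be fed to the same thresholded linear assignment as in the first association.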

Summary

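The network structure is simple: it mainly adds embedding learning after the YOLOv3 prediction heads. Compared with tracking-by-detection methods, the network outputs both the box locations of objects in the image and the embeddings of the objects inside those boxes, which speeds up MOT. However, JDE only outputs boxes and embeddings together; a matching step is still needed afterwards, so it is really still two-stage. Some analyses argue that putting detection and ReID into one network still leaves a gap compared with doing them separately. Detection essentially needs category features while ReID needs identity features, and a single network struggles to learn good class and identity features at the same time.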

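The code is fairly easy to run, though you may hit a few small problems. If installing cython-bbox fails, you can download the source and build it. Remember to adjust the code and cfg/ccmcpe.json to match the data and weights you downloaded. It is best to use the data provided by the author, which is already converted to the format the code expects; data downloaded from the official dataset websites cannot be used directly.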

Comments and corrections are welcome.
References
[MOT] An in-depth analysis of JDE
One Shot Multi-Object Tracking Overview
Multi-object tracking algorithm (JDE), Towards Real-Time Multi-Object Tracking: training method
