环境:win10,cuda 10.1 , GTX1060

一、数据处理

1、数据集获取:

链接:https://pan.baidu.com/s/1K3rI9PvzHc1KqOJITNMdVg 
提取码:lox4

2、数据集格式

数据格式也不一定完全按照上面这种,但是必须得保证图片和标签的名字相同。以MOT17(JDE的Modelzoo中下载得到)为例:

文件夹结构:

labels_with_ids文件夹里面是用转换工具将gt.txt生成对应的JDE训练所需的标注文件,对应每一个视频序列的每一帧图片。

而这片博客是要在UA-DETRAC数据集上训练JDE,所以先看看UA-DETRAC原始数据集(所有图片分辨率为960x540)

标签是XML格式,且一个xml对应一个视频序列,每一个xml内容包含该视频序列中所有帧的标注信息:

其中一帧中包括多个车辆标注,标注信息包括:车辆ID,box坐标,以及一些属性:方向,速度,轨迹长度,遮挡率,车辆类别。要想在JDE中训练,需要进行转换,JDE要求的标注格式:

编写脚本,解析原始xml标注文件,生成上述的标注txt文件,因为FairMOT算法和JDE用的是同一个数据处理方式,甚至是完全相同的数据集,因此我直接在FairMOT的数据转换工具基础上做了修改,内容如下:

import os.path as osp
import os
import numpy as np
import shutil
import xml.dom.minidom as xml
import abcdef mkdirs(d):if not osp.exists(d):os.makedirs(d)seq_root = 'F:/dataset/MOT/UA-DETRAC/DETRAC-train-data/Insight-MVT_Annotation_Train'#图片
xml_root = 'F:/dataset/MOT/UA-DETRAC/DETRAC-Train-Annotations-XML'  #原始xml标注
label_root="F:/dataset/MOT/UA-DETRAC/DETRAC-Train-Annotations-track" #新生成的标签保存目录#mkdirs(label_root)
seqs = [s for s in os.listdir(seq_root)]'''
读取xml文件
'''class XmlReader(object):__metaclass__ = abc.ABCMetadef __init__(self):passdef read_content(self,filename):content = Noneif (False == os.path.exists(filename)):return contentfilehandle = Nonetry:filehandle = open(filename,'rb')except FileNotFoundError as e:print(e.strerror)try:content = filehandle.read()except IOError as e:print(e.strerror)if (None != filehandle):filehandle.close()if(None != content):return content.decode("utf-8","ignore")return content@abc.abstractmethoddef load(self,filename):passclass XmlTester(XmlReader):def __init__(self):XmlReader.__init__(self)def load(self, filename):filecontent = XmlReader.read_content(self,filename)#print(filecontent)seq_gt=[]if None != filecontent:dom = xml.parseString(filecontent)root = dom.getElementsByTagName('sequence')[0]if root.hasAttribute("name"):seq_name=root.getAttribute("name")print ("*"*20+"sequence: %s" %seq_name +"*"*20)#获取所有的frameframes = root.getElementsByTagName('frame')for frame in frames:if frame.hasAttribute("num"):frame_num=int(frame.getAttribute("num"))print ("-"*10+"frame_num: %s" %frame_num +"-"*10)target_list = frame.getElementsByTagName('target_list')[0]#获取一帧里面所有的targettargets = target_list.getElementsByTagName('target')targets_dic={}for target in targets:if target.hasAttribute("id"):tar_id=int(target.getAttribute("id"))#print ("id: %s" % tar_id)box = target.getElementsByTagName('box')[0]if box.hasAttribute("left"):left=box.getAttribute("left")#print ("  left: %s" % left)if box.hasAttribute("top"):top=box.getAttribute("top")#print ("  top: %s" %top )if box.hasAttribute("width"):width=box.getAttribute("width")#print ("  width: %s" % width)if box.hasAttribute("height"):height=box.getAttribute("height")#print ("  height: %s" %height )#中心坐标x=float(left)+float(width)/2y=float(top)+float(height)/2#宽高中心坐标归一化# x/=img_w# y/=img_h# width=float(width)/img_w# height=float(height)/img_hattribute = target.getElementsByTagName('attribute')[0]if attribute.hasAttribute("vehicle_type"):type=attribute.getAttribute("vehicle_type")if type=="car":type=0if type=="van":type=1if type=="bus":type=2if type=="others":type=3#anno_f.write(str(type)+" "+tar_id+" %.3f"%x+" %.3f"%y+" %.3f"%width+" %.3f"%height+"\n")seq_gt.append([frame_num,tar_id,x,y,float(width),float(height),type])         return seq_gttid_curr = 0
tid_last = -1  #用于在下一个视频序列时,ID数接着上一个视频序列最大值
for seq in seqs: #每一个视频序列print(seq)seq_width = 960seq_height = 540gt_xml = osp.join(xml_root, seq+'.xml')reader = XmlTester()gt=reader.load(gt_xml)#统计这个序列所有IDids=[]for line in gt:if not line[1] in ids:ids.append(line[1])print (ids)#根据ID将同一ID的不同帧标注放在一起final_gt=[]for id in ids:for line in gt:if line[1]==id:final_gt.append(line)print(len(final_gt))seq_label_root = osp.join(label_root, seq)if not os.path.exists(seq_label_root):mkdirs(seq_label_root)for fid, tid, x, y, w, h, label in final_gt:label=int(label)print(" ",fid,label)fid = int(fid)tid = int(tid)if not tid == tid_last:tid_curr += 1tid_last = tidlabel_fpath = osp.join(seq_label_root, 'img{:05d}.txt'.format(fid))label_str = '{:d} {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(int(label),tid_curr, float(x) / seq_width, float(y) / seq_height, float(w) / seq_width, float(h) / seq_height) #宽高中心坐标归一化with open(label_fpath, 'a') as f:f.write(label_str)

生成的标注:

这里需要注意一点,就是生成的标签中,相邻两个视频序列的目标ID是连续的,而不是每个序列的目标ID全部从1开始。可以这么理解,假如上一个视频中共60个目标,那么在下一个视频开始,新目标就应该是61,62....以此类推,这样在后续训练时加载数据集统计ID时才不会错,我之前就是脚本写的有问题,60个视频序列的车辆ID才343个,这显然是不对的。后面对比了MOT的数据,发现MOT的原始数据标注都是按同一ID的标注放在一起,比如目标1出现了10帧,那么前10行就是目标1的标注,目标2出现了15帧,那么接下来的15行就是目标2的标注。因此我在转换工具中加入了一段代码,用来处理这个“放在一起”的过程:

生成训练所需的xxxx.train文件,脚本如下:

import os
root_path="F:/dataset/MOT/UA-DETRAC"
label_flder="DETRAC-Train-Annotations-track"
img_folder="DETRAC-train-data/Insight-MVT_Annotation_Train"
seqs=os.listdir(root_path+"/"+label_flder)
train_f=open("UA-DETRAC.train","w")
count=0
for seq in seqs:print("seq:",seq)labels=os.listdir(root_path+"/"+label_flder+"/"+seq)for label in labels:img_name=label[:-4]+".jpg"save_str=root_path+"/"+img_folder+"/"+seq+'/'+img_name+"\n"print("img:",save_str)count+=1print(count)train_f.write(save_str)
train_f.close()

到此数据处理结束。

二、训练

1、训练相关代码修改

(1)因为前面数据集做了修改,所以要对应的修改dataset.py文件

因为图片和标签的文件夹层次结构不同,所以这里替换图片的路径中的部分来得到标签路径。

(2)修改网络定义配置cfg。JDE中使用的是YOLO v3,其中3个yolo层的anchor,尺寸都是针对行人比例大小特殊设置的,因为UA-DETRAC所有标注数据都是车辆,且车辆大多数都是近似1:1的框(没有像行人那么大的宽高比),因此我直接将三层yolo层的anchor都按照原始416x416大小的yolov3的cfg设置来修改,此外需要注意的是,类别个数,JDE中全部是行人,所以类别数为1,检测和分类分支的卷积通道数为24=4*(1+5),4表示每一个yolo层的anchor数,1表示类别数,5表示conf,x,y,w,h。现在UA-DETRAC数据集中车辆类别有4个:['car', 'van', 'bus','others'],每一个yolo层的anchor也改成了3,所以检测和分类分支的卷积通道数为27=3*(4+5)。

(3)需要在数据配置文件中,将训练数据修改成刚生成的xxxx.train文件:

2、训练

设置训练参数:

vscode中,ctrl+F5开始训练,或者命令行中python train.py开始训练(统计出总共5920个目标,训练集+测试集共8250个)。

训练中各项loss收敛正常(图为训练到第7个epoch)

不过训练中total loss出现负值,不知道为何,total loss会是负数, 这个问题还没弄清楚,有大佬若知道请不吝赐教。

#----------------------------------------------------------------------------------------------------------------------------------------------------

2020/0609更新

之在UA-DETRAC数据集上训练,使用了4个类别:['car', 'van', 'bus','others'],但是JDE默认是只有一类,也就是一个类别的多目标跟踪,例如行人多目标跟踪,车辆多目标跟踪。因此我把类别全部改成一类:car,对应的cfg文件就得修改:

这里18=3*(1+5),1表示只有1类。

这次使用darknet53预训练模型fineturn训练,修改参数中的weights-from参数,修改成darknet53.conv.74文件所在目录。

此外设置初始学习率:0.01,分辨率为[416,416]

 命令行中输入:python train.py ,开始训练,大概训练到26个epoch时的loss如下:

2020-06-04 14:19:15 [INFO]:    Epoch       Batch       box      conf        id     total  nTargets      time    cur_lr
2020-06-04 15:31:00 [INFO]:    26/29  6080/13459   0.00161  0.000657       5.9     -21.8      43.7     0.446    0.0001
2020-06-04 15:31:27 [INFO]:    26/29  6120/13459   0.00161  0.000657       5.9     -21.8      43.7     0.434    0.0001
2020-06-04 15:31:55 [INFO]:    26/29  6160/13459   0.00161  0.000657       5.9     -21.8      43.7     0.465    0.0001
2020-06-04 15:32:23 [INFO]:    26/29  6200/13459   0.00161  0.000657       5.9     -21.8      43.7     0.432    0.0001
2020-06-04 15:32:50 [INFO]:    26/29  6240/13459   0.00161  0.000657      5.89     -21.8      43.7     0.455    0.0001
2020-06-04 15:33:18 [INFO]:    26/29  6280/13459   0.00161  0.000656       5.9     -21.8      43.7     0.449    0.0001
2020-06-04 15:33:46 [INFO]:    26/29  6320/13459   0.00161  0.000656      5.89     -21.9      43.7     0.443    0.0001
2020-06-04 15:34:14 [INFO]:    26/29  6360/13459   0.00161  0.000656      5.89     -21.9      43.7     0.469    0.0001
2020-06-04 15:34:42 [INFO]:    26/29  6400/13459   0.00161  0.000655      5.89     -21.9      43.7     0.444    0.0001
2020-06-04 15:35:10 [INFO]:    26/29  6440/13459   0.00161  0.000655      5.89     -21.9      43.7      0.44    0.0001
2020-06-04 15:35:38 [INFO]:    26/29  6480/13459   0.00161  0.000654      5.89     -21.9      43.6     0.433    0.0001
2020-06-04 15:36:07 [INFO]:    26/29  6520/13459   0.00161  0.000654      5.89     -21.9      43.7      0.44    0.0001
2020-06-04 15:36:35 [INFO]:    26/29  6560/13459   0.00161  0.000655      5.89     -21.9      43.7      0.47    0.0001
2020-06-04 15:37:03 [INFO]:    26/29  6600/13459   0.00161  0.000655      5.89     -21.9      43.7      0.46    0.0001
2020-06-04 15:37:31 [INFO]:    26/29  6640/13459   0.00161  0.000655      5.89     -21.9      43.7     0.439    0.0001
2020-06-04 15:37:59 [INFO]:    26/29  6680/13459   0.00161  0.000659       5.9     -21.9      43.7     0.464    0.0001
2020-06-04 15:38:27 [INFO]:    26/29  6720/13459   0.00162  0.000669       5.9     -21.7      43.7     0.449    0.0001
2020-06-04 15:38:55 [INFO]:    26/29  6760/13459   0.00164  0.000684      5.92     -21.5      43.7     0.442    0.0001
2020-06-04 15:39:24 [INFO]:    26/29  6800/13459   0.00165  0.000691      5.93     -21.5      43.7     0.437    0.0001
2020-06-04 15:39:54 [INFO]:    26/29  6840/13459   0.00166  0.000705      5.94     -21.4      43.7     0.453    0.0001
2020-06-04 15:40:27 [INFO]:    26/29  6880/13459   0.00167  0.000715      5.95     -21.3      43.7     0.808    0.0001
2020-06-04 15:41:34 [INFO]:    26/29  6920/13459   0.00168  0.000719      5.96     -21.3      43.7     0.967    0.0001
2020-06-04 15:42:55 [INFO]:    26/29  6960/13459   0.00168  0.000724      5.97     -21.3      43.7      1.16    0.0001
2020-06-04 15:43:33 [INFO]:    26/29  7000/13459   0.00169  0.000726      5.97     -21.3      43.6     0.431    0.0001

我是训练到26个eopch结束,跑一下demo,修改下cfg文件和训练好的权重目录,以及测试图片所在文件夹的目录,如下所示:

原始的JDE只支持mp4格式的视频demo,参数是--input-vedio,我这里主要大多是h264的视频,为了测试还得去转成MP4格式,为了方便,我修改了这个参数为:--input-vedio-images,可以测视频,也可以测图片,修改下如下代码:

在detaset.py中,复制一份class LoadVideo类,改名为LoadImages,然后增加一个成员变量self.frame_rate=30,这个因为后面统一读取,默认是按视频格式,所以有帧率,这里也加上帧率这个参数,防止报错。

开始测试:输入 python demo.py 在results/frame文件夹下生成了每一帧的跟踪结果

然后在整个测试图片文件夹测试完后会将跟踪结果拼成一个mp4视频

视频截图如下:

三、采坑记录

在跑demo时遇到一个问题,就是有些尺寸比例的车辆显示没有跟踪到,如下图所示:

但是正面正对摄像头的车辆却效果很好,如下图所示:

所以我一度以为是anchor问题,自己也在ua-detrac数据集的训练集上聚类出了一组anchor专门训练。但是后来发现不管怎么训练还是跟踪不到(没有跟踪框),一直觉得检测没训练好,各种检查训练数据,换anchor,调学习率,换网络,但还是同样的问题。

所以我觉得,先不管跟踪,先看看检测效果怎么样,在检测结束,显示一下检测结果,代码修改如下:multitracker.py中的 def update(self, im_blob, img0):函数,增加显示代码:

结果显示出来是检测到了,有些没显示是因为置信度低于阈值,说明检测没问题。接下来就检查跟踪模块,发现跟踪过程也是完全正常,每次都有7、8个目标进行匹配,而且跟踪reid分支提取的特征,构成的距离矩阵也是正常的,相同车辆的距离最小。匹配结束也是有好多个目标被确认跟踪。所以跟踪模块也没问题。最后检查输出模块,问题就出在了这。。。

这里对输出的跟踪框做了过滤, 由于JDE原始是做行人跟踪,所以过滤掉了宽高比大于1.6的跟踪框,所以导致很多符合这种比例的车辆全部被过滤,显示不出来。好了到此问题查清楚了,注释掉过滤语句,重新跑demo,天下太平,一切正常了。

#---------------------------------------------------------------------------------------------------------------------

2020/06/10更新

JDE中训练时的数据增强:

1、原始图片

2、csv增强50%,添加忽略区域(黑色部分,只针对UA-DETRAC数据集):

3、Letterbox:resize+pad,就是将长边缩放到416,然后短边填充(127.5,127.5,127.5),这个值是0-255之间的中值。

4、仿射变换( random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20))):

旋转,平移

5、水平翻转(概率0.5):

6、随机裁剪(416x416,自己增加)

多目标跟踪算法JDE在 UA-DETRAC数据集上训练相关推荐

  1. 模拟数据集上训练神经网络,网络解决二分类问题练习

    #2018-06-24 395218 June Sunday the 25 week, the 175 day SZ ''' 模拟数据集上训练神经网络,网络解决二分类问题.'''import tens ...

  2. pascal行人voc_在一个很小的Pascal VOC数据集上训练一个实例分割模型

    只使用1349张图像训练Mask-RCNN,有代码. 代码:https://github.com/kayoyin/tiny-inst-segmentation 介绍 计算机视觉的进步带来了许多有前途的 ...

  3. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  4. internetreadfile读取数据长度为0_YOLOV3的TensorFlow2.0实现,支持在自己的数据集上训练...

    GitHub链接: calmisential/YOLOv3_TensorFlow2​github.com 我主要参考了yolov3的一个keras实现版本: qqwweee/keras-yolo3​g ...

  5. domain gap(域间隙)是什么?==>在一个数据集上训练好的模型无法应用在另一个数据集上

    不同数据集之间存在domain gap,在一个数据集上训练模型,在另外一个数据集上进行预测性能下降很大 re-id(视频行人重识别问题) 现有公开的数据集与真实场景存在很大不同,不同re-id的数据集 ...

  6. pascal行人voc_在Pascal VOC 数据集上训练YOLOv3模型

    上节介绍了<从零开始在Windows10中编译安装YOLOv3>,本节介绍在Pascal VOC 数据集上训练YOLOv3. 第一步,下载并安装YOLOv3训练依赖项. a.下载Pasca ...

  7. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  8. 在自己的数据集上训练CrowdDet过程记录

    论文链接:https://readpaper.com/pdf-annotate/note?noteId=656650387498369024&pdfId=542662939605901312 ...

  9. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

最新文章

  1. Clojure世界:单元测试
  2. 如何在本地站点打开html,如何在本地运行的网页上创建指向本地文件的链接?...
  3. Convert .Net Program To Mono
  4. Python3.0 我的DailyReport 脚本(一) 使用COM操作Excel
  5. c语言 遍历.jpg图像,求指导,如何用c语言实现读取*.raw格式图像
  6. 360安全浏览器兼容模式怎么设置_360浏览器及安全卫士怎么减少广告弹出?
  7. C#工程添加了DLL编译运行时却提示”无法加载DLL“的解决方案
  8. nuxt SSR部署到iis7方案
  9. 数据分析和挖掘常用方法
  10. 经典论文-SqueezeNet论文及实践
  11. 谈谈架构师是何种生物
  12. 【脚本】Python+adb王者荣耀闯关自动刷金币
  13. 计算机音乐卡内基大学,卡耐基梅隆大学音乐暑期课程 年轻音乐家的成功之路...
  14. 【Day1.4】奢华的海滨酒店,打发半天时间不成问题
  15. 思科 Spanning Tree Protocol(STP)生成树
  16. 乐优商城架构介绍(一)
  17. 计步器(Pedometer)实现原理简介
  18. easyexcel实现导出
  19. jQuery选择器代码详解(一)——Sizzle方法
  20. 《趣学Python编程》——第1部分 学习编程 第1章 Python不是大蟒蛇 1.1 关于计算机语言...

热门文章

  1. MySQL 的read_only 只读属性说明
  2. 根据排队论阐述路由器和高速公路的拥堵以及拥堵缓解问题
  3. 【PowerApps 基础函数介绍】
  4. 模块结构篇:7.1)动力型塑料齿轮轮系设计步骤详解
  5. BZOJ3398 [Usaco2009 Feb]Bullcow 牡牛和牝牛
  6. DoubleListView效果
  7. debain10更换源和配置
  8. 一起安装多个depot文件
  9. 如何实现bilibili最新头部景深效果~炫酷
  10. Spring框架的自动装配