文章目录

  • 1、数据集格式及存放
  • 2、修改两处
  • 3、用训练命令生成配置文件
  • 4、正式训练开始
  • 5、报错记录
  • 6、模型评价测试(VOC指标mAP、COCO指标AP)
  • 7、绘制每个类别bbox 的结果曲线图并保存
  • 8、统计模型参数量和FLOPs
  • 9 计算混淆矩阵
  • 10 画PR曲线
  • 11 查看完整config配置文件
  • 12 核查数据增强的结果是否正确
  • 8、参考链接

1、数据集格式及存放

mmdet支持COCO格式和VOC格式,能用COCO格式,还是建议COCO的。网上有YOLO转COCO,VOC转COCO,可以自己转换。

在mmdetection代码的根目录下,创建 data/coco 文件夹,按照coco的格式排放好数据集。annotations下面是标签文件,train2017val2017test2017是图片。

2、修改两处

第一处: mmdet/core/evalution/class_names.py 代码下的 def coco_classes() 的 return 内容改为自己数据集的类别;

第二处:mmdet/datasets/coco.py 代码下的 class CocoDataset(CustomDataset) 的 CLASSES 改为自己数据集的类别;

注意:修改两处后,一定要在根目录下,输入命令:
python setup.py install build
重新编译代码,要不然类别会没有载入,还是原coco类别,训练异常。

3、用训练命令生成配置文件

python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs

其中,work_dirs是自己在根目录新建的工作目录,训练文件存储在这里。

注意,此时运行命令之后,并不是直接训练就可以不管了!我们还有参数设置没改!这里输入训练命令,只是需要它生成一个配置文件,便于我们改参数!

打开配置文件 cascade_rcnn_r50_fpn_1x_coco.py :
(1)修改 num_classes ,将其改为自己数据类别(直接全局搜索,有3处,都要改);

(2)修改 data_root 路径和训练集、验证集、测试集的图片和标签路径,如下图:


(3)修改训练图片大小和学习率

修改下处代码,可以更改图片大小

img_scale = (1333, 800),

batch_size, mmdet默认的方式是由 GPU 数量与 samples_per_gpu 参数决定:
samples_per_gpu: 每个gpu读取的图像数量(意思不就是batch_size=2),该参数和训练时的gpu数量决定了训练时的batch_size。(为什么这么说呢?因为mmdet是8个GPU训练的,那么总的batch就是 8 *samples_per_gpu=16,即训练时是batch_size为16) 。
但我们通常是只有一个gpu, 该参数设置为 2, 意思就是我们训练的 batch_size为2;

workers_per_gpu: 读取数据时每个gpu分配的线程数 ,一般设置为 2即可;(我感觉既然用单个GPU,设置到8也无妨吧?我还没试)

学习率设置:
mmdet 默认的学习率是基于8个gpu,而且默认是1个GPU处理2个图像(就上面说的samples_per_gpu为2),可以这样理解:
8个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括16张图片,学习率为0.02;
4个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括8张图片,学习率为0.01;
1个GPU,每个GPU处理2张图片,那么真实训练总的一个batch就包括2张图片,学习率为0.0025;
1个GPU,每个GPU处理1张图片,那么真实训练总的一个batch就包括1张图片,学习率为0.00125;

(4)使用预训练模型
提前从github上下载预训练模型,新建一个checkpoints文件夹下,放到里面。(模型下载链接:https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md)
然后修改以下代码:

# 原本是 load_from = None ,修改为
load_from = 'checkpoints/fcascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth’

(5)训练轮数,保存模型间隔,日志保存参数

4、正式训练开始

!!!看清楚路径!使用的是更改过的配置文件训练!!!

python tools/train.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py

5、报错记录

在第三步生成配置文件时,遇到以下报错:

AssertionError: The num_classes (10) in Shared2FCBBoxHead of
MMDataParallel does not matches the length of CLASSES 80) in
CocoDataset

即使在修改 coco.py 和 class_names.py 后运行 python setup.py install仍然无法解决;

解决方法:
根据报错信息,找到自己虚拟环境的/mmdet/datasets/coco.pymmdet/core/evaluation/class_names.py,再次修改
CocoDataset()coco_classes()l两处(跟第二步一样,其实打开,就能看到虚拟环境下的并没有修改成功)

参考链接:AssertionError: The num_classes (3) in Shared2FCBBoxHead of
MMDataParallel does not
matches

6、模型评价测试(VOC指标mAP、COCO指标AP)

(1)生成中间件

python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth  --out results.pkl
  • work_dirs/cascade_rcnn_r50_fpn_1x_coco.py 模型配置文件(跟训练时的一样)
  • work_dirs/epoch_20.pth: 训练好的模型(我是训练了20epoch)
  • --out 指定 results.pkl 输出目录,可以自己指定输出目录

(2)使用COCO标准评估指标

python tools/analysis_tools/eval_metric.py  work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl  --eval=bbox
  • --eval,COCO数据集可选参数有:bbox 、segm、proposal ;对VOC数据集可选参数有:mAP

(3)使用VOC标准评估指标

# results.pkl 的顺序别放错,在中间。
python tools/voc_eval.py results.pkl work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
  • voc_eval.py 文件 mmdetection 2.X 版本删除了,可以去老版本1.X 找找

7、绘制每个类别bbox 的结果曲线图并保存

(1)使用 test.py 生成 results.bbox.json 文件(在根目录下,路径可自己指定)

python tools/test.py  work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth  --format-only  --options "jsonfile_prefix=./results"

(2)获得COCO bbox错误结果每个类别,保存分析结果图像到目录results/

python tools/analysis_tools/coco_error_analysis.py results.bbox.json results  --ann=data/coco/annotations/instances_val2017.json
  • results.bbox.json:上一步生成的文件
  • results: 结果曲线图的生成目录, 此处将生成到results/ 目录下
  • –ann=data/coco/annotations/instances_val2017.json: 数据集标注文件存放路径

8、统计模型参数量和FLOPs

python tools/analysis_tools/get_flops.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --shape 640 640
  • --shape 参数指定输入图片尺寸

9 计算混淆矩阵

python tools/analysis_tools/confusion_matrix.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl coco_confusion_matrix/
  • 需要三个参数,配置文件、pkl文件、输出目录

10 画PR曲线

plot_pr_curve.py 代码来自:https://blog.csdn.net/weixin_44966641/article/details/124558532

import os
import sys
import mmcv
import numpy as np
import argparse
import matplotlib.pyplot as pltfrom pycocotools.coco import COCO
from pycocotools.cocoeval import COCOevalfrom mmcv import Config
from mmdet.datasets import build_datasetdef plot_pr_curve(config_file, result_file, out_pic, metric="bbox"):"""plot precison-recall curve based on testing results of pkl file.Args:config_file (list[list | tuple]): config file path.result_file (str): pkl file of testing results path.metric (str): Metrics to be evaluated. Options are'bbox', 'segm'."""cfg = Config.fromfile(config_file)# turn on test mode of datasetif isinstance(cfg.data.test, dict):cfg.data.test.test_mode = Trueelif isinstance(cfg.data.test, list):for ds_cfg in cfg.data.test:ds_cfg.test_mode = True# build datasetdataset = build_dataset(cfg.data.test)# load result file in pkl formatpkl_results = mmcv.load(result_file)# convert pkl file (list[list | tuple | ndarray]) to jsonjson_results, _ = dataset.format_results(pkl_results)# initialize COCO instancecoco = COCO(annotation_file=cfg.data.test.ann_file)coco_gt = cocococo_dt = coco_gt.loadRes(json_results[metric]) # initialize COCOeval instancecoco_eval = COCOeval(coco_gt, coco_dt, metric)coco_eval.evaluate()coco_eval.accumulate()coco_eval.summarize()# extract eval dataprecisions = coco_eval.eval["precision"]'''precisions[T, R, K, A, M]T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100K: category, idx from 0 to ...A: area range, (all, small, medium, large), idx from 0 to 3M: max dets, (1, 10, 100), idx from 0 to 2'''pr_array1 = precisions[0, :, 0, 0, 2] pr_array2 = precisions[1, :, 0, 0, 2] pr_array3 = precisions[2, :, 0, 0, 2] pr_array4 = precisions[3, :, 0, 0, 2] pr_array5 = precisions[4, :, 0, 0, 2] pr_array6 = precisions[5, :, 0, 0, 2] pr_array7 = precisions[6, :, 0, 0, 2] pr_array8 = precisions[7, :, 0, 0, 2] pr_array9 = precisions[8, :, 0, 0, 2] pr_array10 = precisions[9, :, 0, 0, 2] x = np.arange(0.0, 1.01, 0.01)# plot PR curveplt.plot(x, pr_array1, label="iou=0.5")plt.plot(x, pr_array2, label="iou=0.55")plt.plot(x, pr_array3, label="iou=0.6")plt.plot(x, pr_array4, label="iou=0.65")plt.plot(x, pr_array5, label="iou=0.7")plt.plot(x, pr_array6, label="iou=0.75")plt.plot(x, pr_array7, label="iou=0.8")plt.plot(x, pr_array8, label="iou=0.85")plt.plot(x, pr_array9, label="iou=0.9")plt.plot(x, pr_array10, label="iou=0.95")plt.xlabel("recall")plt.ylabel("precison")plt.xlim(0, 1.0)plt.ylim(0, 1.01)plt.grid(True)plt.legend(loc="lower left")plt.savefig(out_pic)if __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument('config', help='config file path')parser.add_argument('pkl_result_file', help='pkl result file path')parser.add_argument('--out', default='pr_curve.png')parser.add_argument('--eval', default='bbox')cfg = parser.parse_args()plot_pr_curve(config_file=cfg.config, result_file=cfg.pkl_result_file, out_pic=cfg.out, metric=cfg.eval)

输入命令:

python plot_pr_curve.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl

11 查看完整config配置文件

python tools/misc/print_config.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py

12 核查数据增强的结果是否正确

python tools/misc/browse_dataset.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py  --output-dir work_dirs/

8、参考链接

https://blog.csdn.net/qq_35077107/article/details/124768460?spm=1001.2014.3001.5502

https://blog.csdn.net/weixin_44966641/article/details/124558532

【MMDetection】——训练个人数据集相关推荐

  1. mmdetection训练voc数据集

  2. 一步一步教你在 docker 容器下使用 mmdetection 训练自己的数据集

    这里不再介绍 mmdetection 的安装和配置,使用 mmdetection 较简单的方法是使用已安装 mmdetection 的 docker 容器.这样直接省去了安装 mmdetection ...

  3. mmdetection训练自己的COCO数据集及常见问题

    训练自己的VOC数据集及常见问题见下文: mmdetection训练自己的VOC数据集及常见问题_不瘦8斤的妥球球饼的博客-CSDN博客_mmdetection训练voc 目录 一.环境安装 二.训练 ...

  4. mmdetection训练自己的VOC数据集及常见问题

    训练自己的COCO数据集及常见问题见下文: mmdetection训练自己的COCO数据集及常见问题_不瘦8斤的妥球球饼的博客-CSDN博客 目录 一.环境安装 二.训练测试步骤 三.常见问题 bat ...

  5. mmdetection实战,训练扑克牌数据集(VOC格式)并测试计算mAP

    mmdetection实战,训练扑克牌数据集(VOC格式)并测试计算mAP 一.数据集准备 二.mmdetection的安装 三.修改相关文件 1. 修改class_names.py文件 2. 修改v ...

  6. 建立自己的voc数据集_将自己数据集转化成voc数据集格式并用mmdetection训练

    一.准备自己的数据 拿nwpu数据集来举例,nwpu数据集文件夹中的内容是: images文件夹:存放数据图片 labelTxt文件夹:存放标注信息,images文件夹中每张图片都对应一个txt文件存 ...

  7. mmdetection训练出现:IndexError: list index out of range 错误

    mmdetection训练出现:IndexError: list index out of range 错误 文章目录: 1 问题分析 1.1 尝试解决错误:第一次 1.2 尝试解决错误:第二次 2 ...

  8. mmdetection训练、测试

    文章目录 前言 一.使用mmdetection训练测试Mask-Rcnn 1.数据集转化 2.准备配置文件 3.训练 4.测试 二.mmdetection补充知识 前言 用于记录mmdetection ...

  9. mmdetection训练自己的数据并评估mAP

    用mmdetection做目标检测的训练还是比较简单的,但是目前代码尚不稳定,其中也有很多的坑,下面简单讲解一下如何用mmdetetection在VOC的数据集上进行模型的训练,算是对mmdetect ...

  10. MMDetection实战:MMDetection训练与测试

    文章目录 摘要 配置文件参数详解 环境准备 训练 制作数据集 修改配置文件 修改数据集的类别 开始训练 测试 完整代码和数据集: 摘要 MMDetection是商汤和港中文大学针对目标检测任务推出的一 ...

最新文章

  1. Albert: A lite bert for self-supervised learning of language representations (Albert)
  2. linux——虚拟机的图形安装、管理以及快照
  3. vim--之初学轻松几步走
  4. Java如何跨语言调用Python/R训练的模型
  5. 竞价账户烧钱的七大病因和处理办法
  6. 【Bash】实现指定目录下的文件编码转换,以原文件名保存
  7. vim显示python嵌套级_在Vim中为Python突出显示语法
  8. JavaScript学习笔记:创建自定义对象
  9. dao层如何调用对象_以k8s集群管理为例,大牛教你如何设计优秀项目架构
  10. 【Jenkins】Jenkins : jenkins-2.121.1 安装 与 使用
  11. celery-分布式任务队列-原理
  12. android抽屉风格,Android开发实战之拥有Material Design风格的抽屉式布局
  13. openstack ha 部署
  14. vue之神奇的动态按钮
  15. docker-compose 学习:通过 image 指令指定镜像搭建一个简单LNMP
  16. oracle判断字符串以什么开头_sql语句判断字符串以什么什么开头
  17. C语言字符与数字的互转
  18. 立创EDA超详细的PCB设计流程
  19. 超市管理系统数据库设计
  20. 把小写金额转成大写金额 (Java经典编程案例)

热门文章

  1. 写作的“收益”超乎想象
  2. java表格界面_Java自学-图形界面 表格
  3. 爬取开眼app小视频
  4. 和小黄鸭交谈:全球公认的调试代码好方法!
  5. Mastering Sublime Text 下载
  6. 巧用模板方法模式,实现加载违禁词文件功能
  7. java从多少不免费了_Java要开始收费了,为什么使用了23年的Java不再免费?
  8. java线程cutdown_Java线程池实现原理与技术II
  9. 技巧1 以空格分隔 第一个最后一个数后无空格
  10. 【Python 数据科学】聚合apply和agg