The approach and code are based on the post below. Its author used a Baidu dataset that can no longer be found, so I used the Bosch dataset instead.

https://zhuanlan.zhihu.com/p/89877517

The dataset is Bosch's:

http://link.zhihu.com/?target=https%3A//hci.iwr.uni-heidelberg.de/node/6132

Dataset introduction:

https://hci.iwr.uni-heidelberg.de/content/bosch-small-traffic-lights-dataset

Data description

This dataset contains 13427 camera images at a resolution of 1280x720 pixels and contains about 24000 annotated traffic lights. The annotations include bounding boxes of traffic lights as well as the current state (active light) of each traffic light.
The camera images are provided as raw 12bit HDR images taken with a red-clear-clear-blue filter and as reconstructed 8-bit RGB color images. The RGB images are provided for debugging and can also be used for training. However, the RGB conversion process has some drawbacks. Some of the converted images may contain artifacts and the color distribution may seem unusual.

Dataset specifications:

Training set:

  • 5093 images
  • Annotated about every 2 seconds
  • 10756 annotated traffic lights
  • Median traffic lights width: ~8.6 pixels
  • 15 different labels
  • 170 lights are partially occluded

Test set:

  • 8334 consecutive images
  • Annotated at about 15 fps
  • 13486 annotated traffic lights
  • Median traffic light width: 8.5 pixels
  • 4 labels (red, yellow, green, off)
  • 2088 lights are partially occluded

Bosch provides its own model scripts, which are implemented on top of the YOLO v1 model.

https://github.com/bosch-ros-pkg/bstld

The basic workflow for training a model on custom data with Detectron2 can also be found here:

https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=b2bjrfb2LDeo

Detectron2's requirements for the custom dataset format are documented on the official site:

https://detectron2.readthedocs.io/tutorials/datasets.html
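
In short, Detectron2 expects a function that returns a list of dicts in its standard format, one dict per image. A minimal sketch of a single record (the values below are illustrative, not taken from the actual dataset):

# Illustrative single record in Detectron2's standard dataset-dict format
from detectron2.structures import BoxMode

record = {
    "file_name": "/path/to/image.png",   # absolute path to the image (hypothetical)
    "image_id": 0,
    "height": 720,
    "width": 1280,
    "annotations": [
        {
            "bbox": [608.875, 354.0, 613.625, 364.75],   # x_min, y_min, x_max, y_max
            "bbox_mode": BoxMode.XYXY_ABS,
            "category_id": 0,                            # index into thing_classes
        },
    ],
}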

The implementation source code follows (run in a Jupyter notebook):

#1 cell1

import torch
import torchvision

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import cv2
from matplotlib import pyplot as plt

# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

def cv_imshow(img):
    im = img[:,:,::-1]
    fig_h = 12
    plt.figure(figsize=(fig_h, int(1.0 * fig_h * im.shape[0] / im.shape[1])))
    plt.axis('off')
    plt.imshow(im, aspect='auto')

# https://github.com/facebookresearch/detectron2
DETECTRON2_REPO_PATH = './detectron2/'

#2 cell2

# register the traffic light dataset
import os
import numpy as np
import json
import yaml
from detectron2.structures import BoxMode
import itertools
#from tl_dataset import parse_label_file

#dataset_path = "/local/mnt/workspace/myname/dataset/bosch_traffic/rgb/"
dataset_path = "/local/mnt/workspace/myname/dataset/bosch-traffic/"

def get_tl_dicts(data_dir):
    dataset_dicts = []

    yaml_path = ''
    # data_dir is only used to check whether this is the train or the test split
    if 'train' in data_dir:
        yaml_path = os.path.join(data_dir, "train.yaml")
        is_train = True
    elif 'test' in data_dir:
        yaml_path = os.path.join(data_dir, "test.yaml")
        is_train = False
    else:
        print("***path error***")
        return

    if is_train:
        print("***path train***")
        yaml_path = os.path.join(dataset_path, "train.yaml")
    else:
        yaml_path = os.path.join(dataset_path, "test.yaml")

    # read the yaml annotation file
    file = open(yaml_path, 'r', encoding="utf-8")
    file_data = file.read()
    file.close()

    #print("file_data=", file_data)
    #print("file_data type=", type(file_data))

    # parse the yaml text into a list of per-image records
    data = yaml.load(file_data, Loader=yaml.FullLoader)

    for i in range(len(data)):
        image_path = os.path.abspath(os.path.join(dataset_path, data[i]['path']))
        
        print('image_path=',image_path)
        record = {}
        height, width = cv2.imread(image_path).shape[:2]
        record["file_name"] = image_path
        record["image_id"] = i
        record["height"] = height
        record["width"] = width
        print('width*height=',width,height)
        objs = []

        for box in data[i]['boxes']:
            obj = {
                "bbox": [box['x_min'], box['y_min'], box['x_max'], box['y_max']],
                "bbox_mode": BoxMode.XYXY_ABS,
                "category_id": 0,
                "iscrowd": 0
            }
            print('x_min=',box['x_min'])
            '''
            if(box['label'] == 'RedLeft'):
                obj['category_id'] = 1
            if (box['label'] == 'RedRight'):
                obj['category_id'] = 2
            elif(box['label'] == 'Yellow'):
                obj['category_id'] = 10
            elif(box['label'] == 'Green'):
                obj['category_id'] = 20
            elif(box['label'] == 'GreenLeft'):
                obj['category_id'] = 21
            elif(box['label'] == 'GreenRight'):
                obj['category_id'] = 22
            else:
                obj['category_id'] = 30
            '''
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts

#3 cell3

from detectron2.data import DatasetCatalog, MetadataCatalog
for d in ["train", "test"]:
    DatasetCatalog.register(dataset_path + "rgb/" + d, lambda d=d: get_tl_dicts(dataset_path + "rgb/" + d))
    MetadataCatalog.get(dataset_path + "rgb/" + d).set(thing_classes=["traffic_light"])
tl_metadata = MetadataCatalog.get(dataset_path + "rgb/train")
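
The commented-out label mapping in get_tl_dicts hints at training with more than one class. Note that Detectron2 requires category_id to be a contiguous integer in [0, num_classes), so the sparse ids (1, 2, 10, 20, ...) would have to be remapped. A rough sketch of how that could look (the label list and its ordering are my assumption, not the full Bosch label set):

# Hypothetical contiguous label-to-id mapping for multi-class training
TL_CLASSES = ["Red", "RedLeft", "RedRight", "Yellow", "Green", "GreenLeft", "GreenRight", "off"]
TL_CLASS_TO_ID = {name: i for i, name in enumerate(TL_CLASSES)}

# inside get_tl_dicts, instead of a fixed category_id of 0:
#     obj["category_id"] = TL_CLASS_TO_ID.get(box["label"], TL_CLASS_TO_ID["off"])

# and register matching class names plus head size:
#     MetadataCatalog.get(dataset_path + "rgb/" + d).set(thing_classes=TL_CLASSES)
#     cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(TL_CLASSES)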

#4 cell4

# show samples from dataset
import random
from google.colab.patches import cv2_imshow

dataset_dicts = get_tl_dicts(dataset_path+"train")
for d in random.sample(dataset_dicts, 3):
    print('file_name=', d["file_name"])
    #img_path = os.path.join(dataset_path, d["file_name"])
    img_path = d["file_name"]
    img = cv2.imread(img_path)
    visualizer = Visualizer(img[:, :, ::-1], metadata=tl_metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)
    cv2_imshow(vis.get_image()[:, :, ::-1])

#5 cell5

# Train
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(DETECTRON2_REPO_PATH + "./configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = (dataset_path+'rgb/train',)
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
# initialize from model zoo
cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.01
cfg.SOLVER.MAX_ITER = 300    # 300 iterations seems good enough, but you can certainly train longer
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (traffic light)

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
print('outdir=',cfg.OUTPUT_DIR)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
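
If detectron2 was installed as a package instead of cloned from the repo, the config file and the pretrained weights can also be resolved through the model zoo API rather than the local repo path; a minimal sketch:

# Alternative: resolve the config and COCO-pretrained weights via the model zoo
from detectron2 import model_zoo

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")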

#6 cell6

# load the trained weights and build a predictor for inference
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
print('output=',cfg.OUTPUT_DIR)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # set the testing threshold for this model
cfg.DATASETS.TEST = (dataset_path+'rgb/test', )
predictor = DefaultPredictor(cfg)

#7 cell7

from detectron2.utils.visualizer import ColorMode
from google.colab.patches import cv2_imshow

# testsets contain no labels
# dataset_dicts = get_tl_dicts("apollo_tl_demo_data/testsets")
dataset_dicts = get_tl_dicts(dataset_path+'train')
for d in random.sample(dataset_dicts, 3):
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=tl_metadata,
                   scale=0.8,
    )
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2_imshow(v.get_image()[:, :, ::-1])

A few things to note:

#1. The Bosch annotation format is an array of entries like the one below: each image has a path and a boxes list, and each element of boxes is a dict describing one traffic light's bounding box, label, and occlusion flag.

- boxes:
  - {label: RedLeft, occluded: false, x_max: 613.625, x_min: 608.875, y_max: 364.75, y_min: 354.0}
  - {label: Red, occluded: false, x_max: 638.0, x_min: 633.125, y_max: 353.875, y_min: 343.375}
  - {label: Red, occluded: false, x_max: 656.875, x_min: 652.875, y_max: 363.5, y_min: 355.375}
  path: ./rgb/train/2017-02-03-11-44-56_los_altos_mountain_view_traffic_lights_bag/207458.png
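
A quick way to sanity-check the annotation file is to parse it directly with PyYAML; a minimal sketch (run from the dataset root where train.yaml lives):

import yaml

# load the Bosch annotation file and inspect the first entry
with open("train.yaml", "r", encoding="utf-8") as f:
    annotations = yaml.safe_load(f)

first = annotations[0]
print(first["path"])          # relative image path, e.g. ./rgb/train/...
for box in first["boxes"]:
    print(box["label"], box["x_min"], box["y_min"], box["x_max"], box["y_max"])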

#2. If you hit an error because cv2 is not installed, install it from a terminal with pip install -U opencv-python

#3. If cv2.imshow cannot display images from the script, consider using cv2_imshow instead, which requires the import:

from google.colab.patches import cv2_imshow
