机器视觉库之detectron2安装及使用详解

本文转自:https://blog.csdn.net/qq_18560985/article/details/124539628

文章目录

  • 前言
  • 一、Detectron2的安装
  • 二、简单的运行案例
    • 1.利用已有的模型进行各种测试
    • 2.训练自己的模型
  • 总结

前言

  detectron2是Facebook的一个机器视觉相关的库,建立在Detectron和maskrcnn-benchmark基础之上,可以进行目标检测、语义分割、全景分割,以及人体体姿骨干的识别。这个项目在GitHub上已经有超过了20k的星星。


一、Detectron2的安装

  在Detectron2的官网上已经给出了Linux平台的安装方法。在这里介绍另外一个常用的平台即windows上的安装方法。

  • 首先,需要保证你的电脑上有3.6版本以上的python环境以及c++相关的生成工具。
  • 其次,如果你的电脑正好配有Nivdia的显卡,且支持cuda。你可以在Nvidia的官网上下载最新的pytorch对应的cudatoolkit版本和cudatoolkit对应的cudnn,在下载cudnn的时候可能需要你注册成为Nvidia的会员(免费)并填写一个意向问卷。如果没有相应的显卡,可以直接下载cpu版的pytorch
  • 然后,你需要保证你的python环境有一下的这些包,pillow、Cython、opencv-python、numpy、matplotlib,并按照pytorch官网的提示安装pytorch
  • 最后你需要在Detectron2克隆下完整的代码,并解压,在Detectron2-main(从GitHub上克隆下来的项目)的目录下执行同样的命令即 pip install -e. 在执行命令的时候请务必保证你所在的目录有setup.py,如果你的电脑没有右键在终端打开,可以在cmd下找到项目的目录运行同样的命令。

二、简单的运行案例

1.利用已有的模型进行各种测试

  首先,不管是训练模型还是预测模型都需要有数据集,在这里作者要介绍的便是大名鼎鼎的coco数据集(Common Objects in Context) 它由世界各大商业巨头如谷歌、微软、脸书 赞助的。cocodataset包含常见类别80类,大约330K的图片(超200K以标注)。coco数据集每年都在更新,在这里我们以val2017数据集为例做预测。
小提示: 如果你感觉在coco数据集的官网下载速度较慢的话可以复制下载链接放到百度云的新建下载中进行中转加速下载。第一次运行以下代码建议在非自带ide中或者直接双击py文件运行。

#detectron2相关的库
import detectron2
import cv2
from detectron2 import model_zoo  #前人训好的模型
from detectron2.engine import DefaultPredictor #默认预测器
from detectron2.config import get_cfg  #配置函数
from detectron2.utils.visualizer import Visualizer  #可视化检测出来的框框函数
from detectron2.data import MetadataCatalog #detectron2对数据集预留的标签
from matplotlib import pyplot as plt  #画画函数
cfg=get_cfg()
#cfg.MODEL.DEVICE='cpu'  #如果你的电脑没有Nvidia的显卡或者你下载的是cpu版本的pytorch就将注释打开
im=cv2.imread("val2017\\000000009590.jpg")  #放入图片的地址
#plt.figure(figsize=(20,10))
#plt.imshow(im[:,:,::-1])
#plt.show()

进行目标检测

##物件辨识
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))  #设定参数档/迁移
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.7   #阈值
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")   #真正参数档
predictor=DefaultPredictor(cfg)
outputs=predictor(im)
print(outputs['instances'].pred_classes)
print(outputs['instances'].pred_boxes)  #方框标记值
all_things=MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes   #中继资料/标记值
preds=[all_things[x] for x in outputs['instances'].pred_classes]
v=Visualizer(im[:,:,::-1],MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),scale=1.2)
v=v.draw_instance_predictions(outputs['instances'].to("cpu"))
plt.figure(figsize=(20,10))
plt.imshow(v.get_image())
plt.show()

效果图:

进行语义分割

##语义分割
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.5
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor=DefaultPredictor(cfg)
outputs=predictor(im)
v=Visualizer(im[:,:,::-1],MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),scale=1.2)
v=v.draw_instance_predictions(outputs['instances'].to("cpu"))
plt.figure(figsize=(20,10))
plt.imshow(v.get_image())
plt.show()

效果图:

进行体姿骨干检测

##体姿骨干
cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.7
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")
predictor=DefaultPredictor(cfg)
outputs=predictor(im)
v=Visualizer(im[:,:,::-1],MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),scale=1.2)
v=v.draw_instance_predictions(outputs['instances'].to("cpu"))
plt.figure(figsize=(20,10))
plt.imshow(v.get_image())
plt.show()

效果图:

进行全景分割

##全景分割
cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.7
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
predictor=DefaultPredictor(cfg)
outputs=predictor(im)
v=Visualizer(im[:,:,::-1],MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),scale=1.2)
v=v.draw_instance_predictions(outputs['instances'].to("cpu"))
plt.figure(figsize=(20,10))
plt.imshow(v.get_image())
plt.show()

效果图:

2.训练自己的模型

  要想训练自己的模型,那就得有自己的数据集或者从某些地方下载下来的数据集,在这里我们使用coco数据集的气球数据集,免去了自己标记的繁琐行为,当然你也可以自己标记,在这里作者推荐使用labelme来实现图像的标记,在其项目地址有具体的安装方式,当然如果你的平台是windows,也可以直接下载可执行的exe文件 。
  在detectron2的官方文档中详细介绍了如何注册自己的数据集,包含文件路径、标记值、框的高度和宽度、框的样式、图片的id、类别的id、是否只包含单独一个类的标记等,另外的一个Metadata可注册也可不注册,如果不注册则标签视为默认的0,1,2,3,…,detectron2给出了标准的字典(dicts),你需要对相应的json文件进行一定的处理使之符合detectron2的标准。具体的标准形式长什么样子,可以尝试使用labelme标记一个图片并查看其生成的json文件。
  在这个案例中笔者用balloon数据集来尝试进行物体识别。数据集采用balloon数据集 。

加载必要的库:

import numpy as np
import json
import cv2
import os
import random
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog  #注册Metadata
from detectron2.data import DatasetCatalog   #注册资料集
from detectron2.engine import DefaultTrainer
from detectron2.structures import BoxMode  #标记方式
from matplotlib import pyplot as plt

将数据集中已经标好的json文件做成符合detectron2的字典格式:

def get_balloon_dicts(img_dir):json_file=os.path.join(img_dir,'via_region_data.json')with open(json_file) as f:imgs_anns=json.load(f)dataset_dicts=[]for idx,v in enumerate(imgs_anns.values()):record={}  #标准字典档filename=os.path.join(img_dir,v['filename'])height,width=cv2.imread(filename).shape[:2]  #获取尺寸record['file_name']=filename       record['image_id']=idxrecord['height']=heightrecord['width']=widthannos=v['regions']  #范围objs=[]for _,anno in annos.items():assert not anno['region_attributes']anno=anno['shape_attributes']px=anno['all_points_x']py=anno['all_points_y']poly=[(x+0.5,y+0.5) for x,y in zip(px,py)] #标记框框poly=[p for x in poly for p in x]obj={'bbox':[np.min(px),np.min(py),np.max(px),np.max(py)], #左上角坐标和右下角坐标'bbox_mode':BoxMode.XYXY_ABS,'segmentation':[poly],'category_id':0, #类别id'iscrowd':0    #标注对象是否发生堆叠}objs.append(obj)record['annotations']=objsdataset_dicts.append(record)return dataset_dicts

注册数据集:

for d in ['train','val']:  #注册数据集DatasetCatalog.register('balloon_'+d,lambda d=d: get_balloon_dicts('./balloon/'+d))MetadataCatalog.get('balloon_'+d).set(thing_classes=['balloon'])

配置并开始训练&测试:

cfg=get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")) #预设档,参数
cfg.DATASETS.TRAIN=('balloon_train',)  #训练集
cfg.DATASETS.TEST=('balloon_val',)  #测试集
cfg.DATALOADER.NUM_WORKERS=0   #执行序,0是cpu
cfg.SOLVER.IMS_PER_BATCH=1  #每批次改变的大小
cfg.SOLVER.BASE_LR=0.01  #学习率
cfg.SOLVER.MAX_ITER=100  #最大迭代次数
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  #迁移基础
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE=128  #default:512 批次大小
cfg.MODEL.ROI_HEADS.NUM_CLASSES=1  #一类
cfg.MODEL.DEVICE='cpu'  #注释掉此项,系统默认使用NVidia的显卡
cfg.OUTPUT_DIR = 'D:/temp_model'
os.makedirs(cfg.OUTPUT_DIR,exist_ok=True)
trainer=DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
cfg.MODEL.WEIGHTS=os.path.join(cfg.OUTPUT_DIR,'model_final.pth')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.7
predictor=DefaultPredictor(cfg)
val_dicts=DatasetCatalog.get('balloon_val')
balloon_metadata=MetadataCatalog.get('balloon_val')
s1,s2=0,0
for d in val_dicts:im=cv2.imread(d['file_name'])outputs=predictor(im)s1+=len(outputs['instances'].get("pred_classes"))
with open('./balloon/val/via_region_data.json') as f:im_js=json.load(f)
for i in im_js.keys():s2+=len(im_js[i]['regions'])
print(s1/s2)

由于笔者的电脑显卡不够所以用cpu来跑:

  由于最大迭代次数只有100,所以在50个标注的气球框里只找出了21个气球,准确率高达42%…,不过这个准确率也可能大于1,很多时候同一个气球可能会被标上好几个框,特别是多个气球重合的情况下。

随机测试3张图看一下效果:

for d in random.sample(val_dicts,3):im=cv2.imread(d['file_name'])outputs=predictor(im)v=Visualizer(im[:,:,::-1],metadata=balloon_metadata,scale=0.8)v=v.draw_instance_predictions(outputs['instances'].to('cpu'))plt.figure(figsize=(20,10))plt.imshow(v.get_image()[::,::-1])plt.show()



总结

  detectron2是目前非常流行的一个目标检测的库,虽然当下有很多库都在宣称其效果已经超过了detectron,但是这不影响detectron的受欢迎程度。高强度的“炼丹”需要强大的“丹炉”,建议大家在训练数据集的时候最好使用配有独立显卡的计算机,如果条件不允许也可以使用Google的Colab或者百度的AI studio来完成你的训练和测试。

本篇参考:https://www.youtube.com/channel/UC_TZfmL1ob6M5yEL00NZu1A

机器视觉库之detectron2安装及使用详解相关推荐

  1. mysql 5.6 安装库_MySQL5.6安装步骤图文详解

    MySQL是一个开放源码的小型关系型数据库管理系统,目前MySQL被广泛地应用在Internet上的中小型网站中.由于其体积小.速度快.总体拥有成本低,尤其是开放源码这一特点,许多中小型网站为了降低网 ...

  2. mysql data文件夹恢复_【专注】Zabbix源码安装教程—步骤详解(2)安装并配置mysql...

    四.安装并配置mysql(1) 解压mysql-5.7.26.tar.gz与boost_1_59_0.tar.gz #tar -xvf mysql-5.7.26.tar.gz #tar -xvf bo ...

  3. php多线程安装pthreads步骤详解

    摘要: 本文讲的是php多线程安装pthreads步骤详解, PHP扩展下载:https://github.com/krakjoe/pthreads PHP手册文档:http://php.net/ma ...

  4. Nagios远程监控软件的安装与配置详解

    Nagios远程监控软件的安装与配置详解 作者:redhat_hu Nagios是一款功能强大的网络监视工具,它可以有效的监控windows.linux.unix主机状态以及路由器交换机的网络设置,打 ...

  5. Linux kail环境下安装pyrit 问题详解

    Linux kail2021环境下手动安装pyrit问题详解 周末花了两天搭建环境,利用pyrit做无线安全实验.在网上转了一圈,发现没有完整能解决手动安装pyrit问题的文章.所以自己安装完后做了一 ...

  6. linux安装Openssl步骤详解_问题:OpenSSL: error:100AE081:elliptic curve routines:EC_GROUP_new_by_curve_name:un

    linux上安装Openssl步骤详解 问题: OpenSSL: error:100AE081:elliptic curve routines:EC_GROUP_new_by_curve_name:u ...

  7. linux上安装Openssl步骤详解

    linux上安装Openssl步骤详解     1,查看原有系统是否已安装Openssl openssl version -a 系统已经安装了openssl,我们先安装新的版本,然后将环境配置成最新的 ...

  8. 一加一 安装 Kali NetHunter 详解

    一加一 安装 Kali NetHunter 详解# 2018年4月20日13:02:44 手机:Oneplus one 软件: Kali NetHunter 工具:一加工具包 解读:手机安装 kali ...

  9. Windows系统Git安装教程(详解Git安装过程)

    Windows系统Git安装教程(详解Git安装过程)   今天更换电脑系统,需要重新安装Git,正好做个记录,希望对第一次使用的博友能有所帮助! 获取Git安装程序   到Git官网下载,网站地址: ...

最新文章

  1. 内核方式挂载cephfs
  2. 告别手敲 SQL ?GPT-3 自动帮你写
  3. 利用css对input[type=file] 样式进行美化,input上传按钮美化
  4. R语言合并两个或多个有序数dataframe实战(dataframe的纵向合并):使用R原生方法、data.table、dplyr等方案
  5. 逸鹏说道:漫漫人生路
  6. VTK:图表之CreateTree
  7. python---Socket编程
  8. 12、OpenCV Python 图像梯度
  9. hadoop概念介绍
  10. [转载]无线通信系统中的调制解调基础(一):AM和FM
  11. 关于使用,NI采集卡+labview信号采集,问题交流【第二贴】
  12. Ubuntu 16.04安装sogou 拼音输入法
  13. C# 调用腾讯云的短信发送服务API
  14. 微信小程序显示空格符
  15. B. Partial Replacement
  16. 解决谷歌翻译不能使用的问题(2023.01.14)
  17. 判断IE浏览器的文档模式以及浏览器模式
  18. P1462 通往奥格瑞玛的道路
  19. 120W快充!Redmi Note 11系列1199元起
  20. 牛刀小试:利用Python分析豆瓣电影Top250(一)

热门文章

  1. CTF---basecrack---Base编码分析工具安装详解
  2. PAT(乙级)1090.危险品装箱(25)
  3. PTA寒假基础题训练(含解题思路)(中)
  4. VL综述:视觉-语言智能:任务、表征学习、大模型
  5. 职称计算机Word2003是考什么,职称计算机考试:word2003考点
  6. 多进程pool.join函数的含义
  7. 【FPGA实例】基于FPGA的DDS信号发生器设计
  8. IOS开发之蘑菇街框架
  9. Java语言brea使用方法
  10. C++系列8:常用库