从源代码开始 Detectron2学习笔记
`从零开始 Detectron2学习笔记(一)
- 框架简介
- 1.Detection2的安装
- 2. 用预训练模型进行检测
- 2.1官方demo示例
- 2. 2源代码解读
- 2.2.1 模型的配置和构建
- 2.2.2 模型检测与结果可视化
- 3.小结
框架简介
1.Detection2的安装
Detectron2的安装过程如下,原链接及安装的常见问题见INSTALL.md
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
detectron2 相当于python的一个扩展库(类似numpy),安装后即可通过import … 来进行使用,十分方便。需要注意的是detectron2 依赖于pytorch与torchvision 因此也要安装好这两个库。
此外,detectron2 的配置文件为yaml格式,因此我们也要安装pyyaml来对yaml文件进行读取和操作。相关库的安装过程如下:
# install dependencies:
pip install pyyaml==5.1
conda install pytorch torchvisioncudatoolkit=10.1 -c pytorch
2. 用预训练模型进行检测
2.1官方demo示例
话不多说,先上代码(Detectron2 Beginner’s Tutorial)
在使用detectron2框架前,我们需要在项目中导入detectron2中的一些常用的组件,以及其他的常用函数库,如opencv等,如下图所示
# set up detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
之后,我们从coco数据集中下载一张示例图片,做为我们第一次使用detectron2来进行检测的对象:
!wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
im = cv2.imread("./input.jpg")
cv2_imshow(im)
下载的图片如下图所示:
做好准备工作之后,我们就开始进行detectron2的最简单应用----直接使用预训练的检测模型对上文图片进行检测。
首先,我们要构建出要使用的检测模型。在detectron2中,模型的结构和超参数均由.yaml格式的configs文件决定。下面给出构建一个检测模型的简单代码:
# get config files of Mask-RCNN
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# 设定测试时ROI_HEADS的正负样本判定的得分阈值
cfg.MODEL.ROI_HEADS_SCORE_THRESH_TEST = 0.5
# 加载预训练模型
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# 使用上述的配置文件构建默认的检测器(DefaultPredictor)
predictor = DefaultPredictor(cfg)
# 对之前下载的图片进行检测
output = predictor(im)
# 查看检测结果的格式
for k,v in outputv.items():print(k,type(v))
获得的输出结果如下:
# predictor(im)输出的结果为detectron2中的Instances类的实例
instances <class 'detectron2.structures.instances.Instances'>
最后,我们对检测结果进行可视化。在detectron2中有集成的可视化工具,可以很方便地展示模型的检测结果。
# 前文代码中有导入“from detectron2.utils.visualizer import Visualizer”
# 此处im[:,:,::-1] 是因为cv2是用BGR格式读入图片,我们在可视化时需将其转为BGR
v = Visualizer(im[:,:,::-1],MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),scale=1.2)
out = v.draw_instance_predictions(output['instances'].to('cpu'))
# 同理,使用cv2的 imshow函数,我们就要把RGB再变为BGR
cv2_imshow(out.get_image()[:,:,::-1])
最终得到的检测及分割结果如下:
2. 2源代码解读
虽然框架是现成的,但我们不能止步于拿着现成的东西去用,至少要知道它是如何运作的,这样在以后写自己的代码的时候就能事半功倍。为此,我们从github上Facebook的源代码来逐句分析前文代码的作用方式。
2.2.1 模型的配置和构建
首先来看前两句:
# from detectron2.config import get_cfg
cfg = get_cfg()
# model_zoo.get_config_file()返回配置文件的地址名
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
上述代码创建了“cfg”这一实例,并用"merge_from_file"函数来获取Mask-RCNN的模型配置及相关参数。为了深入了解原理,我们来看到与config相关的源代码detectron2/detectron2/config/config.py
首先在第84行找到我们的“get_cfg()”函数:
# line 84
def get_cfg() -> CfgNode:"""Get a copy of the default config.Returns:a detectron2 CfgNode instance."""from .defaults import _Creturn _C.clone()
如注释所示,此函数会返回一个detectron2中的 CfgNode 实例
CfgNode类的定义,在源文件的开头便可找到,在这里不做完整的复制粘贴,感兴趣的同学可以自行查看。目前我们只要知道,在detectron2中,CfgNode类起着获取,改变模型配置文件的作用,可以看作是构建模型的起点,也是使用detectron2的起点。
我们还需注意到一个细节,那就是detectron2中的CfgNode类继承自fvcore中的CfgNode:
from fvcore.common.config import CfgNode as _CfgNode
class CfgNode(_CfgNode):"""The same as `fvcore.common.config.CfgNode`, but different in:1. Use unsafe yaml loading by default.Note that this may lead to arbitrary code execution: you must notload a config file from untrusted sources before manually inspectingthe content of the file.2. Support config versioning.When attempting to merge an old config, it will convert the old config automatically."""
因此代码中的使用的部分函数要到fvcore库中去阅读,比如我们接下来要看到的"merge_from_file()"函数。代码有些长,在这里只展示比较关键的部分:
# "load_yaml_with_base()"与merge_from_other_cfg()"函数均来自fvcore
# 函数主要接收的参数为"cfg_filename",即.yaml文件
def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"# 读取由参数指定的.yaml文件到loaded_cfg中loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)loaded_cfg = type(self)(loaded_cfg)...(此处均为检查configs的版本之间是否匹配,在这略去)if loaded_ver == self.VERSION:# 由于实际上CfgNode类最终可以溯源到python最基本的字典类(dict),因此CfgNode的实例本身可以看成一个字典,而"merge_from_other_cfg()"函数就是将原本的字典变为目标字典self.merge_from_other_cfg(loaded_cfg)
现在,我们不妨打印一下cfg:
print(type(cfg))
for k,v in cfg.items():print(k)
得到的输出如下:
<class 'detectron2.config.config.CfgNode'>
VERSION
MODEL
INPUT
DATASETS
DATALOADER
SOLVER
TEST
OUTPUT_DIR
SEED
CUDNN_BENCHMARK
VIS_PERIOD
GLOBAL
可以看出,CfgNode的实例完全可以当成字典来访问和使用。
进一步,我们看一下cfg的键"MODEL"中都包含着哪些信息:
for k,v in cfg.MODEL.items():print(k)
# output as follows
LOAD_PROPOSALS
MASK_ON
KEYPOINT_ON
DEVICE
META_ARCHITECTURE
WEIGHTS
PIXEL_MEAN
PIXEL_STD
BACKBONE
FPN
PROPOSAL_GENERATOR
ANCHOR_GENERATOR
RPN
ROI_HEADS
ROI_BOX_HEAD
ROI_BOX_CASCADE_HEAD
ROI_MASK_HEAD
ROI_KEYPOINT_HEAD
SEM_SEG_HEAD
PANOPTIC_FPN
RETINANET
RESNETS
ROI_HEADS_SCORE_THRESH_TEST
可以看到,MODEL中包含着如BACKBONE,WEIGHTS,ROI_HEADS,等检测和分割模型的关键组件的配置信息。此外,从cfg.MODEL这一写法可以看到,除了用键值对的方式访问,cfg还可以直接用属性(attribute)的方式来访问和修改。
配置文件的部分先告一段落,接下来我们来看模型是如何根据配置文件构建的,相关代码非常简单,只有一行:
# from detectron2.engine import DefaultPredictor
predictor = DefaultPredictor(cfg)
我们来看一下关于DefaultPredictor的源代码(defaults.py):
class DefaultPredictor:def __init__(self, cfg):self.cfg = cfg.clone() # cfg can be modified by model# 用build_model()函数构建模型self.model = build_model(self.cfg)# prediction onlyself.model.eval()if len(cfg.DATASETS.TEST):self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])checkpointer = DetectionCheckpointer(self.model)# load weightscheckpointer.load(cfg.MODEL.WEIGHTS)self.aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)self.input_format = cfg.INPUT.FORMATassert self.input_format in ["RGB", "BGR"], self.input_format# call of predictionsdef __call__(self, original_image):"""Args:original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).Returns:predictions (dict):the output of the model for one image only.See :doc:`/tutorials/models` for details about the format."""with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258# Apply pre-processing to image.if self.input_format == "RGB":# whether the model expects BGR inputs or RGBoriginal_image = original_image[:, :, ::-1]height, width = original_image.shape[:2]image = self.aug.get_transform(original_image).apply_image(original_image)# (H,W,C) to (C,H,W)(pytorch-style input)image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))inputs = {"image": image, "height": height, "width": width}# self.model([inputs]) single image predictionpredictions = self.model([inputs])[0]return predictions
以上代码并不难理解,主要内容为模型的构建,权重的加载和prediction的执行。比较重要的几个地方有:
# from detectron2.modeling import build_model
1.self.model = build_model(self.cfg)
2.predictions = self.model([inputs])[0]
我们来看一下build_model()函数:
import torchfrom detectron2.utils.registry import RegistryMETA_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""
def build_model(cfg):"""Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.Note that it does not load any weights from ``cfg``."""meta_arch = cfg.MODEL.META_ARCHITECTUREmodel = META_ARCH_REGISTRY.get(meta_arch)(cfg)model.to(torch.device(cfg.MODEL.DEVICE))return model
此函数的运行分为三步:
- 得到模型的元结构,即META_ARCHITECTRUE.
- 根据元结构来得到一个nn.Module的实例。
- 将模型加载到配置文件指定的gpu或cpu上。
可以看到,在此函数前,有一个模型结构注册(Registry)的步骤,我们来重点关注这一操作及其源代码(fvcore.common.registry):
我们先关注初始化函数:
# META_ARCH_REGISTRY = Registry("META_ARCH")
class Registry(Iterable[Tuple[str, Any]]):# 初始化一个注册表(name: "BACKBONE" for example)def __init__(self, name: str) -> None:"""Args:name (str): the name of this registry"""self._name: str = nameself._obj_map: Dict[str, Any] = {}'''可以看到,此类实例的初始化非常简单,仅有两个变量需要关注,分别为self._name,即此Registry的名称,以及self._obj_map,一个记录对应键值对信息的字典'''
接下来是此类最重要的注册功能:
# 这个函数的功能简洁明了,就是将键值对参数写入self._obj_map中
def _do_register(self, name: str, obj: Any) -> None:assert (name not in self._obj_map), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)self._obj_map[name] = objdef register(self, obj: Any = None) -> Any:"""Register the given object under the the name `obj.__name__`.Can be used as either a decorator or not. See docstring of this class for usage."""if obj is None:# used as a decorator# @BACKBONE_REGISTER.register()# class MyBackbone():def deco(func_or_class: Any) -> Any:name = func_or_class.__name__self._do_register(name, func_or_class)return func_or_classreturn deco# used as a function call#BACKBONE_REGISTER.register(MyBackbone)name = obj.__name__self._do_register(name, obj)
接下来是get()函数:
# 此函数的作用为,由给定的键返回self._obj_map中对应的值
# 在本节的预训练模型检测中,get函数的作用为 返回一个在注册表中已经注册好了的GeneralizedRCNNdef get(self, name: str) -> Any:ret = self._obj_map.get(name)if ret is None:raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))return ret
现在,我们可以回到最初构建模型的函数上了:
meta_arch = cfg.MODEL.META_ARCHITECTUREmodel = META_ARCH_REGISTRY.get(meta_arch)(cfg)model.to(torch.device(cfg.MODEL.DEVICE))# 看一个基础的RCNN类配置文件:[BASE_RCNN_C4.yaml](https://github.com/facebookresearch/detectron2/blob/master/configs/Base-RCNN-C4.yaml)'''MODEL:META_ARCHITECTURE: "GeneralizedRCNN"RPN:PRE_NMS_TOPK_TEST: 6000POST_NMS_TOPK_TEST: 1000ROI_HEADS:NAME: "Res5ROIHeads"代码中的meta_arch即为此配置文件中的"GeneralizedRCNN",对应于detectron2/modeling/meta_arch/rcnn.py中的GeneralizedRCNN类。第二行代码的功能为,利用cfg中的参数来实例化一个GeneralizedRCNN,其中get函数的作用为,接收Config中的meta_arch的名称,以名称为key在META_ARCH注册表中找到对应的value(GeneralizedRCNN)。'''
2.2.2 模型检测与结果可视化
检测模型在测试时,只需要输入待检测的图片:
from detectron2.engine import DefaultPredictor
predictor = DefaultPredictor(cfg)
# Testing
output = predictor(im)
再看到DefaultPredictor 的调用函数__call__():
# DefaultPredictordef __call__(self, original_image):"""Args:original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).Returns:predictions (dict):the output of the model for one image only.See :doc:`/tutorials/models` for details about the format."""with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258# Apply pre-processing to image.if self.input_format == "RGB":# whether the model expects BGR inputs or RGBoriginal_image = original_image[:, :, ::-1]height, width = original_image.shape[:2]image = self.aug.get_transform(original_image).apply_image(original_image)image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))inputs = {"image": image, "height": height, "width": width}predictions = self.model([inputs])[0]return predictions
可以看到,当执行语句“output = predictor(im)”时,先是对图片进行一系列处理,包括Tensor的转换,长边短边的裁剪等,最后真正输入进model的为一个字典"inputs = {“image”: image, “height”: height, “width”: width}",包括图片Tensor以及图片的长宽。
再看到GeneralizedRCNN的inference函数:
def inference(self,batched_inputs: Tuple[Dict[str, torch.Tensor]],detected_instances: Optional[List[Instances]] = None,do_postprocess: bool = True,):
推理时要求的输入也是一个字典,正好与inputs字典相对应。至于 GeneralizedRCNN究竟是如何前向传播的,在后面深入阅读模型结构代码的时候再继续学习,此处先搁置。
得到outputs后,来看一下检测模型的outputs都包含哪些内容:
print(type(output))
for k,v in output.items():print(k)
print(type(output['instances']))
得到的输出如下:
<class 'dict'>
instances
<class 'detectron2.structures.instances.Instances'>
可以看到,outputs仅仅包含一个instances对象,而这个instances是detectron2中自己在structure中定义好的类别。instances类的详细代码不谈,这里只展示instances的几个关键属性:
# 对于检测而言,得出结果的关键属性只有两个,分别是目标的类别和bounding box
print(output["instances"].pred_classes)
print(output["instances"].pred_boxes)Results:
tensor([17, 0, 0, 0, 0, 0, 0, 0, 25, 0, 25, 25, 0, 0, 24],device='cuda:0')
Boxes(tensor([[126.6035, 244.8977, 459.8291, 480.0000],[251.1083, 157.8127, 338.9731, 413.6379],[114.8496, 268.6864, 148.2352, 398.8111],[ 0.8217, 281.0327, 78.6072, 478.4210],[ 49.3954, 274.1229, 80.1545, 342.9808],[561.2248, 271.5816, 596.2755, 385.2552],[385.9072, 270.3125, 413.7130, 304.0397],[515.9295, 278.3744, 562.2792, 389.3802],[335.2409, 251.9167, 414.7491, 275.9375],[350.9300, 269.2060, 386.0984, 297.9081],[331.6292, 230.9996, 393.2759, 257.2009],[510.7349, 263.2656, 570.9865, 295.9194],[409.0841, 271.8646, 460.5582, 356.8722],[506.8767, 283.3257, 529.9403, 324.0392],[594.5663, 283.4820, 609.0577, 311.4124]], device='cuda:0'))
可以看到,检测结果中给出了类别对应的标签和(left,top,right,bottom)结构的bounding box。
3.小结
本文介绍了detectron2的安装和用预训练模型对样例图片进行测试的方法,初步展示了detectron2的便捷性。同时,通过阅读源代码的方式,详细讲述了在detectron2中模型的构建,参数加载,得到图片测试结果的一般过程。下一步要尝试的是,使用detectron2构建自己的数据集并进行训练。
从源代码开始 Detectron2学习笔记相关推荐
- Detectron2学习笔记
文章目录 一.Detectron2 操作介绍 1.1 训练 1.2 测试 1.3 数据及格式要求 1.4 Load/Save model 1.5 模型输入形式 1.6 模型输出 1.7 config ...
- detectron2 学习笔记
目录 一.安装 二.项目详细介绍 训练 三.tools文件夹 四.换自己的数据集 制作数据集 换数据集 标准数据集字典 Metadata Dataloader 数据增强Data Augmentatio ...
- Ngrx Store实现源代码的MemoizedSelector学习笔记
定义一个类型AnyFn,代表任意的函数: export type AnyFn = (...args: any[]) => any; let a: AnyFn;a = (data) => c ...
- 树莓派学习笔记—— 源代码方式安装opencv
0.前言 本文介绍如何在树莓派中通过编译源代码的方式安装opencv,并通过一个简单的例子说明如何使用opencv. 更多内容请参考--[树莓派学习笔记--索引博文] 1.下载若干依赖项 在开始安装之 ...
- java存入光盘_java 这是 学习笔记(jdk7)书中的光盘里的源码,不知大家需要不,里面都是新手 的好 Develop 238万源代码下载- www.pudn.com...
文件名称: java下载 收藏√ [ 5 4 3 2 1 ] 开发工具: Java 文件大小: 1272 KB 上传时间: 2013-04-01 下载次数: 18 提 供 者: 孙鹏启 详细 ...
- eos 源代码学习笔记一
文章目录 eos 源代码学习笔记 1.eos 中的常见合约类型 2.语言环境局部( locale )变量的使用简介(目的是通过 gettext 软件包 来实现软件的全球化) 3.eos 源代码的一些优 ...
- 【OS学习笔记】二十五 保护模式七:任务和特权级保护对应的汇编源代码
本汇编代码是以下两篇文章讲解的内容的内核代码; [OS学习笔记]二十三 保护模式七:保护模式下任务的隔离与任务的特权级概念 [OS学习笔记]二十四 保护模式七:调用门与依从的代码段----特权级保护 ...
- CTFHUB http协议题目 学习笔记 详细步骤 请求方式 302跳转 cookie 基础认证 响应源代码
CTFHUB http协议题目 学习笔记 详细步骤 请求方式 302跳转 cookie 基础认证 响应源代码 WEB-HTTP协议 1-请求方式 2-302跳转 3.cookie 4.基础认证 5.响 ...
- PCA(主成分分析-principal components analysis)学习笔记以及源代码实战讲解
PCA(主成分分析-principal components analysis)学习笔记以及源代码实战讲解 文章目录 PCA(主成分分析-principal components analysis)学 ...
最新文章
- MVC系列1-MVC基础
- 2020五大技术趋势一览!超自动化、人类增强技术、无人驾驶发展、机器视觉崛起、区块链实用化...
- Two ways to assign values to member variables
- socket bufferedinputstream通信读取不到服务器返回的响应_TCP角度看socket通信过程,socket怎么表示三次握手,四次挥手...
- 智能优化算法改进算法 -附代码
- IEEE745浮点数格式
- sql server cross/outer apply 用法
- wps表格宏被禁用如何解禁_wps excel宏被禁用如何启用 - 卡饭网
- 网络型 PLC可编程控制器综合实训装置
- 横向扩展文件服务器,如何在 VMM 中创建横向扩展文件服务器
- 一些开源的项目 收藏
- 三峡学院计算机调剂,重庆三峡学院2019考研调剂信息公告
- 这几款超实用办公神器,让你的工作省心省时又省力!
- mysql 餐饮管理系统_Java Mysql 餐饮管理系统 过程心得记录
- 【unity插件】Rewired插件-unity3d实现主机、PC手柄震动Vibration
- linux shell 三元运算符,语法 - Bash中的三元运算符(?:)
- 安全合规/GDPR--24--研究:GDPR合规体系设立与执行
- eNSP上华为路由器开SNMP
- 向U盘中安装Linux系统的经验(不是制作安装盘)
- 用芝麻二维码生成器制作App下载二维码
热门文章
- 鲸鱼算法(WOA)优化支持向量机的数据回归预测,WOA-SVM回归预测,多输入单输出模型。
- PHP使用CURL详解
- 移动4g有信号无法连接服务器,在门口有4G的信号,但是进了房间里就没有了,上网都连不上。怎么避免这种情况?...
- 告别2017,码农翻身全年文章精华
- opencv 保存读取16位深度的图像
- JAVA编写学校超市选址问题_学校超市选址问题课程设计
- 如何使用wce进行hash注入
- 基于AutoJs实现的薅羊毛App专业版源码大分享---更新啦
- 计算机网络第一章学习
- Microsemi Libero使用技巧6——FPGA全局网络的设置