点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达本文转自|计算机视觉联盟

最近做了一段时间的目标检测,不得不说检测这块还是相对比较复杂的,在熟悉项目的同时也确实学习到了很多有用的东西。MMdetetion是现在最著名、算法包最多并且使用人数最多的训练框架,其中的源码非常值得学习,今天总结下我对其中HOOK(钩子)机制的理解。

MMdetection最近更新很多,我以2.4.0版本的代码进行解读,分享自己的理解,也吸纳观众的点评。HOOK、Runer的定义在MMCV当中,MMdetection和MMCV是版本匹配的,我这里使用的是MMCV 1.1.2的代码。(HOOK相关的定义主要在MMCV中,下面用的代码都是摘自于MMCV)。

1.HOOK机制的作用

MMdetection中的HOOK可以理解为一种触发器,也可以理解为一种训练框架的架构规范,它规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。

首先看一下HOOK的基类定义

# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import RegistryHOOKS = Registry('hook')class Hook:def before_run(self, runner):passdef after_run(self, runner):passdef before_epoch(self, runner):passdef after_epoch(self, runner):passdef before_iter(self, runner):passdef after_iter(self, runner):passdef before_train_epoch(self, runner):self.before_epoch(runner)def before_val_epoch(self, runner):self.before_epoch(runner)def after_train_epoch(self, runner):self.after_epoch(runner)def after_val_epoch(self, runner):self.after_epoch(runner)def before_train_iter(self, runner):self.before_iter(runner)def before_val_iter(self, runner):self.before_iter(runner)def after_train_iter(self, runner):self.after_iter(runner)def after_val_iter(self, runner):self.after_iter(runner)def every_n_epochs(self, runner, n):return (runner.epoch + 1) % n == 0 if n > 0 else Falsedef every_n_inner_iters(self, runner, n):return (runner.inner_iter + 1) % n == 0 if n > 0 else Falsedef every_n_iters(self, runner, n):return (runner.iter + 1) % n == 0 if n > 0 else Falsedef end_of_epoch(self, runner):return runner.inner_iter + 1 == len(runner.data_loader)

可以说基类函数中定义了许多我们在模型训练中需要用到的一些功能,如果想定义一些操作我们就可以继承这个类并定制化我们的功能,可以看到HOOK中每一个参数都是有runner作为参数传入的。关于Runner的作用下一篇文章接着说,简而言之,Runner是一个模型训练的工厂,在其中我们可以加载数据、训练、验证以及梯度backward等等全套流程。MMdetection在设计的时候也为runner传入丰富的参数,定义了一个非常好的训练范式。在你的每一个hook函数中,都可以对runner进行你想要的操作。

而HOOK是怎么嵌套进runner中的呢?其实是在Runner中定义了一个hook的list,list中的每一个元素就是一个实例化的HOOK对象。其中提供了两种注册hook的方法,register_hook是传入一个实例化的HOOK对象,并将它插入到一个列表中,register_hook_from_cfg是传入一个配置项,根据配置项来实例化HOOK对象并插入到列表中。当然第二种方法又是MMLab的开源生态中定义的一种基础方法mmcv.build_from_cfg了,无论在MMdetection还是其他MMLab开源的算法框架中,都遵循着MMCV的这套基于配置项实例化对象的方法。毕竟MMCV是提供了一个基础的功能,服务于各个算法框架,这也是为什么MMLab的代码高质量的原因。不仅仅是算法的复现,更是架构、编程范式的一种体现,真·代码如诗

def register_hook(self, hook, priority='NORMAL'):"""Register a hook into the hook list.The hook will be inserted into a priority queue, with the specifiedpriority (See :class:`Priority` for details of priorities).For hooks with the same priority, they will be triggered in the sameorder as they are registered.Args:hook (:obj:`Hook`): The hook to be registered.priority (int or str or :obj:`Priority`): Hook priority.Lower value means higher priority."""assert isinstance(hook, Hook)if hasattr(hook, 'priority'):raise ValueError('"priority" is a reserved attribute for hooks')priority = get_priority(priority)hook.priority = priority# insert the hook to a sorted listinserted = False# hook是分优先级插入到list中的,在MMdetection中不同的HOOK是有优先级的,为什么呢?稍后在hook的调用中解释哈for i in range(len(self._hooks) - 1, -1, -1):if priority >= self._hooks[i].priority:self._hooks.insert(i + 1, hook)inserted = Truebreakif not inserted:self._hooks.insert(0, hook)def register_hook_from_cfg(self, hook_cfg):"""Register a hook from its cfg.Args:hook_cfg (dict): Hook config. It should have at least keys 'type'and 'priority' indicating its type and priority.Notes:The specific hook class to register should not use 'type' and'priority' arguments during initialization."""hook_cfg = hook_cfg.copy()priority = hook_cfg.pop('priority', 'NORMAL')hook = mmcv.build_from_cfg(hook_cfg, HOOKS)self.register_hook(hook, priority=priority)

调用HOOK函数

def call_hook(self, fn_name):"""Call all hooks.Args:fn_name (str): The function name in each hook to be called, such as"before_train_epoch"."""for hook in self._hooks:getattr(hook, fn_name)(self)

可以看到HOOK是调用的时候是遍历List,然后根据HOOK的名字来调用。这也是为什么要区分优先级的原因,优先级越高的放在List的前面,这样就能更快地被调用。当你想用_before_run_epoch_来做A和B两件事情的时候,在runner里面就是调用一次self.before_run_epoch,但是先做A还是先做B,就是通过不同的HOOK的优先级来决定了。比如在evaluation的时候对需要做测试,但是测试前对参数做滑动平均。比如emaHOOK中的72行,也写明了要在测试之前做指数滑动平均。

def after_train_epoch(self, runner):"""We load parameter values from ema backup to model before theEvalHook."""self._swap_ema_parameters()

checkpoint.py的HOOK中,同样也定义了after_train_epoch函数如下:

@master_onlydef after_train_epoch(self, runner):if not self.by_epoch or not self.every_n_epochs(runner, self.interval):returnrunner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')if not self.out_dir:self.out_dir = runner.work_dirrunner.save_checkpoint(self.out_dir, save_optimizer=self.save_optimizer, **self.args)# remove other checkpointsif self.max_keep_ckpts > 0:filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')current_epoch = runner.epoch + 1for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):ckpt_path = os.path.join(self.out_dir,filename_tmpl.format(epoch))if os.path.exists(ckpt_path):os.remove(ckpt_path)else:break

从测试代码中可以看到不同的HOOK虽然都是重写了after_train_epoch函数,但是调用的顺序还是先调用ema.py中的,然后再调用checkpoint.py中的after_train_epoch

resume_ema_hook = EMAHook(momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')runner = _build_demo_runner()runner.model = demo_model# 设置了HIGHREST的优先级runner.register_hook(resume_ema_hook, priority='HIGHEST')checkpointhook = CheckpointHook(interval=1, by_epoch=True)runner.register_hook(checkpointhook)runner.run([loader, loader], [('train', 1), ('val', 1)], 2)

具体的优先级定义有以下7种,作为HOOK的类成员属性。具体定义在链接中。

+------------+------------+| Level      | Value      |+============+============+| HIGHEST    | 0          |+------------+------------+| VERY_HIGH  | 10         |+------------+------------+| HIGH       | 30         |+------------+------------+| NORMAL     | 50         |+------------+------------+| LOW        | 70         |+------------+------------+| VERY_LOW   | 90         |+------------+------------+| LOWEST     | 100        |+------------+------------+

2.举一个简单的例子

最近打算好好锻炼身体,健康生活,努力工作,我打算让自己变得更加自律。我给自己定下了几个条例,每天吃早饭之前得晨练30分钟,运动完之后才会感觉充满活力。每天吃午饭之前我得跑上一个实验,吃完饭之后回来刚好可以看下中间结果,吃完午饭之后我感觉结果没问题我需要午休30分钟, 晚上下班前我如果没什么事再锻炼30分钟。秉承着这样的原则我给自己定义一个HOOK来规范我的生活。

  • 定义我的HOOK

import sys
class HOOK:def before_breakfast(self, runner):print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))def after_breakfast(self, runner):print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))def before_lunch(self, runner):print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name))def after_lunch(self, runner):print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name))def before_dinner(self, runner):print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))def after_dinner(self, runner):print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))def after_finish_work(self, runner, are_you_busy=False):if are_you_busy:print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name))else:print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name))
  • 定义我的Runner

class Runner(object):def __init__(self, ):passself._hooks = []def register_hook(self, hook):# 这里不做优先级判断,直接在头部插入HOOKself._hooks.insert(0, hook)def call_hook(self, hook_name):for hook in self._hooks:getattr(hook, hook_name)(self)def run(self):print('开始启动我的一天')self.call_hook('before_breakfast')self.call_hook('after_breakfast')self.call_hook('before_lunch')self.call_hook('after_lunch')self.call_hook('before_dinner')self.call_hook('after_dinner')self.call_hook('after_finish_work')print('~~睡觉~~')
  • 运行main函数,注册HOOK并且调用Runner.run()开启我的一天

from MyHook import HOOK
from MyRunner import Runner
runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()
  • 得到的输出结果如下:

开始启动我的一天
before_breakfast:吃早饭之前晨练30分钟
after_breakfast:吃早饭之前晨练30分钟
before_lunch:吃午饭之前跑上实验
after_lunch:吃完午饭午休30分钟
before_dinner: 没想好做什么
after_dinner: 没想好做什么
after_finish_work:今天没啥事,去锻炼30分钟
~~睡觉~~

3.总结

MMdetection中的HOOK设计巧妙,很好地对算法训练、测试进行了抽象和解耦。每一个做上层算法模型的,都值得一看。感谢MMLab贡献这么优质的代码,让我等凡夫俗子醍醐灌顶。

除了HOOK之外,这个代码中还有很多优质的思想。比如Runner是怎么做到包办一切的?注册器这个中枢管理系统是怎么工作的?多卡训练的一些坑是怎么解决的?等等等等,我也在持续地学习和消化。路漫漫其修远兮,吾将上下而求索。

一个小题目:我的代码中每个函数输出的时候都会打印出这个函数名,这个可以用_装饰器_很方便地解决奥。装饰器这个东西在MMLab的系列项目中有大量的应用。其中对fp16的支持让大家赞不绝口。接下来有时间,对Runner、Register、装饰器这些东西好好盘一盘。

end

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

深度理解目标检测(MMdetection)-HOOK机制相关推荐

  1. 深度学习目标检测工具箱mmdetection,训练自己的数据

    文章目录 一.简介 二.安装教程 1. 使用conda创建Python虚拟环境(可选) 2. 安装PyTorch 1.1 3. 安装Cython ~~4. 安装mmcv~~ 5. 安装mmdetect ...

  2. 深度学习目标检测详细解析以及Mask R-CNN示例

    深度学习目标检测详细解析以及Mask R-CNN示例 本文详细介绍了R-CNN走到端到端模型的Faster R-CNN的进化流程,以及典型的示例算法Mask R-CNN模型.算法如何变得更快,更强! ...

  3. 值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(下)

    作者 | 黄浴 来源 | 转载自知乎专栏自动驾驶的挑战和发展 [导读]在近日发布的<值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(上)>一文中,作者介绍了一部分各大公司和机构基于 ...

  4. 自动驾驶深度多模态目标检测和语义分割:数据集、方法和挑战

    自动驾驶深度多模态目标检测和语义分割:数据集.方法和挑战 原文地址:https://arxiv.org/pdf/1902.07830.pdf Deep Multi-Modal Object Detec ...

  5. 深度篇——目标检测史(八) 细说 CornerNet-Lite 目标检测

    返回主目录 返回 目标检测史 目录 上一章:深度篇--目标检测史(七) 细说 YOLO-V3目标检测 之 代码详解 论文地址:https://arxiv.org/pdf/1904.08900.pdf ...

  6. 深度学习目标检测指南:如何过滤不感兴趣的分类及添加新分类?

    编译 | 庞佳 责编 | Leo 出品 | AI 科技大本营(公众号ID:rgznai100) AI 科技大本营按:本文编译自 Adrian Rosebrock 发表在 PyImageSearch 上 ...

  7. 深度学习目标检测模型全面综述:Faster R-CNN、R-FCN和SSD

    为什么80%的码农都做不了架构师?>>>    Faster R-CNN.R-FCN 和 SSD 是三种目前最优且应用最广泛的目标检测模型,其他流行的模型通常与这三者类似.本文介绍了 ...

  8. 深度剖析目标检测算法YOLOV4

    深度剖析目标检测算法YOLOV4 目录 简述 yolo 的发展历程 介绍 yolov3 算法原理 介绍 yolov4 算法原理(相比于 yolov3,有哪些改进点) YOLOV4 源代码日志解读 yo ...

  9. 深度学习目标检测方法汇总

    目标检测简介   目标检测是计算机视觉的一个重要研究方向,是指从一个场景(或图片)中找到感兴趣的目标.任务大致分为三个流程: 从场景中提取候选区 从候选区提取特征 识别候选区的类别并对有效的候选框进行 ...

最新文章

  1. 《精通 ASP.NET MVC 3 框架(第三版)》----第2章 准备工作 2.1 准备工作站
  2. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码
  3. python爬百度翻译-爬虫 python爬取百度翻译接口 超详细附源码
  4. 【PAT乙级】1003 我要通过! (20 分)详解
  5. 深入理解 MySQL ——锁、事务与并发控制
  6. centos arm-linux-gcc,CentOS 6.4配置arm-linux-gcc交叉环境
  7. Ubuntu通过可视化界面配置 查找IP地址不存在的解决办法
  8. linux进程映像由哪些构成,Linux编程开发进程映像类型分析
  9. 2名数学家或发现史上最快超大乘法运算法,欲破解困扰人类近半个世纪的问题...
  10. mac 源码编译yar遇见的坑
  11. I2S接口以及Verilog实现数据接收
  12. 前期交互流程(PTES的第一步)
  13. 如何理解冲突域和广播域?(转)
  14. ps切图怎么做成html,PS切图怎么导出网页 PS切图怎么生成源代码
  15. 周鸿祎:创业者需要有点阿Q精神
  16. java复数类实部_Java编写一个复数类Complex,具有实部、虚部成员变量,可以完成加、减、乘、除和获得实部和虚部的方法...
  17. 图片怎样调整分辨率?如何在线修改分辨率?
  18. 「基因组学」使用CAFE进行基因家族扩张收缩分析
  19. Java 抛出异常【throw】
  20. 《A Mixed-Initiative Interface for Animating Static Pictures》翻译

热门文章

  1. Python十大装腔语法
  2. COCO 2019挑战赛,旷视研究院拿下三项计算机识别冠军 | ICCV 2019
  3. 一周焦点 | 最强AI芯片麒麟980发布;前端开发者将被取代?
  4. 重磅!阿里开源AI核心技术,95%算法工程师受用
  5. 推荐好用 Spring Boot 内置工具类
  6. PageHelper 在 Spring Boot + MyBatis 中合理且规范的使用方法
  7. 序列化/反序列化,我忍你很久了,淦!
  8. 我把SpringBoot项目从18.18M瘦身到0.18M,部署起来真省事!
  9. Datawhale数据分析教程来了!
  10. 人工神经网络背后的数学原理!