1. 什么是Hook

经常会听到钩子函数(hook function)这个概念,最近在看目标检测开源框架mmdetection,里面也出现大量Hook的编程方式,那到底什么是hook?hook的作用是什么?

  • what is hook ?钩子hook,顾名思义,可以理解是一个挂钩,作用是有需要的时候挂一个东西上去。具体的解释是:钩子函数是把我们自己实现的hook函数在某一时刻挂接到目标挂载点上。
  • hook函数的作用 举个例子,hook的概念在windows桌面软件开发很常见,特别是各种事件触发的机制; 比如C++的MFC程序中,要监听鼠标左键按下的时间,MFC提供了一个onLeftKeyDown的钩子函数。很显然,MFC框架并没有为我们实现onLeftKeyDown具体的操作,只是为我们提供一个钩子,当我们需要处理的时候,只要去重写这个函数,把我们需要操作挂载在这个钩子里,如果我们不挂载,MFC事件触发机制中执行的就是空的操作。

从上面可知

  • hook函数是程序中预定义好的函数,这个函数处于原有程序流程当中(暴露一个钩子出来)
  • 我们需要再在有流程中钩子定义的函数块中实现某个具体的细节,需要把我们的实现,挂接或者注册(register)到钩子里,使得hook函数对目标可用
  • hook 是一种编程机制,和具体的语言没有直接的关系
  • 如果从设计模式上看,hook模式是模板方法的扩展
  • 钩子只有注册的时候,才会使用,所以原有程序的流程中,没有注册或挂载时,执行的是空(即没有执行任何操作)

本文用python来解释hook的实现方式,并展示在开源项目中hook的应用案例。hook函数和我们常听到另外一个名称:回调函数(callback function)功能是类似的,可以按照同种模式来理解。

2. hook实现例子

据我所知,hook函数最常使用在某种流程处理当中。这个流程往往有很多步骤。hook函数常常挂载在这些步骤中,为增加额外的一些操作,提供灵活性。

下面举一个简单的例子,这个例子的目的是实现一个通用往队列中插入内容的功能。流程步骤有2个

  • 需要再插入队列前,对数据进行筛选 input_filter_fn
  • 插入队列 insert_queue
  1. class ContentStash(object):
  2. """
  3. content stash for online operation
  4. pipeline is
  5. 1. input_filter: filter some contents, no use to user
  6. 2. insert_queue(redis or other broker): insert useful content to queue
  7. """
  8. def __init__(self):
  9. self.input_filter_fn = None
  10. self.broker = []
  11. def register_input_filter_hook(self, input_filter_fn):
  12. """
  13. register input filter function, parameter is content dict
  14. Args:
  15. input_filter_fn: input filter function
  16. Returns:
  17. """
  18. self.input_filter_fn = input_filter_fn
  19. def insert_queue(self, content):
  20. """
  21. insert content to queue
  22. Args:
  23. content: dict
  24. Returns:
  25. """
  26. self.broker.append(content)
  27. def input_pipeline(self, content, use=False):
  28. """
  29. pipeline of input for content stash
  30. Args:
  31. use: is use, defaul False
  32. content: dict
  33. Returns:
  34. """
  35. if not use:
  36. return
  37. # input filter
  38. if self.input_filter_fn:
  39. _filter = self.input_filter_fn(content)
  40. # insert to queue
  41. if not _filter:
  42. self.insert_queue(content)
  43. # test
  44. ## 实现一个你所需要的钩子实现:比如如果content 包含time就过滤掉,否则插入队列
  45. def input_filter_hook(content):
  46. """
  47. test input filter hook
  48. Args:
  49. content: dict
  50. Returns: None or content
  51. """
  52. if content.get('time') is None:
  53. return
  54. else:
  55. return content
  56. # 原有程序
  57. content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probility": 0.9}}
  58. content_stash = ContentStash('audit', work_dir='')
  59. # 挂上钩子函数, 可以有各种不同钩子函数的实现,但是要主要函数输入输出必须保持原有程序中一致,比如这里是content
  60. content_stash.register_input_filter_hook(input_filter_hook)
  61. # 执行流程
  62. content_stash.input_pipeline(content)

3. hook在开源框架中的应用

3.1 keras

在深度学习训练流程中,hook函数体现的淋漓尽致。

一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:

  • 开始训练
  • 训练一个epoch前
  • 训练一个batch前
  • 训练一个batch后
  • 训练一个epoch后
  • 评估验证集
  • 结束训练

这些步骤是穿插在训练一个batch数据的过程中,这些可以理解成是钩子函数,我们可能需要在这些钩子函数中实现一些定制化的东西,比如在训练一个epoch后我们要保存下训练的模型,在结束训练时用最好的模型执行下测试集的效果等等。

keras中是通过各种回调函数来实现钩子hook功能的。这里放一个callback的父类,定制时只要继承这个父类,实现你过关注的钩子就可以了。

  1. @keras_export('keras.callbacks.Callback')
  2. class Callback(object):
  3. """Abstract base class used to build new callbacks.
  4. Attributes:
  5. params: Dict. Training parameters
  6. (eg. verbosity, batch size, number of epochs...).
  7. model: Instance of `keras.models.Model`.
  8. Reference of the model being trained.
  9. The `logs` dictionary that callback methods
  10. take as argument will contain keys for quantities relevant to
  11. the current batch or epoch (see method-specific docstrings).
  12. """
  13. def __init__(self):
  14. self.validation_data = None # pylint: disable=g-missing-from-attributes
  15. self.model = None
  16. # Whether this Callback should only run on the chief worker in a
  17. # Multi-Worker setting.
  18. # TODO(omalleyt): Make this attr public once solution is stable.
  19. self._chief_worker_only = None
  20. self._supports_tf_logs = False
  21. def set_params(self, params):
  22. self.params = params
  23. def set_model(self, model):
  24. self.model = model
  25. @doc_controls.for_subclass_implementers
  26. @generic_utils.default
  27. def on_batch_begin(self, batch, logs=None):
  28. """A backwards compatibility alias for `on_train_batch_begin`."""
  29. @doc_controls.for_subclass_implementers
  30. @generic_utils.default
  31. def on_batch_end(self, batch, logs=None):
  32. """A backwards compatibility alias for `on_train_batch_end`."""
  33. @doc_controls.for_subclass_implementers
  34. def on_epoch_begin(self, epoch, logs=None):
  35. """Called at the start of an epoch.
  36. Subclasses should override for any actions to run. This function should only
  37. be called during TRAIN mode.
  38. Arguments:
  39. epoch: Integer, index of epoch.
  40. logs: Dict. Currently no data is passed to this argument for this method
  41. but that may change in the future.
  42. """
  43. @doc_controls.for_subclass_implementers
  44. def on_epoch_end(self, epoch, logs=None):
  45. """Called at the end of an epoch.
  46. Subclasses should override for any actions to run. This function should only
  47. be called during TRAIN mode.
  48. Arguments:
  49. epoch: Integer, index of epoch.
  50. logs: Dict, metric results for this training epoch, and for the
  51. validation epoch if validation is performed. Validation result keys
  52. are prefixed with `val_`.
  53. """
  54. @doc_controls.for_subclass_implementers
  55. @generic_utils.default
  56. def on_train_batch_begin(self, batch, logs=None):
  57. """Called at the beginning of a training batch in `fit` methods.
  58. Subclasses should override for any actions to run.
  59. Arguments:
  60. batch: Integer, index of batch within the current epoch.
  61. logs: Dict, contains the return value of `model.train_step`. Typically,
  62. the values of the `Model`'s metrics are returned. Example:
  63. `{'loss': 0.2, 'accuracy': 0.7}`.
  64. """
  65. # For backwards compatibility.
  66. self.on_batch_begin(batch, logs=logs)
  67. @doc_controls.for_subclass_implementers
  68. @generic_utils.default
  69. def on_train_batch_end(self, batch, logs=None):
  70. """Called at the end of a training batch in `fit` methods.
  71. Subclasses should override for any actions to run.
  72. Arguments:
  73. batch: Integer, index of batch within the current epoch.
  74. logs: Dict. Aggregated metric results up until this batch.
  75. """
  76. # For backwards compatibility.
  77. self.on_batch_end(batch, logs=logs)
  78. @doc_controls.for_subclass_implementers
  79. @generic_utils.default
  80. def on_test_batch_begin(self, batch, logs=None):
  81. """Called at the beginning of a batch in `evaluate` methods.
  82. Also called at the beginning of a validation batch in the `fit`
  83. methods, if validation data is provided.
  84. Subclasses should override for any actions to run.
  85. Arguments:
  86. batch: Integer, index of batch within the current epoch.
  87. logs: Dict, contains the return value of `model.test_step`. Typically,
  88. the values of the `Model`'s metrics are returned. Example:
  89. `{'loss': 0.2, 'accuracy': 0.7}`.
  90. """
  91. @doc_controls.for_subclass_implementers
  92. @generic_utils.default
  93. def on_test_batch_end(self, batch, logs=None):
  94. """Called at the end of a batch in `evaluate` methods.
  95. Also called at the end of a validation batch in the `fit`
  96. methods, if validation data is provided.
  97. Subclasses should override for any actions to run.
  98. Arguments:
  99. batch: Integer, index of batch within the current epoch.
  100. logs: Dict. Aggregated metric results up until this batch.
  101. """
  102. @doc_controls.for_subclass_implementers
  103. @generic_utils.default
  104. def on_predict_batch_begin(self, batch, logs=None):
  105. """Called at the beginning of a batch in `predict` methods.
  106. Subclasses should override for any actions to run.
  107. Arguments:
  108. batch: Integer, index of batch within the current epoch.
  109. logs: Dict, contains the return value of `model.predict_step`,
  110. it typically returns a dict with a key 'outputs' containing
  111. the model's outputs.
  112. """
  113. @doc_controls.for_subclass_implementers
  114. @generic_utils.default
  115. def on_predict_batch_end(self, batch, logs=None):
  116. """Called at the end of a batch in `predict` methods.
  117. Subclasses should override for any actions to run.
  118. Arguments:
  119. batch: Integer, index of batch within the current epoch.
  120. logs: Dict. Aggregated metric results up until this batch.
  121. """
  122. @doc_controls.for_subclass_implementers
  123. def on_train_begin(self, logs=None):
  124. """Called at the beginning of training.
  125. Subclasses should override for any actions to run.
  126. Arguments:
  127. logs: Dict. Currently no data is passed to this argument for this method
  128. but that may change in the future.
  129. """
  130. @doc_controls.for_subclass_implementers
  131. def on_train_end(self, logs=None):
  132. """Called at the end of training.
  133. Subclasses should override for any actions to run.
  134. Arguments:
  135. logs: Dict. Currently the output of the last call to `on_epoch_end()`
  136. is passed to this argument for this method but that may change in
  137. the future.
  138. """
  139. @doc_controls.for_subclass_implementers
  140. def on_test_begin(self, logs=None):
  141. """Called at the beginning of evaluation or validation.
  142. Subclasses should override for any actions to run.
  143. Arguments:
  144. logs: Dict. Currently no data is passed to this argument for this method
  145. but that may change in the future.
  146. """
  147. @doc_controls.for_subclass_implementers
  148. def on_test_end(self, logs=None):
  149. """Called at the end of evaluation or validation.
  150. Subclasses should override for any actions to run.
  151. Arguments:
  152. logs: Dict. Currently the output of the last call to
  153. `on_test_batch_end()` is passed to this argument for this method
  154. but that may change in the future.
  155. """
  156. @doc_controls.for_subclass_implementers
  157. def on_predict_begin(self, logs=None):
  158. """Called at the beginning of prediction.
  159. Subclasses should override for any actions to run.
  160. Arguments:
  161. logs: Dict. Currently no data is passed to this argument for this method
  162. but that may change in the future.
  163. """
  164. @doc_controls.for_subclass_implementers
  165. def on_predict_end(self, logs=None):
  166. """Called at the end of prediction.
  167. Subclasses should override for any actions to run.
  168. Arguments:
  169. logs: Dict. Currently no data is passed to this argument for this method
  170. but that may change in the future.
  171. """
  172. def _implements_train_batch_hooks(self):
  173. """Determines if this Callback should be called for each train batch."""
  174. return (not generic_utils.is_default(self.on_batch_begin) or
  175. not generic_utils.is_default(self.on_batch_end) or
  176. not generic_utils.is_default(self.on_train_batch_begin) or
  177. not generic_utils.is_default(self.on_train_batch_end))

这些钩子的原始程序是在模型训练流程中的

keras源码位置: tensorflowpythonkerasenginetraining.py

部分摘录如下(## I am hook):

  1. # Container that configures and calls `tf.keras.Callback`s.
  2. if not isinstance(callbacks, callbacks_module.CallbackList):
  3. callbacks = callbacks_module.CallbackList(
  4. callbacks,
  5. add_history=True,
  6. add_progbar=verbose != 0,
  7. model=self,
  8. verbose=verbose,
  9. epochs=epochs,
  10. steps=data_handler.inferred_steps)
  11. ## I am hook
  12. callbacks.on_train_begin()
  13. training_logs = None
  14. # Handle fault-tolerance for multi-worker.
  15. # TODO(omalleyt): Fix the ordering issues that mean this has to
  16. # happen after `callbacks.on_train_begin`.
  17. data_handler._initial_epoch = ( # pylint: disable=protected-access
  18. self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
  19. for epoch, iterator in data_handler.enumerate_epochs():
  20. self.reset_metrics()
  21. callbacks.on_epoch_begin(epoch)
  22. with data_handler.catch_stop_iteration():
  23. for step in data_handler.steps():
  24. with trace.Trace(
  25. 'TraceContext',
  26. graph_type='train',
  27. epoch_num=epoch,
  28. step_num=step,
  29. batch_size=batch_size):
  30. ## I am hook
  31. callbacks.on_train_batch_begin(step)
  32. tmp_logs = train_function(iterator)
  33. if data_handler.should_sync:
  34. context.async_wait()
  35. logs = tmp_logs # No error, now safe to assign to logs.
  36. end_step = step + data_handler.step_increment
  37. callbacks.on_train_batch_end(end_step, logs)
  38. epoch_logs = copy.copy(logs)
  39. # Run validation.
  40. ## I am hook
  41. callbacks.on_epoch_end(epoch, epoch_logs)

3.2 mmdetection

mmdetection是一个目标检测的开源框架,集成了许多不同的目标检测深度学习算法(pytorch版),如faster-rcnn, fpn, retianet等。里面也大量使用了hook,暴露给应用实现流程中具体部分。

详见https://github.com/open-mmlab/mmdetection

这里看一个训练的调用例子(摘录)(https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py

  1. def train_detector(model,
  2. dataset,
  3. cfg,
  4. distributed=False,
  5. validate=False,
  6. timestamp=None,
  7. meta=None):
  8. logger = get_root_logger(cfg.log_level)
  9. # prepare data loaders
  10. # put model on gpus
  11. # build runner
  12. optimizer = build_optimizer(model, cfg.optimizer)
  13. runner = EpochBasedRunner(
  14. model,
  15. optimizer=optimizer,
  16. work_dir=cfg.work_dir,
  17. logger=logger,
  18. meta=meta)
  19. # an ugly workaround to make .log and .log.json filenames the same
  20. runner.timestamp = timestamp
  21. # fp16 setting
  22. # register hooks
  23. runner.register_training_hooks(cfg.lr_config, optimizer_config,
  24. cfg.checkpoint_config, cfg.log_config,
  25. cfg.get('momentum_config', None))
  26. if distributed:
  27. runner.register_hook(DistSamplerSeedHook())
  28. # register eval hooks
  29. if validate:
  30. # Support batch_size > 1 in validation
  31. eval_cfg = cfg.get('evaluation', {})
  32. eval_hook = DistEvalHook if distributed else EvalHook
  33. runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
  34. # user-defined hooks
  35. if cfg.get('custom_hooks', None):
  36. custom_hooks = cfg.custom_hooks
  37. assert isinstance(custom_hooks, list),
  38. f'custom_hooks expect list type, but got {type(custom_hooks)}'
  39. for hook_cfg in cfg.custom_hooks:
  40. assert isinstance(hook_cfg, dict),
  41. 'Each item in custom_hooks expects dict type, but got '
  42. f'{type(hook_cfg)}'
  43. hook_cfg = hook_cfg.copy()
  44. priority = hook_cfg.pop('priority', 'NORMAL')
  45. hook = build_from_cfg(hook_cfg, HOOKS)
  46. runner.register_hook(hook, priority=priority)

4. 总结

本文介绍了hook的概念和应用,并给出了python的实现细则。希望对比有帮助。

感谢阅读!!!

多说一句,很多人学Python过程中会遇到各种烦恼问题,没有人解答容易放弃。小编是一名python开发工程师,这里有我自己整理了一套最新的python系统学习教程,包括从基础的python脚本到web开发、爬虫、数据分析、数据可视化、机器学习等。想要这些资料的可以关注小编,并在后台私信小编:“01”即可领取。

mfc中嵌入python_Python 中的 Hook 钩子函数相关推荐

  1. java内嵌excel_如何在Excel中嵌入URL中的图像?

    我试图从URL中提取图像并将其嵌入Excel中 . 我的Excel表格很简单:它包含2列 . 第1列具有图像URL . 在第2列中,我想嵌入图像 . 我使用以下代码 . 它在第一行工作得非常好,我在本 ...

  2. pytest合集(10)— Hook钩子函数

    一.钩子函数 钩子函数这个称呼是很多开发语言中都会涉及到的一个东西. 1.理解钩子函数 如何理解钩子函数 - 知乎 2.pytest的钩子函数 Hooks钩子函数是pytest框架预留的函数,通过这些 ...

  3. vue子组件mounted不执行_vue中父子组件传值,解决钩子函数mounted只运行一次的问题...

    因为mounted函数只会在html和模板渲染之后会加载一次,但是在子组件中只有第一次的数据显示是正常的,所以需要再增加一个updated函数,在更新之后就可以重新进行取值加载,完成数据的正常显示. ...

  4. mysql 钩子函数_pod 生命周期hook钩子函数

    参考: 0.如果没有设置钩子,pod如何删除 给pod里的每个容器中pid为1的进程发送 kill -9 (SIGTERM) 信号, 1.postStart 这个钩子在创建容器之后立即执行.但是,并不 ...

  5. Vue3.x中自定义时钟钩子函数(TypeScript语法)

      钩子函数的运用能使我们的代码更加简洁且易于维护,那么在Vue3.x中钩子函数的编写方式是怎样的呢?   下面,我以一个点击获取当前时间的例子来记录钩子函数的编写过程. 创建hooks目录   一般 ...

  6. gorm time.Time 使用钩子函数解决反序列化问题

    问题描述: gorm中使用下面的CreatedAt 和UpdateAt,可以实现在记录创建和更新时自动更新下面两个字段.虽然使用默认的json解析,从json中到golang中,从golang中写入到 ...

  7. Vue钩子函数之钩子事件hookEvent,监听组件

    在Vue当中,hooks可以作为一种event,在Vue的源码当中,称之为hookEvent. 在Vue组件中,可以用过$on,$once去监听所有的生命周期钩子函数,如监听组件的updated钩子函 ...

  8. Python钩子函数

    文章目录 python hook 机制 一. 概念 1. hook概述 2. hook 二. 示例 python hook 机制 一. 概念 1. hook概述 hook,又称钩子,在C/C++中一般 ...

  9. python keyboard hook_键盘监控的实现Ⅰ——Keyboard Hook API函数

    在实际应用中,键盘监控是一种很常见的技术,它包括按键的记录.按键的过滤.按键的修改(映射)等.比方说,我们想统计用户的击键情况,这个就是按键的记录:我们想屏蔽某些系统键(例如Alt键.Win键),这个 ...

最新文章

  1. 宿松长铺程集高中2021年高考成绩查询,2017宿松程集中学录取分数线(附2017高考成绩喜报)...
  2. 手机zip模拟器_【教程】萌新手机krkr2模拟器运行教程
  3. 数据仓库分层和元数据管理
  4. Oracle_视图_索引_plsql_游标_存储过程_存储函数_触发器
  5. ThinkPHP开发博客系统笔记之二
  6. 一线城市的繁荣vs年轻人的梦想?
  7. jsp 引入java类库报错_myeclipse中运行Jsp项目调用java,运行不了,报错说不能解析jsp中的类型,资源文件无法使用,求解,...
  8. Java完全自学手册pdf,由浅入深,循序渐进(1)
  9. github 仓库中文名_github仓库的使用
  10. Unity3D 通过脚本设置PlayerSettings的属性(GPU Skinning,Auto Graphics APi[OpenGLES2])等
  11. 淘宝店铺层级每个月更新么?如何提高淘宝店铺层级?
  12. 伯恩半导体 - ESD 选型指南
  13. 微信外包公司—北京动点软件:微信公众平台案例介绍
  14. 计算机网络实验二静态路由基础
  15. Android USB HID整理
  16. 基于面向对象 来写一个简单的贪吃蛇小游戏(代码可直接用)
  17. 从零开始搭建自己的网站二十一:网站IP/PV统计功能设计
  18. Mac版Ps、AE、PR不能突然使用?
  19. vmware 虚拟机使用windows的 http/socks 代理
  20. 【MATLAB】matlab中clc,close,close all,clear,clear all作用区别

热门文章

  1. java mysql 死锁,java-Spring JPA MySQL和死锁
  2. redis 内存不足 排查_一文深入了解 Redis 内存模型,Redis 的快是有原因的!
  3. tp的echo输出字符串后换行
  4. 微信小程序之redirectTo、switchTab和navigateTo
  5. c语言 库 键盘,python 函数 map 、lambda
  6. docker安装redis(最新)
  7. 高德地图 街道范围_高德地图发布交通“评诊治”系统:针对各类交通拥堵场景“因地制宜”...
  8. angular项目打包_vue项目部署的最佳实践
  9. 量子计算机新科技未来,能够“预测多个未来”的量子计算机诞生
  10. Promise 上手