文章目录

  • hook简介
  • PyTorch的四个hook
    • 1. torch.Tensor.register_hook(hook)
    • 2. torch.nn.Module.register_forward_hook
    • 3. torch.nn.Module.register_forward_pre_hook
    • 4.torch.nn.Module.register_backward_hook

本博文由TensorSense发表于PyTorch的hook及其在Grad-CAM中的应用,转载请注明出处。

hook简介

pytorch中的hook是一个非常有意思的概念,hook意为钩、挂钩、鱼钩。
引用知乎用户“马索萌”对hook的解释:“(hook)相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。”
简单讲,就是不修改主体,而实现额外功能。对应到在pytorch中,主体就是forward和backward,而额外的功能就是对模型的变量进行操作,如“提取”特征图,“提取”非叶子张量的梯度,修改张量梯度等等。

hook的出现与pytorch运算机制有关,pytorch在每一次运算结束后,会将中间变量释放,以节省内存空间,这些会被释放的变量包括非叶子张量的梯度,中间层的特征图等。但有时候,我们想可视化中间层的特征图,又不能改动模型主体代码,该怎么办呢?这时候就要用到hook了。
举个例子演示hook提取非叶子张量的梯度:

import torch
def grad_hook(grad):y_grad.append(grad)
y_grad = list()
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x+1
y.register_hook(grad_hook)
z = torch.mean(y*y)
z.backward()
print("type(y): ", type(y))
print("y.grad: ", y.grad)
print("y_grad[0]: ", y_grad[0])>>> ('type(y): ', <class 'torch.Tensor'>)
>>> ('y.grad: ', None)
>>> ('y_grad[0]: ', tensor([[1.0000, 1.5000],[2.0000, 2.5000]]))

可以看到y.grad的值为None,这是因为y是非叶子结点张量,在z.backward()完成之后,y的梯度被释放掉以节省内存,但可以通过torch.Tensor的类方法register_hook将y的梯度提取出来。

PyTorch的四个hook

PyTorch(1.1.0版)有如下4个hook:
torch.Tensor.register_hook (Python method, in torch.Tensor)
torch.nn.Module.register_forward_hook (Python method, in torch.nn)
torch.nn.Module.register_backward_hook (Python method, in torch.nn)
torch.nn.Module.register_forward_pre_hook (Python method, in torch.nn)

这4个hook中有一个是应用于tensor的,另外3个是针对nn.Module的。

1. torch.Tensor.register_hook(hook)

功能:注册一个反向传播hook函数,这个函数是Tensor类里的,当计算tensor的梯度时自动执行。
为什么是backward?因为这个hook是针对tensor的,tensor中的什么东西会在计算结束后释放呢?
只有gradient嘛,所以是 backward hook.

形式: hook(grad) -> Tensor or None ,其中grad就是这个tensor的梯度。

返回值:a handle that can be used to remove the added hook by calling handle.remove()

应用场景举例:在hook函数中可对梯度grad进行in-place操作,即可修改tensor的grad值。
这是一个很酷的功能,例如当浅层的梯度消失时,可以对浅层的梯度乘以一定的倍数,用来增大梯度;
还可以对梯度做截断,限制梯度在某一区间,防止过大的梯度对权值参数进行修改。
下面举两个例子,例1是如何获取中间变量y的梯度,例2是利用hook函数将变量x的梯度扩大2倍。

例1:

import torch
y_grad = list()
def grad_hook(grad):y_grad.append(grad)
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = y.register_hook(grad_hook)
z.backward()
print("y.grad: ", y.grad)
print("y_grad[0]: ", y_grad[0])
h.remove()    # removes the hook>>> ('y.grad: ', None)
>>> ('y_grad[0]: ', tensor([0.2500, 0.2500, 0.2500, 0.2500]))

可以看到当z.backward()结束后,张量y中的grad为None,因为y是非叶子节点张量,在梯度反传结束之后,被释放。
在对张量y的hook函数(grad_hook)中,将y的梯度保存到了y_grad列表中,因此可以在z.backward()结束后,仍旧可以在y_grad[0]中读到y的梯度为tensor([0.2500, 0.2500, 0.2500, 0.2500])

例2:

import torch
def grad_hook(grad):grad *= 2
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove()    # removes the hook>>> tensor([2., 2., 2., 2.])

原x的梯度为tensor([1., 1., 1., 1.]),经grad_hook操作后,梯度为tensor([2., 2., 2., 2.])。

2. torch.nn.Module.register_forward_hook

功能:Module前向传播中的hook,module在前向传播后,自动调用hook函数。
形式:hook(module, input, output) -> None。注意不能修改input和output
返回值:a handle that can be used to remove the added hook by calling handle.remove()
应用场景举例:用于提取特征图
举例:假设网络由卷积层conv1和池化层pool1构成,输入一张4*4的图片,现采用forward_hook获取module——conv1之后的feature maps,示意图如下:

import torch
import torch.nn as nn
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return x
def farward_hook(module, data_input, data_output):fmap_block.append(data_output)input_block.append(data_input)
if __name__ == "__main__":# 初始化网络net = Net()net.conv1.weight[0].fill_(1)net.conv1.weight[1].fill_(2)net.conv1.bias.data.zero_()# 注册hookfmap_block = list()input_block = list()net.conv1.register_forward_hook(farward_hook)# inferencefake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * Woutput = net(fake_img)# 观察print("output shape: {}\noutput value: {}\n".format(output.shape, output))print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

首先初始化一个网络,卷积层有两个卷积核,权值分别为全1和全2,bias设置为0,池化层采用2*2的最大池化。
在进行forward之前对module——conv1注册了forward_hook函数,然后执行前向传播(output=net(fake_img)),当前向传播完成后,
fmap_block列表中的第一个元素就是conv1层输出的特征图了。
这里注意观察farward_hook函数有data_input和data_output两个变量,特征图是data_output这个变量,而data_input是conv1层的输入数据,
conv1层的输入是一个tuple的形式。

下面剖析一下module是怎么样调用hook函数的呢

  1. output = net(fake_img)
    net是一个module类,对module执行 module(input)是会调用module.call
  2. module.call
    在module.__call__中执行流程如下:
def __call__(self, *input, **kwargs):for hook in self._forward_pre_hooks.values():hook(self, input)if torch._C._get_tracing_state():result = self._slow_forward(*input, **kwargs)else:result = self.forward(*input, **kwargs)for hook in self._forward_hooks.values():hook_result = hook(self, input, result)if hook_result is not None:raise RuntimeError("forward hooks should never return any values, but '{}'""didn't return None".format(hook))...省略

首先判断module(这里是net)是否有forward_pre_hook,即在执行forward之前的hook;
然后执行forward;
forward结束之后才到forward_hook。
但是这里主要了,现在执行的是net.call,我们组成的hook是在module——net.conv1中,
所以第2个跳转是在net.__call__的 result = self.forward(*input, **kwargs)
3. net.forward

def forward(self, x):x = self.conv1(x)x = self.pool1(x)return x

在net.forward中,首先执行self.conv1(x), 而 conv1是一个nn.Conv2d(也是一个module类)。
在2中有说到,对module执行 module(input)是会调用module.call,因此第四步
4. nn.Conv2d.call
在nn.Conv2d.__call__中与2中说到的流程是一样的,再看一遍代码:

def __call__(self, *input, **kwargs):for hook in self._forward_pre_hooks.values():hook(self, input)if torch._C._get_tracing_state():result = self._slow_forward(*input, **kwargs)else:result = self.forward(*input, **kwargs)for hook in self._forward_hooks.values():hook_result = hook(self, input, result)if hook_result is not None:raise RuntimeError("forward hooks should never return any values, but '{}'""didn't return None".format(hook))

在这里终于要执行我们注册的forward_hook函数了,就在hook_result = hook(self, input, result)这里!
看到这里我们需要注意两点:

  1. hook_result = hook(self, input, result)中的input和result不可以修改!
    这里的input对应forward_hook函数中的data_input,result对应forward_hook函数中的data_output,在conv1中,input就是该层的输入数据,result就是经过conv1层操作之后的输出特征图。虽然可以通过hook来对这些数据操作,但是不能修改这些值,否则会破坏模型的计算。
  2. 注册的hook函数是不能带返回值的,否则抛出异常,这个可以从代码中看到
    if hook_result is not None:
    raise RuntimeError

总结一下调用流程:
net(fake_img) --> net.call : result = self.forward(*input, **kwargs) -->
net.forward: x = self.conv1(x) --> conv1.call:hook_result = hook(self, input, result)
hook就是我们注册的forward_hook函数了。

3. torch.nn.Module.register_forward_pre_hook

功能:执行forward()之前调用hook函数。
形式:hook(module, input) -> None
应用场景举例:暂时没碰到过,希望读者朋友补充register_forward_pre_hook相关应用场景。
register_forward_pre_hook与forward_hook一样,是在module.__call__中注册的,与forward_hook不同的是,其在module执行forward之前就运行了,具体可看module.__call__中的代码,第一行就是执行forward_pre_hook的相关操作。

4.torch.nn.Module.register_backward_hook

功能:Module反向传播中的hook,每次计算module的梯度后,自动调用hook函数。
形式:hook(module, grad_input, grad_output) -> Tensor or None
注意事项:当module有多个输入或输出时,grad_input和grad_output是一个tuple。
返回值:a handle that can be used to remove the added hook by calling handle.remove()
应用场景举例:例如提取特征图的梯度
举例:采用register_backward_hook实现特征图梯度的提取,并结合Grad-CAM(基于类梯度的类激活图可视化)方法对卷积神经网络的学习模式进行可视化。

关于Grad-CAM请看论文:《Grad-CAM Visual Explanations from Deep Networks via Gradient-based Localization》
简单介绍Grad-CAM的操作,Grad-CAM通过对最后一层特征图进行加权求和得到heatmap,整个CAM系列的主要研究就在于这个加权求和中的权值从那里来。

Grad-CAM是对特征图进行求梯度,将每一张特征图上的梯度求平均得到权值(特征图的梯度是element-wise的)。求梯度时并不采用网络的输出,而是采用类向量,即one-hot向量。
下图是ResNet的Grad-CAM示意图,上图类向量采用的是猫的标签,下图采用的是狗的标签,可以看到在上图模型更关注猫(红色部分),下图判别为狗的主要依据是狗的头部。

下面采用一个LeNet-5演示backward_hook在Grad-CAM中的应用。
简述代码过程:

  1. 创建网络net;
  2. 注册forward_hook函数用于提取最后一层特征图;
  3. 注册backward_hook函数用于提取类向量(one-hot)关于特征图的梯度;
  4. 对特征图的梯度进行求均值,并对特征图进行加权;
  5. 可视化heatmap。

代码位于PyTorch_Tutorial

需要注意的是在backward_hook函数中,grad_out是一个tuple类型的,要取得特征图的梯度需要这样grad_block.append(grad_out[0].detach())

这里对3张飞机的图片进行观察heatmap,如下图所示,第一行是原图,第二行是叠加了heatmap的图片。
这里发现一个有意思的现象,模型将图片判为飞机的依据是蓝天,而不是飞机(图1-3)。
那么我们喂给模型一张纯天蓝色的图片,模型会判为什么呢?如图4所示,发现模型判为了飞机

从这里发现,虽然能将飞机正确分类,但是它学到的却不是飞机的特征!
这导致模型的泛化性能大打折扣,从这里我们可以考虑采用trick让模型强制的学习到飞机而不是常与飞机一同出现的蓝天,或者是调整数据。

对于图4疑问:heatmap蓝色区域是否对图像完全不起作用呢?是否仅仅通过红色区域就可以对图像进行判别呢?
接下来将一辆正确分类的汽车图片(图5)叠加到图4蓝色响应区域(即模型并不关注的区域),结果如图6所示,汽车部分的响应值很小,模型仍通过天蓝色区域将图片判为了飞机。
接着又将汽车叠加到图4红色响应区域(图的右下角),结果如图7所示,仍将图片判为了飞机。
有意思的是将汽车叠加到图7的红色响应区域,模型把图片判为了船,而且红色响应区域是蓝色区域的下部分,这个与船在大海中的位置很接近!

通过以上代码学习backward_hook的使用及其在Grad-CAM中的应用,并通过Grad-CAM能诊断模型是否学习到了关键特征。
关于CAM( class activation maping,类激活响应图)是一个很有趣的研究,有兴趣的朋友可以对CAM、Grad-CAM和Grad-CAM++进行研究。

本博文由TensorSense发表于PyTorch的hook及其在Grad-CAM中的应用,转载请注明出处。

PyTorch的hook及其在Grad-CAM中的应用相关推荐

  1. Pytorch:NLP 迁移学习、NLP中的标准数据集、NLP中的常用预训练模型、加载和使用预训练模型、huggingface的transfomers微调脚本文件

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) run_glue.py微调脚本代码 python命令执行run ...

  2. PyTorch学习笔记(11)——论nn.Conv2d中的反向传播实现过程

    0. 前言 众所周知,反向传播(back propagation)算法 (Rumelhart et al., 1986c),经常简称为backprop,它允许来自代价函数的信息通过网络向后流动,以便计 ...

  3. 微型计算机在cad和cam中,重庆大学网络教育学院2013年9月份考试机械CAD/CAM第一次作业及答...

    时间:2013-09-05 09:10 来源:未知 作者:admin 点击: 次 2013年9月份考试机械CAD/CAM第一次作业 一.单项选择题(本大题共50分,共 25 小题,每小题 2 分) 1 ...

  4. 安装pytorch时,anaconda的Jupyter Notebook中出现实心圆,并且代码失效的解决办法

    最近我开始进行深度学习(Pytorch),需要用到anaconda中Jupyter Notebook的torch模块,因为之前闲暇时下载过anaconda,以为可以直接加载torch模块,是我太天真了 ...

  5. Pytorch获取中间变量的梯度grad

    为了节约显存,pytorch在反向传播的过程中只保留了计算图中的叶子结点的梯度值,而未保留中间节点的梯度 import torchx = torch.tensor(3., requires_grad= ...

  6. 【Pytorch神经网络理论篇】 20 神经网络中的注意力机制

    注意力机制可以使神经网络忽略不重要的特征向量,而重点计算有用的特征向量.在抛去无用特征对拟合结果于扰的同时,又提升了运算速度. 1 注意力机制 所谓Attention机制,便是聚焦于局部信息的机制,比 ...

  7. 微型计算机在cad和cam中,CAM CAD考试题

    机械CAD/CAM习题 第一章 CAD/CAM技术概述 选择题 1.下述CAD/CAM过程的操作中,属于CAD范畴的为( A ).CAD范畴几何 造型工程分析仿真模拟图形处理 A.模拟仿真 B.CAP ...

  8. Hook技术在APP测试中的应用

    在对APP进行安全检测和渗透测试的过程中,常会遇到APP采用一些安全防护措施,测试人员需要绕过这些安全防护措施才能开展后续的测试工作,例如环境安全检测.传输数据加密等.Hook技术可用来改变程序的执行 ...

  9. 封装一个hook,在Vue3 setup中使用Vuex中的mapState,mapGetters

    在Vue3中没有很好的方法使用Vuex中的映射函数到setup中使用,一般就使用一种效率低一些的方法 setup(props, context) {const store = useStore();c ...

最新文章

  1. 最佳学习方法(3)听课--听一反三
  2. css学习入门篇(1)
  3. 案例39-后台查询订单详情代码实现
  4. boost::callable_traits添加const成员的测试程序
  5. 线性及非线性方程组的解法
  6. django模板系统(下)
  7. 【OpenCV 例程200篇】84. 由低通滤波器得到高通滤波器
  8. 程序员,你还要迷茫多久?
  9. Blazeface 人脸检测器
  10. 设计模式-创建型模式-模板方法
  11. android java程序中调用shell命令
  12. TMOD、SCON、PCON寄存器的配置
  13. 第10课:图片管理模块
  14. linux桌面 英文,Linux桌面最好看的40+种英文Sans字体(2019版)
  15. MySQL数据库知识的总结
  16. PHP实现密钥分发中心,密钥分发中心(KDC)
  17. kubernetes入门之Downward API
  18. Windows 10x64 Pro Modified By Michael
  19. Android back按键基础开发
  20. 【栈和队列】栈的push、pop序列

热门文章

  1. 常用代码块:java使用系统浏览器打开url
  2. eclipse 快捷键收藏
  3. MyEclipse提示Errors occurred during the build
  4. fltk在UbuntuLinux中搭建和测试-《C++程序设计原理与实践》Chapter12:显示模型 环境构建...
  5. WCF,Net remoting,Web service
  6. 计算机求百钱买百鸡采用的算法,多种解法求百钱百鸡问题.doc
  7. hive高级数据类型
  8. String类的流程控制
  9. 编译fdk-aac for ios
  10. 老李推荐:第14章8节《MonkeyRunner源码剖析》 HierarchyViewer实现原理-获取控件列表并建立控件树 1...