一、Hook函数概念

Hook函数机制:不改变模型主体,实现额外功能,像一个挂件或挂钩等。

为什么需要这个函数呢?这与Pytorch的动态图计算机制有关,在动态图的计算过程中,一些中间变量会释放掉,比如特征图、非叶子节点的梯度,在模型前向传播、反向传播的时候添加hook这个额外函数,提取一些释放掉而后面又需要用到的变量,也可以用hook函数来改变中间变量的梯度。

Pytorch中提供四种hook函数:
1、torch.Tensor.register_hook(hook): 针对tensor
2、torch.nn.Module.register_forward_hook:后面这三个针对Module
3、torch.nn.Module.register_forward_pre_hook
4、torch.nn.Module.register_backward_hook

二、Hook函数与特征提取

1、torch.Tensor.register_hook()

功能:这是一个针对张量的hook函数,作用是注册一个反向传播的hook函数,为什么是在反向传播呢?因为只有在反向传播过程中非叶子的梯度会释放掉,用hook函数来保存这些中间变量的信息。

hook(grad) -> Tensor or None

hook函数仅有一个输入参数为张量的梯度,返回值是tensor或者none
例如:
下图是pytorch中一个简单的计算图与梯度求导

在上面计算图反响传播过程中,非叶子节点a和b的梯度会释放掉,在前面的学习中可知retain_grad()可保留参数的梯度,也可用hook函数来保留梯度,如下所示:

# 构建计算图,在反向传播中用hook来保存a的梯度
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)# 构建一个list用来存储a的梯度
a_grad = list()# 自定义hook函数,存放a的梯度,然后将a的梯度存放到前面构建的list中
def grad_hook(grad):a_grad.append(grad)# 接受一个hook函数的钩子,相当于把hook函数挂到计算图上,这样在反向传播时可以保存a的梯度
handle = a.register_hook(grad_hook)y.backward()# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()

输出结果:

可看出在反向传播结束后是将a和b的梯度释放掉了,而hook函数则是保留了a的梯度,这样可以方便后续的使用。另外hook函数可以在反向传播中改变节点的梯度值,如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)a_grad = list()# 改变节点的梯度值,在hook里可以实现具体的改变方式,并用return返回
def grad_hook(grad):grad *= 2return grad*3handle = w.register_hook(grad_hook)y.backward()# 查看梯度
print("w.grad: ", w.grad)
handle.remove()

输出结果:

通过hook函数的变化之后w的梯度变为原来的6倍。

2、Module.register_forward_hook

hook(module, input, output) -> None

功能:注册module前向传播的hook函数
model:当前的网络层
input:当前网络层输入的数据
output:当前网络层输出数据

3、Module.register_forward_pre_hook

hook(module, input) -> None

功能:注册module前向传播的hook函数
module:当前的网络层
input:当前网络层的输入数据
因为这个hook函数是用在前向传播前的函数,所以这里接受参数之后就没有返回值,这个功能可以查看网络之前的数据。

4、Module.register_backward_hook

hook(module, grad_input, grad_output) -> Tensor or None

功能:注册module反向传播的hook函数
module:当前网络层
grad_input:当前网络层的输入梯度数据
grad_output:当前网络层的输出梯度数据

以上就是Pytorch中的hook函数,第一个是针对tensor,后三个是针对module,根据hook函数的使用位置可分为前向传播前,前向传播,反向传播。下面通过具体的示例来了解一下:

假设输入是44的图像经过33的卷积之后得到2*2的feature map,然后经过池化得到后面的输出值,下面就用hook函数来获取中间的feature map层

#根据上图的示例,构建一个网络,只有卷积和池化两个操作
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return x#定义前向传播的hook函数
def forward_hook(module, data_input, data_output):fmap_block.append(data_output)input_block.append(data_input)#定义前向传播前的hook函数
def forward_pre_hook(module, data_input):print("forward_pre_hook input:{}".format(data_input))#定义反向传播的hook函数
def backward_hook(module, grad_input, grad_output):print("backward hook input:{}".format(grad_input))print("backward hook output:{}".format(grad_output))# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()# 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

输出结果:

在output = net(fake_img)处打上断点,查看一下上面的三个hook函数是如何实现的。debug进入到module.py中的_call_impl函数中,在这里会调用前向传播函数:

进一步debug,会进入到构建的网络前向传播函数中

这是网络的第一个字模块,也就是卷积模块,这里也是定义钩子的地方,进一步debug,可看到又进入到了module.py文件中的_call_impl函数中,仔细观察_call_impl函数可看到主要有四个模块

在上面的示例中设置了三个钩子,分别是前向传播之前,前向传播,反向传播,在不同的过程中会调用_call_impl函数的对应的模块,比如forward_pre_hook钩子会对应上面的第一个模块,然后在result=hook(self, input)会跳到自定义的钩子函数中。继续debug可看到其他钩子也是如此。

【总结】
上面的hook函数的运行机制,都是在module中的_call_impl函数中实现,这个函数完成了4部分的工作,前向传播之前的hook函数(这里钩子主要是查看输入数据的信息),前向传播,forward hook函数(这里的钩子接受参数的输入和输出,存储中间特征图的信息),backward hook函数(这里的钩子常是查看参数的梯度信息)。总体来说hook机制就是在计算图上挂一些钩子,然后钩子上定义一些函数,在不改变模型或者计算图主体的情况下,提供了一些实现别的额外功能的接口。

三、CAM可视化

CAM:类激活图, class activation map。主要功能就是分析卷积神经网络,图像通过卷积神经网络得到了输出之后,可以分析网络是关注图像的哪些部分而得到的这个结果。通过这个可以分析出网络是否学习到了图片中物体本身的特征信息, 如下所示的过程图:

论文:《Learning Deep Features for Discriminative Localization》

上面网络最后的输出是澳大利亚犬种。那么网络从图像中看到了什么东西才确定是这一个类呢?这里通过CAM算法进行一个可视化,结果就如图中所示。红色的就是网络重点关注的, 在这个结果中看以发现,这个网络重点关注了狗的头部,最后判定是一个这样的犬种。

CAM的基本思想:它会对网络的最后一个特征图进行加权求和,就可以得到一个注意力机制,就是卷积神经网络更关注于什么地方。那如何得到这些特征图的权值呢?对每一个feature map进行golbal average pooling就得到其对应的权值,再通过加权求和最后得到 class activation map。

缺点:CAM是通过golbal average pooling得到权值的,如果输入值改变就得重新训练网络得到权重值,所以就有了如下的改进算法

Grad-CAM:CAM改进版,利用梯度作为特征图权重

具体思想:根据最后网络输出的向量值进行backward,求出feature map中每一个像素值对应的梯度值,将feature map每一个像素值对应的梯度值进行平均,将梯度的平均值作为此feature map的权重值,然后进行加权求和得到 CAM.
论文:Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

pytorch学习笔记十五:Hook函数与CAM可视化相关推荐

  1. python复制指定字符串_python3.4学习笔记(十五) 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转换、分割等)...

    python3.4学习笔记(十五) 字符串操作(string替换.删除.截取.复制.连接.比较.查找.包含.大小写转换.分割等) python print 不换行(在后面加上,end=''),prin ...

  2. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  3. windows内核开发学习笔记十五:IRP结构

    windows内核开发学习笔记十五:IRP结构   IRP(I/O Request Package)在windows内核中,有一种系统组件--IRP,即输入输出请求包.当上层应用程序需要访问底层输入输 ...

  4. 《Go语言圣经》学习笔记 第五章函数

    <Go语言圣经>学习笔记 第五章 函数 目录 函数声明 递归 多返回值 匿名函数 可变参数 Deferred函数 Panic异常 Recover捕获异常 注:学习<Go语言圣经> ...

  5. Polyworks脚本开发学习笔记(十五)-用Python连接Polyworks的COM组件

    Polyworks脚本开发学习笔记(十五)-用Python连接Polyworks的COM组件 用Polyworks脚本开发,没有高级语言的支持,功能难免单一,一些比较复杂的交互实现不了,界面和报告也很 ...

  6. JS学习笔记(五)函数类型、箭头函数、arguments参数、标签函数

    JS学习笔记(五) 本系列更多文章,可以查看专栏 JS学习笔记 文章目录 JS学习笔记(五) 一.函数 1. 函数定义 2. 方法( 对象 + 函数 ) 二.函数参数及返回值 1. 传递原始类型参数 ...

  7. IOS之学习笔记十五(协议和委托的使用)

    1.协议和委托的使用 1).协议可以看下我的这篇博客 IOS之学习笔记十四(协议的定义和实现) https://blog.csdn.net/u011068702/article/details/809 ...

  8. Mr.J-- jQuery学习笔记(十五)--实现页面的对联广告

    请看之前的:Mr.J-- jQuery学习笔记(十四)--动画显示隐藏 话不多说,直接上demo <!DOCTYPE html> <html lang="en"& ...

  9. 世界是有生命的(通向财富自由之路学习笔记十五)

    最近因为工作调度的事情,有了一段空闲的日子,有比较多的时间来回望自己走过的路以及如何走好以后的路.之前忙得很少时间来写博文,很少时间来写读书笔记,逐渐将自己一些很好的习惯丢弃了.从今天起将重拾写博文的 ...

最新文章

  1. 再论JavaScript原型继承和对象继承
  2. SAP CRM和Cloud for Customer中的Event handler(事件处理器)
  3. 1.2 xss原理分析与剖析(3)
  4. sql子查询示例_学习SQL:SQL查询示例
  5. url安全处理函数+php,php常用的url处理函数汇总
  6. 数据分析第一步 | 做好数据埋点
  7. 【解题报告】【HODJ1231】【最大子序列和】最大连续子序列
  8. 李子奈计量经济学笔记和课后习题答案
  9. Gym - 101808K Another Shortest Path Problem (Damascus University Collegiate)【并查集+LCA】
  10. python-docx 复制一页_python 怎么用docx读取word的某一页然后放到新的word文档中?...
  11. 统计 | 几种特殊随机变量的分布
  12. 算法注册机编写扫盲---第二课
  13. 手机照片局部放大镜_怎样发照片才能惊艳朋友圈?
  14. win10安装xshell免费版
  15. 请19级的童鞋们接收一下
  16. 《创新思维设计》自学报告#2 | 设计思维的特征
  17. npm 安装 node-sass 失败问题分析及解决方案
  18. 承香墨影的行业周报-0x0010
  19. ORAN专题系列-18:5G O-RAN FrontHaul前传接口互操作性测试规范IOT概述与总体架构
  20. spring-boot整合redies、mybatis、thymeleaf

热门文章

  1. Jenkins流水线整合钉钉
  2. selenium 教程 汇总
  3. android 微票效果,Android ShimmerLayout实现微光效果解析
  4. . NET6 Core 日志组件Log4net和Nlog
  5. html设置表单透明度,css利用transparent属性设置透明度的方法
  6. 九九加发表和九九乘法表。
  7. HTML水平垂直居中的四种方式
  8. Spring初级入门(一)--易百教程
  9. iis高并发 大量数据并发设置
  10. 车用摄像头的一个应用(想法)