为了节省显存(内存),pytorch在计算过程中不保存中间变量,包括中间层的特征图和非叶子张量的梯度等。有时对网络进行分析时需要查看或修改这些中间变量,此时就需要注册一个钩子(hook)来导出需要的中间变量。网上介绍这个的有不少,但我看了一圈,多少都有不准确或不易懂的地方,我这里再总结一下,给出实际用法和注意点。
hook方法有四种:
torch.Tensor.register_hook()
torch.nn.Module.register_forward_hook()
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook().

1, torch.Tensor.register_hook(hook)

  用来导出指定张量的梯度,或修改这个梯度值。

import torch
def grad_hook(grad):grad *= 2
x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
h = x.register_hook(grad_hook)
z.backward()
print(x.grad)
h.remove()    # removes the hook
>>> tensor([2., 2., 2., 2.])

注意:(1)上述代码是有效的,但如果写成 grad = grad * 2就失效了,因为此时没有对grad进行本地操作,新的grad值没有传递给指定的梯度。保险起见,最好在def语句中写明return grad。即:

def grad_hook(grad):grad = grad * 2return grad

(2)可以用remove()方法取消hook。注意remove()必须在backward()之后,因为只有在执行backward()语句时,pytorch才开始计算梯度,而在x.register_hook(grad_hook)时它仅仅是"注册"了一个grad的钩子,此时并没有计算,而执行remove就取消了这个钩子,然后再backward()时钩子就不起作用了。
(3)如果在类中定义钩子函数,输入参数必须先加上self,即

def grad_hook(self, grad):...

2, torch.nn.Module.register_forward_hook(module, in, out)

  用来导出指定子模块(可以是层、模块等nn.Module类型)的输入输出张量,但只可修改输出,常用来导出或修改卷积特征图。

inps, outs = [],[]
def layer_hook(module, inp, out):inps.append(inp[0].data.cpu().numpy())outs.append(out.data.cpu().numpy())hook = net.layer1.register_forward_hook(layer_hook)
output = net(input)
hook.remove()

注意:(1)因为模块可以是多输入的,所以输入是tuple型的,需要先提取其中的Tensor再操作;输出是Tensor型的可直接用。
   (2)导出后不要放到显存上,除非你有A100。
   (3)只能修改输出out的值,不能修改输入inp的值(不能返回,本地修改也无效),修改时最好用return形式返回,如:

def layer_hook(self, module, inp, out):out = self.lam * out + (1 - self.lam) * out[self.indices]return out

  这段代码用在manifold mixup中,用来对中间层特征进行混合来实现数据增强,其中self.lam是一个[0,1]概率值,self.indices是shuffle后的序号。

3, torch.nn.Module.register_forward_pre_hook(module, in)

  用来导出或修改指定子模块的输入张量。

def pre_hook(module, inp):inp0 = inp[0]inp0 = inp0 * 2inp = tuple([inp0])return inphook = net.layer1.register_forward_pre_hook(pre_hook)
output = net(input)
hook.remove()

注意:(1)inp值是个tuple类型,所以需要先把其中的张量提取出来,再做其他操作,然后还要再转化为tuple返回。
(2)在执行output = net(input)时才会调用此句,remove()可放在调用后用来取消钩子。

4, torch.nn.Module.register_backward_hook(module, grad_in, grad_out)

  用来导出指定子模块的输入输出张量的梯度,但只可修改输入张量的梯度(即只能返回gin),输出张量梯度不可修改。

gouts = []
def backward_hook(module, gin, gout):print(len(gin),len(gout))gouts.append(gout[0].data.cpu().numpy())gin0,gin1,gin2 = gingin1 = gin1*2gin2 = gin2*3gin = tuple([gin0,gin1,gin2])return ginhook = net.layer1.register_backward_hook(backward_hook)
loss.backward()
hook.remove()

注意:
(1)其中的grad_in和grad_out都是tuple,必须要先解开,修改时执行操作后再重新放回tuple返回。
(2)这个钩子函数在backward()语句中被调用,所以remove()要放在backward()之后用来取消钩子。

【pytorch学习】四种钩子方法(register_forward_hook等)的用法和注意点相关推荐

  1. python中如何创建一个空列表_Python学习笔记(1):列表的四种创建方法

    我的电脑安装的是Anaconda 3开源的Python发行版本,其中是集合3.6版本的Python与可视化编程工具采用的是Spyder. 打开Spyder可视化工具,新建一个空白文件,做好备注为&qu ...

  2. 流形学习的四种降维方法

    文章目录 流形学习 主成分分析(PCA) 原理 实现 手写版 调库版 缺点 奇异值分解(SVD) 原理 实现 线性判别分析(LDA) 原理 手写版 调库版 PCA与LDA 局部线性嵌入(LLE) 原理 ...

  3. 计算机无法连接无线信号,win7系统连接无线信号时提示Windows无法连接到路由器名称的四种解决方法...

    现如今网络发展速度非常快,无线网络已经普及了,使用率高了遇到的问题也就多了.比如有时候笔记本win7系统连接无线信号时出现"Windows无法连接到路由器名称"(如下图所示),该如 ...

  4. 大数据可视化python_大数据分析之Python数据可视化的四种简易方法

    本篇文章探讨了大数据分析之Python数据可视化的四种简易方法,希望阅读本篇文章以后大家有所收获,帮助大家对相关内容的理解更加深入. < 数据可视化是任何数据科学或机器学习项目的一个重要组成部分 ...

  5. 怎么将file转换为html,怎么将PDF文件转换为HTML?分享四种实用方法!

    原标题:怎么将PDF文件转换为HTML?分享四种实用方法! 在我们日常学习和日常工作中,如果想要将PDF文件转换为HTML文件要怎么办呢?随着需求的增加,我们需要会的技能也要增加了.不止要将PDF文件 ...

  6. 计算机桌面都有说明,电脑桌面上所有图标都消失了的四种处理方法

    有些小伙伴们还不会处理电脑桌面上所有图标都消失了的问题,今天小编就带来了关于电脑桌面上所有图标都消失了的四种处理方法.快来学习吧! 电脑桌面上所有图标都消失了的四种处理方法 方法一:首先我们要看桌面上 ...

  7. C++/python描述 898. 数字三角形 (四种实现方法)

    C++/python描述 898. 数字三角形 (四种实现方法)   大家好,我叫亓官劼(qí guān jié ),在CSDN中记录学习的点滴历程,时光荏苒,未来可期,加油~博主目前仅在CSDN中写 ...

  8. python 财务分析可视化方法_Python数据可视化的四种简易方法

    Python数据可视化的四种简易方法 作者:PHPYuan 时间:2018-11-28 03:40:43 摘要: 本文讲述了热图.二维密度图.蜘蛛图.树形图这四种Python数据可视化方法. 数据可视 ...

  9. 产品设计中多见的四种倒角方法

    在工业设计中,对产品外观设计特别是关键点的把握,基本上离不开一个专业术语--倒角.无论是手绘画外观设计或是三维外观,都需要把握倒角的应用. 1.倒角定义 在机械设备制造中,倒角就是指将铸件的边角切割成 ...

最新文章

  1. React router 的 Route 中 component 和 render 属性理解
  2. Perforce使用之创建DEPOT流程
  3. linux 第十五章 shell 脚本习题
  4. OpenCV 使用方向梯度直方图估计图像旋转角度
  5. 做为 iOS 开发者 现在对未来迷茫怎么办?
  6. [攻防世界 pwn]——welpwn
  7. 现实版“奇异博士”?原来是这款神秘的“数学黑盒”
  8. jsonview浏览器插件 查看格式化json数据
  9. javaweb学习总结—jsp简单标签标签库开发
  10. 使用cacti监控CISCO交换机
  11. Atitit.web 视频播放器classid clsid 大总结quicktime,vlc 1. Classid的用处。用来指定播放器 1 2. object 标签用于包含对象,比如图像、音
  12. 多元线性回归实现代码
  13. oracle job定时报错,Oracle定时任务Job笔记
  14. linux51单片机烧录程序,单片机成长之路(51基础篇) - 006 在Linux下搭建51单片机的开发烧写环境...
  15. 模电——电源与地之间串联电容的作用
  16. 从PCC到MIC(2)
  17. uniapp 跳转外部链接
  18. 漫画:位运算技巧助你俘获offer
  19. 男友是程序员,看着他压力大我难受。有哪些缓解压力的好方法?
  20. win10出现打印机无法打印,而其他显示正常,重启没反应

热门文章

  1. P2024 食物链 (补集)
  2. Obj文件和Bin文件
  3. pytoch word_language_model 代码阅读
  4. 数据结构——HDU1312:Red and Black(DFS)
  5. Thread如何中断
  6. 去除表单元素的默认样式
  7. 今週木曜日までの日程表
  8. mysql 数据库编程_MySQL数据库编程(C++语言)
  9. 腾讯云Centos升级python2到python3
  10. oracle 更新记录语句,Oracle语句自动判断是要更新记录还是要插入记录