何为叶子节点和非叶子节点

在理解register_hook之前,首先得搞懂什么叶子节点和非叶子节。简单来说叶子节点是有梯度且独立得张量,例如a = torch.tensor(2.0,requires_grad=True),b= torch.tensor(3.0,requires_grad=True),非叶子节点是依赖其他张量而得到得张量如c = a+b
判断是叶子节点还是非叶子节点可以使用 is_leaf来判断一个张量是叶子节点还是非叶子节点。

import torch
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.is_leaf)
print(b.is_leaf)
c = a +b
print(c.is_leaf)>>> True
>>> True
>>> False

中间张量 c 作为非叶子节点是没有梯度信息得。pytorch默认在梯度反向传播过程中不会记录中间变量梯度信息。而且叶子节点的梯度信息在反向传播流过程中是不允许我们修改的。只能通过print(a.grad)查看张量的梯度信息。
那么,如果我们想查看中间变量 c 以及想改变叶子节点反向传播过程中的梯度值,应该怎么办呢。这时候就要使用register_hook这个钩子函数了。通过一下两段代码看一下钩子函数的主要作用。

register_hook

a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.grad)
print(b.grad)
c = a*b
print(c.grad)  # 由于c是叶子节点,所以他是不记录梯度信息得。前后打印梯度信息都为Noned = torch.tensor(4.0,requires_grad=True)
e = c * d
e.backward()
print(a.grad)
print(b.grad)
print(c.grad)>>>输出
None
None
None
tensor(12.)
tensor(8.)
None

通过上面代码可以看出,c作为中间变量在反向传播过程中不记录梯度信息。c=a*b其中a的梯度就为b的值,b的梯度就是a的值。接下来对中间变量c 使用register_hook,这个函数传入的参数得是一个函数。

import torcha = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)c = a * bdef c_hook(grad):print("c_hook",grad)return grad + 2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度print(a.grad)
print(b.grad)
print(c.grad)c.backward()print(a.grad)
print(b.grad)
print(c.grad)>>>
None
None
None
c_hook tensor(1.)
hello my grad is tensor(3.)
tensor(9.)
tensor(6.)
tensor(3.)

为什么输出会是这样的结果呢,一个张量可以注册多个钩子函数,反向传播过程中按照注册的顺序依次运行。 c.register_hook(c_hook) c.register_hook(lambda grad:)
,这两个函数可以重写c的梯度,第一个函数传入的参数是c的梯度,自身对自身的梯度pytorch中默认为1。所以此时c_hook中传入的grad=1,这个函数返回值为grad+2=3,此时会重写中间变量c的梯度信息。第二个钩子函数传入的函数为匿名函数,这个匿名函数对c的梯度没有进行重写,使用的还是上一个钩子函数重写的值,此使打印信息就为3。最后通过c.retain_grad()记c的梯度信息。通过这个例子,我稍微懂了点register_hook这个钩子函数的作用,是不是本来不可修改的梯度信息值,通过这个函数修改了呢。

通过一下这个例子比较再来看一下registe_hook函数的作用。

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)c = a * bdef c_hook(grad):print("c_hook",grad)return grad + 2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)  # 将使用100+grad代替本来返回得梯度值e = c * dprint(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)# e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()e.backward()print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)>>>输出
None
None
None
None
None
c_hook tensor(8.)
hello my grad is tensor(10.)
tensor(30.)
tensor(20.)
tensor(10.)
tensor(112.)
tensor(2.)

这段代码前部分和前面的代码保持一致,后面添加了e = c * d,在反向传播前,毋庸置疑a,b,,c,d,e的梯度都为None。反向传播过程中首先看 e,自身对自身的倒数默认为1,但是e注册的钩子将对原本的梯度 * 2 ,来替代原先的梯度信息,所以打印出的e的梯度信息为2。相应的,e 对 c的梯度信息相应的就变为 2d=8,e对d的梯度信息就变为 2c=12,案例说此使d的梯度信息为12,为什么是112呢,可以看出d注册了一个钩子函数,这个钩子给d原本的梯度信息加了100,来代替旧的梯度信息,所以d的梯度信息为112。由于c注册的钩子函数给他加了2,所以c的梯度信息为10。相应的a b 的梯度就都要乘以c 的梯度信息了。 同样,原本不变的梯度信息值在这里都根据register_hook这个函数相应的被重写。

以上就是我根据视频链接对register_hook的理解。

register_forward_hook

register_forward_hook register_forward_pre_hook这个函数主要使用在nn.Module网络中。
第一个函数看名称是用在网络forward之前,第二个是运行在forward之后,举例:

import torch
import torch.nn as nnclass SumNet(nn.Module):def __init__(self):super(SumNet, self).__init__()@staticmethoddef forward(a, b, c):d = a + b + cprint('forward():')print('    a:', a)print('    b:', b)print('    c:', c)print()print('    d:', d)print()return ddef forward_pre_hook(module, input_positional_args):a, b, c = input_positional_argsnew_input_positional_args = a + 10, b,c+10print('forward_pre_hook():')print('    module:', module)print('    input_positional_args:', input_positional_args)print()print('    new_input_positional_args:', new_input_positional_args)print()return new_input_positional_argsdef forward_hook(module, input_positional_args, output):new_output = output + 100print('forward_hook():')print('    module:', module)print('    input_positional_args:', input_positional_args)print('    output:', output)print()print('    new_output:', new_output)print()return new_outputdef main():sum_net = SumNet()sum_net.register_forward_pre_hook(forward_pre_hook)sum_net.register_forward_hook(forward_hook)a = torch.tensor(1.0, requires_grad=True)b = torch.tensor(2.0, requires_grad=True)c = torch.tensor(3.0, requires_grad=True)print('start')print()print('a:', a)print('b:', b)print('c:', c)print()print('before model')print()d = sum_net(a, b, c)   # 前向传播得时候钩子函数起作用了,先是forward_pre_hook,接下来是forward,接下来是forward_hook函数。print('after model')print()print('d:', d)if __name__ == '__main__':main()

输出信息:

starta: tensor(1., requires_grad=True)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)before modelforward_pre_hook():module: SumNet()input_positional_args: (tensor(1., requires_grad=True), tensor(2., requires_grad=True), tensor(3., requires_grad=True))new_input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))forward():a: tensor(11., grad_fn=<AddBackward0>)b: tensor(2., requires_grad=True)c: tensor(13., grad_fn=<AddBackward0>)d: tensor(26., grad_fn=<AddBackward0>)forward_hook():module: SumNet()input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))output: tensor(26., grad_fn=<AddBackward0>)new_output: tensor(126., grad_fn=<AddBackward0>)after modeld: tensor(126., grad_fn=<AddBackward0>)

分析以上为什么会输出这样的结果,前面提到register_forward_hook这个函数会在网络前向传播前运行,需要两个参数modul 和 input案例中输入为 tensor 1 2 3,经过这个函数给2 3 分别加了10,并且返回了一组新的值,这组值是要传入forward中,可以看出,forward函数打印的a b c 为传入的这组新值,而不是刚开始定义的1 2 3,forward函数运行过程中返回每层的输出会运行forward_hook函数。这个函数主要需要三个参数,module input output
以下从Lenet网络来使用这个函数:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = self.conv1(x)out = F.relu(out)     out = F.max_pool2d(out, 2)      out = self.conv2(out)out = F.relu(out)  out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out
model = LeNet()# 分别对model的第一个卷积层和最后一层使用了钩子函数,这样既可以取出对应层的输出。
def hook(model,input_,output):print("最后一层输出:",output.shape)def conv_hook(model,input_,output):print("conv1后",input_[0].shape,output.shape)model.register_forward_hook(hook)
model.conv1.register_forward_hook(conv_hook)img = torch.randn([1,3,32,32])
out_put = model(img)>>>
conv1后 torch.Size([1, 3, 32, 32]) torch.Size([1, 6, 28, 28])
最后一层输出: torch.Size([1, 10])

基于上可以看出给不同层使用钩子函数,可以提取出每一层的输出,并进行相应的处理。

以上就是pytorch中register_hookregister_forward_hook的基本理解。
如果有问题烦请指出加以改正。

pytorch中register_hook以及register_forward_hook相关推荐

  1. 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

    关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...

  2. tensor torch 构造_详解Pytorch中的网络构造

    背景 在PyTroch框架中,如果要自定义一个Net(网络,或者model,在本文中,model和Net拥有同样的意思),通常需要继承自nn.Module然后实现自己的layer.比如,在下面的示例中 ...

  3. pytorch中使用TensorBoard进行可视化Loss及特征图

    pytorch中使用TensorBoard进行可视化Loss及特征图 安装导入TensorBoard 安装TensorBoard pip install tensorboard 导入TensorBoa ...

  4. Pytorch 中retain_graph的用法

    Pytorch 中retain_graph的用法 用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? #################### ...

  5. pytorch中调整学习率的lr_scheduler机制

    pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...

  6. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  7. PyTorch中的MIT ADE20K数据集的语义分割

    PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...

  8. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  9. 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型

    作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...

最新文章

  1. tcp connection setup的实现
  2. mysql5.6配置semi_sync
  3. [CommunityServer]看RBAC的一方景象
  4. Python+django网页设计入门(18):自定义模板过滤器
  5. 小白设计模式:责任链模式
  6. 【公众号】微信第三方登录(静默授权和非静默授权)(具体代码:U盘 新浪云SAE)...
  7. [转载]直接保存Matlab图像到PPT文件
  8. xss绕过尖括号和双括号_【Web安全入门】三个技巧教你玩转XSS漏洞
  9. 有关设计网站的收藏集合
  10. 国际直拨电话号码格式
  11. 第九届GIS应用技能大赛上午(试题及答案含数据)
  12. 测试人生 | 转行测试开发,4年4“跳”年薪涨3倍,我的目标是星辰大海(附大厂面经)!
  13. 图解GC(垃圾回收)复制算法加强版(1)Cheney的复制算法
  14. 加入共享宽带,让你的闲置宽带循环利用再变现
  15. python 常见日期转换、excel时间转化、日期加N天、减N天等操作
  16. 人口会一直增长下去吗_现在世界人口约多少亿 世界人口会一直增加吗还是越来越少...
  17. QMessageBox 中的 OK 按钮改为中文“确定”
  18. netstat 命令用法详解
  19. eclipse安装springboot插件
  20. Verilog 每日一题 (VL5 信号发生器)

热门文章

  1. GitHub 优秀的开源项目学习
  2. 软考论文分享--论项目的沟通管理
  3. java基于springboot+vue的协同过滤算法的图书推荐系统 nodejs
  4. P2770【USACO 2014 January Gold】难度系数
  5. 非对称加密和对称加密
  6. android telephonymanager 电话状态,TelephonyManager类:Android手机及Sim卡状态的获取
  7. 怎样设计完整的交易系统(主观交易和程序化交易均可借鉴)
  8. POJ 1849 Two(树的直径+思维)
  9. 一文了解各大图数据库查询语言(Gremlin vs Cypher vs nGQL)| 操作入门篇
  10. 什么样的企业需要舆情优化系统?什么样的企业需要手工监测?