首先明确一点,hook是什么呢?翻译出来是"钩子",顾名思义,就是挂在某个东西上的一种即插即用的结构。

有哪些hook呢?
常用的主要有3个:
1. torch.autograd.Variable.register_hook (in Automatic differentiation package)
2. torch.nn.Module.register_backward_hook (in torch.nn.Module)
3. torch.nn.Module.register_forward_hook (in torch.nn.Module)
第一个是register_hook,是针对Variable对象的(可以理解为挂在Variable用来提取某些信息的一种结构); 后面的两个register_backward_hook和register_forward_hook是针对nn.Module(挂在Module上的)这个对象的。

其次,我们为何要用hook呢或者说hook有什么作用呢?

举个例子,比如有这么一个函数,

,你想通过梯度下降法来求得函数的极小值 (或者最小值)。这在pytorch里很容易实现:
import 

在pytorch的计算图中,只有叶子节点 (leaf node)是可以被追踪梯度的(即requires_grad=True),中间节点的梯度为了节省内存的原因而没有被保留 (对于中间变量,一旦它们完成了自身反传的使命,梯度就会被释放掉 ),因此当我们输出中间变量y的梯度的时候:

y.grad

系统会返回None。那怎么办呢?

因此,hook就派上用场了。简而言之,register_hook( )的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个列表,将每次的grad值添加到里面去保留起来。

import torch
from torch.autograd import Variablegrad_list = []def print_grad(grad):grad_list.append(grad)x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data

需要注意的是,register_hook函数接收的是一个函数(函数名),这个函数有如下的形式:

hook(grad) -> Variable or None

PS:这个函数是可以改变被执行变量梯度的!

v = torch.tensor([1, 1, 1], dtype = torch.float32, requires_grad=True)
u = torch.pow(v, 2)
z = torch.mean(u)
# register hook for u
h = u.register_hook(lambda grad: print(2 * grad))  # double the gradient
z.backward()
h.remove()  # removes the hook

系统输出:tensor([0.6667, 0.6667, 0.6667])

这个函数返回一个句柄h, 它有一个方法h.remove( ), 可以使用这个方法将hook从变量u上"卸载"下来。

卸载pytorch_Pytorch中的hook的使用详解相关推荐

  1. eclipse配置python开发环境_Eclipse中配置python开发环境详解

    Eclipse中配置python开发环境详解 1.下载python安装包.python-2.6.6.msi.并安装. 默认python会安装在C:\Python26下,查看环境变量,如果没有在path ...

  2. php调用linux摄像头,Linux_Linux中开发USB摄像头驱动详解,USB摄像头以其良好的性能和低 - phpStudy...

    Linux中开发USB摄像头驱动详解 USB摄像头以其良好的性能和低廉的价格得到广泛应用.同时因其灵活.方便的特性,易于集成到嵌入式系统中.但是如果使用现有的符合Video for Linux标准的驱 ...

  3. Linux中/proc目录下文件详解

    Linux中/proc目录下文件详解(一) 声明:可以自由转载本文,但请务必保留本文的完整性. 作者:张子坚 email:zhangzijian@163.com 说明:本文所涉及示例均在fedora ...

  4. python创建列向量_关于Numpy中的行向量和列向量详解

    关于Numpy中的行向量和列向量详解 行向量 方式1 import numpy as np b=np.array([1,2,3]).reshape((1,-1)) print(b,b.shape) 结 ...

  5. jQuery中getJSON跨域原理详解

    详见:http://blog.yemou.net/article/query/info/tytfjhfascvhzxcytp28 jQuery中getJSON跨域原理详解 前几天我再开发一个叫 河蟹工 ...

  6. java mod %区别_Java中 % 与Math.floorMod() 区别详解

    %为取余(rem),Math.floorMod()为取模(mod) 取余取模有什么区别呢? 对于整型数a,b来说,取模运算或者取余运算的方法都是: 1.求 整数商: c = a/b; 2.计算模或者余 ...

  7. python的执行过程_在交互式环境中执行Python程序过程详解

    前言 相信接触过Python的伙伴们都知道运行Python脚本程序的方式有多种,目前主要的方式有:交互式环境运行.命令行窗口运行.开发工具上运行等,其中在不同的操作平台上还互不相同.今天,小编讲些Py ...

  8. python平方数迭代器_对python中的高效迭代器函数详解

    python中内置的库中有个itertools,可以满足我们在编程中绝大多数需要迭代的场合,当然也可以自己造轮子,但是有现成的好用的轮子不妨也学习一下,看哪个用的顺手~ 首先还是要先import一下: ...

  9. 对python 数据处理中的LabelEncoder 和 OneHotEncoder详解

    对python 数据处理中的LabelEncoder 和 OneHotEncoder详解_起飞的木木的博客-CSDN博客_labelencoder原理

最新文章

  1. 谈谈我对服务熔断、服务降级的理解 专题
  2. phd for engineering at industry
  3. jQuery 库 - 特性
  4. gcc对C语言的扩展:标签变量(Labels as Values)
  5. java监控gc线程_Java应用性能监控系统,使用JMX实现,实现了类加载监控、内存监控、线程监控、GC监控...
  6. 同时开多个独立窗口Visio 2003/2007版本的软件
  7. activiti6创建28张表
  8. 桌面应用软件开发语言调查(转)
  9. java我的世界填充方块,我的世界怎么快速填充方块-快速填充方块攻略
  10. 最全的阿里面试经验(一)
  11. jenkins-github上提交代码后构建job(十二)
  12. 企业微信 web 项目工业级蜕变
  13. SQL之to_date()
  14. 2022年度总结——2022我在CSDN的那些事暨2023我的目标展望:Pursue freedom Realize self-worth
  15. Spring04:自动装配
  16. 杭州 职称 计算机免试,浙职称评审政策调整外语计算机免考年限有变动
  17. 《黑天鹅》观感:成长的蜕变 --摘抄
  18. Jira、confluence和crowd安装文档
  19. Java SSM开发大众点评后端
  20. 蚁剑加密 WebShell 过杀软

热门文章

  1. get assigned pageset and my pages
  2. ABAP开发环境终于支持以驼峰命名法自动格式化ABAP变量名了
  3. My task - how is inline creation implemented
  4. Solution for Lead OPA test error ( add button clicked after cancel button )
  5. how is OData url select option implemented in the backend
  6. 通过配置文件避免硬编码的一个例子
  7. Use BAdI to link appointment to a given opportunity during creation
  8. note deletion case
  9. 2019年6月19日Jerry Wang的SAP SAP Cloud Connector练习
  10. Import project出现Select at least one project的解决方法