PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call
PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解
在看pytorch官方文档的时候,发现在nn.Module部分和Variable部分均有hook的身影。感到很神奇,因为在使用tensorflow的时候没有碰到过这个词。所以打算一探究竟。
文章目录 [隐藏]
- 1 Variable 的 hook
- 1.1 register_hook(hook)
- 2 nn.Module的hook
- 2.1 register_forward_hook(hook)
- 3 register_backward_hook
Variable 的 hook
register_hook(hook)
注册一个backward钩子。
每次gradients被计算的时候,这个hook都被调用。hook应该拥有以下签名:
1
|
hook(grad) -> Variable or None
|
hook不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度。
这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hook从module移除。
例子:
1
2
3
4
5
6
|
v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
h = v.register_hook(lambda grad: grad * 2)# double the gradient
v.backward(torch.Tensor([1, 1, 1]))
#先计算原始梯度,再进hook,获得一个新梯度。
print(v.grad.data)
h.remove()# removes the hook
|
输出:
1
2
3
4
|
2
2
2
[torch.FloatTensor of size 3]
|
nn.Module的hook
register_forward_hook(hook)
在module上注册一个forward hook。
这里要注意的是,hook 只能注册到 Module 上,即,仅仅是简单的 op 包装的 Module,而不是我们继承 Module时写的那个类,我们继承 Module写的类叫做 Container。
每次调用forward()计算输出的时候,这个hook就会被调用。它应该拥有以下签名:
1
|
hook(module, input, output) -> None
|
hook不应该修改 input和output的值。 这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hook从module移除。
看这个解释可能有点蒙逼,但是如果要看一下nn.Module的源码怎么使用hook的话,那就乌云尽散了。
先看 register_forward_hook
handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle</textarea></div><div class="crayon-main" style="position: relative; z-index: 1; overflow: hidden;"><table class="crayon-table" style=""><tbody><tr class="crayon-row"><td class="crayon-nums " data-settings="show"><div class="crayon-nums-content" style="font-size: 13px !important; line-height: 18px !important;"><div class="crayon-num" data-line="crayon-5dd4e6587a059500983678-1">1</div><div class="crayon-num" data-line="crayon-5dd4e6587a059500983678-2">2</div><div class="crayon-num" data-line="crayon-5dd4e6587a059500983678-3">3</div><div class="crayon-num" data-line="crayon-5dd4e6587a059500983678-4">4</div><div class="crayon-num" data-line="crayon-5dd4e6587a059500983678-5">5</div></div></td><td class="crayon-code"><div class="crayon-pre" style="font-size: 13px !important; line-height: 18px !important; -moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4;"><div class="crayon-line" id="crayon-5dd4e6587a059500983678-1"><span class="crayon-r">def</span><span class="crayon-h"> </span><span class="crayon-e">register_forward_hook</span><span class="crayon-sy">(</span><span class="crayon-r">self</span><span class="crayon-sy">,</span><span class="crayon-h"> </span><span class="crayon-v">hook</span><span class="crayon-sy">)</span><span class="crayon-o">:</span></div><div class="crayon-line" id="crayon-5dd4e6587a059500983678-2"> </div><div class="crayon-line" id="crayon-5dd4e6587a059500983678-3"><span class="crayon-h"> </span><span class="crayon-v">handle</span><span class="crayon-h"> </span><span class="crayon-o">=</span><span class="crayon-h"> </span><span class="crayon-v">hooks</span><span class="crayon-sy">.</span><span class="crayon-e">RemovableHandle</span><span class="crayon-sy">(</span><span class="crayon-r">self</span><span class="crayon-sy">.</span><span class="crayon-v">_forward_hooks</span><span class="crayon-sy">)</span></div><div class="crayon-line" id="crayon-5dd4e6587a059500983678-4"><span class="crayon-h"> </span><span class="crayon-r">self</span><span class="crayon-sy">.</span><span class="crayon-v">_forward_hooks</span><span class="crayon-sy">[</span><span class="crayon-v">handle</span><span class="crayon-sy">.</span><span class="crayon-k ">id</span><span class="crayon-sy">]</span><span class="crayon-h"> </span><span class="crayon-o">=</span><span class="crayon-h"> </span><span class="crayon-e">hook</span></div><div class="crayon-line" id="crayon-5dd4e6587a059500983678-5"><span class="crayon-e"> </span><span class="crayon-st">return</span><span class="crayon-h"> </span><span class="crayon-v">handle</span></div></div></td></tr></tbody></table></div></div><p>这个方法的作用是在此module上注册一个hook,函数中第一句就没必要在意了,主要看第二句,是把注册的hook保存在_forward_hooks字典里。</p><p>再看 nn.Module 的__call__方法(被阉割了,只留下需要关注的部分):</p><div id="crayon-5dd4e6587a05a173089609" class="crayon-syntax crayon-theme-github crayon-font-monaco crayon-os-pc print-yes notranslate" data-settings=" minimize scroll-mouseover" style="margin-top: 15px; margin-bottom: 15px; font-size: 13px !important; line-height: 18px !important; height: auto;"><div class="crayon-plain-wrap"><textarea wrap="soft" class="crayon-plain print-no" data-settings="dblclick" readonly="" style="tab-size: 4; font-size: 13px !important; line-height: 18px !important; z-index: 0; opacity: 0; overflow: hidden;">def __call__(self, *input, **kwargs):
相关文章:
- Python __dict__属性详解
- Pytorch nn.init 参数初始化方法
- Bert代码详解(一)重点详细
- Bert代码详解(二)重点
- 一本读懂BERT(实践篇)重点
- pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法
- 关于pytorch--embedding的问题
- pytorch中的transpose()
- view(*args)改变张量的大小和形状_pytorch reshape numpy
- gelu
- 【PyTorch学习笔记】4:在Tensor上的索引和切片
- Ramsey定理数学
- 如何将模糊的扫描版pdf转为清晰的pdf或word_pdf问题小结
- Python怎么利用多核cpu
- 使用Pycharm给Python程序传递参数
- 获取准确路径目录
- 打标遗留的问题
- pytorch 与 numpy 的数组广播机制
- pytorch numpy 数据类型转换
- numpy数组方法
- pytorch.range() 和 pytorch.arange() 的区别
- python的print格式化输出,以及使用format来控制。
- 查错bug
- 使virtualenv从您的全局站点包继承特定的包
- python pycharm 包 安装问题
- 查看分析网络层次
- PyTorch的torch.cat
- pycharm连接远程服务器并进行代码上传+远程调试
- 创 keras_contrib 安装
- tf.concat()详解
PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call相关推荐
- PyTorch 学习笔记(一):让PyTorch读取你的数据集
本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 Dataset类 ...
- PyTorch学习笔记(六)——Sequential类、参数管理与GPU
系列文章\text{\bf 系列文章}系列文章 PyTorch学习笔记(一)--Tensor的基础语法 PyTorch学习笔记(二)--自动微分 PyTorch学习笔记(三)--Dataset和Dat ...
- PyTorch学习笔记(21) ——损失函数
0. 前言 本博客内容翻译自纽约大学数据科学中心在2020发布的<Deep Learning>课程的Activation Functions and Loss Functions 部分. ...
- 1.pytorch 学习笔记--Getting stared
pytorch 学习笔记–Getting stared 1.什么是pytorch Pytorch 是一个基于Python的科学计算包,主要面向以下人群: 替代numpy以使用GPU做计算加速 一个深度 ...
- PyTorch学习笔记(六):PyTorch进阶训练技巧
PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...
- PyTorch学习笔记(七):PyTorch可视化
PyTorch可视化 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一) ...
- Pytorch学习笔记总结
往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...
- PyTorch学习笔记(五):模型定义、修改、保存
往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...
- PyTorch学习笔记(四):PyTorch基础实战
PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...
最新文章
- ArchSummit2018深圳站筹备中,18大专题征集演讲嘉宾
- 穿上这件全球首款「隐形衣」,做这条街最「无脸」的仔;阿里给钱给资源,求解AI安全难题...
- 产品经理必备知识之网页设计系列(三)-移动端适配无障碍设计及测试
- 单片机学校实训老师上课需要的工具以及源码分享
- Atitit nodejs js 获取图像分辨率 尺寸 大小 宽度 高度
- python如何调用tess_python下以api形式调用tesseract识别图片验证码
- 常见数通设备镜像制作模板
- Hibernate逍遥游记-第5章映射一对多-02双向(set、key、one-to-many、inverse、cascade=all-delete-orphan)...
- 五子棋c语言策划书活动内容,五子棋比赛活动的策划案
- 小Z解读:企业证书利用itms-services协议分发应用在蜂窝网络下的限制
- 江苏省发布我国首个公路行业BIM省地方标准
- Linux服务器使用网络代理
- Android动画之Interpolator插入器
- 亚马逊云科技软件开发工程师团队
- 使用pca进行坐标系转换、降维
- 2017年商汤科技前端面试题
- SEM是什么,基础知识讲解
- 数据仓库ODS层的作用
- 计算机视觉算法竞聘者的职业技能需求
- 听云缓存报错:java.lang.NoClassDefFoundError: com.networkbench.agent.impl.instrumentation.NBSEventTraceEngi