欢迎关注 “小白玩转Python”,发现更多 “有趣”

使用过深度学习的人都知道,有时候调试模型是非常困难的。张量的不匹配、梯度爆炸,以及其他无数的问题都会让你大吃一惊。解决这些问题需要细微的观察这些模型。最基本的方法包括在forward()方法中添加print语句或引入断点。但是这也相当麻烦,因为需要猜测哪里出现了问题。

现在有了一个解决方案:hooks。这些是特定的函数,可以附加到每一层,并在每次使用层时调用。它们基本上允许冻结特定模块的正向或反向传递的执行,并处理其输入和输出。

下面让我们详细介绍一下!

Hooks速成小课堂

钩子只是一个带有预定义签名的可调用对象,它可以注册到任何nn.Module。当在模块上使用触发器方法(即forward()或backward())时,模块本身及其输入和可能的输出将传递给钩子,在计算进行到下一个模块之前执行。

在 PyTorch 中,可以将钩子注册为:

· forward prehook(在前向传播之前执行)

· forward hook(在前向传播之后执行)

· backward hook(在后向传播之后执行)

乍一看可能很复杂,让我们来看一个具体的例子!

例:保存每个卷积层的输出

假设我们要查看 ResNet34框架中每个卷积层的输出。这项工作非常适合使用hooks。在下一部分中,我将向您展示如何执行这一操作。

我们的模型定义如下:

import torch
from torchvision.models import resnet34
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = resnet34(pretrained=True)
model = model.to(device)

创建一个hook来保存输出非常简单,对于我们的目的来说,一个基本的可调用对象就足够了。

class SaveOutput:def __init__(self):self.outputs = []  def __call__(self, module, module_in, module_out):self.outputs.append(module_out)def clear(self):self.outputs = []

SaveOutput的一个实例将只记录前向传播过程的输出张量并将其存储在一个列表中。

可以使用 register_forward_hook(hook)方法注册前向钩子。(对于其他类型的钩子,我们有 register_backward_hook 和 register_forward_pre_hook。)这些方法的返回值是hook句柄,可用于从模块中删除hook。

现在我们将hook注册到每个卷积层。

save_output = SaveOutput()
hook_handles = []
for layer in model.modules():if isinstance(layer, torch.nn.modules.conv.Conv2d):handle = layer.register_forward_hook(save_output)hook_handles.append(handle)

完成后,钩子将在每个卷积层的每个前向传递后被调用。为了测试它,我们将使用下面的图像。

前向传播:

from PIL import Image
from torchvision import transforms as T
image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)
out = model(X)

正如预期的那样,输出被正确存储。

>>> len(save_output.outputs)
36

通过查看此列表中的张量,我们可以将网络所看到的可视化。

出于好奇,我们可以看看网络接下来会发生什么。如果我们深入网络,学习到的特征会越来越高。例如,有一个过滤器似乎是负责检测眼睛。

超越

当然,这只是冰山一角。Hooks可以做的远不止简单地存储中间层的输出。例如,神经网络剪枝,这是一种策略,以减少参数的数目,也可以通过hooks实现。

总而言之,如果你想增强你的工作,学会使用hooks是一种非常有用的技巧。有了这些,你就能做得更多,做得更有效。

·  END  ·

HAPPY LIFE

你应该知道的一个PyTorch小技巧相关推荐

  1. 程序员的反击!每天一个离职小技巧

    作者 | 梦想橡皮擦 来源 | 非本科程序员(ID:htmlhttp) 写在前面 俗话说的好,代码写的少,离职少不了. 最近畅游互联网,发现一些离职小技巧,读后,内心被深深的打动了,但是细细的品过之后 ...

  2. 3分钟学会python_3分钟学会一个Python小技巧

    Python时间日期转换在开发中是非常高频的一个操作,你经常会遇到需要将字符串转换成 datetime 或者是反过来将 datetime 转换成字符串. datetime 分别提供了两个方法 strp ...

  3. pandas apply lambda_一分钟一个Pandas小技巧(二)

    " 在逛Kaggle的时候发现了一篇不错的Pandas技巧,我将挑选一些有用的并外加一些自己的想法分享给大家.本系列虽基础但带仍有一些奇怪操作,粗略扫一遍,您或将发现一些您需要的技巧.&qu ...

  4. vob转mp4,每天一个实用小技巧

    vob转mp4,vob的英文全称是Video Object,它是DVD视频媒体使用的容器格式,vob格式擅长将数字视频.音频.字幕.菜单等多个元素复用在流格式中.而且vob格式的文件可以被加密保护.经 ...

  5. 每天一个前端小技巧——生成gif动图下载

    每天一个前端小技巧--生成gif动图下载 动态热图的展现,分别展现某个时间段的热图时间变化,例如:最近一周七天内,每天的热图分布变化图:这个动态变化的图生成一个gif图提供下载是否可行? 实现方案: ...

  6. 每天一个脱发小技巧 | Eclipse环境下spotbugs的安装配置和详细使用方法

    每天一个脱发小技巧 | Eclipse环境下spotbugs的安装配置和详细使用方法 SpotBugs介绍 Eclipse环境下SpotBugs安装 SpotBugs的使用 其他 SpotBugs介绍 ...

  7. 每30秒学会一个Python小技巧,GitHub星数4600+

    (图片付费下载自视觉中国) 作者 | xiaoyu,数据爱好者 来源 | Python数据科学(ID:PyDataScience) 很多学习Python的朋友在项目实战中会遇到不少功能实现上的问题,有 ...

  8. 震惊了!每30秒学会一个Python小技巧,Github星数6000+

    点击上方"Python数据科学",星标公众号 重磅干货,第一时间送达 ☞500g+超全学习资源免费领取,干货来袭! 作者:xiaoyu,数据爱好者 Python数据科学出品 很多学 ...

  9. 30秒就能学会一个Python小技巧?

    作者:wLsq 来源:Python数据科学 大家好,很多学习Python的朋友在项目实战中会遇到不少功能实现上的问题,有些问题并不是很难的问题,或者已经有了很好的方法来解决.当然,孰能生巧,当我们代码 ...

最新文章

  1. ggplot2笔记2:图层的使用——基础、怎样加标签、注释
  2. Spring Boot 搭载属于你的网站框架(一)
  3. java instanceof 原理_java-在现代JVM实现中如何实现instanceof?
  4. python套接字socket的作用_【学习笔记】python实现的套接字socket
  5. python绘制正态分布曲线
  6. 线程运行程序c语言,理解线程1 C语言示例的程序
  7. Codeforces - tag::data structures 大合集 [占坑 25 / 0x3f3f3f3f]
  8. 新生命 · 人工智能 · 未来
  9. git指令快捷 idea_IDEA+Git+Gitlab使用详细教程
  10. curl 升级 php,将命令行cURL转换为PHP cURL
  11. 【公告】社区周刊即日起停刊
  12. [vue-cli]怎么使用vue-cli3创建一个项目?
  13. 云起智慧中心连接华为_【转发】华为智慧屏HiLink控制联动,操作指南来了!
  14. JavaScript变量声明+数据类型+数字格式+操作符+进制
  15. 不属于python循环结构的是( )_Python语句print(type(['a','1',2,3]))的输出结果是哪一项?_学小易找答案...
  16. Python collections的使用
  17. vue-count-to插件使用方法
  18. 判断四个点是否可以构成矩形(优雅的解法!!!)
  19. [java学习笔记]-注解和反射
  20. 链接、图像、列表、计数器

热门文章

  1. 揭秘TVS管是否能替代稳压二极管吗?
  2. Web前端:UI设计对提高用户参与度的重要性
  3. 《Web安全攻防 渗透测试实战指南》学习笔记(2) - Sqlmap
  4. 繁体系统下因输入法引起部分软件乱码
  5. python----->第二天,数据类型,三种基本结构,函数,文件操作,打包、导包
  6. android 获取通讯录全选反选_Android Recyclerview实现多选,单选,全选,反选,批量删除的功能...
  7. mysql查询不同老师所教不同课程_MySQL学生表、老师表、课程表和成绩表查询语句,全部亲测...
  8. 2. 查询教师编号、教师姓名、课程名称、平均成绩。
  9. Gallery2源码阅读图片编辑
  10. python读取raw图片文件_【IT专家】使用Python读取CR2(原始佳能图像)头。