简单来说就是在不改动网络结构的情况下获取网络中间层输出
比如有个LeNet:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):out = self.conv1(x)out = F.relu(out)     out = F.max_pool2d(out, 2)      out = self.conv2(out)out = F.relu(out)  out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out

如果我们要获取conv2的输出,一种最直观的思路是这样:

def forward(self, x):out = self.conv1(x)out = F.relu(out)     out = F.max_pool2d(out, 2)      out = self.conv2(out)out_conv2 = outout = F.relu(out)out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out, out_conv2

直接修改forward部分的代码,将conv2的中间结果return即可。

但很多时候,我们并没有办法去直接修改网络的源代码,比如在pytorch中已经封装好的网络,那么这个时候就可以利用hook从外部获取Module的中间输出结果了。即:

features = []
def hook(module, input, output): features.append(output.clone().detach())net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)
print(features[0])
handle.remove()

取出网络的相应层后,对该层调用register_forward_hook方法。这个方法需要传入一个hook方法:

hook(module, input, output) -> None or modified output
  • module:表示该层网络
  • input:该层网络的输入
  • output:该层网络的输出

从这里可以发现hook甚至可以更改输入输出(不过并不会影响网络forward的实际结果),不过在这里我们只是简单地将output给保存下来。
需要注意的是hook函数在使用后应及时删除,以避免每次都运行增加运行负载。

参考:
https://blog.csdn.net/winycg/article/details/100695373
https://blog.csdn.net/foneone/article/details/107099060

Pytorch register_forward_hook()简单用法相关推荐

  1. PyTorch搭建简单神经网络实现回归和分类

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 安装 PyTorch 会安装两个模块,一个是torch,一个 torchvision, tor ...

  2. 反编译工具jad简单用法

    反编译工具jad简单用法 下载地址: [url]http://58.251.57.206/down1?cid=B99584EFA6154A13E5C0B273C3876BD4CC8CE672& ...

  3. QCustomPlot的简单用法总结

    QCustomPlot的简单用法总结 第一部分:QCustomPlot的下载与安装 第二部分:QCustomPlot在VS2013+QT下的使用 QCustomPlot的简单用法总结    写在前面, ...

  4. python matplotlib 简单用法

    python matplotlib 简单用法 具体内容请参考官网 代码 import matplotlib.pyplot as plt import numpy as np # 支持中文 plt.rc ...

  5. Windump网络命令的简单用法

    Windump网络命令的简单用法 大家都知道,unix系统下有个tcpdump的抓包工具,非常好用,是做troubleshooting的好帮手.其实在windows下也有一个类似的工作,叫windum ...

  6. Android TabLayout(选项卡布局)简单用法实例分析

    本文实例讲述了Android TabLayout(选项卡布局)简单用法.分享给大家供大家参考,具体如下: 我们在应用viewpager的时候,经常会使用TabPageIndicator来与其配合.达到 ...

  7. shell expect的简单用法

    为什么需要expect?     我们通过Shell可以实现简单的控制流功能,如:循环.判断等.但是对于需要交互的场合则必须通过人工来干预,有时候我们可能会需要实现和交互程序如 telnet服务器等进 ...

  8. Shellz中awk的简单用法

    其实shell脚本的功能常常被低估.在实际应用中awk sed 等用法可以为shell提供更为强大的功能.下面我们将一下awk调用的简单方法进行了总结.方便同学们学习: awk的简单用法: 第一种调用 ...

  9. python装饰器实例-Python装饰器原理与简单用法实例分析

    本文实例讲述了Python装饰器原理与简单用法.分享给大家供大家参考,具体如下: 今天整理装饰器,内嵌的装饰器.让装饰器带参数等多种形式,非常复杂,让人头疼不已.但是突然间发现了装饰器的奥秘,原来如此 ...

最新文章

  1. c语言常见50题 及答案(递归 循环 以及常见题目)
  2. 图解VC++ opengl环境配置和几个入门例子
  3. 实现单台测试机6万websocket长连接
  4. Remoting系列专题---构建Remoting“防火墙”
  5. 三步拆解一个数据分析体系
  6. R-Sys.time计算程序运行时间
  7. 应用华云对象存储服务实现网站存储的平滑迁移实践
  8. DRL前沿之:Benchmarking Deep Reinforcement Learning for Continuous Control
  9. 电脑机器人_磨小分校参加成都市“青少年电脑机器人创新实践活动”巡航者决赛...
  10. C语言递归方法求解背包问题
  11. 云忧cms搭建在宝塔nginx服务器,登录报错
  12. 软件工程 学习笔记 知识梳理
  13. origin做相关性分析图_相关性分析的可视化_相关系数图的绘制过程
  14. android o 开发者大会,谷歌开发者大会刚结束Android O又要来了?
  15. vue项目的导出功能
  16. 禁U盘不禁USB设备
  17. HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions
  18. 阿里软件开发工程师面经
  19. 51单片机的PID水温控制器设计
  20. 华为云OBS深度体验之迁移

热门文章

  1. 数学分析笔记—python基础语法
  2. 5不能另存为dwg_5.建立数模
  3. ft2232驱动安装方法_教你win10系统显卡驱动安装失败的解决方法「系统天地」
  4. 台式linux桌面远程链接华为云windows服务器桌面
  5. python多态(一分钟读懂)
  6. php 使用json 教程,PHP使用JSON 教程
  7. 查看python版本命令_Anaconda常用命令小结
  8. mysql库表的触发器表名_MySQL 触发器,实现不同数据库,不同表名,表结构不同,数据实时同步...
  9. [转]awsome-python
  10. keras中的loss、optimizer、metrics