大纲

  • 引言
  • 一、钩子截流
    • 附:钩子函数
  • 二、视作输出
  • 参考

引言

 本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:

  • 一是在前向传播过程中通过"钩子"截取;
  • 二是将中间层输出也视作输出,在最后获取

一、钩子截流

 这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:

import torch
from torch import  nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)return self.linear_class(x)features = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)

实际过程中建议先打印所有层的名字以做到精确提取。

附:钩子函数

 值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释

r"""Registers a forward hook on the module.

The hook will be called every time after :func:forward has
computed an output. It should have the following signature::

   hook(module, input, output) -> None or modified output

The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
the forward. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward is called.

我们可以分析得到:
 register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)

二、视作输出

 通过返回值提取中间层输出比较简单,同样有两种方法来实现:

  • 一是将中间层的返回值作为模型的属性,在初始化时定义好;
  • 二是在forward函数将中间层返回值一并输出;

代码如下:

import torch
from torch import  nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())self.feature=[]def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)self.feature.append(x.detach())feature = x.detach()return self.linear_class(x),featurefeatures = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))_,final_out=net(a)
hook_out=features
att_out=net.feature

对比可以发现三者输出是一致的。

参考

Pytorch获取中间层输出的几种方法
pytorch的hook机制之register_forward_hook

pytorch中间层输出方法相关推荐

  1. 提取CNN模型中间层输出方法

    前言 针对结构中定义了多个nn.sequential的网络模型,无法直接获取其内部某一中间层的输出,本文将给出两个方法进行解决. 方法 1 逐层进行forward 创建自定义函数,实现按照执行顺序逐层 ...

  2. pytorch获取模型的中间层输出结果

    在inference阶段,整个模型会load到GPU上,进行端到端的计算,通常只会给你输出一个最终结果. 如果想要获取模型的中间层输出,则需要在计算前标定目标层位置(通过forward返回),或者把模 ...

  3. java单链表输出_数据结构基础------1.线性表之单链表的创建与输出方法(Java版)...

    基础知识: 线性表(linear list),是其组成元素间具有线性关系的一种线性结构. 线性表有 ①顺序存储结构(sequential storage structure) 顺序存储结构可以简单的理 ...

  4. sass学习笔记(二):sass的不同样式风格的输出方法

    sass的不同样式风格的输出方法 1.嵌套式nested Sass 提供了一种嵌套显示 CSS 文件的方式.例如 nav {ul {margin: 0;padding: 0;list-style: n ...

  5. python语言格式化输出_Python format()格式化输出方法详解

    原标题:Python format()格式化输出方法详解 format() 方法的语法格式如下: str.format(args) 此方法中,str 用于指定字符串的显示样式:args 用于指定要进行 ...

  6. Pytorch gpu加速方法

    Pytorch gpu加速方法 原文: https://www.zhihu.com/question/274635237 relu 用 inplace=True 用 eval() 和 with tor ...

  7. android简化log输出方法

    android简化版log输出 希望实现的效果: 只需要将类实现ILog接口(不需要进行任何额外的操作),然后就能直接通过printLog进行日志的输出 如下: class A implements ...

  8. 获取keras中间层输出、模型保存与加载

    1. 获取keras中间层输出 # model summary and plot import keras from keras.models import Model from keras.util ...

  9. 怎样检查python环境是否安装好_如何搭建pytorch环境的方法步骤

    1.conda创建虚拟环境pytorch_gpu conda create -n pytorch_gpu python=3.6 创建虚拟环境还是相对较快的,它会自动为本环境安装一些基本的库,等待时间无 ...

最新文章

  1. 2022-2028年中国丁基橡胶行业市场深度分析及投资前景展望报告
  2. python decode hex_在python2.7中使用b64decode()将base64转换为hex
  3. python 后台执行
  4. Java面向对象之构造方法
  5. 十大经典排序算法之希尔排序及其优化
  6. java使用validator进行校验
  7. Eclipse中Maven插件的使用技巧及原理
  8. Spring事务总结(一) 内部调用事务失效、异常回滚
  9. 在工作之余,你是怎么提升自己的?
  10. 程序员,Linux 下如何避免从删库到跑路的悲剧?
  11. 信号与系统考研复习例题详解_小语种日语日本文学复习考研资料加藤周一《日本文学史序说(上)》笔记和考研真题详解...
  12. JAVA 内存泄露的理解
  13. sqoop导出到mysql中文乱码问题总结、utf8、gbk
  14. python爬虫之如何建立一个自己的代理IP池
  15. 修改dns服务器转发器,域服务器dns设置转发器
  16. html网页设计作品文字,40个以大文字排版的网页设计作品
  17. 社交网络叠加直播功能,会产生什么化学反应?
  18. 《Python语言程序设计》刘卫国主编字符串与正则表达式习题5详解(选择)
  19. 教程:使用EXCEL制作均值曲线图表
  20. 简单使用Search()函数

热门文章

  1. Android之粗仿微信6.0——微信分界面
  2. facebook-javascript-sdk
  3. The job failed. Unable to determine if the owner (SINOOCEANLAND\v-baidd) of job sendmail has server
  4. ExBPA工具的使用方法
  5. 16宫格的翻牌消除游戏、纯前端实现16宫格的翻牌消除游戏
  6. IoT时代:Wi-Fi“配网”技术剖析总结
  7. web应用用户头像处理
  8. 银行卡和预付卡的区别
  9. 百度地图api,第一次定位成功,后面505错误
  10. 软件自动更新解决方案及QT实现(源码已上传)