pytorch中间层输出方法
大纲
- 引言
- 一、钩子截流
- 附:钩子函数
- 二、视作输出
- 参考
引言
本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:
- 一是在前向传播过程中通过"钩子"截取;
- 二是将中间层输出也视作输出,在最后获取
一、钩子截流
这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:
import torch
from torch import nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)return self.linear_class(x)features = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)
实际过程中建议先打印所有层的名字以做到精确提取。
附:钩子函数
值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释
r"""Registers a forward hook on the module.
The hook will be called every time after :func:
forward
has
computed an output. It should have the following signature::hook(module, input, output) -> None or modified output
The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
theforward
. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward
is called.
我们可以分析得到:
register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)
二、视作输出
通过返回值提取中间层输出比较简单,同样有两种方法来实现:
- 一是将中间层的返回值作为模型的属性,在初始化时定义好;
- 二是在forward函数将中间层返回值一并输出;
代码如下:
import torch
from torch import nnclass test_model(nn.Module):def __init__(self):super(test_model, self).__init__()self.conv_16 = nn.Sequential(nn.Conv2d(1,16,(3,3),(1,1)),nn.ReLU(),nn.MaxPool2d(kernel_size=(2,2)))self.conv_32 = nn.Sequential(nn.Conv2d(16,32,(3,3),(1,1)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))self.linear_1 = nn.Sequential(nn.Linear(32,64),nn.ReLU())self.linear_class = nn.Sequential(nn.Linear(64,5),nn.ReLU())self.feature=[]def forward(self, x):x = self.conv_16(x)x = self.conv_32(x)x = x.view(x.shape[0],x.shape[1])x = self.linear_1(x)self.feature.append(x.detach())feature = x.detach()return self.linear_class(x),featurefeatures = []def hook(module, input, output):features.append(input)return Nonenet = test_model()# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))_,final_out=net(a)
hook_out=features
att_out=net.feature
对比可以发现三者输出是一致的。
参考
Pytorch获取中间层输出的几种方法
pytorch的hook机制之register_forward_hook
pytorch中间层输出方法相关推荐
- 提取CNN模型中间层输出方法
前言 针对结构中定义了多个nn.sequential的网络模型,无法直接获取其内部某一中间层的输出,本文将给出两个方法进行解决. 方法 1 逐层进行forward 创建自定义函数,实现按照执行顺序逐层 ...
- pytorch获取模型的中间层输出结果
在inference阶段,整个模型会load到GPU上,进行端到端的计算,通常只会给你输出一个最终结果. 如果想要获取模型的中间层输出,则需要在计算前标定目标层位置(通过forward返回),或者把模 ...
- java单链表输出_数据结构基础------1.线性表之单链表的创建与输出方法(Java版)...
基础知识: 线性表(linear list),是其组成元素间具有线性关系的一种线性结构. 线性表有 ①顺序存储结构(sequential storage structure) 顺序存储结构可以简单的理 ...
- sass学习笔记(二):sass的不同样式风格的输出方法
sass的不同样式风格的输出方法 1.嵌套式nested Sass 提供了一种嵌套显示 CSS 文件的方式.例如 nav {ul {margin: 0;padding: 0;list-style: n ...
- python语言格式化输出_Python format()格式化输出方法详解
原标题:Python format()格式化输出方法详解 format() 方法的语法格式如下: str.format(args) 此方法中,str 用于指定字符串的显示样式:args 用于指定要进行 ...
- Pytorch gpu加速方法
Pytorch gpu加速方法 原文: https://www.zhihu.com/question/274635237 relu 用 inplace=True 用 eval() 和 with tor ...
- android简化log输出方法
android简化版log输出 希望实现的效果: 只需要将类实现ILog接口(不需要进行任何额外的操作),然后就能直接通过printLog进行日志的输出 如下: class A implements ...
- 获取keras中间层输出、模型保存与加载
1. 获取keras中间层输出 # model summary and plot import keras from keras.models import Model from keras.util ...
- 怎样检查python环境是否安装好_如何搭建pytorch环境的方法步骤
1.conda创建虚拟环境pytorch_gpu conda create -n pytorch_gpu python=3.6 创建虚拟环境还是相对较快的,它会自动为本环境安装一些基本的库,等待时间无 ...
最新文章
- 2022-2028年中国丁基橡胶行业市场深度分析及投资前景展望报告
- python decode hex_在python2.7中使用b64decode()将base64转换为hex
- python 后台执行
- Java面向对象之构造方法
- 十大经典排序算法之希尔排序及其优化
- java使用validator进行校验
- Eclipse中Maven插件的使用技巧及原理
- Spring事务总结(一) 内部调用事务失效、异常回滚
- 在工作之余,你是怎么提升自己的?
- 程序员,Linux 下如何避免从删库到跑路的悲剧?
- 信号与系统考研复习例题详解_小语种日语日本文学复习考研资料加藤周一《日本文学史序说(上)》笔记和考研真题详解...
- JAVA 内存泄露的理解
- sqoop导出到mysql中文乱码问题总结、utf8、gbk
- python爬虫之如何建立一个自己的代理IP池
- 修改dns服务器转发器,域服务器dns设置转发器
- html网页设计作品文字,40个以大文字排版的网页设计作品
- 社交网络叠加直播功能,会产生什么化学反应?
- 《Python语言程序设计》刘卫国主编字符串与正则表达式习题5详解(选择)
- 教程:使用EXCEL制作均值曲线图表
- 简单使用Search()函数
热门文章
- Android之粗仿微信6.0——微信分界面
- facebook-javascript-sdk
- The job failed. Unable to determine if the owner (SINOOCEANLAND\v-baidd) of job sendmail has server
- ExBPA工具的使用方法
- 16宫格的翻牌消除游戏、纯前端实现16宫格的翻牌消除游戏
- IoT时代:Wi-Fi“配网”技术剖析总结
- web应用用户头像处理
- 银行卡和预付卡的区别
- 百度地图api,第一次定位成功,后面505错误
- 软件自动更新解决方案及QT实现(源码已上传)