获取Pytorch中间某一层权重或者特征

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1.获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch

import pandas as pd

import numpy as np

import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}

for name,parameters in resnet18.named_parameters():

print(name,':',parameters.size())

parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])

bn1.weight : torch.Size([64])

bn1.bias : torch.Size([64])

layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])

layer1.0.bn1.weight : torch.Size([64])

layer1.0.bn1.bias : torch.Size([64])

layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])

layer1.0.bn2.weight : torch.Size([64])

layer1.0.bn2.bias : torch.Size([64])

layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])

layer1.1.bn1.weight : torch.Size([64])

layer1.1.bn1.bias : torch.Size([64])

layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])

layer1.1.bn2.weight : torch.Size([64])

layer1.1.bn2.bias : torch.Size([64])

layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])

layer2.0.bn1.weight : torch.Size([128])

layer2.0.bn1.bias : torch.Size([128])

layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])

layer2.0.bn2.weight : torch.Size([128])

layer2.0.bn2.bias : torch.Size([128])

layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])

layer2.0.downsample.1.weight : torch.Size([128])

layer2.0.downsample.1.bias : torch.Size([128])

layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])

layer2.1.bn1.weight : torch.Size([128])

layer2.1.bn1.bias : torch.Size([128])

layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])

layer2.1.bn2.weight : torch.Size([128])

layer2.1.bn2.bias : torch.Size([128])

layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])

layer3.0.bn1.weight : torch.Size([256])

layer3.0.bn1.bias : torch.Size([256])

layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])

layer3.0.bn2.weight : torch.Size([256])

layer3.0.bn2.bias : torch.Size([256])

layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])

layer3.0.downsample.1.weight : torch.Size([256])

layer3.0.downsample.1.bias : torch.Size([256])

layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])

layer3.1.bn1.weight : torch.Size([256])

layer3.1.bn1.bias : torch.Size([256])

layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])

layer3.1.bn2.weight : torch.Size([256])

layer3.1.bn2.bias : torch.Size([256])

layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])

layer4.0.bn1.weight : torch.Size([512])

layer4.0.bn1.bias : torch.Size([512])

layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])

layer4.0.bn2.weight : torch.Size([512])

layer4.0.bn2.bias : torch.Size([512])

layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])

layer4.0.downsample.1.weight : torch.Size([512])

layer4.0.downsample.1.bias : torch.Size([512])

layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])

layer4.1.bn1.weight : torch.Size([512])

layer4.1.bn1.bias : torch.Size([512])

layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])

layer4.1.bn2.weight : torch.Size([512])

layer4.1.bn2.bias : torch.Size([512])

fc.weight : torch.Size([1000, 512])

fc.bias : torch.Size([1000])

parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],

[-0.07455588, -0.799308 , -0.21283598],

[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):

with pd.ExcelWriter(excel_name) as writer:

[output_num,input_num,filter_size,_]=parm[key_name].size()

for i in range(output_num):

for j in range(input_num):

data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())

#print(data)

data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1

with pd.ExcelWriter('test1.xlsx') as writer:

for key in parm_resnet50.keys():

data=parm_resnet50[key].reshape(-1,1)

data=data[data>0.001]

data=pd.DataFrame(data,columns=[key])

data.to_excel(writer,index=False,startcol=counter)

counter+=1

2.获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):

x = net.conv1(input_data)

x = net.bn1(x)

x = F.relu(x)

x = net.layer1(x)

x = net.layer2(x)

x = net.layer3(x)

x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出

x=x.view(x.shape[0],-1)

return x

model = models.resnet18()

x = resnet_cifar(model,input_data)

原文:https://blog.csdn.net/happyday_d/article/details/88974361

pytorch 提取权重_获取Pytorch中间某一层权重或者特征相关推荐

  1. pytorch卷积可视化_使用Pytorch可视化卷积神经网络

    pytorch卷积可视化 Filter and Feature map Image by the author 筛选和特征图作者提供的图像 When dealing with image's and ...

  2. julia有 pytorch包吗_用 PyTorch 实现基于字符的循环神经网络 | Linux 中国

    导读:在过去的几周里,我花了很多时间用 PyTorch 实现了一个 char-rnn 的版本.我以前从未训练过神经网络,所以这可能是一个有趣的开始. 本文字数:7201,阅读时长大约: 9分钟 htt ...

  3. pytorch保存准确率_初学Pytorch:MNIST数据集训练详解

    前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...

  4. bert pytorch源码_【PyTorch】梯度爆炸、loss在反向传播变为nan

    点击上方"MLNLP",选择"星标"公众号 重磅干货,第一时间送达 作者丨CV路上一名研究僧 知乎专栏丨深度图像与视频增强 地址丨https://zhuanla ...

  5. pytorch 矩阵相乘_深入浅出PyTorch(算子篇)

    Tensor 自从张量(Tensor)计算这个概念出现后,神经网络的算法就可以看作是一系列的张量计算.所谓的张量,它原本是个数学概念,表示各种向量或者数值之间的关系.PyTorch的张量(torch. ...

  6. pytorch 矩阵相乘_编译PyTorch静态库

    背景 众所周知,PyTorch项目作为一个C++工程,是基于CMake进行构建的.然而当你想基于CMake来构建PyTorch静态库时,你会发现: 静态编译相关的文档不全: CMake文件bug太多, ...

  7. pytorch卷积神经网络_使用Pytorch和Matplotlib可视化卷积神经网络的特征

    在处理图像和图像数据时,CNN是最常用的架构.卷积神经网络已经被证明在深度学习和计算机视觉领域提供了许多最先进的解决方案.没有CNN,图像识别.目标检测.自动驾驶汽车就不可能实现. 但当归结到CNN如 ...

  8. java 身份证地址提取籍贯_获取身份证信息中的籍贯、出生及性别信息

    前言:之前在项目开发中,经常需要用户录入身份证信息,同时还要提供籍贯等信息.那么,如何从身份证号码中解析出籍贯等信息,就是本篇博客索要解决的. 其实,身份证号码前6位就是用户的籍贯编码,直接解析该6位 ...

  9. java分词 词权重_分析牛:查询分词权重,巧妙布局网页关键词

    今天和大家分享一个纯干货,关键词的布局,也许很多人会说,这个还不容易,title出现一次,keywords出现一次,description在出现一次,然后正文的H标签里在出现一次,最后在每段的开头,末 ...

  10. python yolov5 脚本制作(第一部分:环境搭建、yolov5源码、权重文件获取、pycharm配置、pytorch下载、初次运行yolov5代码)

    开发前准备 在这里先梳理一下整个脚本开发用到的东西: python解释器 / 3.7.4版本 pycharm / 版本随意 pytorch / 最好10.2版本 / 11.3版本 yolov5代码文件 ...

最新文章

  1. python 正则表达式方法_Python正则表达式一: 基本使用方法
  2. Java为什么print显示不完全,read-eval-print-loop – 在Java 9上,为什么System.getenv()的输出在jshell中不完整?...
  3. Java的一维数组和二维数组的关系
  4. 初学者应该看看的6个free命令例子
  5. 在浏览器的背后(二) —— HTML语言的语法解析
  6. SQLite Update 语句(http://www.w3cschool.cc/sqlite/sqlite-update.html)
  7. CSS做个Switch开关
  8. 服务器每天产生1t文件,编写自己的服务器 - osc_popfjd1t的个人空间 - OSCHINA - 中文开源技术交流社区...
  9. 封装jquery插件 uoload file
  10. combres java_ASP.NET MVC3 Combres错误:'System.Web.Mvc.UrlHelper'不包含'CombresLink'的定义
  11. Oracle 日期时间函数详解
  12. java 实现宠物领养_基于jsp的宠物领养-JavaEE实现宠物领养 - java项目源码
  13. 网易云音乐显示服务器发生错误,网易云音乐加载失败怎么回事 网易云音乐出现加载失败的有效解决方法...
  14. 无偏估计的数学证明和分析
  15. three.js 入门指南(敷衍一下)
  16. wordpress手动安装插件WooCommerce
  17. 15.CUDA编程手册中文版---附录K CUDA计算能力
  18. 页面布局的几种宽度设置方式—html
  19. 深度残差网络RESNET
  20. 华为手机左侧快捷方式_华为手机的这六个快捷键,让使用更简便!

热门文章

  1. python爬取淘宝数据魔方_读《淘宝数据魔方技术架构解析》有感
  2. 支持全球科研抗疫,艾柏森成功研发Omicron变异株重组蛋白
  3. 【xinfanqie】熟知针式与喷墨打印机之间的区别
  4. html模糊遮罩层磨砂玻璃,常见的PPT背景:如何设计PPT背景?
  5. spring使用之旅(二) ---- AOP的使用
  6. 基于 vue-element-admin 基础模板实现侧边栏菜单动态渲染
  7. python死循环_python中死循环
  8. java获取法定节假日
  9. Font Awesome 是一套绝佳的图标字体库和CSS框架
  10. Latex输入大小写罗马数字