文章目录

  • 引言
  • 卷积核可视化
  • 参数直方图可视化
  • 激活可视化
  • 小结

引言

一直以来,深度神经网络作为一种功能强大的“黑盒”,被认为可解释性较弱。目前,常用的一种典型可解释性分析方法是就是可视化方法。

本文整理了深度神经网络训练过程中常用的可视化技巧,便于对训练过程进行分析和检查。

卷积核可视化

以resnet18为例,提取第一层的卷积核(7x7)进行可视化,可以看出大多提取的是边缘、角点之类的底层视觉特征。

在全连接层前的卷积层采用的是3x3卷积核,表达高层语义信息,更加抽象:

这里采用torchvision.utils.make_grid对卷积核进行网格化显示,图像网格的列数由nrow参数确定。

卷积核的可视化代码参考1进行修改:

def plot_conv(writer,model):for name,param in model.named_parameters():if 'conv' in name and 'weight' in name:in_channels = param.size()[1] # 输入通道out_channels = param.size()[0]   # 输出通道k_w, k_h = param.size()[3], param.size()[2]   # 卷积核的尺寸kernel_all = param.view(-1, 1, k_w, k_h)  # 每个通道的卷积核kernel_grid = torchvision.utils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=in_channels)writer.add_image(f'{name}_all', kernel_grid, global_step=0)

参数直方图可视化

利用直方图可以对每一层参数的分布进行直观展示,便于分析模型参数的学习情况。


全连接层的参数分布如下图所示:

代码示例如下:

def plot_param_hist(writer,model):for name, param in model.named_parameters():writer.add_histogram(f"{name}", param, 0)

激活可视化

输入图像经过第一个卷积层的激活映射:


经过layer2和layer3的激活:

从pytorch模型中获取指定层的权重和激活的代码如下,参考facebook的工程2

class GetWeightAndActivation:"""A class used to get weights and activations from specified layers from a Pytorch model."""def __init__(self, model, layers):"""Args:model (nn.Module): the model containing layers to obtain weights and activations from.layers (list of strings): a list of layer names to obtain weights and activations from.Names are hierarchical, separated by /. For example, If a layer follow a path"s1" ---> "pathway0_stem" ---> "conv", the layer path is "s1/pathway0_stem/conv"."""self.model = modelself.hooks = {}self.layers_names = layers# eval modeself.model.eval()self._register_hooks()def _get_layer(self, layer_name):"""Return a layer (nn.Module Object) given a hierarchical layer name, separated by /.Args:layer_name (str): the name of the layer."""layer_ls = layer_name.split("/")prev_module = self.modelfor layer in layer_ls:prev_module = prev_module._modules[layer]return prev_moduledef _register_single_hook(self, layer_name):"""Register hook to a layer, given layer_name, to obtain activations.Args:layer_name (str): name of the layer."""def hook_fn(module, input, output):self.hooks[layer_name] = output.clone().detach()layer = get_layer(self.model, layer_name)layer.register_forward_hook(hook_fn)def _register_hooks(self):"""Register hooks to layers in `self.layers_names`."""for layer_name in self.layers_names:self._register_single_hook(layer_name)def get_activations(self, input, bboxes=None):"""Obtain all activations from layers that we register hooks for.Args:input (tensors, list of tensors): the model input.bboxes (Optional): Bouding boxes data that might be requiredby the model.Returns:activation_dict (Python dictionary): a dictionary of the pair{layer_name: list of activations}, where activations are outputs returnedby the layer."""input_clone = [inp.clone() for inp in input]if bboxes is not None:preds = self.model(input_clone, bboxes)else:preds = self.model(input_clone)activation_dict = {}for layer_name, hook in self.hooks.items():# list of activations for each instance.activation_dict[layer_name] = hookreturn activation_dict, predsdef get_weights(self):"""Returns weights from registered layers.Returns:weights (Python dictionary): a dictionary of the pair{layer_name: weight}, where weight is the weight tensor."""weights = {}for layer in self.layers_names:cur_layer = get_layer(self.model, layer)if hasattr(cur_layer, "weight"):weights[layer] = cur_layer.weight.clone().detach()else:logger.error("Layer {} does not have weight attribute.".format(layer))return weights

对给定输入进行测试,输出指定层的激活映射,并绘制在tensorboard中:

# 模型测试,避免改变权重
model.eval()# Set up writer for logging to Tensorboard format.
writer = tb.TensorboardWriter(cfg)# 注册指定层的激活hook
layer_ls=["conv1","layer1/1/conv2","layer2/1/conv2","layer3/1/conv2","layer4/1/conv2"]
model_vis = GetWeightAndActivation(model, layer_ls)# 给定一个输入,获取指定层的激活映射
activations, preds = model_vis.get_activations(inputs)# 绘制激活映射(如画在tensorboard中)
plot_weights_and_activations(writer,activations,tag="Input {}/Activations: ".format(0))

小结

本文整理了深度神经网络常用的局部可视化代码,对卷积核、权重和激活映射进行可视化,便于对训练过程进行分析和检查。有需要的朋友可以马住收藏。


  1. https://zhuanlan.zhihu.com/p/54947519 ↩︎

  2. https://github.com/facebookresearch/SlowFast ↩︎

深度神经网络可解释性:卷积核、权重和激活可视化(pytorch+tensorboard)相关推荐

  1. 深度神经网络是否模拟了人类大脑皮层结构

    深度神经网络(DNN)是否模拟了人类大脑皮层结构? 来源:AI科技大本营 微信号 概要:人工智能交融了诸多学科,而目前对人工智能的探索还处于浅层面,我们需要从不同角度和层次来思考,比如人工智能和大脑的 ...

  2. 【深度学习】Pytorch的深度神经网络剪枝应用

    [深度学习]Pytorch的深度神经网络剪枝应用 文章目录 1 概述 2 pytorch基于卷积层通道剪枝的方法 3 模型剪枝:Learning Efficient Convolutional Net ...

  3. pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络

    作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络​mp.weixin.qq.com 导 ...

  4. 打开深度神经网络黑箱:竟是模块化的?图聚类算法解密权重结构 | ICML 2020

    十三 发自 凹非寺 量子位 报道 | 公众号 QbitAI 深度神经网络这个黑箱子,似乎有了更清晰的轮廓. 我们都知道深度神经网络性能十分强大,但具体效果为什么这么好,权重为什么要这么分配,可能连&q ...

  5. NNs(Neural Networks,神经网络)和Polynomial Regression(多项式回归)等价性之思考,以及深度模型可解释性原理研究与案例...

    1. Main Point 0x1:行文框架 第二章:我们会分别介绍NNs神经网络和PR多项式回归各自的定义和应用场景. 第三章:讨论NNs和PR在数学公式上的等价性,NNs和PR是两个等价的理论方法 ...

  6. 神经网络可解释性、深度学习新方法,2020 年 AI 有哪些势不可挡的研究趋势?...

    来演:雷锋网 2019 年最后一场学术顶会告诉我们 2020 年该研究什么! 文 | MrBear 作为 2019 年最后一场重量级的人工智能国际学术顶会,NeurIPS 2019 所反映出的一些人工 ...

  7. 打开深度学习的黑盒,详解神经网络可解释性

    深度学习的可解释性研究在近年来顶会的录取文献词云上频频上榜,越来越多的研究工作表明,打开深度学习的黑盒并不是那么遥不可及.这些工作令人们更加信赖深度学习算法生成的结果,也通过分析模型工作的机理,让新的 ...

  8. 【深度学习】基于深度神经网络进行权重剪枝的算法(二)

    [深度学习]基于深度神经网络进行权重剪枝的算法(二) 文章目录 1 摘要 2 介绍 3 OBD 4 一个例子 1 摘要 通过从网络中删除不重要的权重,可以有更好的泛化能力.需求更少的训练样本.更少的学 ...

  9. 【深度学习】基于深度神经网络进行权重剪枝的算法(一)

    [深度学习]基于深度神经网络进行权重剪枝的算法(一) 1 pruning 2 代码例子 3 tensorflow2 keras 权重剪裁(tensorflow-model-optimization)3 ...

最新文章

  1. 完全平方数(打表+二分)
  2. 分享10个实用的超绚CSS3按钮设计
  3. 软件项目开发应写的13类文档
  4. 南邮 Android 课程设计,南邮大四课程设计.doc
  5. qpython3可视图形界面_python GUI库图形界面开发之PyQt5窗口控件QWidget详细使用方法...
  6. 哈夫曼编码压缩率计算_考研经验分享(哈工大计算机)
  7. 媒体转码切片_cVideo云转码系统
  8. (转)Spring Boot(七):Mybatis 多数据源最简解决方案
  9. 《Android音视频开发》— Android 书籍
  10. 12.docker inspect
  11. 关于jenkins打包部署
  12. 管家婆软件使用在线支付教程
  13. 软件定义汽车-AUTOSAR解决方案
  14. 安装Kylin Linux Advanced Server V10操作系统
  15. 完全用计算机制作的三维动画,通过四个步骤告诉你三维动画怎么制作
  16. python类的实例化和继承
  17. 工程项目提成标准方案_工程绩效提成奖金方案
  18. 基于深度强化学习的机器人运动控制研究进展 | 无模型强化学习 | 元学习
  19. PBRT-v2在windows下的配置与使用
  20. 求n的阶乘和求n的阶乘和——两种方法

热门文章

  1. 数据仓库入门(实验3)添加主键和关系
  2. [转]Gson的基本使用
  3. NYOJ-491 幸运三角形
  4. xcode 4.2下怎么添加framework?
  5. NextCloud Installation on CentOS 7 server
  6. Tomcat整合APR
  7. 产品经理之市场需求分析详解(非原创)
  8. Android8.1 MTK平台 截屏功能分析
  9. 函数式编程 -- 纯函数、柯里化函数
  10. C#LeetCode刷题之#594-最长和谐子序列​​​​​​​​​​​​​​(Longest Harmonious Subsequence)