文章目录

  • 特征图可视化方法
    • 1. tensor->numpy->plt.save
    • 2. register_forward_pre_hook函数实现特征图获取
    • 3. 反卷积可视化
  • 特征图可视化的意义
    • 1. 改进训练网络结构
    • 2. 删除冗余节点实现模型压缩

特征图可视化方法

1. tensor->numpy->plt.save

以VGG网络可视化为例,参考代码见链接。

  • 不同层的特征图比较
modulelist = list(vgg.features.modules())def to_grayscale(image):# mean valueimage = torch.sum(image, dim=0)image = torch.div(image, image.shape[0])return imagedef layer_outputs(image):outputs = []names = []for layer in modulelist[1:]:outputs.append(layer(image))names.append(str(layer))output_im = []for i in outputs:temp = to_grayscale(i.squeeze(0))output_im.append(temp.data.cpu().numpy())fig = plt.figure()plt.rcParams["figure.figsize"] = (30, 50)for i in range(len(output_im)):a = fig.add_subplot(8, 4, i+1)imgplot = plt.imshow(output_im[i])plt.axis('off')a.set_title(names[i].partition('(')[0], fontsize=30)plt.savefig('layer_outputs.jpg', bbox_inches='tight')
  • 指定层的不同通道特征图比较
def filter_outputs(image, layer_to_visualize):if layer_to_visualize < 0:layer_to_visualize += 31output = Nonename = Nonefor count, layer in enumerate(modulelist[1:]):image = layer(image)if count == layer_to_visualize: output = imagename = str(layer)filters = []output = output.data.squeeze()for i in range(output.shape[0]):filters.append(output[i, :, :])fig = plt.figure()plt.rcParams["figure.figsize"] = (10, 10)for i in range(int(np.sqrt(len(filters))) * int(np.sqrt(len(filters)))):fig.add_subplot(np.sqrt(len(filters)), np.sqrt(len(filters)), i+1)imgplot = plt.imshow(filters[i].cpu())plt.axis('off')

2. register_forward_pre_hook函数实现特征图获取

采用register_forward_pre_hook(hook_func: Callable[..., None])函数获取特征图,括号中的参数是一个需要自行实现的函数名,其参数 module, input, output 固定,分别代表模块名称、一个tensor组成的tuple输入和tensor输出;随后采用torchvision.utils.make_gridtorchvision.utils.save_image将特征图转化为 PIL.Image 类型,存储为png格式图片并保存。保存图片的尺寸与特征图张量尺寸一致。关于上述函数的详细解释可参考博文。
其中由于hook_func参数固定,故定义get_image_name_for_hook函数为不同特征图命名,并定义全局变量COUNT表示特征图在网络结构中的顺序。具体实现如下。

COUNT = 0  # global_para for featuremap naming
IMAGE_FOLDER = './save_image'
INSTANCE_FOLDER = Nonedef hook_func(module, input, output):image_name = get_image_name_for_hook(module)data = output.clone().detach().permute(1, 0, 2, 3)# torchvision.utils.save_image(data, image_name, pad_value=0.5)from PIL import Imagefrom torchvision.utils import make_gridgrid = make_grid(data, nrow=8, padding=2, pad_value=0.5, normalize=False, range=None, scale_each=False)ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()im = Image.fromarray(ndarr)# wandb save from jpg/png filewandb.log({f"{image_name}": wandb.Image(im)})# save locally# im.save(image_path)def get_image_name_for_hook(module):os.makedirs(INSTANCE_FOLDER, exist_ok=True)base_name = str(module).split('(')[0]image_name = '.'  # '.' is surely exist, to make first loop condition Trueglobal COUNTwhile os.path.exists(image_name):COUNT += 1image_name = '%d_%s' % (COUNT, base_name)return image_nameif __name__ == '__main__':# clear output folderif os.path.exists(IMAGE_FOLDER):shutil.rmtree(IMAGE_FOLDER)# TODO: wandb & model initializationmodel.eval()# layers to logmodules_for_plot = (torch.nn.LeakyReLU, torch.nn.BatchNorm2d, torch.nn.Conv2d)for name, module in model.named_modules():if isinstance(module, modules_for_plot):module.register_forward_hook(hook_func)index = 1for idx, batch in enumerate(val_loader):# global COUNTCOUNT = 1INSTANCE_FOLDER = os.path.join(IMAGE_FOLDER, f'{index}_pic')# forwardimages_val = Variable(torch.from_numpy(batch[0]).type(torch.FloatTensor)).cuda()outputs = model(images_val)

3. 反卷积可视化

参考文献:Visualizing and Understanding Convolutional Networks
对特征图 tensor 张量进行反池化-反激活-反卷积得到与原始输入图片尺寸一致的特征图。

  • 反卷积为卷积核转置后进行卷积操作(实为转置卷积);
  • 反激活与激活操作相同,直接调用ReLU函数(保证输出值非负即可);
  • 反池化操作为利用池化过程中记录的激活值位置信息(Switches)复原特征图尺寸,其余位置赋零值。


特征图可视化的意义

1. 改进训练网络结构

图(b)包含过多低频、高频信息,很少有中频信息;图(d)中存在较多混叠伪影。因此对神经网络进行如下改进:

  1. 将卷积核尺寸从11×11缩小为7×7
  2. 将卷积层步长从4缩减为2

改进后对应特征层输出如图(c)和图(e)所示,特征提取结果更为鲜明,无效特征(dead feature map)减少,且特征图更加清晰,混影减少。

2. 删除冗余节点实现模型压缩

可视化结果里有一些纯黑的特征图(下图红色方框标出),即所谓的 dead feature map,且不同的输入数据下固定卷积层的 dead feature map 位置相同。这些 dead feature map 没有办法提供有效信息,又因它们位置固定,因此可以将对应的卷积核从网络中剔除,起到模型压缩的作用。


卷积神经网络特征图可视化及其意义相关推荐

  1. 卷积神经网络特征图可视化热图可视化

    文章目录 前言 一.可视化特征图 二.热力图可视化(图像分类) 总结 前言 使用pytorch中的钩子将特征图和梯度勾出来,从而达到可视化特征图(featuremap)和可视化热图(heatmap)的 ...

  2. 卷积神经网络特征图可视化(自定义网络和VGG网络)

    借助Keras和Opencv实现的神经网络中间层特征图的可视化功能,方便我们研究CNN这个黑盒子里到发生了什么. 自定义网络特征可视化 代码: # coding: utf-8from keras.mo ...

  3. Grad-CAM 神经网络特征图可视化

    参见:https://zhuanlan.zhihu.com/p/269702192 神经网络的可解释性离不开特征图(feature map)的可视化. 如何分析CNN feature map上哪些区域 ...

  4. 神经网络特征图可视化

    一.原理 pytorch 中的hook可以不必改变网络输入输出的结构,方便的获取.改变网络中间层变量的值和梯度.这个功能广泛用于可视化神经网络中间层的feature.gradient.从而诊断神经网络 ...

  5. 卷积神经网络特征图大小计算公式

    基本公式

  6. 卷积神经网络及其特征图可视化

    参考链接:https://www.jianshu.com/p/362b637e2242 参考链接:https://blog.csdn.net/dcrmg/article/details/8125549 ...

  7. 可视化卷积神经网络的过滤器_万字长文:深度卷积神经网络特征可视化技术(CAM)最新综述...

    ↑ 点击蓝字 关注极市平台作者丨皮特潘@知乎来源丨https://zhuanlan.zhihu.com/p/269702192编辑丨极市平台 极市导读 本文通过引用七篇论文来论述CAM技术,对CAM的 ...

  8. CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化

    AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...

  9. 三行代码可视化神经网络特征图

    三行代码可视化神经网络特征图 正文 正文 在科研论文,方案讲解,模型分析中,合理解释特征图是对最终结果的一个加分项.但是之前的一些可视化特征图的方法往往会有一些tedious,于是我在这里给大家推荐一 ...

最新文章

  1. 03-背景音乐及广播
  2. FastReport.net 使用记录
  3. Java 实现图片合成
  4. java day04【 Idea、方法】
  5. Java 开发必备,EasyExcel 操作详解!
  6. CentOS系统配置 iptables防火墙
  7. 复现monodepth2之KITTI数据集准备
  8. eclipse汉化 eclipse汉化版退回英文版详细介绍
  9. 高等数学(第七版)同济大学 习题3-4 个人解答(前8题)
  10. HTML5 界面元素 Canvas 参考手册
  11. 利用Veeam BackupReplication工具实现vsphere虚拟机备份
  12. 北京/苏州内推 | 微软亚洲互联网工程院招聘NLP算法工程师(可实习)
  13. Mac苹果电脑总是自动重启?怎么解决自动重启问题
  14. 双系统正确卸载Ubuntu系统
  15. 组建计算机网络通常采用3种模式,对等网的组建_计算机中的543原则_计算机网络工作模式(2)...
  16. 用于单眼3D物体检测的可学习深度引导卷积
  17. C++ #define用法详解
  18. KNN算法(二) sklearn KNN实践
  19. Kitty用HTML和css咋做,使用 CSS3 绘制 Hello Kitty
  20. 开关电源波纹的产生、测量及抑制,一篇全搞定!

热门文章

  1. 时​钟​周​期​及​秒​(​s​)​ ​毫​秒​(​m​s​)​ ​微​秒​(​μ​s​)​ ​纳​秒​(​n​s​)​ ​皮​秒​(​p​s​)​之​间​转​换
  2. Linux(4):文件属性
  3. jq获取id变量值(Ajax)
  4. mixly for Mac以及Arduino uno开发板的使用
  5. Android textView文字渐变色设置
  6. C++随机马赛克图程序
  7. h5 动画效果常见制作手法
  8. 63-Linux如何解决僵死进程
  9. vnc与windows之间的复制粘贴
  10. SQL Sever数据库损坏修复