文章目录

  • 前言
  • 1. CAM(Class Activation Map)
  • 2. Grad-CAM
  • 3. PyTorch中的hook机制
  • 4. Grad-CAM的PyTorch简洁实现
  • 参考资料

前言

CNN中的特征可视化大体可分为两类:

  • 细节信息:ZFNet中使用的deconvolution,改进的guide backpropagation
  • 信息的重要性区分:类激活图(CAM),改进的Grad-CAM

第一类方法只显示了在深层特征中保留了哪些信息,而没有突出显示这些信息的相对重要性。第二类方法则具有一定的解释性,例如在分类任务中,通过CAM能够解释模型究竟是通过重点学习输入图像中的哪些信息来判断类别的。

1. CAM(Class Activation Map)

Network in Network中提出了用全局平均池化(GAP)替代全连接层以加强特征映射与类别之间的联系,更具可解释性。受该思想启发,CAM可视化技术应运而出。生成CAM的流程如下图所示(论文原图):

可以看出,生成CAM的步骤非常简单,但是对网络结构有要求(网络末端为GAP+FC这样的结构,并且FC只有一层,用于输出类别概率)。假设分类任务采用的是VGG网络,此时生成CAM的步骤为:

  1. 将VGG中的前两个FC替换为GAP,重新训练;
  2. 获取最后一个卷积层输出的特征图[f1,f2,...,fn][f_1, f_2, ..., f_n][f1​,f2​,...,fn​],以及全连接层的权重[w1,w2,...,wn][w_1, w_2, ..., w_n][w1​,w2​,...,wn​];
  3. 计算CAM=∑i=1nwifiCAM=\sum_{i=1}^{n}w_if_iCAM=∑i=1n​wi​fi​

不难发现,若网络结构不符合要求,按照上述方法计算CAM需要修改网络结构和重新训练。针对该问题,后续研究中提出了Gard-CAM。

2. Grad-CAM

由上述CAM的计算方法可知,生成CAM的关键是获取特征图的权重。基于对原始CAM的改进,Grad-CAM通过求网络输出的类别置信度对特征图的偏导来获取权重,适用于任意网络,并且能够可视化任意层的类激活图(通常选择最后一个卷积层,因为其包含了丰富的高级语义和空间信息)。

  • 生成Grad-CAM的步骤如下:
  1. 图片送入网络,前向传播,获取最后一个卷积层的特征图AkA^kAk(可选,任意层均可,kkk为通道index);
  2. 反向传播,获取网络输出的类别 ccc 的概率ycy^cyc关于AkA^kAk的梯度∂yc∂Ak\frac{\partial y^c}{\partial A^k}∂Ak∂yc​;
  3. 计算权重αkc=1Z∑i∑j∂yc∂Ai,jk\alpha^{c}_{k}=\frac{1}{Z}\sum\limits_{i}\sum\limits_{j}\frac{\partial y^c}{\partial A^k_{i,j}}αkc​=Z1​i∑​j∑​∂Ai,jk​∂yc​
  4. 计算Grad-CAM:LGrad−CAMc=ReLU(∑kαkcAk)L_{Grad-CAM}^{c}=ReLU(\sum\limits_{k}\alpha^{c}_{k}A^k)LGrad−CAMc​=ReLU(k∑​αkc​Ak)
  • 求偏导的意义:参考知乎中的文章,偏导表示输出关于输入的变化率,也就是特征图上变化一个单位,得到的输出变化多少单位。可以反映出输出ycy^cyc关于Ai,jkA^k_{i,j}Ai,jk​的敏感程度,如果梯度大,则非常敏感,表示该位置更有可能属于类别 ccc。

3. PyTorch中的hook机制

  • PyTorch中设计hook的目的:在不改变网络代码、不在forward中返回某一层的输出的情况下,获取网络中某一层在前向传播或反向传播过程的输入和输出,并对其进行相关操作(例如:特征图可视化,梯度裁剪)。

4. Grad-CAM的PyTorch简洁实现

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import torchvision.models as models
from torchvision.transforms import Compose, Normalize, ToTensorclass GradCAM():'''Grad-cam: Visual explanations from deep networks via gradient-based localizationSelvaraju R R, Cogswell M, Das A, et al. https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html'''def __init__(self, model, target_layers, use_cuda=True):super(GradCAM).__init__()self.use_cuda = use_cudaself.model = modelself.target_layers = target_layersself.target_layers.register_forward_hook(self.forward_hook)self.target_layers.register_full_backward_hook(self.backward_hook)self.activations = []self.grads = []def forward_hook(self, module, input, output):self.activations.append(output[0])def backward_hook(self, module, grad_input, grad_output):self.grads.append(grad_output[0].detach())def calculate_cam(self, model_input):if self.use_cuda:device = torch.device('cuda')self.model.to(device)                 # Module.to() is in-place method model_input = model_input.to(device)  # Tensor.to() is not a in-place methodself.model.eval()# forwardy_hat = self.model(model_input)max_class = np.argmax(y_hat.cpu().data.numpy(), axis=1)# backwardmodel.zero_grad()y_c = y_hat[0, max_class]y_c.backward()# get activations and gradientsactivations = self.activations[0].cpu().data.numpy().squeeze()grads = self.grads[0].cpu().data.numpy().squeeze()# calculate weightsweights = np.mean(grads.reshape(grads.shape[0], -1), axis=1)weights = weights.reshape(-1, 1, 1)cam = (weights * activations).sum(axis=0)cam = np.maximum(cam, 0) # ReLUcam = cam / cam.max()return cam@staticmethoddef show_cam_on_image(image, cam):# image: [H,W,C]h, w = image.shape[:2]cam = cv2.resize(cam, (h,w))cam = cam / cam.max()heatmap = cv2.applyColorMap((255*cam).astype(np.uint8), cv2.COLORMAP_JET) # [H,W,C]heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)image = image / image.max()heatmap = heatmap / heatmap.max()result = 0.4*heatmap + 0.6*imageresult = result / result.max()plt.figure()plt.imshow((result*255).astype(np.uint8))plt.colorbar(shrink=0.8)plt.tight_layout()plt.show()@staticmethoddef preprocess_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):preprocessing = Compose([ToTensor(),Normalize(mean=mean, std=std)])return preprocessing(img.copy()).unsqueeze(0) if __name__ == '__main__':image = cv2.imread('both.png') # (224,224,3)input_tensor = GradCAM.preprocess_image(image)model = models.resnet18(pretrained=True)grad_cam = GradCAM(model, model.layer4[-1], 224)cam = grad_cam.calculate_cam(input_tensor)GradCAM.show_cam_on_image(image, cam)
  • 测试结果

    (https://github.com/jacobgil/pytorch-grad-cam/blob/master/examples/both.png)

参考资料

  • CAM论文
  • Grad-CAM论文
  • 如何使用 PyTorch Hook
  • Grad-cam:原理及pytorch实现
  • Grad-CAM 原理和实现

CNN可视化技术 -- CAM Grad-CAM详解及pytorch简洁实现相关推荐

  1. CNN可视化技术总结(一)--特征图可视化

    导言: 在CV很多方向所谓改进模型,改进网络,都是在按照人的主观思想在改进,常常在说CNN的本质是提取特征,但并不知道它提取了什么特征,哪些区域对于识别真正起作用,也不知道网络是根据什么得出了分类结果 ...

  2. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  3. qml学习笔记(二):可视化元素基类Item详解(上半场anchors等等)

    原博主博客地址:http://blog.csdn.net/qq21497936 本文章博客地址:http://blog.csdn.net/qq21497936/article/details/7851 ...

  4. R语言使用survminer包生存分析及可视化(ggsurvplot)实战详解:从数据集导入、生存对象生成、ggsurvplot可视化参数配置、设置、可视化对比

    R语言使用survminer包生存分析及可视化(ggsurvplot)实战详解:从数据集导入.生存对象生成.ggsurvplot可视化参数配置.设置.可视化对比 目录 R语言使用survminer包生 ...

  5. 可视化数据库管理工具DataGrip使用详解

    参考链接:https://www.hangge.com/blog/cache/detail_2829.html 日常开发中少不了各种可视化数据库管理工具.如果需要同时能连接多种数据库,大家肯定都会想到 ...

  6. 高可用集群技术之corosync应用详解(一)

    Corosync概述: Corosync是集群管理套件的一部分,它在传递信息的时候可以通过一个简单的配置文件来定义信息传递的方式和协议等.它是一个新兴的软件,2008年推出,但其实它并不是一个真正意义 ...

  7. 大数据技术Hbase 和 Hive 详解

    目录 两者的特点 各自的限制 应用场景 大数据技术Hbase 和 Hive 详解, 今天给大家介绍一下关于零基础学习大数据视频教程之HBASE 和 HIVE 是多么重要的技术,那么两者有什么区别呢 ? ...

  8. 【Dash搭建可视化网站】项目13:销售数据可视化大屏制作步骤详解

    销售数据可视化大屏制作步骤详解 1 项目效果图 2 项目架构 3 文件介绍和功能完善 3.1 assets文件夹介绍 3.2 app.py和index.py文件完善 3.3 header.py文件完善 ...

  9. 【转载】城域网IPv6过渡技术—NAT444与DS-lite详解

    城域网IPv6过渡技术-NAT444与DS-lite详解 转自 https://network.51cto.com/art/201311/419211.htm### 文章目录 城域网IPv6过渡技术- ...

最新文章

  1. ylbtech-Unitity-CS:Hello world
  2. U盘无法拷贝超过4G的大文件
  3. 爬虫必须得会的预备知识
  4. php如何加网址链接,怎么给一个PHP密码访问页面加超链接
  5. 计算机应用基础word教程,计算机应用基础-文字处理word教程PPT课件.ppt
  6. 鱼骨图分析法实际案例_8D根本原因分析——5WHY与鱼骨图培训课件(PPT64完整详细)...
  7. MySQL之视图、触发器、事务、存储过程、函数
  8. CES现场直击 AI让你现场获得虚拟双胞胎
  9. jQuery右键菜单ContextMenu使用笔记
  10. ModelAttribue注解的使用
  11. spring mvc 页面跳转 携带数据的两种方式
  12. Windows98 win98.bif 文件
  13. drozer 找不到java_自己安装drozer时出现各种问题的解决
  14. 请问下面这段代码哪里有错? private static final String s=
  15. 淘宝评论爬取 python pandas
  16. 2021年N1叉车司机模拟考试题库软件及全国真题汇总
  17. 【7gyy】高手分享辨别电脑病毒技巧
  18. mysql-5.7 基础篇
  19. 机器学习笔记:随机深度网络 stochastic depth
  20. 如何使用在线客服转接功能

热门文章

  1. 无创血糖检测技术研究进展
  2. 应用层(计网_06)
  3. prefixTreeEspan 频繁子树模式挖掘 A pattern growth 算法实现 mining embedded subtrees.
  4. MATLAB中的共轭转置与转置
  5. commvault 配置mysql_Commvault_Oracle DG恢复到单机操作手册
  6. 直流电动机输出功率与转速的关系问题
  7. css音阶波浪动画图,线性渐变色
  8. us域名在哪里注册_us域名,什么是us域名,注册us域名有什么优势
  9. 2022年登高架设考试题模拟考试题库及模拟考试
  10. 尚未领取号牌和行驶证的机动车需要临时上道路行驶的,怎么办?