Global Average Pooling(GAP)

参考:
深度学习基础系列(十)| Global Average Pooling是否可以替代全连接层?和深度学习|Global Average Pooling
Network In Network中对GAP的描述:
In this paper, we propose another strategy called global average pooling to replace the traditional fully connected layers in CNN. The idea is to generate one feature map for each corresponding category of the classification task in the last mlpconv layer. Instead of adding fully connected layers on top of the feature maps, we take the average of each feature map, and the resulting vector is fed directly into the softmax layer. One advantage of global average pooling over the fully connected layers is that it is more native to the convolution structure by enforcing correspondences between feature maps and categories. Thus the feature maps can be easily interpreted as categories confidence maps. Another advantage is that there is no parameter to optimize in the global average pooling thus overfitting is avoided at this layer. Futhermore, global average pooling sums out the spatial information, thus it is more robust to spatial translations of the input.
用图来表示:

从图中可以直观看出GAP就是对每张特征图取其均值,用这个均值来表示该特征图送入softmax计算。

Grad-CAM

参考文献:
深度学习论文笔记(可解释性)——CAM与Grad-CAM
在讲之前先明确一点:CNN最后一层特征图富含有最为丰富类别语意信息。

import torch
import torch.nn.functional as Fdef find_vgg_layer(arch, target_layer_name):"""Find vgg layer to calculate GradCAM and GradCAM++Args:arch: default torchvision densenet modelstarget_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.target_layer_name = 'features'target_layer_name = 'features_42'target_layer_name = 'classifier'target_layer_name = 'classifier_0'Return:target_layer: found layer. this layer will be hooked to get forward/backward pass information."""hierarchy = target_layer_name.split('_')if len(hierarchy) >= 1:target_layer = arch.featuresif len(hierarchy) == 2:target_layer = target_layer[int(hierarchy[1])]return target_layerclass GradCAM(object):"""Calculate GradCAM salinecy map.A simple example:# initialize a model, model_dict and gradcamresnet = torchvision.models.resnet101(pretrained=True)resnet.eval()model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224))gradcam = GradCAM(model_dict)# get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)img = load_img()normed_img = normalizer(img)# get a GradCAM saliency map on the class index 10.mask, logit = gradcam(normed_img, class_idx=10)# make heatmap from mask and synthesize saliency map using heatmap and imgheatmap, cam_result = visualize_cam(mask, img)Args:model_dict (dict): a dictionary that contains 'model_type', 'arch', layer_name', 'input_size'(optional) as keys.verbose (bool): whether to print output size of the saliency map givien 'layer_name' and 'input_size' in model_dict."""def __init__(self, model_dict, verbose=False):model_type = model_dict['type']layer_name = model_dict['layer_name']self.model_arch = model_dict['arch']self.gradients = dict()self.activations = dict()def backward_hook(module, grad_input, grad_output):self.gradients['value'] = grad_output[0]return Nonedef forward_hook(module, input, output):self.activations['value'] = outputreturn Noneif 'vgg' in model_type.lower():target_layer = find_vgg_layer(self.model_arch, layer_name)target_layer.register_forward_hook(forward_hook)target_layer.register_backward_hook(backward_hook)if verbose:try:input_size = model_dict['input_size']except KeyError:print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}")passelse:device = 'cuda' if next(self.model_arch.parameters()).is_cuda else 'cpu'self.model_arch(torch.zeros(1, 3, *(input_size), device=device))print('saliency_map size :', self.activations['value'].shape[2:])def forward(self, input, class_idx=None, retain_graph=False):"""Args:input: input image with shape of (1, 3, H, W)class_idx (int): class index for calculating GradCAM.If not specified, the class index that makes the highest model prediction score will be used.Return:mask: saliency map of the same spatial dimension with inputlogit: model output"""b, c, h, w = input.size()logit = self.model_arch(input)print(logit.shape)if class_idx is None:score = logit[:, logit.max(1)[-1]].squeeze()  # get the max socreprint(score)else:score = logit[:, class_idx].squeeze()self.model_arch.zero_grad()score.backward(retain_graph=retain_graph)gradients = self.gradients['value']activations = self.activations['value']# print(gradients.shape, activations.shape)  # torch.Size([1, 512, 14, 14]) torch.Size([1, 512, 14, 14])b, k, u, v = gradients.size()alpha = gradients.view(b, k, -1).mean(2)  # torch.Size([1, 512])# alpha = F.relu(gradients.view(b, k, -1)).mean(2)weights = alpha.view(b, k, 1, 1)  # torch.Size([1, 512, 1, 1])saliency_map = (weights * activations).sum(1, keepdim=True)saliency_map = F.relu(saliency_map)print('saliency_map', saliency_map.shape)saliency_map = F.upsample(saliency_map, size=(h, w), mode='bilinear', align_corners=False)saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).datareturn saliency_map, logitdef __call__(self, input, class_idx=None, retain_graph=False):return self.forward(input, class_idx, retain_graph)

代码解读(注意这里的讲解用的字母都是根据原论文的):
输入预训练模型,要提取的层(这里用vgg16最后一个MaxPool2d()前的relu(),即features_29),使用hook提取features_29层的激活值和梯度,这层的特征图、激活值和梯度大小为 1 × 512 × 14 × 14 1 \times 512 \times 14 \times 14 1×512×14×14。将每一个特征图用一个GAP获得神经元重要性权重 α k c \alpha_k^c αkc​,对应代码和公式:
α k c = 1 Z ∑ i ∑ j ∂ y c ∂ A i j k \alpha_k^c=\frac{1}{Z}\sum_i \sum_j{\frac{\partial y^c}{\partial A^k_{ij}}} αkc​=Z1​i∑​j∑​∂Aijk​∂yc​

alpha = gradients.view(b, k, -1).mean(2)  # torch.Size([1, 512])
weights = alpha.view(b, k, 1, 1)  # torch.Size([1, 512, 1, 1])

We perform a weighted combination of forward activation maps, and follow it by a ReLU to obtain:
L G r a d − C A M c = R e L U ( ∑ k α k c A k ) L_{Grad-CAM}^c=ReLU(\sum_k{\alpha_k^cA^k}) LGrad−CAMc​=ReLU(k∑​αkc​Ak)

saliency_map = (weights * activations).sum(1, keepdim=True)
saliency_map = F.relu(saliency_map)

关于上采样的可以看pytorch torch.nn 实现上采样——nn.Upsample

实现结果:从左到右依次为原图,Grad-CAM的heatmap,Grad-CAM叠加后的效果

XAI系列基础知识之Grad-CAM相关推荐

  1. 51单片机系列--基础知识

    51单片机系列--基础知识 主要参数及功能 引脚及功能 工作时序 主要参数及功能 (1)8位CPU (2)4KB程序存储器(ROM) (3)128字节的数据存储器(RAM) (4)32条 I/O 口线 ...

  2. JUC系列——基础知识 day1-1

    JUC系列--基础知识 day1-1 JUC基础知识 进程 线程 进程和线程区别 并行与并发 同步 使用场景 异步 使用情景 QuickStart(new Thread方式创建新线程) 匿名内部类方式 ...

  3. mysql全套基础知识_mysql系列--基础知识

    注:本文为mysql基础知识的总结,基础点很多若是有些不足,还请自行搜索.持续更新 一.mysql简介 数据库简介 数据库是计算机应用系统中的一种专门管理数据资源的系统 数据库是一组经过计算机处理后的 ...

  4. 深度学习——keras教程系列基础知识

    大家好,本期我们将开始一个新的专题的写作,因为有一些小伙伴想了解一下深度学习框架Keras的知识,恰好本人也会一点这个知识,因此就开始尝试着写一写吧.本着和大家一起学习的态度,有什么写的不是很好的地方 ...

  5. ssas连接mysql_SSAS系列基础知识

    1.什么是Cube? 简单 Cube 对象由基本信息.维度和度量值组组成. 基本信息包括多维数据集的名称.多维数据集的默认度量值.数据源和存储模式等.维度是多维数据集中使用的实际维度组.所有维度都必须 ...

  6. 基金投资从入门到精通之一:基础知识篇

    第一篇 基础知识篇 第一节      认识基金 基金投资入门系列--基础知识 1.什么是证券投资基金? 通俗地说,证券投资基金是通过汇集众多投资者的资金,交给银行保管,由专业的基金管理公司负责投资于股 ...

  7. GNN 系列:Graph 基础知识介绍

    点击上方"Datawhale",选择"星标"公众号 第一时间获取价值内容 [导读]图卷积神经网络(Graph Convolutional Network)作为最 ...

  8. [C#基础知识系列]专题十七:深入理解动态类型

    本专题概要: 动态类型介绍 为什么需要动态类型 动态类型的使用 动态类型背后的故事 动态类型的约束 实现动态行为 总结 引言: 终于迎来了我们C# 4中特性了,C# 4主要有两方面的改善--Com 互 ...

  9. [基础知识]Linux新手系列之三

    2019独角兽企业重金招聘Python工程师标准>>> [基础知识]Linux新手系列之三 给Linux新手 [系列之三] Linux相关资料由兄弟连分享 OK,从哪里得到Linux ...

最新文章

  1. python网络编程-异常处理-异常捕获-抛出异常-断言-自定义异常-UDP通信-socketserver模块应用-03
  2. linux远程登录ssh免密码配置方法
  3. 批量删除数据库中有特定开始字符的表、视图和存储过程
  4. JavaXml教程(十)XML作为属性文件使用
  5. jmeterhttp代理服务器_Jmeter使用HTTP代理服务器录制
  6. Win7——Win10系统如何安装Win7系统
  7. 接口自动化测试系列之PHPUnit-POST请求接口测试方法
  8. 学习笔记-数据结构与算法之线性表
  9. html页面的ajax请求,【提问】ajax请求返回整个html页面
  10. UT斯达康互动电视UI界面设计大赛作品 求拍砖
  11. 一页纸商业计划书 (Business Plan) 模板
  12. Qt ui 到底是什么?
  13. MacPro 迁移至 Mac Mini-M1 与 踩坑 For 后端开发
  14. js库笔记(一):swr ahooks
  15. vue接入下载文件接口
  16. java jodd框架介绍及使用示例
  17. Golang 对接宝付、通联、富友金账户...填坑记
  18. 汇总|CVPR 2021 自动驾驶相关论文
  19. macos iTerm2 优化
  20. denoiser插件_红巨人降噪磨皮调色插件套装 Red Giant Magic Bullet Suite v13.0.4 Win/Mac

热门文章

  1. 2015最新最全 Android 谷歌消息推送GCM 详细使用教程
  2. JavaScript字符串去重
  3. JavaScript 字符串:字符串相加
  4. 从零开始,直到···
  5. 人脸识别时,一定要穿衣服!要不然……
  6. 简单聊聊 Perlin 噪声(下篇)
  7. VR科普研学体验馆VR交通模拟体验设备厂家
  8. Springboot解决业务并发问题
  9. 夏普迎来108周年庆,全方位发力8K+5G
  10. ELECTRON-VUE相关报错