CAM激活图可视化系列(代码直接可用)——GradCAM(Pytorch官方版plus+可自定义修改版)

  • 原理
  • 官方模块安装
  • 注意
  • GradCAM代码(本人修改:直接可用——官方版PLUS)
  • 结果对比
  • GradCAM代码(简书+本人修改:直接可用——可自定义修改版)
  • 另一组实验结果:
  • 后记:GradCAM本身其实只是一个工具,最后呈现的效果还是由所用模型+(类别编码)本身决定
  • PLUS相关参考

原理

原作者论文CAM原理图
特征图关于这个类别分数的梯度(维度为[C, H, W])。最后对特征图梯度的空间维度计算平均值,得到与类别信息有关且与特征图通道数一致的权重,再根据权重将原图与热力图(激活图)叠加即可。其中叠加的热力图要与原图大小一致(所以其中可能有加入一些如插值等的图像增强操作)

官方模块安装

pip install grad-cam

注意

GradCAM代码(本人修改:直接可用——官方版PLUS)

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget#####torch 模块项目(官方版+进阶使用)来源:https://github.com/jacobgil/pytorch-grad-cam####
####原理参考:https://blog.csdn.net/qq_37541097/article/details/123089851
######CAM 模块化代码(黑盒)def Grad_CAM_perBox(model, target_layers, img_path, save_path, target_category):image = img_path# 此部分可抽取出来做成模块,然后列表使用data_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = imageassert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path).convert('RGB')img = np.array(img, dtype=np.uint8)# [N, C, H, W]img_tensor = data_transform(img)# expand batch dimensioninput_tensor = torch.unsqueeze(img_tensor, dim=0)cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)#####torch官方用法target_category = target_category####target_category = None  ####默认获取最大概率类别获得CAM图grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)# #####按自己输入的类别编码获取CAM图# targets = [ClassifierOutputTarget(target_category)]## grayscale_cam = cam(input_tensor=input_tensor, targets=targets)grayscale_cam = grayscale_cam[0, :]visualization = show_cam_on_image(img.astype(dtype=np.float32) / 255.,grayscale_cam,use_rgb=True)plt.imshow(visualization)plt.savefig(save_path)plt.show()if __name__ == '__main__':####模型可换成自己的或其他的modelsmodel = models.mobilenet_v3_large(pretrained=True)target_layers = [model.features[-1]]# model = models.vgg16(pretrained=True)# target_layers = [model.features]# model = models.resnet34(pretrained=True)# target_layers = [model.layer4]# model = models.regnet_y_800mf(pretrained=True)# target_layers = [model.trunk_output]# model = models.efficientnet_b0(pretrained=True)# target_layers = [model.features]img_path = "./cat.png"save_path = "./CAM.png"# target_category = 281  # tabby, tabby cat# target_category = 254  # pug, pug-dogtarget_category = NoneGrad_CAM_perBox(model, target_layers, img_path, save_path, target_category)

结果对比


GradCAM代码(简书+本人修改:直接可用——可自定义修改版)

######CAM 详细实现原理(可更改代码):https://www.jianshu.com/p/fd2f09dc3cc9
######Grad_CAM 详细实现原理(可更改代码):https://www.jianshu.com/p/fd2f09dc3cc9
import math
import numpy as np
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_imagedef Grad_CAM_perModify(img_path,save_path,net,target_layers):preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])feature_map = []  # 建立列表容器,用于盛放输出特征图def forward_hook(module, inp, outp):  # 定义hookfeature_map.append(outp)  # 把输出装入字典feature_maptarget_layers.register_forward_hook(forward_hook)  # 对net.layer4这一层注册前向传播feature_map = []  # 建立列表容器,用于盛放输出特征图def forward_hook(module, inp, outp):  # 定义hookfeature_map.append(outp)  # 把输出装入字典feature_maptarget_layers.register_forward_hook(forward_hook)  # 对net.layer4这一层注册前向传播grad = []  # 建立列表容器,用于盛放特征图的梯度def backward_hook(module, inp, outp):  # 定义hookgrad.append(outp)  # 把输出装入列表gradtarget_layers.register_full_backward_hook(backward_hook)  # 对net.features这一层注册反向传播orign_img = Image.open(img_path).convert('RGB')  # 打开图片并转换为RGB模型img = preprocess(orign_img)  # 图片预处理img = torch.unsqueeze(img, 0)  # 增加batch维度 [1, 3, 224, 224]# out = net(img.cuda())  # 前向传播out = net(img)  # 前向传播###自动获取预测类别编码cls_idx = torch.argmax(out).item()  # 获取预测类别编码###或者自行指定类别编码# cls_idx = 281score = out[:, cls_idx].sum()  # 获取预测类别分数net.zero_grad()score.backward(retain_graph=True)  # 由预测类别分数反向传播weights = grad[0][0].squeeze(0).mean(dim=(1, 2))  # 获得权重grad_cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)def _normalize(cams: Tensor) -> Tensor:"""CAM normalization"""cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))return camsgrad_cam = _normalize(F.relu(grad_cam, inplace=True)).cpu()mask = to_pil_image(grad_cam.detach().numpy(), mode='F')def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:"""Overlay a colormapped mask on a background imageArgs:img: background imagemask: mask to be overlayed in grayscalecolormap: colormap to be applied on the maskalpha: transparency of the background imageReturns:overlayed image"""if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):raise TypeError('img and mask arguments need to be PIL.Image')if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:raise ValueError('alpha argument is expected to be of type float between 0 and 1')cmap = cm.get_cmap(colormap)# Resize mask and apply colormapoverlay = mask.resize(img.size, resample=Image.BICUBIC)overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)# Overlay the image with the maskoverlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))return overlayed_imgresult = overlay_mask(orign_img, mask)result.show()result.save(save_path)if __name__ == '__main__':img_path = "./cat2.png"save_path = "./CAM2.png"# net = models.mobilenet_v3_large(pretrained=True)# net = models.vgg11_bn(pretrained=True).cuda()  # 导入模型net = models.vgg11_bn(pretrained=True)  # 导入模型# print(net)### 指定激活(可视化)哪一层target_layers = net.featuresGrad_CAM_perModify(img_path,save_path,net,target_layers)

另一组实验结果:

后记:GradCAM本身其实只是一个工具,最后呈现的效果还是由所用模型+(类别编码)本身决定

PLUS相关参考

原理参考:https://blog.csdn.net/qq_37541097/article/details/123089851
torch 模块项目(官方版+进阶使用)来源:https://github.com/jacobgil/pytorch-grad-cam
CAM 详细实现原理(可更改代码):https://www.jianshu.com/p/fd2f09dc3cc9
Grad_CAM 详细实现原理(可更改代码):https://www.jianshu.com/p/fd2f09dc3cc9

CAMs激活图可视化系列——GradCAM相关推荐

  1. CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化

    AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...

  2. 【计算机视觉】Class Activation Mapping(CAM、GradCAM) 特征定位、激活图

    转载自:https://zhuanlan.zhihu.com/p/51631163 目录 论文来源 GAP(全局平均池化层) CAM(类激活映射) CAM的缺陷 CAM的应用 Grad-CAM 两者区 ...

  3. 简明代码介绍类激活图CAM, GradCAM, GradCAM++

      类激活图(class activation map, CAM)能够显示输入图像各区域对于分类神经网络指定类别提供信息的多少,可以帮助我们更好的理解神经网络的工作过程.关于CAM网上讲的有不少,我这 ...

  4. CNN可视化技术总结(一)--特征图可视化

    导言: 在CV很多方向所谓改进模型,改进网络,都是在按照人的主观思想在改进,常常在说CNN的本质是提取特征,但并不知道它提取了什么特征,哪些区域对于识别真正起作用,也不知道网络是根据什么得出了分类结果 ...

  5. 【花雕动手做】有趣好玩的音乐可视化系列项目(31)--LCD1602液晶屏

    偶然心血来潮,想要做一个音乐可视化的系列专题.这个专题的难度有点高,涉及面也比较广泛,相关的FFT和FHT等算法也相当复杂,不过还是打算从最简单的开始,实际动手做做试验,耐心尝试一下各种方案,逐步积累 ...

  6. 【花雕动手做】有趣好玩的音乐可视化系列项目(30)--P6 LED单元板

    偶然心血来潮,想要做一个音乐可视化的系列专题.这个专题的难度有点高,涉及面也比较广泛,相关的FFT和FHT等算法也相当复杂,不过还是打算从最简单的开始,实际动手做做试验,耐心尝试一下各种方案,逐步积累 ...

  7. 数据可视化系列(三):布局格式定方圆

    前言 期待了好久的datawhale可视化教程终于出来了,这次标题狠有文艺范儿,哈哈哈 这次我主要目的是最近要写篇论文,也正好为以后建模画图打劳基础~ 大家可以多看看官方教程: 中文官方网站:http ...

  8. 数据可视化系列(二):艺术画笔见乾坤

    前言 期待了好久的datawhale可视化教程终于出来了,这次标题狠有文艺范儿,哈哈哈 这次我主要目的是最近要写篇论文,也正好为以后建模画图打劳基础~ 大家可以多看看官方教程: 中文官方网站:http ...

  9. 数据可视化系列(一):Matplotlib初相识

    前言 期待了好久的datawhale可视化教程终于出来了,这次标题狠有文艺范儿,哈哈哈 这次我主要目的是最近要写篇论文,也正好为以后建模画图打捞基础~ 大家可以多看看官方教程: 中文官方网站:http ...

最新文章

  1. JavaScript学习总结(十六)——Javascript闭包(Closure)
  2. 实现商城商品秒杀分析
  3. 以新ICT构建全联接的电力物联网,迈入能源智能时代
  4. 七步确定一个优化项目的难易度
  5. qt 判断ctrl键被按下_惊雷!证监会公告,又一家千亿白马股被按下“暂停键”...
  6. Python实训day02pm【元组、字典、lambda】
  7. 语言语法糖_【c#】几种常用语法糖
  8. 排序之选择排序:简单选择+堆排序
  9. UITableView 重用cell方法edequeueReusableCellWithIdentifier,出现错误
  10. AnnotationConfigApplicationContext ad has not been refreshed yet 错误
  11. Oracle 20c 新特性:持久化内存数据库 - Persistent Memory Database
  12. 数据科学 IPython 笔记本 8.1 matplotlib
  13. 学习总结-《父与子的编程之旅》chapter 9
  14. H5+APP安卓原生插件开发+离线打包
  15. zabbix—监控mysql数据
  16. chrome拓展 --截屏文字识别
  17. nginx配置Strict Transport Security
  18. Synchronized和Reentrantlock的区别
  19. Linux误删数据恢复实验
  20. oracle 修改用户信息表,Oracle批量修改用户表table的表空间 | 学步园

热门文章

  1. 华三交换机配置多个镜像口_【转】交换机端口镜像,如何配置多个观察口
  2. python新建文件夹代码_Python文件夹与文件的操作实现代码
  3. mysql数据库学习之sql调优思路
  4. 教你如何实现一个完美的移动端瀑布流组件(附源码)
  5. 电子科大互加数据库课程作业——ER图设计
  6. Java 面向对象与对象的创建过程及变量
  7. 简易公交车查询系统c语言,公交线路免费api接口代码
  8. 快速将微信文章导成word
  9. 分析:人名搜索Spock会成下个谷歌吗
  10. html5会员管理,微信会员管理系统支持客户微信一键注册成为会员?