论文:ICCV 2017《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》

代码:https://github.com/yizt/Grad-CAM.pytorch/blob/master/main.py
           https://github.com/jacobgil/pytorch-grad-cam/blob/master/grad-cam.py

1、首先定义并训练好CNN网络,原网络结构不用调整。假设网路训练好,得到一个best_net。

class GradCAM(object):"""1: gradients update when input2: backpropatation by the high scores of class"""def __init__(self, net, layer_name):self.net = netself.layer_name = layer_nameself.feature = Noneself.gradient = Noneself.net.eval()self.handlers = []self._register_hook()def _get_features_hook(self, module, input, output):self.feature = output#print("feature shape:{}".format(output.size()))def _get_grads_hook(self, module, input_grad, output_grad):""":param input_grad: tuple, input_grad[0]: Noneinput_grad[1]: weightinput_grad[2]: bias:param output_grad:tuple,length = 1:return:"""self.gradient = output_grad[0]def _register_hook(self):for (name, module) in self.net.named_modules():if name == self.layer_name:self.handlers.append(module.register_forward_hook(self._get_features_hook))self.handlers.append(module.register_backward_hook(self._get_grads_hook))def remove_handlers(self):for handle in self.handlers:handle.remove()def __call__(self, inputs, index=None):""":param inputs: [1,3,H,W]:param index: class id:return:"""self.net.zero_grad()output = self.net(inputs)  # [1,num_classes]if index is None:index = np.argmax(output.cpu().data.numpy())target = output[0][index]target.backward()gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]weight = np.mean(gradient, axis=(1, 2))  # [C]feature = self.feature[0].cpu().data.numpy()  # [C,H,W]cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]cam = np.sum(cam, axis=0)  # [H,W]cam = np.maximum(cam, 0)  # ReLU# nomalizationcam -= np.min(cam)cam /= np.max(cam)# resize to 256*256cam = cv2.resize(cam, (256, 256))return camclass GradCamPlusPlus(GradCAM):def __init__(self, net, layer_name):super(GradCamPlusPlus, self).__init__(net, layer_name)def __call__(self, inputs, index=None):""":param inputs: [1,3,H,W]:param index: class id:return:"""self.net.zero_grad()output = self.net(inputs)  # [1,num_classes]if index is None:index = np.argmax(output.cpu().data.numpy())target = output[0][index]target.backward()gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]gradient = np.maximum(gradient, 0.)  # ReLUindicate = np.where(gradient > 0, 1., 0.)  # 示性函数norm_factor = np.sum(gradient, axis=(1, 2))  # [C]归一化for i in range(len(norm_factor)):norm_factor[i] = 1. / norm_factor[i] if norm_factor[i] > 0. else 0.  # 避免除零alpha = indicate * norm_factor[:, np.newaxis, np.newaxis]  # [C,H,W]weight = np.sum(gradient * alpha, axis=(1, 2))  # [C]  alpha*ReLU(gradient)feature = self.feature[0].cpu().data.numpy()  # [C,H,W]cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]cam = np.sum(cam, axis=0)  # [H,W]# cam = np.maximum(cam, 0)  # ReLU# nomalizationcam -= np.min(cam)cam /= np.max(cam)# resize cam = cv2.resize(cam, (256, 256))return camclass GuidedBackPropagation(object):def __init__(self, net):self.net = netfor (name, module) in self.net.named_modules():if isinstance(module, nn.ReLU):module.register_backward_hook(self.backward_hook)self.net.eval()@classmethoddef backward_hook(cls, module, grad_in, grad_out):""":param module::param grad_in: tuple,length=1:param grad_out: tuple,length=1:return: tuple(new_grad_in,)"""return torch.clamp(grad_in[0], min=0.0),def __call__(self, inputs, index=None):""":param inputs: [1,3,H,W]:param index: class_id:return:"""self.net.zero_grad()output = self.net(inputs)  # [1,num_classes]if index is None:index = np.argmax(output.cpu().data.numpy())target = output[0][index]target.backward()return inputs.grad[0]  # [3,H,W]
def show_cam_on_image(img, mask):heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)heatmap = np.float32(heatmap) / 255cam = heatmap + np.float32(img)cam = cam / np.max(cam)cv2.imwrite("cam.jpg", np.uint8(255 * cam))root='/data/fjsdata/qtsys/img/sz.002509-20200325.png'
img_list = []
img_list.append( cv2.resize(cv2.imread(root).astype(np.float32), (256, 256)))#(256, 256) is the model input size
inputs = torch.from_numpy(np.array(img_list)).type(torch.FloatTensor).cuda()
# Grad-CAM
#grad_cam = GradCAM(net=best_net, layer_name='conv3')
#mask = grad_cam(inputs.permute(0, 3, 1, 2))  # cam mask
#show_cam_on_image(img_list[0], mask)
#grad_cam.remove_handlers()# Grad-CAM++
#grad_cam_plus_plus = GradCamPlusPlus(net=best_net, layer_name='conv3')
#mask_plus_plus = grad_cam_plus_plus(inputs.permute(0, 3, 1, 2))  # cam mask
#show_cam_on_image(img_list[0], mask)
#grad_cam_plus_plus.remove_handlers()# GuidedBackPropagation
gbp = GuidedBackPropagation(best_net)
inputs = inputs.requires_grad_(True)
inputs.grad.zero_()
grad = gbp(inputs.permute(0, 3, 1, 2))
print(grad)

最后GuidedBackPropagation没完全调通,详细阅读论文后再处理。前面Grad-CAM 和Grad-CAM++可以。

Grad-CAM (CNN可视化) Python示例相关推荐

  1. Class Activation Mapping (CNN可视化) Python示例

    Class Activation Mapping 论文:CVPR2016<Learning Deep Features for Discriminative Localization> 代 ...

  2. CNN可视化!从CVPR 2022出发,聊聊CAM是如何激活我们文章的热度!

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 转载自:极市平台  | 作者:matrix明仔 导读 本文从CVPR2022中三篇不同领域的文章中CAM的表 ...

  3. CNN可视化最新研究方法进展(附结构、算法)

    译者 | reason_W 责编 | 明 明 出品 | AI科技大本营(公众号ID:rgznai100) [AI科技大本营导读]深度学习一直被看做是一个难以解释的"黑匣子".一方面 ...

  4. caffe预测、特征可视化python接口调用

    转载自: 深度学习(九)caffe预测.特征可视化python接口调用 - hjimce的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/hjimce/articl ...

  5. 深度学习(九)caffe预测、特征可视化python接口调用

    caffe预测.特征可视化python接口调用 原文地址:http://blog.csdn.net/hjimce/article/details/48972877 作者:hjimce 网上有很多caf ...

  6. 如何使用Elasticsearch,Logstash和Kibana实时可视化Python中的日志

    by Ritvik Khanna Ritvik Khanna着 如何使用Elasticsearch,Logstash和Kibana实时可视化Python中的日志 (How to use Elastic ...

  7. Excel 数据的统计分析及绘图自动处理的python示例(精益办公实战2)

    Excel 数据统计分析及绘图的自动处理python示例(精益办公实战2) 1.背景描述: "看数不如看表,看表不如看图" 2.数据准备和任务要求: 数据准备 一份已经经过数据清洗 ...

  8. [数据分析与可视化] Python绘制数据地图2-GeoPandas地图可视化

    本文主要介绍GeoPandas结合matplotlib实现地图的基础可视化.GeoPandas是一个Python开源项目,旨在提供丰富而简单的地理空间数据处理接口.GeoPandas扩展了Pandas ...

  9. CNN可视化技术总结(一)--特征图可视化

    导言: 在CV很多方向所谓改进模型,改进网络,都是在按照人的主观思想在改进,常常在说CNN的本质是提取特征,但并不知道它提取了什么特征,哪些区域对于识别真正起作用,也不知道网络是根据什么得出了分类结果 ...

最新文章

  1. Performance Metrics for Binary Classification
  2. 动物为什么会预知地震,地震后为什么会下雨?
  3. matlab axb c,matlab调用C源代码(续)
  4. 民间75个不传之密 ,医院都不知道的秘密
  5. bootstrap 页面分成三列_20分钟成功编写bootstrap响应式页面 就这么简单
  6. unity小技巧总结
  7. PHP 循环时间控制缓冲方法
  8. Nginx笔记总结十六:nginx优化指南
  9. JavaScript 函数基础
  10. FREESPACE 发布 logo v1.1
  11. 最新支持android的手机型号,Andorid10支持手机型号有哪些 安卓10适配机型介绍
  12. 多按键多界面二维数组表驱动设计
  13. uniapp 微信浏览器H5页面自定义分享链接
  14. 永洪bi logo更换
  15. ps6人脸识别液化工具在哪_Photoshop教学:人脸识别液化功能介绍
  16. 【pycharm】复制粘贴快捷键失效
  17. webdriver中的截图截图方法
  18. 机器学习服务文本翻译能力升级,中文直译模型让译文表达更地道!
  19. R语言:scatterplot3d(绘制三维散…
  20. 通过redis-cli批量删除多个指定模式的key

热门文章

  1. 计算机右键管理中没有用户管理,我的电脑右键菜单中没有管理选项如何解决? 我的电脑右键菜单中没有管理选项解决的方法有哪些?...
  2. xshell进行ssh链接报错“所选的用户密钥未在远程主机上注册”处理
  3. 一周一论文(翻译)——[SIGMOD 19] Elasticutor:Rapid Elasticity for Realtime Stateful Stream Processing
  4. 泛珠三角计算机作品大赛2018,2018年泛珠三角大学生计算机作品赛广西赛区选拔赛圆满结束...
  5. 20190703 关于如何驱动
  6. HTML5API(5)
  7. Atitit 关于处理环保行动联盟和动物解放阵线游击队的任命书 委任状
  8. 参考滴滴左右对齐自适应
  9. react native windows 搭建(完整版)
  10. JavaScript 3D图表