我们这里介绍的一种可视化方法,它有助于了解一张图像的哪一部分让卷积神经网络做出了最终的分类决策。这有助于对卷积神经网络的决策过程进行调试,特别是分类错误的情况下。这种方法可以定位图像中的特定目标。

我们使用预训练的VGG网络来演示这种方法。

from keras.applications.vgg16 import VGG16K.clear_session()
model = VGG16(weights='imagenet')


如图所示,这是两只非洲象的图片。我们将这张图片转换为VGG16能够读取的格式:模型大小为224224的图像上进行训练,这些训练图像都根据keras.applications.vgg16.preprocess_input函数中的内置的规则进行预处理。因此,我们需要加载图像,将其大小调整为224224,然后将其转化为float32格式的Numpy张量,并应用这些预处理规则。

from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as npimg_path = '/Users/fchollet/Downloads/creative_commons_elephant.jpg'img = image.load_img(img_path, target_size=(224, 224))   # 大小为224*224的Python图像库图像x = image.img_to_array(img)  # 形状为(224, 224, 3)的float32格式Numpy数组x = np.expand_dims(x, axis=0)  # 添加一个维度,将数组转化为(1, 224, 224, 3)的形状批量x = preprocess_input(x)   #按批量进行预处理(按通道颜色进行标准化)

可以在图像上运行预训练的VGG16网络,并将预测向量解码为我们可以读的形式。

preds = model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])

Predicted: [(‘n02504458’, ‘African_elephant’, 0.90942144), (‘n01871265’, ‘tusker’, 0.08618243), (‘n02504013’, ‘Indian_elephant’, 0.0043545929)]

对这个图像预测的前三个类别分别是:

  • 非洲象:92.5%的概率
  • 长牙动物:7%的概率
  • 印度象:0.4%的概率

网络认为预测向量中最大激活的元素对应是“非洲象”类别的元素,索引编号386

np.argmax(preds[0])

386

为了展示图像中哪些部分最像非洲象,我们使用Grad-CAM算法:

african_elephant_output = model.output[:, 386]   # 预测向量中的非洲象元素last_conv_layer = model.get_layer('block5_conv3')  # block5_conv3层的输出特征图,它是VGG16的最后一个卷积层grads = K.gradients(african_elephant_output, last_conv_layer.output)[0]   # 非洲象类别相对于block5_conv3输出特征图的梯度pooled_grads = K.mean(grads, axis=(0, 1, 2))   # 形状是(512, )的向量,每个元素是特定特征图通道的梯度平均大小iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])  # 这个函数允许我们获取刚刚定义量的值:对于给定样本图像,pooled_grads和block5_conv3层的输出特征图pooled_grads_value, conv_layer_output_value = iterate([x])  # 给我们两个大象样本图像,这两个量都是Numpy数组for i in range(512):conv_layer_output_value[:, :, i] *= pooled_grads_value[i]  # 将特征图数组的每个通道乘以这个通道对大象类别重要程度heatmap = np.mean(conv_layer_output_value, axis=-1)  # 得到的特征图的逐通道的平均值即为类激活的热力图
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
plt.matshow(heatmap)
plt.show()

最后,我们可以用OpenCV来生成一张图像,将原始图像叠加在刚刚得到的热力图上

import cv2img = cv2.imread(img_path)  # 用cv2加载原始图像heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)   # 将热力图应用于原始图像superimposed_img = heatmap * 0.4 + img    # 这里的0.4是热力图强度因子cv2.imwrite('/Users/fchollet/Downloads/elephant_cam.jpg', superimposed_img)   # 将图像保存到硬盘

更多精彩内容,欢迎关注我的微信公众号:数据瞎分析

可视化卷及神经网络热力图相关推荐

  1. keras_猫狗分类案例(三)_卷机神经网络的可视化(可视化类激活的热力图)

    卷机神经网络的可视化(可视化类激活的热力图) 参考:https://www.cnblogs.com/zhhfan/p/9978099.html python深度学习 可视化类激活的热力图 我还要介绍另 ...

  2. SIGIR阿里论文 | 可视化理解深度神经网络CTR预估模型

    小叽导读:尽管业界对于图像处理和自然语言处理领域,在算法可解释性方向上已经取得了一些进展,但对于电商与广告领域,目前还是空白.另一方面,深度学习技术已经开始被大规模应用到广告业务中.广告是很多互联网现 ...

  3. 可视化类激活的热力图

    可视化类激活的热力图 #类激活图:它是指对输入图像生成类激活的热力图. #类激活热力图是与特定输出类别相关的二维分数网格,对任何输入图像的每一个位置都要进行计算,它表示每个位置的重要程度. #这种方法 ...

  4. CVPR 2019 | 基于可解释性以及细粒度的可视化解释卷积神经网络

    作者丨张彪 学校丨北京交通大学硕士生 研究方向丨卷积神经网络的内部可视化(可解释性) 研究目的 卷积神经网络(CNN)已经被证明在许多视觉基准测试上产生了最先进的结果,尽管如此,CNN 的黑盒特性使得 ...

  5. 最全目标检测相关资料整理 (目标检测+数据增强+卷价神经网络+类别不均衡...)

    1 小目标检测: 综述: 综述论文Augmentation for small object detection 深度学习笔记(十)Augmentation for small object dete ...

  6. 基于Python可视化的卷积神经网络的城市感知评估系统

    资源下载地址:https://download.csdn.net/download/sheziqiong/85661101 资源下载地址:https://download.csdn.net/downl ...

  7. Python遥感可视化 — folium模块展示热力图

    欢迎关注博主的微信公众号:"智能遥感". 该公众号将为您奉上Python地学分析.爬虫.数据分析.Web开发.机器学习.深度学习等热门源代码. 本人的GitHub代码资料主页(持续 ...

  8. 在R中对李克特量表(likert)数据进行可视化描述性统计分析,热力图、密度图、柱状图

    在R中对李克特量表带数据进行可视化描述性统计分析 李克特量表是一种常用的社会调查问卷模式.常规论文中对多级的李克特量表数据大多计算均值来进行描述性统计分析,但均值较难表现样本整体分布状况,R中like ...

  9. 卷及神经网络CNN for image retrieval

    背景 对于CBIR的背景以及应用,直接上两张图上来,分别对于下面的图1和图2:图1 图1的背景是自己用matlab在cifar10上生成的背景图,前景图是来自How many public photo ...

最新文章

  1. 社区拼团软件系统开发为什么这么火热?
  2. catia 安装打开闪退_win10catia r20应用程序无法正常启动的解决办法
  3. Linux常用的50个命令
  4. Qt Creator调试Qt Quick示例应用程序
  5. c语言汽车租赁系统实验报告,汽车租赁系统的c语言,数据结构的语言程序
  6. ArduinoUNO实战-第十三章-步进电机驱动实验
  7. linux 硬盘分区,分区,删除分区,格式化,挂载,卸载笔记
  8. windows安全中心(windows defender)对下载内容报毒解决方案
  9. c语言:用二分法求方程在(-10,10)之间的根:2x^3-4x^2+3x-6=0.
  10. PostgreSQL中的索引—7(GIN)
  11. psd2html 阿里,psd2html
  12. 网站访问流程及原理分析
  13. AI处理器-寒武纪NPU芯片简介
  14. 电脑重装系统,微信备份与恢复聊天记录,保存的文件。微信聊天记录迁移
  15. 课题申请的技术指标是什么
  16. matlab求解解析解,Matlab中解析解与数值解的区别
  17. 很难找齐的常识(转收藏)
  18. 山东2021年高考成绩查询状元,2021年山东高考状元多少分,今年山东高考状元资料名单...
  19. 从软件公司的企业文化浅谈什么是管理能力
  20. #离散#ssl 1747 登山机器人问题

热门文章

  1. springboot整合elasticJob实战(纯代码开发三种任务类型用法)以及分片系统,事件追踪详解...
  2. WCF与ASP.NET Core性能比较
  3. cacti监控一览无余
  4. 25个你可能不知道的Linux真相
  5. UCenter实现同步登陆原理
  6. 虚拟化部署之Hyper-V虚拟网络配置
  7. .NET1.0升级至2.0十个问题
  8. const 的学习(转载)
  9. python删除中文停用词_python词云 wordcloud+jieba生成中文词云图
  10. OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks