借助Keras和Opencv实现的神经网络中间层特征图的可视化功能,方便我们研究CNN这个黑盒子里到发生了什么。

自定义网络特征可视化

代码:

# coding: utf-8from keras.models import Model
import cv2
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Activation
from pylab import *
import kerasdef get_row_col(num_pic):squr = num_pic ** 0.5row = round(squr)col = row + 1 if squr - row > 0 else rowreturn row, coldef visualize_feature_map(img_batch):feature_map = np.squeeze(img_batch, axis=0)print(feature_map.shape)feature_map_combination = []plt.figure()num_pic = feature_map.shape[2]row, col = get_row_col(num_pic)for i in range(0, num_pic):feature_map_split = feature_map[:, :, i]feature_map_combination.append(feature_map_split)plt.subplot(row, col, i + 1)plt.imshow(feature_map_split)axis('off')title('feature_map_{}'.format(i))plt.savefig('feature_map.png')plt.show()# 各个特征图按1:1 叠加feature_map_sum = sum(ele for ele in feature_map_combination)plt.imshow(feature_map_sum)plt.savefig("feature_map_sum.png")def create_model():model = Sequential()# 第一层CNN# 第一个参数是卷积核的数量,第二三个参数是卷积核的大小model.add(Convolution2D(9, 5, 5, input_shape=img.shape))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(4, 4)))# 第二层CNNmodel.add(Convolution2D(9, 5, 5, input_shape=img.shape))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(3, 3)))# 第三层CNNmodel.add(Convolution2D(9, 5, 5, input_shape=img.shape))model.add(Activation('relu'))model.add(MaxPooling2D(pool_size=(2, 2)))# 第四层CNNmodel.add(Convolution2D(9, 3, 3, input_shape=img.shape))model.add(Activation('relu'))# model.add(MaxPooling2D(pool_size=(2, 2)))return modelif __name__ == "__main__":img = cv2.imread('001.jpg')model = create_model()img_batch = np.expand_dims(img, axis=0)conv_img = model.predict(img_batch)  # conv_img 卷积结果visualize_feature_map(conv_img)

这里定义了一个4层的卷积,每个卷积层分别包含9个卷积、Relu激活函数和尺度不等的池化操作,系数全部是随机初始化。
输入的原图如下:

第一层卷积后可视化的特征图:

所有第一层特征图1:1融合后整体的特征图:

第二层卷积后可视化的特征图:

所有第二层特征图1:1融合后整体的特征图:

第三层卷积后可视化的特征图:

所有第三层特征图1:1融合后整体的特征图:

第四层卷积后可视化的特征图:

所有第四层特征图1:1融合后整体的特征图:

从不同层可视化出来的特征图大概可以总结出一点规律:

  • 1. 浅层网络提取的是纹理、细节特征
  • 2. 深层网络提取的是轮廓、形状、最强特征(如猫的眼睛区域)
  • 3. 浅层网络包含更多的特征,也具备提取关键特征(如第一组特征图里的第4张特征图,提取出的是猫眼睛特征)的能力
  • 4. 相对而言,层数越深,提取的特征越具有代表性
  • 5. 图像的分辨率是越来越小的

VGG19网络特征可视化

代码:

# coding: utf-8
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from pylab import *def get_row_col(num_pic):squr = num_pic ** 0.5row = round(squr)col = row + 1 if squr - row > 0 else rowreturn row, coldef visualize_feature_map(img_batch):feature_map = img_batchprint(feature_map.shape)feature_map_combination = []plt.figure()num_pic = feature_map.shape[2]row, col = get_row_col(num_pic)for i in range(0, num_pic):feature_map_split = feature_map[:, :, i]feature_map_combination.append(feature_map_split)plt.subplot(row, col, i + 1)plt.imshow(feature_map_split)axis('off')plt.savefig('feature_map.png')plt.show()# 各个特征图按1:1 叠加feature_map_sum = sum(ele for ele in feature_map_combination)plt.imshow(feature_map_sum)plt.savefig("feature_map_sum.png")if __name__ == "__main__":base_model = VGG19(weights='imagenet', include_top=False)# model = Model(inputs=base_model.input, outputs=base_model.get_layer('block1_pool').output)# model = Model(inputs=base_model.input, outputs=base_model.get_layer('block2_pool').output)# model = Model(inputs=base_model.input, outputs=base_model.get_layer('block3_pool').output)# model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)model = Model(inputs=base_model.input, outputs=base_model.get_layer('block5_pool').output)img_path = '001.jpg'img = image.load_img(img_path)x = image.img_to_array(img)x = np.expand_dims(x, axis=0)x = preprocess_input(x)block_pool_features = model.predict(x)print(block_pool_features.shape)feature = block_pool_features.reshape(block_pool_features.shape[1:])visualize_feature_map(feature)

从第一到第五层的特征图分别如下:

从第一层到第五层各特征图按1:1比例融合后特征依次为:

卷积神经网络特征图可视化(自定义网络和VGG网络)相关推荐

  1. 卷积神经网络特征图可视化及其意义

    文章目录 特征图可视化方法 1. tensor->numpy->plt.save 2. register_forward_pre_hook函数实现特征图获取 3. 反卷积可视化 特征图可视 ...

  2. 卷积神经网络特征图可视化热图可视化

    文章目录 前言 一.可视化特征图 二.热力图可视化(图像分类) 总结 前言 使用pytorch中的钩子将特征图和梯度勾出来,从而达到可视化特征图(featuremap)和可视化热图(heatmap)的 ...

  3. Grad-CAM 神经网络特征图可视化

    参见:https://zhuanlan.zhihu.com/p/269702192 神经网络的可解释性离不开特征图(feature map)的可视化. 如何分析CNN feature map上哪些区域 ...

  4. 神经网络特征图可视化

    一.原理 pytorch 中的hook可以不必改变网络输入输出的结构,方便的获取.改变网络中间层变量的值和梯度.这个功能广泛用于可视化神经网络中间层的feature.gradient.从而诊断神经网络 ...

  5. 卷积神经网络特征图大小计算公式

    基本公式

  6. 卷积神经网络及其特征图可视化

    参考链接:https://www.jianshu.com/p/362b637e2242 参考链接:https://blog.csdn.net/dcrmg/article/details/8125549 ...

  7. 可视化卷积神经网络的过滤器_万字长文:深度卷积神经网络特征可视化技术(CAM)最新综述...

    ↑ 点击蓝字 关注极市平台作者丨皮特潘@知乎来源丨https://zhuanlan.zhihu.com/p/269702192编辑丨极市平台 极市导读 本文通过引用七篇论文来论述CAM技术,对CAM的 ...

  8. 三行代码可视化神经网络特征图

    三行代码可视化神经网络特征图 正文 正文 在科研论文,方案讲解,模型分析中,合理解释特征图是对最终结果的一个加分项.但是之前的一些可视化特征图的方法往往会有一些tedious,于是我在这里给大家推荐一 ...

  9. CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化

    AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...

最新文章

  1. 写了 5 年 Java,这些坑还是没躲过……
  2. python 的回调函数
  3. Spring Cloud Alibaba基础教程:几种服务消费方式(RestTemplate、WebClient、Feign)
  4. ThinkServer TD340服务器安装操作系统[转]
  5. 数据数组赋值_嵌入式-数组赋值
  6. 201521123007《Java程序设计》第13周学习总结
  7. 【专栏】国内外物联网平台初探(篇二:阿里云物联网套件)
  8. Linux三剑客之grep
  9. java学习之路 之 Java集合练习题
  10. EXCEL图表 横坐标日期格式无法修改问题
  11. 概率学A和C公式,Java计算阶乘,不重复三位数
  12. 2021-08-18
  13. w乐ndows update更新失败,黑鲨教你解决Windows系统update更新失败问题
  14. 电脑Svchost.exe 进程占CPU100% 的解决办法
  15. ELS3120代替品MPCS-341 3A 光电耦合器 用于IGBT/MOSFET隔离栅极驱动芯片
  16. linux kill一个进程杀不掉怎么解决?
  17. 接口练习(台灯案例)
  18. FreeSwitch呼入处理流程
  19. 如何避免陷入流量旋涡
  20. 对于lpad与level的理解

热门文章

  1. SNL编译器之词法分析器
  2. iPhone、iPad 即将过气 Apple TV 才是苹果的未来
  3. JavaScript将时间戳转为日期
  4. python开发框架——Django基础知识(十一)
  5. Game - Deppin绿色安装红色警戒(OpenRA-Red Alert)
  6. Rrd 文档 总结(一)
  7. NASA World Wind开源项目配置
  8. 总结java高级面试题
  9. serverlet java_IDEA2019 JavaWeb Serverlet 基础
  10. python 百度ocr安装_Python利用百度文字识别(OCR)服务实现图片文字提取,准确率超高...