本文给出完整代码实现CNN特征的可视化输入图像,也就是简单的deep dream图,有助于更好的理解CNN工作原理,并掌握用梯度上升法生成满足要求输入图像的技术。更清晰美观的deep dream图需要加入一些其他技巧,可以参考我另一篇文章。

1,原理

  深度网络的常规工作过程是,给定输入样本及标签,前向处理后的输出结果和标签计算出某种损失函数,再由损失函数反向传播求网络各参数的梯度,根据此梯度更新参数,以使损失函数减小,逐渐训练得到一个逼近标签的网络。
  有趣的是,我们也可以反向思维,固定网络的参数不变,而是优化输入图像。根据某项指标求得输入图像的梯度,再根据此梯度优化输入,就可以得到满足要求的输入图像。由于往往需要最大化某种指标,与通常最小化损失函数的梯度下降法不同,所以这种方法也被称为梯度上升法。在程序实现上,为了利用SGD,Adam等成熟的梯度下降优化算法,我们只需要给指标加个负号,也就变成和梯度下降一样了。
  如果我们把特征图的某个部分的均值作为最大化的指标,此时就可以得到使指定特征图部分最大响应的输入图,也就能够直观的看出指定部分的特征图到底处理的是什么类型的特征。这个指定的特征图部分可以是某个层,也可以是某个层的某个通道,甚至也可以是某个通道上某个元素。我们先来看单个元素的情况:

2,指定特征图单个元素的最大响应输入图像可视化

import torch
import torchvision.models as models
import cv2
import time
t0 = time.time()model = models.resnet18(pretrained=True).cuda()
batch_size = 1for params in model.parameters():params.requires_grad = False
model.eval()def hook(module,inp,out):global featuresfeatures = outdata = torch.rand(batch_size,3,224,224).cuda()
data.requires_grad=Truemu = torch.Tensor([0.485, 0.456, 0.406]).unsqueeze(-1).unsqueeze(-1).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).unsqueeze(-1).unsqueeze(-1).cuda()
unnormalize = lambda x: x*std + mu
normalize = lambda x: (x-mu)/std#optimizer = torch.optim.SGD([data], lr=1, momentum=0.99) #参数需根据实际情况再调
optimizer = torch.optim.Adam([data], lr=0.1, weight_decay=1e-6)
myhook = model.layer2.register_forward_hook(hook)
n,h,w = 0,3,8
for i in range(4001):x = (data - data.min()) / (data.max() - data.min())x = normalize(x)_ = model(x)loss =  - features[:,n,h,w].mean() #指定元素#loss =  - features[:,n,:,:].mean() #指定通道#loss =  - features.mean() #指定层optimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('data.abs().mean():',data.abs().mean().item())print('loss:',loss.item())print('time: %.2f'%(time.time()-t0))
myhook.remove()
data_i = data[0]
data_i = (data_i - data_i.min()) / (data_i.max() - data_i.min())
data_i = data_i.permute(1,2,0).data.cpu().numpy()*255
data_i = data_i[...,::-1].astype('uint8')  #注意cv2使用BGR顺序
cv2.imwrite('./feature_visual/layer2_filter%d.png'%n,data_i)

完整代码如上,我们以torchvision中的预训练resnet18为例,按照大的模块划分,resnet18中有4个layer,画出每个layer中某个元素对应的输入特征图如下:

图1.resnet18不同层指定元素的最大响应输入图像

从图中还可以看出,每个特征元素在输入图像上的感受野有多大,显然越靠后层元素的感受野越大。

3,指定通道

图2.resnet18不同层指定通道最大响应输入图像

可以看出,前序层的输入响应图都是一些均匀的纹理图,显然是图像的更基础的组成要素;随着网络越深,输入响应图则逐渐呈现出更加全局性的一些图像概念,以及更加高级和接近实物的一些图像概念,有些似乎是某种实物的形状特征。
我们还可以看出,单个通道的响应中可以看到多个类别的图像的“影子”,所以通道并不是类别依赖的。

4,指定层

图3.resnet18不同层最大响应输入图像

5,其他一些网络

图4.其他网络最末卷积层不同通道的最大响应输入图像

可以看出,不同网络的风格还是很不相同的,其中densenet不像别的那么恶心,还挺好看。

pytorch简单代码实现deep dream图(即CNN特征可视化 features visualization)相关推荐

  1. CV之IE之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成不同尺寸和质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)—五个架构设计思维导图

    CV之IE之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成不同尺寸和质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)-五个架构设计思维导图 ...

  2. CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成带背景的不同尺寸高质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例

    CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成带背景的不同尺寸高质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例 目录 基于 ...

  3. CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成更高质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用

    CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成更高质量的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用 目录 基于TF框架利 ...

  4. CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成更大尺寸的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用

    CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成更大尺寸的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用 目录 基于TF框架利 ...

  5. CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成原始的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用

    CV之IG之Inception:基于TF框架利用Inception模型+GD算法的某层网络图像生成原始的Deep Dream幻觉梦境图片(特征可视化实现图像可解释性)案例应用 目录 基于TF框架利用I ...

  6. python 热度图_keras CNN卷积核可视化,热度图教程

    卷积核可视化 import matplotlib.pyplot as plt import numpy as np from keras import backend as K from keras. ...

  7. 好玩的deep dream(清晰版,pytorch完整代码)

      本文给出pytorch完整代码实现deep dream,加入了图像金字塔处理和高斯平滑处理,使生成图更加清晰美观.文中还讨论了各种因素对生成图的影响. 1, 完整代码   Deep dream图是 ...

  8. 窥探神经网络:Deep Dream

    Deep Dream : 窥探神经网络模型的内部 通常我们通过使用大量的标记数据训练神经网络模型,以图像识别模型为例,模型通常由多个卷积层堆叠而成,中间还有一些池化和激活的操作,每一个图像从输入层到输 ...

  9. python导入txt文件并绘图-Python实现读取txt文件并画三维图简单代码示例

    记忆力差的孩子得勤做笔记! 刚接触python,最近又需要画一个三维图,然后就找了一大堆资料,看的人头昏脑胀的,今天终于解决了!好了,废话不多说,直接上代码! #由三个一维坐标画三维散点 #codin ...

最新文章

  1. 在linux批量删除多级目录下同一格式的文件
  2. 用java写个简单的直播强求_全网最简单易懂的Netty入门示例,再不会用Netty我直播吃翔...
  3. 【计算机系统设计】实践笔记(1)数据通路构建:取指部件分析
  4. VSTO 得到Office文档的选中内容(Word、Excel、PPT、Outlook)
  5. C++ STL一些注意事项
  6. C++中dynamic_cast的简介
  7. 复合选择器-focus选择器(HTML、CSS)
  8. (3)《Head First HTML与CSS》学习笔记---CSS入门
  9. 计算机专业实习心得,计算机毕业实习心得体会范本5篇
  10. linux内存使用率如何查看,linux内存使用率 linux查看内存
  11. dk 图解计算机科学pdf,DK英语:7套DK经典图解词典,再也不用死记硬背了!
  12. NIOS II --- UART
  13. 【深入理解TcaplusDB技术】入门Tcaplus SQL Driver
  14. 面经_OPPO研究院_数据科学研究员实习岗
  15. sql server delete语句删除行
  16. 原生与H5混合式开发详解
  17. python模拟比赛测试胜率
  18. 计算机用鼠标画图,实现鼠标在电脑上画画
  19. OPENSTACK-1-管理企业OSP部署-验证云上服务的功能性
  20. Android开发——流量统计

热门文章

  1. (笔试题)最大覆盖点
  2. 练习1-17 编写一个程序,打印长度大于80个字符的所有输入行.
  3. java连续输入_java – 要求用户进行多次输入
  4. selenium webdriver(python)_selenium、webdriver及浏览器的关系及对应版本安装
  5. 小鼠皮肤组织细胞悬液制备流程
  6. 第八天学习Java的笔记(方法有参无参,有返回值和无返回值)
  7. Keil forc51安装教程
  8. Android studio | From Zero To One ——安装教程及前期学习总结
  9. 数据库mysql驱动在8.0以上解决时区问题
  10. 下一代 Web 应用模型 —— Progressive Web App (PWA)