文章目录

前言

一、效果图

二、使用步骤

1.使用方法

2.注意事项

总结

参考



前言

最近写论文需要观察中间特征层的特征图,使用的是yolov5的代码仓库,但是苦于找不到很好的轮子,于是参考了很多,只找了这个,但是我觉得作者写的太复杂了(我之前就是这个作者的小粉丝),在参考了github的yolov5作者给出的issue建议后,自己写了个轮子,没有复杂的步骤,借助torchvision中的transforms将tensor转化为PIL,再通过matplotlib保存特图。希望能给大家带来一些帮助。


一、效果图

先上一下效果图,因为深层的特征有高达1024个,这里我只打印了8*8的特征图,用plt.subplot将64张特征图展示在一张图片上。原图为我在百度上随便搜的猫咪:

这是yolov5x.pt进行detect过程中,经过可视化后的第一个C3模块的前64张特征图:

这里也可以设置为灰度图,后续代码中会给出。

可以看到不同特征图所提取到的特征几乎都不相同,有的侧重边缘,有的则是侧重整体,当然这只是第一个C3的特征图,相对于更深层的特征来说,浅层的特征大多是完整的,而更深层的特征则会更小,而且是提取到的细小特征,当然,这些特征图也都是相互联系的,网络结构是个整体。

借助yolov5作者在issue里说到的:

BTW, a single feature map may be in my opinion a shallow set of information, as you are looking at a 2d spatial slice but are not aptly observing relationships across the feature space (as the convolutions do).

I guess an analogy is that you would be viewing the R, G, B layers of a color image by themselves, when it helps to view them together to get the complete picture.

单个特征图可能是一组浅层信息,因为你正在查看 2d 空间切片,但并未恰当地观察特征空间中的关系(如卷积所做的那样)。

这里是我自己的理解,通过特征图的可视化,也进一步的理解了卷积到底干了些什么事情,如果有想进一步交流的小伙伴,私信一起讨论,一起学习呀。

二、使用步骤

1.使用方法

使用方法很简单,只需要在utils中的general.py或者plots.py添加如下函数:

import matplotlib.pyplot as plt
from torchvision import transformsdef feature_visualization(features, model_type, model_id, feature_num=64):"""features: The feature map which you need to visualizationmodel_type: The type of feature mapmodel_id: The id of feature mapfeature_num: The amount of visualization you need"""save_dir = "features/"if not os.path.exists(save_dir):os.makedirs(save_dir)# print(features.shape)# block by channel dimensionblocks = torch.chunk(features, features.shape[1], dim=1)# # size of feature# size = features.shape[2], features.shape[3]plt.figure()for i in range(feature_num):torch.squeeze(blocks[i])feature = transforms.ToPILImage()(blocks[i].squeeze())# print(feature)ax = plt.subplot(int(math.sqrt(feature_num)), int(math.sqrt(feature_num)), i+1)ax.set_xticks([])ax.set_yticks([])plt.imshow(feature)# gray feature# plt.imshow(feature, cmap='gray')# plt.show()plt.savefig(save_dir + '{}_{}_feature_map_{}.png'.format(model_type.split('.')[2], model_id, feature_num), dpi=300)

接着在models中的yolo.py中的这个地方:

def forward_once(self, x, profile=False):y, dt = [], []  # outputsfor m in self.model:if m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPSt = time_synchronized()for _ in range(10):_ = m(x)dt.append((time_synchronized() - t) * 100)print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))x = m(x)  # runy.append(x if m.i in self.save else None)  # save output# add in hereif profile:print('%.1fms total' % sum(dt))return x

添加如下代码:

            feature_vis = Trueif m.type == 'models.common.C3' and feature_vis:print(m.type, m.i)feature_visualization(x, m.type, m.i)

添加在yolo.py后,无论是在detect.py还是在train.py中都会进行可视化特征图。

然而训练的过程中并不一定需要一直可视化特征图,feature_vis参数是用来控制是否保存可视化特征图的,保存的特征图会存在features文件夹中。如果想看其它层的特征只需要修改m.type或是用m.i来进行判断是否可视化特征图。m.type对应的是yaml文件中的module,即yolov5的基础模块,例如c3,conv,spp等等,而m.i则更好理解,即是模块的id,通常就是顺序,如果你尝试修改过配置文件,那么你肯定知道是什么。

如果不明白,多使用print函数,用list.len()和tensor.size去查看列表长度和张量维度,打印出来你就知道了。

这里有一个点我很迷惑,不知道有没有大佬可以告诉我原因,就是我并没有找到yolo.py和detect.py之间的关联,detect.py中使用的是:

model = attempt_load(weights, map_location=device)

而并没有使用yolo.py中的Model函数,但是运行detect.py同样可以可视化特征图,不是很懂pytorch代码中的这个机制,希望有大佬可以指教一下,代码还是有些菜。

2.注意事项

注意1:在yolo.py的开头import feature_visualization:

from utils.general import feature_visualization

注意2:yolov5无论是在detect还是在train的过程中,都会先对模型进行Summary,即验证你的模型的层数,参数以及是否有梯度,这个过程也会保存特征图,但是不要担心,因为你保存的特征图名字是相同的,会被覆盖,如果你打印的出来log就会看到整个模型跑了两次:

Model Summary: 476 layers, 87730285 parameters, 0 gradients

注意3:建议训练完成的网络使用detect.py来进行验证特征图。

当然在yolo.py里面也可以将'__main__'中的 :

model = Model(opt.cfg).to(device)

替换为:

model = attempt_load(opt.weights, map_location=device)

同样可以跑通(把detect.py中的opt.weights复制过来)。在yolo.py中打开Profile,将随机生成的图片换成自己的图片,就可以正常的进行验证。


总结

周末摸鱼时间写了这个(也不算摸鱼,下周该写论文初稿了orz),希望给大家带来帮助,如果有疑问或者错误,在评论区或者私信联系我,之后我会把这个提交一个pr到yolov5的官方仓库里(之前提交了一个visdrone.yaml的配置文件,幸被采用了,参考的就是这个作者的代码,感谢!),就到这里,最后上一个spp结构的特征图输出,希望和大家一起讨论。

以上。

参考

pytorch特征图可视化

pytorch 提取卷积神经网络的特征图可视化

深度学习笔记~卷积网络中特征图的可视化

自用代码 | YOLOv5 特征图可视化代码

将tensor张量转换成图片格式并保存

Pytorch中Tensor与各种图像格式的相互转化

yolov5特征图可视化相关推荐

  1. 收藏 | PyTorch模型训练特征图可视化(TensorboardX)

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨Pa ...

  2. caffe之特征图可视化及特征提取

    上一篇博客,介绍了怎么对训练好的model的各层权重可视化,这篇博客,我们介绍测试图片输入网络后产生的特征图的可视化 记得上篇中,我们是写了一个新的文件test.cpp,然后编译运行那个文件的,这是因 ...

  3. CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化

    AI:CNN神经网络猫狗分类经典案例,深度学习过程中间层激活特征图可视化 基于前文 https://zhangphil.blog.csdn.net/article/details/103581736 ...

  4. CNN可视化技术总结(一)--特征图可视化

    导言: 在CV很多方向所谓改进模型,改进网络,都是在按照人的主观思想在改进,常常在说CNN的本质是提取特征,但并不知道它提取了什么特征,哪些区域对于识别真正起作用,也不知道网络是根据什么得出了分类结果 ...

  5. caffe for windows的matlab接口(四):权重和特征图可视化的一个例子

    模型读取 参照三,想实现一个自己图像的可视化过程: 首先我发现自己训练出的model没有deploy文件.查阅了下:"如果要把训练好的模型拿来测试新的图片,那必须得要一个deploy.pro ...

  6. 卷积神经网络特征图可视化热图可视化

    文章目录 前言 一.可视化特征图 二.热力图可视化(图像分类) 总结 前言 使用pytorch中的钩子将特征图和梯度勾出来,从而达到可视化特征图(featuremap)和可视化热图(heatmap)的 ...

  7. 卷积神经网络特征图可视化及其意义

    文章目录 特征图可视化方法 1. tensor->numpy->plt.save 2. register_forward_pre_hook函数实现特征图获取 3. 反卷积可视化 特征图可视 ...

  8. 深度学习网络和特征图可视化的工具介绍

    1.深度学习网络结构画图工具: 网络结构画图工具https://cbovar.github.io/ConvNetDraw/ 输入:层信息 输出:网络结构图 网络结构图实例 2.caffe可视化工具 输 ...

  9. 【总结】Keras+VGG16特征图可视化,帮助你深入理解VGG16

    Keras+VGG16特征图可视化 一.VGG16结构理解 1. 可视化结构图 2. VGGNet各级别网络结构图 3. VGG16网络结构图 二.Keras实现VGG16 代码实现 三.VGG16特 ...

最新文章

  1. 创新工场2018年夏令营DeepCamp第一套解答笔记
  2. Java中的occur_time,PLSQL报错: ORA-12170:TNS connect timeout occurred
  3. boost::hana模块在无限可迭代对象上测试 hana::index_if
  4. 乐高创意机器人moc_乐高变形金刚爵士方头仔MOC图纸
  5. CF889E-Mod Mod Mod【dp】
  6. iOS Crash常规跟踪方法及Bugly集成运用
  7. 【转】为什么要用GIT而不是SVN
  8. xp计算机关闭139端口,关闭139端口,小编告诉你如何关闭139端口
  9. C# WebApi 返回详细错误信息
  10. 塞拉菲娜创始人 - 钰儿
  11. 嵌入式学习:裸机开发_L4_官方SDK开发LED实验
  12. 吊打面试官系列之:掌握了这166个Linux常用命令,面试官果然被我征服了。。
  13. mysql删除视图sql语句_怎么样删除视图中的全部数据 用SQL语言编写。
  14. mvn编译“Cannot find matching toolchain definitions for the following toolchain types“报错解决方法
  15. python绘制图形沙漏_论计时沙漏对于学习python的重要性
  16. 设置模式之-------原型模型
  17. Base64与图片之间互相转换
  18. OpenGL 3.0,等得花儿都谢了
  19. 天下宝藏手游 服务器维护好久,2017年5月4日服务器停机维护公告
  20. 【辅助驾驶】图像拼接[3]——车载全景可视系统SurroundView

热门文章

  1. spring boot后台管理系统
  2. iphonex适配游戏_王者荣耀Iphone X出现问题怎么办_王者荣耀iPhoneX适配版本常见问题说明_游戏吧...
  3. 手把手教你开发一款1024程序员节日历提醒服务
  4. Python数据可视化:象限图的应用
  5. 王者服务器维护段位掉了,王者荣耀更新掉段机制是什么 S22赛季段位继承规则介绍...
  6. Java中JPS命令监控
  7. FP5207 升压 5V9V12V24V36V42V大功率芯片
  8. creator小功能----关于帧动画Animation和骨骼动画Skeleton一些有趣的东西
  9. 互联网七字诀:专注、极致、口碑、快(雷总提出)
  10. CSS —— 背景图片填满DIV、鼠标滑过放大图片