1.前言

本文就使用tensorboard进行高维向量可视化过程中出现的一些bug和问题,进行一次总结,帮助那些和我一样的小白快速上手Tensorboard高维向量可视化

这里我先展示下最近我做的高维向量可视化的成果,蓝色是非事故图片,红色是事故图片。从结果上看出来,模型对事故和非事故的区分能力还算可以。

事故和非事故二分类结果

那么如何快速上手一个空间向量可视化过程,查看模型分类结果的空间分布呢?

这里我们使用到了tensorflow的tensorboard工具包,以事故分类为例,一步步克服tensorboard中的一些BUG。

2.特征向量提取与记录

一般而言,类似图像分类这种任务,在CNN最后几层中,一般会添加1x1卷积或者是全连接层,将backbone(特征提取网络)输出的特征图进行降维,便于接下来的分类。

在搭建模型的时候,也要刻意引出最后几层全连接层(或1x1卷积层)的输出,因为我们要可视化的高维向量就是这些层的输出(具体层数的输出维度,根据自己任务而定,这里我们在某一神经网络的最后,加入了个维度为8的全连接层(用作可视化的高维向量),然后再接上一个维度为2的全连接层(为了分类))。

这里我们基于pytorch搭建了个模型,模型forward函数代码如下:

    def forward(self,x):batch_size,  channels, height, width = x.size()feature_map=self.cnn(x)feature_map1=self.cnn_layer(feature_map)feature_map=feature_map1.view(batch_size,-1)output1=self.dense1(feature_map) #512 -> 64output2=self.dense2(output1) #64 -> 8output3=self.dense3(output2)#8 -> 2return output1,output2,output3

可见,除了最后的输入output3,我们还额外输出了维度为64的output1和维度为8的output2。

等网络训练完毕后,我们就用训练完的权重跑一遍测试集,将

  • 每张测试集图片的地址,
  • 每张图片经过网络输出的output1,
  • 每张图片经过网络输出的output2,
  • 每张图片经过网络输出的output3

保存到csv文件中,代码如下

    record=[]for original_image in tqdm(image_path):with torch.no_grad():# print('现在处理:', original_image)prep_img = image_process(original_image)prep_img=prep_img.cuda()# print(prep_img.shape)output1, output2, output3 = model(prep_img)# print(output1.shape)record.append([original_image,output1[0],output2[0],output3[0]])record_pd=pd.DataFrame(data=record,index=None,columns=['img_path','output64','output8',"output2"])record_pd.to_csv('vector_record.csv')

保存的CSV的格式如下:

CSV格式

可见,output可都是tensor类型的呀!

3. 标签数据(meta.tsv)和图片数据(sprite.jpg)生成

从上面可视化的结果来看,每个空间展示图片的图片内容标签(蓝或红)信息都被包含了,所以这里我们需要生成需要的标签数据和图片数据。我们从上面的csv文件中导入图片的地址数据,因为知道了地址数据,我们就知道了图片的内容以及标签(因为非事故和事故数据集的地址不同,所以可以用地址的关键字判断类别)

我们先定义一些变量,用作保存时的名称:

# PROJECTOR需要的日志文件名和地址相关参数
LOG_DIR = 'log'
SPRITE_FILE = 'sprite.jpg'
META_FIEL = "meta.tsv"
os.mkdir(LOG_DIR)

接着加载数据:

# 数据加载
sample_data = pd.read_csv(r'E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMAccident_Clustingsample_data.csv'

3.1 载入图片并生成一张大图(sprite image)

这里我们先导入地址,然后将地址中的所有图片导入进来并保存到一张list中(名为img_list)

img_path= sample_data['img_path'].values
img_list=[]
for path in img_path:img=cv2.imread(path,0)# print(img.shape)img=cv2.resize(img,(128,128))# print(img.shape)img_list.append(img.reshape(1,128,128))

接着我们定义一个大图像生成函数create_sprite_image,如下:

def create_sprite_image(images):"""Returns a sprite image consisting of images passed as argument. Images should be count x width x height"""if isinstance(images, list):images = np.array(images)img_h = images.shape[1] #112img_w = images.shape[2] #112# sprite图像可以理解成是小图片平成的大正方形矩阵,大正方形矩阵中的每一个元素就是原来的小图片。于是这个正方形的边长就是sqrt(n),其中n为小图片的数量。n_plots = int(np.ceil(np.sqrt(images.shape[0]))) #根号下2000# 使用全1来初始化最终的大图片。spriteimage = np.ones((img_h*n_plots, img_w*n_plots))for i in range(n_plots):for j in range(n_plots):# 计算当前图片的编号this_filter = i*n_plots + jif this_filter < images.shape[0]:# 将当前小图片的内容复制到最终的sprite图像this_img = images[this_filter]spriteimage[i*img_h:(i + 1)*img_h,j*img_w:(j + 1)*img_w] = this_imgreturn spriteimage

然后使用img_list生成这个大图像并保存:

img=np.concatenate(img_list,0) #(2000, 112, 112)
# 生成sprite图像# to_visualise = 1 - np.reshape(img, (-1, 28, 28))
sprite_image = create_sprite_image(img)# 将生成的sprite图片放到相应的日志目录下
path_for_sprites = os.path.join(LOG_DIR, SPRITE_FILE)
# plt.imsave(path_for_mnist_sprites,sprite_image, cmap='gray')
cv2.imwrite(path_for_sprites,sprite_image,[int(cv2.IMWRITE_JPEG_QUALITY), 100])

这里,sprite_image就生成好了,保存在log/sprite.jpg中。

3.2 生成标签数据meta.tsv

我们的事故和非事故图片集放在文件夹:

非事故:Z:Fast_datasetAccident_model_pretrainClassficationFalse_for_classificaition
事故:
Z:Fast_datasetAccident_model_pretrainClassficationTrue_for_classification

这里我们可以看出,可以用图片地址区分事故标签(False为非事故,True为事故)

那么我们代码这么写:

# # 生成每张图片对应的标签文件并写道相应的日志目录下
path_for_metadata = os.path.join(LOG_DIR,META_FIEL)
with open(path_for_metadata, 'w') as f:f.write("IndextLabeln")for index, path in enumerate(img_path):label=0 if 'False' in path else 1f.write("%dt%dn"%(index, label))

即完成了对meta.csv的记录,meta.csv里面是长这样的。

到这里,标签数据(meta.tsv)和图片数据(sprite.jpg)就生成了,接下来到最后的数据的匹配和关联+总日志文件生成。

4. 数据的匹配和关联+总日志文件生成

上述,我们在CSV文件中保存了每个图片在训练好的模型的输出向量,这个输出向量可以看作图片在浅层空间的表示。那我们就从该CSV文件导出需要的向量,这里我们以8维的向量为例子,导入的程序代码如下:

    sample_data = pd.read_csv(r'E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMAccident_Clustingsample_data.csv')output= sample_data['output8'].valuesoutput=[eval(i).cpu().numpy()  for i in output]# print(type(output2[0]))final_result=np.asarray(output)

这段代码需要注意的是,我们保存时候,向量以torch.tensor形式保存,从CSV导入的时候,输出的类型又是字符串str类型,所以我们需要做两件事

  1. 字符串转torch.tensor,这里使用eval()函数,且需要from torch import tensor
  2. torch.tensor转数组,因为后面要用tensorflow的tensorboard,所以需要转成numpy

然后我们定义我们本文最重要的函数 visualisation(),代码如下:

LOG_DIR = 'log'
SPRITE_FILE = 'sprite.jpg'
META_FIEL = 'meta.tsv'
TENSOR_NAME = "zw"
TRAINING_STEPS = 10def visualisation(final_result):# 使用一个新的变量来保存最终输出层向量的结果,因为embedding是通过Tensorflow中变量完成的,所以PROJECTOR可视化的都是TensorFlow中的变哇。# 所以这里需要新定义一个变量来保存输出层向量的取值y = tf.Variable(final_result, name=TENSOR_NAME)summary_writer = tf.summary.FileWriter(LOG_DIR)# 通过project.ProjectorConfig类来帮助生成日志文件config = projector.ProjectorConfig()# 增加一个需要可视化的bedding结果embedding = config.embeddings.add()# 指定这个embedding结果所对应的Tensorflow变量名称embedding.tensor_name = y.name# Specify where you find the metadata# 指定embedding结果所对应的原始数据信息。比如这里指定的就是每一张MNIST测试图片对应的真实类别。在单词向量中可以是单词ID对应的单词。# 这个文件是可选的,如果没有指定那么向量就没有标签。# embedding.metadata_path = META_FIELembedding.metadata_path = 'meta.tsv'# Specify where you find the sprite (we will create this later)# 指定sprite 图像。这个也是可选的,如果没有提供sprite 图像,那么可视化的结果# 每一个点就是一个小困点,而不是具体的图片。# embedding.sprite.image_path = SPRITE_FILEembedding.sprite.image_path = 'sprite.jpg'# 在提供sprite图像时,通过single_image_dim可以指定单张图片的大小。# 这将用于从sprite图像中截取正确的原始图片。embedding.sprite.single_image_dim.extend([128, 128])# Say that you want to visualise the embeddings# 将PROJECTOR所需要的内容写入日志文件。projector.visualize_embeddings(summary_writer, config)# 生成会话,初始化新声明的变量并将需要的日志信息写入文件。sess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.save(sess, os.path.join(LOG_DIR, "model"), TRAINING_STEPS)sess.close()summary_writer.close()

需要注意的是图片的大小和上述第三部分定义的图片大小应当一致!

然后接着上面,将final_result输入到函数visualisation中,即可完成log日志的生成!

visualisation(final_result)

log文件下的内容如下:

其中projector_config.pbtxt文件中的内容为

包含了各成员对应的关系

到最后,我们打开命令行CMD:

输入

tensorboard --logdir=E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMlog --host=127.0.0.1

这里,一定注意--host!一定注意--host!一定注意--host!(重要的说三遍)

因为如果不输入--host,选择默认,那去网页上就是打不开的!就算打开了,也没有标签文件和图片数据的,这是花了我一个晚上血淋淋的教训呀!

5.总结

关于如何快速上手Tensorboard的高维向量可视化,本篇文章算是比较详细了!算是我一晚上的工作经验吧!终于可以安心做其他事了!

tensorflow生成图片标签_Tensorboard高维向量可视化 + 解决标签和图片不显示BUG相关推荐

  1. python高维向量的可视化_Tensorboard教程:高维向量可视化

    Tensorflow高维向量可视化 觉得有用的话,欢迎一起讨论相互学习~ 参考文献 强烈推荐Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 python3 ...

  2. Tensorboard高维向量可视化

    Tensorboard高维向量可视化 觉得有用的话,欢迎一起讨论相互学习~ 参考文献 强烈推荐Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 python ...

  3. 解决django关于图片无法显示的问题

    解决django关于图片无法显示的问题 参考文章: (1)解决django关于图片无法显示的问题 (2)https://www.cnblogs.com/zhuifeng-mayi/articles/8 ...

  4. 信创办公--基于WPS的Word最佳实践系列(解决WPS插入图片后显示不全问题)

    信创办公–基于WPS的Word最佳实践系列(解决WPS插入图片后显示不全问题) 项目背景 本篇文档是解决WPS插入"嵌入式"图片显示不全的问题.一般造成这个问题的情况是因为图片插入 ...

  5. 【Web】解决简书图片不显示问题“系统维护中,图片暂时无法加载”

    个人博客: http://www.milovetingting.cn 简书不显示图片的解决方法 首次编辑于2019-6-6 最近几天在浏览简书上的文章时,发现图片显示不出来,提示"系统维护中 ...

  6. yy 服务器维护中 图片无法显示,解决简书图片不显示问题“系统维护中,图片暂时无法加载”...

    天突然发现之前的文章图片全部都这样了,我还以为图片丢了! 9949918-5bfb96c2b65b9c7e.png 但是其实图片还是可以访问的 配合 Chrome,Safari 扩展程序 Tamper ...

  7. vue移动端中使用echart折线面积图(设置渐变色)解决ios6/11渐变色不显示bug

    前言: 1.折线本身渐变色 2.折线阴影面积渐变色 效果如图所示: 1.全局引入echart main.js // 如果全局引入就在此加上这两行代码 // 如果就一个页面直接页面引入完事儿 impor ...

  8. 目标检测--将xml文件中标签(矩形框)在其原图片上显示并另存

    """ 目的:将原图片(img)与其xml(xml),合成为打标记的图片(labelled),矩形框标记用红色即可 已有:(1)原图片文件夹(imgs_path),(2) ...

  9. visdom TensorboardX进行可视化-包括对高维特征可视化(T-SNE PCA等)

    文章目录 一.Visdom 安装与使用 小案例 二.TensorBoardX 案例一 案例二 使用PROJECTOR对高维向量可视化 绘制网络结构 一.Visdom Visdom是Facebook在2 ...

最新文章

  1. 泰安服务器维护公司,神云 泰安服务器
  2. 模式识别之数字识别---扑克牌识别
  3. bzoj千题计划143:bzoj1935: [Shoi2007]Tree 园丁的烦恼
  4. opencv7-ml之svm
  5. centos MySQL 双机_CentOS利用Keepalived构建双主MySQL+双机热备
  6. 推荐一个比较好用的Chrome扩展应用,提供了桌面便签功能
  7. Factorial Trailing Zeroes 172
  8. 中国工程院院士,受聘一流大学院长
  9. tesseract 使用说明
  10. 超全!最新互联网大厂的薪资和职级一览
  11. 计算机病毒扩散最快的是什么,根据统计,当前计算机病毒扩散最快的途径是( )...
  12. 前端开发:Mac电脑修改hosts文件的方法
  13. 前端基础面试题附答案
  14. PT6303加充电电路的一套原理图
  15. 解除文件占用,解决文件被占用不能删除
  16. 钉钉ppt放映显示备注_PPT的备注怎么用,放映PPT时如何显示备注 来看看吧
  17. (转)Windows 7 系统下载安装一贴导航
  18. modbustcp测试工具怎么用_年轻人不讲武德不仅白piao接口测试知识还白piao接口测试工具会员...
  19. 视觉导航小车开源项目(1)—小车底盘
  20. 【Java】Java中文分词器Ansj的使用

热门文章

  1. jMeter Thread group 对应的 constant timer
  2. Open SAP 上 SAP Fiori Elements 公开课第一单元学习笔记
  3. 巧用 TypeScript Literal Types 模拟枚举类型
  4. Angular应用bootstrap时的version检测机制
  5. HTML table标签和其子标签如td,td等不同区域focus然后回车的行为差异
  6. SAP Spartacus delivery mode continue button enable与否的逻辑
  7. sublime text里添加对Gradle配置文件的支持
  8. Create new SAP DDL view and click finish in wizard
  9. Method 'GET_ENTITYSET' not implemented in data provider class - correct case
  10. SAP UI5 patternYYY : detailZZZ/{contextPath} - navigation test