tensorflow生成图片标签_Tensorboard高维向量可视化 + 解决标签和图片不显示BUG
1.前言
本文就使用tensorboard进行高维向量可视化过程中出现的一些bug和问题,进行一次总结,帮助那些和我一样的小白快速上手Tensorboard高维向量可视化。
这里我先展示下最近我做的高维向量可视化的成果,蓝色是非事故图片,红色是事故图片。从结果上看出来,模型对事故和非事故的区分能力还算可以。
那么如何快速上手一个空间向量可视化过程,查看模型分类结果的空间分布呢?
这里我们使用到了tensorflow的tensorboard工具包,以事故分类为例,一步步克服tensorboard中的一些BUG。
2.特征向量提取与记录
一般而言,类似图像分类这种任务,在CNN最后几层中,一般会添加1x1卷积或者是全连接层,将backbone(特征提取网络)输出的特征图进行降维,便于接下来的分类。
在搭建模型的时候,也要刻意引出最后几层全连接层(或1x1卷积层)的输出,因为我们要可视化的高维向量就是这些层的输出(具体层数的输出维度,根据自己任务而定,这里我们在某一神经网络的最后,加入了个维度为8的全连接层(用作可视化的高维向量),然后再接上一个维度为2的全连接层(为了分类))。
这里我们基于pytorch搭建了个模型,模型forward函数代码如下:
def forward(self,x):batch_size, channels, height, width = x.size()feature_map=self.cnn(x)feature_map1=self.cnn_layer(feature_map)feature_map=feature_map1.view(batch_size,-1)output1=self.dense1(feature_map) #512 -> 64output2=self.dense2(output1) #64 -> 8output3=self.dense3(output2)#8 -> 2return output1,output2,output3
可见,除了最后的输入output3,我们还额外输出了维度为64的output1和维度为8的output2。
等网络训练完毕后,我们就用训练完的权重跑一遍测试集,将
- 每张测试集图片的地址,
- 每张图片经过网络输出的output1,
- 每张图片经过网络输出的output2,
- 每张图片经过网络输出的output3
保存到csv文件中,代码如下
record=[]for original_image in tqdm(image_path):with torch.no_grad():# print('现在处理:', original_image)prep_img = image_process(original_image)prep_img=prep_img.cuda()# print(prep_img.shape)output1, output2, output3 = model(prep_img)# print(output1.shape)record.append([original_image,output1[0],output2[0],output3[0]])record_pd=pd.DataFrame(data=record,index=None,columns=['img_path','output64','output8',"output2"])record_pd.to_csv('vector_record.csv')
保存的CSV的格式如下:
可见,output可都是tensor类型的呀!
3. 标签数据(meta.tsv)和图片数据(sprite.jpg)生成
从上面可视化的结果来看,每个空间展示图片的图片内容和标签(蓝或红)信息都被包含了,所以这里我们需要生成需要的标签数据和图片数据。我们从上面的csv文件中导入图片的地址数据,因为知道了地址数据,我们就知道了图片的内容以及标签(因为非事故和事故数据集的地址不同,所以可以用地址的关键字判断类别)
我们先定义一些变量,用作保存时的名称:
# PROJECTOR需要的日志文件名和地址相关参数
LOG_DIR = 'log'
SPRITE_FILE = 'sprite.jpg'
META_FIEL = "meta.tsv"
os.mkdir(LOG_DIR)
接着加载数据:
# 数据加载
sample_data = pd.read_csv(r'E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMAccident_Clustingsample_data.csv'
3.1 载入图片并生成一张大图(sprite image)
这里我们先导入地址,然后将地址中的所有图片导入进来并保存到一张list中(名为img_list)
img_path= sample_data['img_path'].values
img_list=[]
for path in img_path:img=cv2.imread(path,0)# print(img.shape)img=cv2.resize(img,(128,128))# print(img.shape)img_list.append(img.reshape(1,128,128))
接着我们定义一个大图像生成函数create_sprite_image,如下:
def create_sprite_image(images):"""Returns a sprite image consisting of images passed as argument. Images should be count x width x height"""if isinstance(images, list):images = np.array(images)img_h = images.shape[1] #112img_w = images.shape[2] #112# sprite图像可以理解成是小图片平成的大正方形矩阵,大正方形矩阵中的每一个元素就是原来的小图片。于是这个正方形的边长就是sqrt(n),其中n为小图片的数量。n_plots = int(np.ceil(np.sqrt(images.shape[0]))) #根号下2000# 使用全1来初始化最终的大图片。spriteimage = np.ones((img_h*n_plots, img_w*n_plots))for i in range(n_plots):for j in range(n_plots):# 计算当前图片的编号this_filter = i*n_plots + jif this_filter < images.shape[0]:# 将当前小图片的内容复制到最终的sprite图像this_img = images[this_filter]spriteimage[i*img_h:(i + 1)*img_h,j*img_w:(j + 1)*img_w] = this_imgreturn spriteimage
然后使用img_list生成这个大图像并保存:
img=np.concatenate(img_list,0) #(2000, 112, 112)
# 生成sprite图像# to_visualise = 1 - np.reshape(img, (-1, 28, 28))
sprite_image = create_sprite_image(img)# 将生成的sprite图片放到相应的日志目录下
path_for_sprites = os.path.join(LOG_DIR, SPRITE_FILE)
# plt.imsave(path_for_mnist_sprites,sprite_image, cmap='gray')
cv2.imwrite(path_for_sprites,sprite_image,[int(cv2.IMWRITE_JPEG_QUALITY), 100])
这里,sprite_image就生成好了,保存在log/sprite.jpg中。
3.2 生成标签数据meta.tsv
我们的事故和非事故图片集放在文件夹:
非事故:Z:Fast_datasetAccident_model_pretrainClassficationFalse_for_classificaition
事故:
Z:Fast_datasetAccident_model_pretrainClassficationTrue_for_classification
这里我们可以看出,可以用图片地址区分事故标签(False为非事故,True为事故)
那么我们代码这么写:
# # 生成每张图片对应的标签文件并写道相应的日志目录下
path_for_metadata = os.path.join(LOG_DIR,META_FIEL)
with open(path_for_metadata, 'w') as f:f.write("IndextLabeln")for index, path in enumerate(img_path):label=0 if 'False' in path else 1f.write("%dt%dn"%(index, label))
即完成了对meta.csv的记录,meta.csv里面是长这样的。
到这里,标签数据(meta.tsv)和图片数据(sprite.jpg)就生成了,接下来到最后的数据的匹配和关联+总日志文件生成。
4. 数据的匹配和关联+总日志文件生成
上述,我们在CSV文件中保存了每个图片在训练好的模型的输出向量,这个输出向量可以看作图片在浅层空间的表示。那我们就从该CSV文件导出需要的向量,这里我们以8维的向量为例子,导入的程序代码如下:
sample_data = pd.read_csv(r'E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMAccident_Clustingsample_data.csv')output= sample_data['output8'].valuesoutput=[eval(i).cpu().numpy() for i in output]# print(type(output2[0]))final_result=np.asarray(output)
这段代码需要注意的是,我们保存时候,向量以torch.tensor形式保存,从CSV导入的时候,输出的类型又是字符串str类型,所以我们需要做两件事
- 字符串转torch.tensor,这里使用eval()函数,且需要from torch import tensor
- torch.tensor转数组,因为后面要用tensorflow的tensorboard,所以需要转成numpy。
然后我们定义我们本文最重要的函数 visualisation(),代码如下:
LOG_DIR = 'log'
SPRITE_FILE = 'sprite.jpg'
META_FIEL = 'meta.tsv'
TENSOR_NAME = "zw"
TRAINING_STEPS = 10def visualisation(final_result):# 使用一个新的变量来保存最终输出层向量的结果,因为embedding是通过Tensorflow中变量完成的,所以PROJECTOR可视化的都是TensorFlow中的变哇。# 所以这里需要新定义一个变量来保存输出层向量的取值y = tf.Variable(final_result, name=TENSOR_NAME)summary_writer = tf.summary.FileWriter(LOG_DIR)# 通过project.ProjectorConfig类来帮助生成日志文件config = projector.ProjectorConfig()# 增加一个需要可视化的bedding结果embedding = config.embeddings.add()# 指定这个embedding结果所对应的Tensorflow变量名称embedding.tensor_name = y.name# Specify where you find the metadata# 指定embedding结果所对应的原始数据信息。比如这里指定的就是每一张MNIST测试图片对应的真实类别。在单词向量中可以是单词ID对应的单词。# 这个文件是可选的,如果没有指定那么向量就没有标签。# embedding.metadata_path = META_FIELembedding.metadata_path = 'meta.tsv'# Specify where you find the sprite (we will create this later)# 指定sprite 图像。这个也是可选的,如果没有提供sprite 图像,那么可视化的结果# 每一个点就是一个小困点,而不是具体的图片。# embedding.sprite.image_path = SPRITE_FILEembedding.sprite.image_path = 'sprite.jpg'# 在提供sprite图像时,通过single_image_dim可以指定单张图片的大小。# 这将用于从sprite图像中截取正确的原始图片。embedding.sprite.single_image_dim.extend([128, 128])# Say that you want to visualise the embeddings# 将PROJECTOR所需要的内容写入日志文件。projector.visualize_embeddings(summary_writer, config)# 生成会话,初始化新声明的变量并将需要的日志信息写入文件。sess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.save(sess, os.path.join(LOG_DIR, "model"), TRAINING_STEPS)sess.close()summary_writer.close()
需要注意的是图片的大小和上述第三部分定义的图片大小应当一致!
然后接着上面,将final_result输入到函数visualisation中,即可完成log日志的生成!
visualisation(final_result)
log文件下的内容如下:
其中projector_config.pbtxt文件中的内容为
包含了各成员对应的关系
到最后,我们打开命令行CMD:
输入
tensorboard --logdir=E:Deep_learning_PytorchVideo_recognitionAccident_RecognitionCNN_LSTMlog --host=127.0.0.1
这里,一定注意--host!一定注意--host!一定注意--host!(重要的说三遍)
因为如果不输入--host,选择默认,那去网页上就是打不开的!就算打开了,也没有标签文件和图片数据的,这是花了我一个晚上血淋淋的教训呀!
5.总结
关于如何快速上手Tensorboard的高维向量可视化,本篇文章算是比较详细了!算是我一晚上的工作经验吧!终于可以安心做其他事了!
tensorflow生成图片标签_Tensorboard高维向量可视化 + 解决标签和图片不显示BUG相关推荐
- python高维向量的可视化_Tensorboard教程:高维向量可视化
Tensorflow高维向量可视化 觉得有用的话,欢迎一起讨论相互学习~ 参考文献 强烈推荐Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 python3 ...
- Tensorboard高维向量可视化
Tensorboard高维向量可视化 觉得有用的话,欢迎一起讨论相互学习~ 参考文献 强烈推荐Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 python ...
- 解决django关于图片无法显示的问题
解决django关于图片无法显示的问题 参考文章: (1)解决django关于图片无法显示的问题 (2)https://www.cnblogs.com/zhuifeng-mayi/articles/8 ...
- 信创办公--基于WPS的Word最佳实践系列(解决WPS插入图片后显示不全问题)
信创办公–基于WPS的Word最佳实践系列(解决WPS插入图片后显示不全问题) 项目背景 本篇文档是解决WPS插入"嵌入式"图片显示不全的问题.一般造成这个问题的情况是因为图片插入 ...
- 【Web】解决简书图片不显示问题“系统维护中,图片暂时无法加载”
个人博客: http://www.milovetingting.cn 简书不显示图片的解决方法 首次编辑于2019-6-6 最近几天在浏览简书上的文章时,发现图片显示不出来,提示"系统维护中 ...
- yy 服务器维护中 图片无法显示,解决简书图片不显示问题“系统维护中,图片暂时无法加载”...
天突然发现之前的文章图片全部都这样了,我还以为图片丢了! 9949918-5bfb96c2b65b9c7e.png 但是其实图片还是可以访问的 配合 Chrome,Safari 扩展程序 Tamper ...
- vue移动端中使用echart折线面积图(设置渐变色)解决ios6/11渐变色不显示bug
前言: 1.折线本身渐变色 2.折线阴影面积渐变色 效果如图所示: 1.全局引入echart main.js // 如果全局引入就在此加上这两行代码 // 如果就一个页面直接页面引入完事儿 impor ...
- 目标检测--将xml文件中标签(矩形框)在其原图片上显示并另存
""" 目的:将原图片(img)与其xml(xml),合成为打标记的图片(labelled),矩形框标记用红色即可 已有:(1)原图片文件夹(imgs_path),(2) ...
- visdom TensorboardX进行可视化-包括对高维特征可视化(T-SNE PCA等)
文章目录 一.Visdom 安装与使用 小案例 二.TensorBoardX 案例一 案例二 使用PROJECTOR对高维向量可视化 绘制网络结构 一.Visdom Visdom是Facebook在2 ...
最新文章
- 泰安服务器维护公司,神云 泰安服务器
- 模式识别之数字识别---扑克牌识别
- bzoj千题计划143:bzoj1935: [Shoi2007]Tree 园丁的烦恼
- opencv7-ml之svm
- centos MySQL 双机_CentOS利用Keepalived构建双主MySQL+双机热备
- 推荐一个比较好用的Chrome扩展应用,提供了桌面便签功能
- Factorial Trailing Zeroes 172
- 中国工程院院士,受聘一流大学院长
- tesseract 使用说明
- 超全!最新互联网大厂的薪资和职级一览
- 计算机病毒扩散最快的是什么,根据统计,当前计算机病毒扩散最快的途径是( )...
- 前端开发:Mac电脑修改hosts文件的方法
- 前端基础面试题附答案
- PT6303加充电电路的一套原理图
- 解除文件占用,解决文件被占用不能删除
- 钉钉ppt放映显示备注_PPT的备注怎么用,放映PPT时如何显示备注 来看看吧
- (转)Windows 7 系统下载安装一贴导航
- modbustcp测试工具怎么用_年轻人不讲武德不仅白piao接口测试知识还白piao接口测试工具会员...
- 视觉导航小车开源项目(1)—小车底盘
- 【Java】Java中文分词器Ansj的使用
热门文章
- jMeter Thread group 对应的 constant timer
- Open SAP 上 SAP Fiori Elements 公开课第一单元学习笔记
- 巧用 TypeScript Literal Types 模拟枚举类型
- Angular应用bootstrap时的version检测机制
- HTML table标签和其子标签如td,td等不同区域focus然后回车的行为差异
- SAP Spartacus delivery mode continue button enable与否的逻辑
- sublime text里添加对Gradle配置文件的支持
- Create new SAP DDL view and click finish in wizard
- Method 'GET_ENTITYSET' not implemented in data provider class - correct case
- SAP UI5 patternYYY : detailZZZ/{contextPath} - navigation test