写在开头:现在网上的资料对测试数据集的混淆矩阵可视化大都基于MNIST、CAIFR等标准数据集来绘制,如果想把自己制作的数据集的预测结果可视化,如何做?本文为你提供一种解决方案。
上篇博文我们写了自制数据集的制作、训练、可视化流程。本篇博文在此基础上将测试集的预测结果可视化,让我们直观的观察哪些数据预测正确,哪些数据预测错误。数据集的制作流程及代码在上篇博文中已有详细介绍,本次不作为解释重点,重点解释绘制混淆矩阵部分的代码及实现。

  1. 首先导入一些必须的包
from __future__ import print_function #这行代码必须放在第一行,不然会报错,原因我也没有细察,请各位读者注意一下
import os,glob,random, csv,itertools
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import optimizers,losses,layers
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
  1. 设置全局种子、tensorflow版本检查、GPU内存分配等
#设置全局种子
tf.random.set_seed(22)
np.random.seed(22)#检查tensorflow版本是否为2.0版本
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')# 设置GPU显存按需分配(如果没有装GPU版本,这部分就不用写,但是如果不用GPU加速计算,使用CPU将会严重限制模型的训练速度)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)logical_gpus = tf.config.experimental.list_logical_devices('GPU')print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")except RuntimeError as e:print(e)
  1. 写入训练数据集的csv文件
def load_csv(root, filename, name2label):if not os.path.exists(os.path.join(root, filename)):images = []for name in name2label.keys()images += glob.glob(os.path.join(root, name, '*.png'))images += glob.glob(os.path.join(root, name, '*.jpg'))images += glob.glob(os.path.join(root, name, '*.jpeg'))print(len(images), images)random.shuffle(images)with open(os.path.join(root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images:  # name = img.split(os.sep)[-2]label = name2label[name]writer.writerow([img, label])print('written into csv file:', filename)  images, labels = [], []with open(os.path.join(root, filename)) as f:reader = csv.reader(f)for row in reader:img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labels
  1. 创建数据集的编码表
def load_huangtu(root, mode='train'):name2label = {}  for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continuename2label[name] = len(name2label.keys())images, labels = load_csv(root, 'images.csv', name2label)if mode == 'train':  # 60%images = images[:int(0.6 * len(images))]labels = labels[:int(0.6 * len(labels))]elif mode == 'val':  # 20% = 60%->80%images = images[int(0.6 * len(images)):int(0.8 * len(images))]labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]else:  # 20% = 80%->100%images = images[int(0.8 * len(images)):]labels = labels[int(0.8 * len(labels)):]return images, labels, name2label
  1. 数据的处理
#下面两个矩阵常量是谷歌经过大量的图片计算得到的值,后续可以直接使用
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
def normalize(x, mean=img_mean, std=img_std):x = (x - mean)/stdreturn x
def preprocess(x,y):# x: 图片的路径,y:图片的数字编码x = tf.io.read_file(x)x = tf.image.decode_jpeg(x, channels=3) # RGBAx = tf.image.resize(x, [244, 244])x = tf.image.random_flip_up_down(x)#数据增强(上下翻转)x= tf.image.random_flip_left_right(x) #数据增强(左右翻转)x = tf.image.random_crop(x, [224, 224, 3])  #数据增强(随机裁剪)x = tf.cast(x, dtype=tf.float32) / 255.x = normalize(x)   #255-0y = tf.convert_to_tensor(y)y = tf.one_hot(y, depth=7)return x, y
  1. 定义混淆矩阵
#参数 y_true为测试数据集的真实标签,y_pred为网络对测试数据集的预测结果
def plot_confusion_matrix(y_true, y_pred, title = "Confusion matrix",cmap = plt.cm.Blues, save_flg = False):classes = [str(i) for i in range(7)]#参数i的取值范围根据你自己数据集的划分类别来修改,我这儿为7代表数据集共有7类labels = range(7)#数据集的标签类别,跟上面I对应cm = confusion_matrix(y_true, y_pred, labels=labels)plt.figure(figsize=(14, 12))plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title, fontsize=40)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, fontsize=20)plt.yticks(tick_marks, classes, fontsize=20)print('Confusion matrix, without normalization')thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True label', fontsize=30)plt.xlabel('Predicted label', fontsize=30)if save_flg:plt.savefig("./confusion_matrix.png")plt.show()
  1. 写入主函数,进行数据集的调用、模型的创建、编译、训练等
def main():images, labels, table = load_huangtu('huangtu', 'train')print('images', len(images), images)print('labels', len(labels), labels)print(table)db = tf.data.Dataset.from_tensor_slices((images, labels))db = db.shuffle(1000).map(preprocess).batch(32)#被注释掉的代码可用于tensorboard可视化'''writter = tf.summary.create_file_writer('此处为文件的创建/保存地质')for step, (x,y) in enumerate(db):with writter.as_default():x = denormalize(x)tf.summary.image('img',x,step=step,max_outputs=9)time.sleep(5)'''batchsz = 128# 创建训练集Datset对象images, labels, table = load_huangtu('huangtu',mode='train')db_train = tf.data.Dataset.from_tensor_slices((images, labels))db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)# 创建验证集Datset对象images2, labels2, table = load_huangtu('huangtu',mode='val')db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))db_val = db_val.map(preprocess).batch(batchsz)# 创建测试集Datset对象images3, labels3, table = load_huangtu('huangtu',mode='test')db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))db_test = db_test.map(preprocess).batch(batchsz)#迁移学习net = keras.applications.DenseNet121(weights='imagenet', include_top=False,pooling='max')net.trainable = Falsenewnet = keras.Sequential([net,layers.Dense(1024,activation='relu'),layers.BatchNormalization(),layers.Dropout(rate=0.5),layers.Dense(7)])newnet.build(input_shape=(4,224,224,3))newnet.summary()newnet.compile(optimizer=optimizers.Adam(lr=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])'''##  调用tensorboardlog_dir='调用上面的路径地址'tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)  '''newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=50)score=newnet.evaluate(db_test)print('Test loss:', score[0])print('Test accuracy:', score[1])#下面三行代码为绘制混淆矩阵的传参predict_classes = newnet.predict(db_test)#对测试数据集进行预测true_classes = np.argmax(predict_classes,1)#汲取预测结果plot_confusion_matrix(true_classes, labels3, save_flg = True)#调用混淆矩阵#调用主函数if __name__ == '__main__':main()   

将上面所有的代码按顺序组合起来,就可以做为一个完整的数据集制作、模型训练、混淆矩阵可视化的代码
下图为自制数据集的混淆矩阵可视化结果

对角线上表示预测正确结果。
写在最后:相比对标准数据集的预测,重点注意和思考我们自制数据集的传参。
本文到此结束

tensorflow2.0自制神经网络数据集及预测结果(混淆矩阵)可视化相关推荐

  1. tensorflow2.0 循环神经网络--情感分类实战

    tensorflow2.0 循环神经网络–情感分类实战代码 本文主要是情感分类单层实战RNN Cell代码 import os import numpy as np import tensorflow ...

  2. 【深度学习】利用tensorflow2.0卷积神经网络进行卫星图片分类实例操作详解

    本文的应用场景是对于卫星图片数据的分类,图片总共1400张,分为airplane和lake两类,也就是一个二分类的问题,所有的图片已经分别放置在2_class文件夹下的两个子文件夹中.下面将从这个实例 ...

  3. Python项目实战-Tensorflow2.0实现泰坦尼克生存预测

    目录 一.数据集下载地址 二.探索性因子分析(EDA) 三.特征工程 四.构建Dataset与Model fit和自定义estimator使用 预定义estimator的使用 一.数据集下载地址 # ...

  4. Tensorflow2.0泰坦尼克数据集的python分析以及离散化数据处理(含数据集下载地址)

    泰坦尼克数据集下载 训练集 测试集 导入需要的库 import matplotlib.pyplot as plt %matplotlib inline import numpy as np impor ...

  5. Tensorflow2.0之用循环神经网络生成周杰伦歌词

    文章目录 1.导入需要的库 2.加载数据集 3.相邻采样 4.定义模型 4.1 定义循环神经网络层 4.2 定义循环神经网络 5.定义预测函数 6.裁剪梯度 7.定义模型训练函数 7.1 困惑度 7. ...

  6. TensorFlow2.0学习笔记2-tf2.0两种方式搭建神经网络

    目录 一,TensorFlow2.0搭建神经网络八股 1)import  [引入相关模块] 2)train,test  [告知喂入网络的训练集测试集以及相应的标签] 3)model=tf.keras. ...

  7. tensorflow2.0实现IMDB文本数据集学习词嵌入

    1. IMDB数据集示例如下所示 [{"rating": 5, "title": "The dark is rising!", " ...

  8. 第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...

  9. tensorflow2.0莺尾花iris数据集分类|超详细

    tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...

  10. TensorFlow2.0 学习笔记(三):卷积神经网络(CNN)

    欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 文章目录 欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 一.神经网络的基本单位:神经元 二.卷 ...

最新文章

  1. MPB:南农成艳芬组-​瘤胃厌氧真菌代谢产物的检测方法
  2. 打开共享文件闪退怎么解决_文件共享解决方案-随时随地共享同步访问文件
  3. springboot 优雅停机_新姿势,Spring Boot 2.3.0 如何优雅停机?
  4. 软件工程--需求分析
  5. linux驱动私有数据,linux驱动开发之字符设备--私有数据和container_of
  6. 数据结构之Dijkstra算法
  7. 软件工程 - 团队重组
  8. 【新功能】媒体处理MPS全新支持自适应多码率、多语言音轨
  9. mac下安装elasticsearch
  10. Java虚拟机(七)——本地方法接口与本地方法栈
  11. JSP的注释、表达式、注意事项
  12. 更改mysql默认连接数_修改mysql最大连接数
  13. vue项目开发实战案例
  14. acr122 java,ACR122U中文开发文档
  15. 小米手机下载二维码APP
  16. 主板温度过高的原因是什么?主板温度高的原因和处理办法
  17. 目标检测中常见指标 AP MAP coco Pascal voc 评价指标说明
  18. 堪称神器的图片无损放大缩小工具
  19. Dell PowerEdge R750 Intel DAOS 顺利通过“HighPerf Ready 1.0”测试
  20. Codeforces Round 862 (Div. 2) 题解

热门文章

  1. 信息学奥赛一本通2011:【20CSPS提高组】贪吃蛇
  2. 文具行业APS解决方案
  3. 图像和视频的主要格式与编码格式。
  4. 普通高等学校毕业生登记表 计算机水平,普通高等学校全日制毕业生登记表自我鉴定怎么写...
  5. python按什么键停止运行_python如何停止运行
  6. markdown基础语法
  7. idea文件名颜色的区别
  8. xcode 中生成和打包 ipa文件的方法和步骤
  9. android别踩白块设计,别踩白块儿实例——按键精灵手机助手
  10. [笔记] 当当音乐人:免费将Midi转化为WAV