tensorflow2.0自制神经网络数据集及预测结果(混淆矩阵)可视化
写在开头:现在网上的资料对测试数据集的混淆矩阵可视化大都基于MNIST、CAIFR等标准数据集来绘制,如果想把自己制作的数据集的预测结果可视化,如何做?本文为你提供一种解决方案。
上篇博文我们写了自制数据集的制作、训练、可视化流程。本篇博文在此基础上将测试集的预测结果可视化,让我们直观的观察哪些数据预测正确,哪些数据预测错误。数据集的制作流程及代码在上篇博文中已有详细介绍,本次不作为解释重点,重点解释绘制混淆矩阵部分的代码及实现。
- 首先导入一些必须的包
from __future__ import print_function #这行代码必须放在第一行,不然会报错,原因我也没有细察,请各位读者注意一下
import os,glob,random, csv,itertools
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import optimizers,losses,layers
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
- 设置全局种子、tensorflow版本检查、GPU内存分配等
#设置全局种子
tf.random.set_seed(22)
np.random.seed(22)#检查tensorflow版本是否为2.0版本
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')# 设置GPU显存按需分配(如果没有装GPU版本,这部分就不用写,但是如果不用GPU加速计算,使用CPU将会严重限制模型的训练速度)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)logical_gpus = tf.config.experimental.list_logical_devices('GPU')print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")except RuntimeError as e:print(e)
- 写入训练数据集的csv文件
def load_csv(root, filename, name2label):if not os.path.exists(os.path.join(root, filename)):images = []for name in name2label.keys()images += glob.glob(os.path.join(root, name, '*.png'))images += glob.glob(os.path.join(root, name, '*.jpg'))images += glob.glob(os.path.join(root, name, '*.jpeg'))print(len(images), images)random.shuffle(images)with open(os.path.join(root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images: # name = img.split(os.sep)[-2]label = name2label[name]writer.writerow([img, label])print('written into csv file:', filename) images, labels = [], []with open(os.path.join(root, filename)) as f:reader = csv.reader(f)for row in reader:img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labels
- 创建数据集的编码表
def load_huangtu(root, mode='train'):name2label = {} for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continuename2label[name] = len(name2label.keys())images, labels = load_csv(root, 'images.csv', name2label)if mode == 'train': # 60%images = images[:int(0.6 * len(images))]labels = labels[:int(0.6 * len(labels))]elif mode == 'val': # 20% = 60%->80%images = images[int(0.6 * len(images)):int(0.8 * len(images))]labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]else: # 20% = 80%->100%images = images[int(0.8 * len(images)):]labels = labels[int(0.8 * len(labels)):]return images, labels, name2label
- 数据的处理
#下面两个矩阵常量是谷歌经过大量的图片计算得到的值,后续可以直接使用
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
def normalize(x, mean=img_mean, std=img_std):x = (x - mean)/stdreturn x
def preprocess(x,y):# x: 图片的路径,y:图片的数字编码x = tf.io.read_file(x)x = tf.image.decode_jpeg(x, channels=3) # RGBAx = tf.image.resize(x, [244, 244])x = tf.image.random_flip_up_down(x)#数据增强(上下翻转)x= tf.image.random_flip_left_right(x) #数据增强(左右翻转)x = tf.image.random_crop(x, [224, 224, 3]) #数据增强(随机裁剪)x = tf.cast(x, dtype=tf.float32) / 255.x = normalize(x) #255-0y = tf.convert_to_tensor(y)y = tf.one_hot(y, depth=7)return x, y
- 定义混淆矩阵
#参数 y_true为测试数据集的真实标签,y_pred为网络对测试数据集的预测结果
def plot_confusion_matrix(y_true, y_pred, title = "Confusion matrix",cmap = plt.cm.Blues, save_flg = False):classes = [str(i) for i in range(7)]#参数i的取值范围根据你自己数据集的划分类别来修改,我这儿为7代表数据集共有7类labels = range(7)#数据集的标签类别,跟上面I对应cm = confusion_matrix(y_true, y_pred, labels=labels)plt.figure(figsize=(14, 12))plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title, fontsize=40)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, fontsize=20)plt.yticks(tick_marks, classes, fontsize=20)print('Confusion matrix, without normalization')thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, cm[i, j],horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.ylabel('True label', fontsize=30)plt.xlabel('Predicted label', fontsize=30)if save_flg:plt.savefig("./confusion_matrix.png")plt.show()
- 写入主函数,进行数据集的调用、模型的创建、编译、训练等
def main():images, labels, table = load_huangtu('huangtu', 'train')print('images', len(images), images)print('labels', len(labels), labels)print(table)db = tf.data.Dataset.from_tensor_slices((images, labels))db = db.shuffle(1000).map(preprocess).batch(32)#被注释掉的代码可用于tensorboard可视化'''writter = tf.summary.create_file_writer('此处为文件的创建/保存地质')for step, (x,y) in enumerate(db):with writter.as_default():x = denormalize(x)tf.summary.image('img',x,step=step,max_outputs=9)time.sleep(5)'''batchsz = 128# 创建训练集Datset对象images, labels, table = load_huangtu('huangtu',mode='train')db_train = tf.data.Dataset.from_tensor_slices((images, labels))db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)# 创建验证集Datset对象images2, labels2, table = load_huangtu('huangtu',mode='val')db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))db_val = db_val.map(preprocess).batch(batchsz)# 创建测试集Datset对象images3, labels3, table = load_huangtu('huangtu',mode='test')db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))db_test = db_test.map(preprocess).batch(batchsz)#迁移学习net = keras.applications.DenseNet121(weights='imagenet', include_top=False,pooling='max')net.trainable = Falsenewnet = keras.Sequential([net,layers.Dense(1024,activation='relu'),layers.BatchNormalization(),layers.Dropout(rate=0.5),layers.Dense(7)])newnet.build(input_shape=(4,224,224,3))newnet.summary()newnet.compile(optimizer=optimizers.Adam(lr=0.01),loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])'''## 调用tensorboardlog_dir='调用上面的路径地址'tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) '''newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=50)score=newnet.evaluate(db_test)print('Test loss:', score[0])print('Test accuracy:', score[1])#下面三行代码为绘制混淆矩阵的传参predict_classes = newnet.predict(db_test)#对测试数据集进行预测true_classes = np.argmax(predict_classes,1)#汲取预测结果plot_confusion_matrix(true_classes, labels3, save_flg = True)#调用混淆矩阵#调用主函数if __name__ == '__main__':main()
将上面所有的代码按顺序组合起来,就可以做为一个完整的数据集制作、模型训练、混淆矩阵可视化的代码
下图为自制数据集的混淆矩阵可视化结果
对角线上表示预测正确结果。
写在最后:相比对标准数据集的预测,重点注意和思考我们自制数据集的传参。
本文到此结束
tensorflow2.0自制神经网络数据集及预测结果(混淆矩阵)可视化相关推荐
- tensorflow2.0 循环神经网络--情感分类实战
tensorflow2.0 循环神经网络–情感分类实战代码 本文主要是情感分类单层实战RNN Cell代码 import os import numpy as np import tensorflow ...
- 【深度学习】利用tensorflow2.0卷积神经网络进行卫星图片分类实例操作详解
本文的应用场景是对于卫星图片数据的分类,图片总共1400张,分为airplane和lake两类,也就是一个二分类的问题,所有的图片已经分别放置在2_class文件夹下的两个子文件夹中.下面将从这个实例 ...
- Python项目实战-Tensorflow2.0实现泰坦尼克生存预测
目录 一.数据集下载地址 二.探索性因子分析(EDA) 三.特征工程 四.构建Dataset与Model fit和自定义estimator使用 预定义estimator的使用 一.数据集下载地址 # ...
- Tensorflow2.0泰坦尼克数据集的python分析以及离散化数据处理(含数据集下载地址)
泰坦尼克数据集下载 训练集 测试集 导入需要的库 import matplotlib.pyplot as plt %matplotlib inline import numpy as np impor ...
- Tensorflow2.0之用循环神经网络生成周杰伦歌词
文章目录 1.导入需要的库 2.加载数据集 3.相邻采样 4.定义模型 4.1 定义循环神经网络层 4.2 定义循环神经网络 5.定义预测函数 6.裁剪梯度 7.定义模型训练函数 7.1 困惑度 7. ...
- TensorFlow2.0学习笔记2-tf2.0两种方式搭建神经网络
目录 一,TensorFlow2.0搭建神经网络八股 1)import [引入相关模块] 2)train,test [告知喂入网络的训练集测试集以及相应的标签] 3)model=tf.keras. ...
- tensorflow2.0实现IMDB文本数据集学习词嵌入
1. IMDB数据集示例如下所示 [{"rating": 5, "title": "The dark is rising!", " ...
- 第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...
- tensorflow2.0莺尾花iris数据集分类|超详细
tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...
- TensorFlow2.0 学习笔记(三):卷积神经网络(CNN)
欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 文章目录 欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 一.神经网络的基本单位:神经元 二.卷 ...
最新文章
- MPB:南农成艳芬组-​瘤胃厌氧真菌代谢产物的检测方法
- 打开共享文件闪退怎么解决_文件共享解决方案-随时随地共享同步访问文件
- springboot 优雅停机_新姿势,Spring Boot 2.3.0 如何优雅停机?
- 软件工程--需求分析
- linux驱动私有数据,linux驱动开发之字符设备--私有数据和container_of
- 数据结构之Dijkstra算法
- 软件工程 - 团队重组
- 【新功能】媒体处理MPS全新支持自适应多码率、多语言音轨
- mac下安装elasticsearch
- Java虚拟机(七)——本地方法接口与本地方法栈
- JSP的注释、表达式、注意事项
- 更改mysql默认连接数_修改mysql最大连接数
- vue项目开发实战案例
- acr122 java,ACR122U中文开发文档
- 小米手机下载二维码APP
- 主板温度过高的原因是什么?主板温度高的原因和处理办法
- 目标检测中常见指标 AP MAP coco Pascal voc 评价指标说明
- 堪称神器的图片无损放大缩小工具
- Dell PowerEdge R750 Intel DAOS 顺利通过“HighPerf Ready 1.0”测试
- Codeforces Round 862 (Div. 2) 题解