本次实验基于自己搭建的CNN网络实现眼睛状态的分类,本来是打算迁移学习利用VGG16网络进行分类的,但是实验效果特别差,而且速度很慢,应该是博主自己的问题。而自己搭建的CNN网络的模型准确率也很高,运行速度很快。本文的重点在于混淆矩阵的绘制,这是之前没有接触过的东西。

1.导入库

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os,pathlib,PIL
from tensorflow import keras
from tensorflow.keras import layers,models,Sequential

2.数据加载

数据所在文件路径

data_dir = "E:/tmp/.keras/datasets/Eye_photos"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*.jpg')))#图片总数

超参数的设置

height = 224
width = 224
epochs = 10
batch_size = 64

构建一个ImageDataGenerator,在之前的实验中,我通常在这一步会进行数据加强,包括左右翻转、图片翻转某个角度,水平翻转等。但是在本次实验中,并没有进行这一操作。因为本次识别的眼睛状态包括左看、右看、前看、闭眼四种状态,如果进行数据增强的话,左看变为右看,这样数据没有达到增强的效果,反而引入噪声数据,得不偿失。

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.2)#以8:2的比例划分为训练集和测试集

分为训练集和测试集

train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training'
)
test_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation'
)
Found 3448 images belonging to 4 classes.
Found 859 images belonging to 4 classes.

查看标签

all_images_paths = list(data_dir.glob('*'))##”*”匹配0个或多个字符
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths]
['close_look', 'forward_look', 'left_look', 'right_look']

3.CNN网络搭建

model = tf.keras.Sequential([tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(1024,activation="relu"),tf.keras.layers.Dense(512,activation="relu"),tf.keras.layers.Dense(4,activation="softmax")
])

优化器的设置,具体的原理可以参考车牌识别那篇博客。

initial_learning_rate = 1e-4
lr_sch = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=initial_learning_rate,decay_rate=0.96,decay_steps=20,staircase=True
)

计算loss值的方式我在上篇博客中讲述了。

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_sch),loss=tf.keras.losses.CategoricalCrossentropy(),metrics=['accuracy']
)history = model.fit(train_ds,validation_data=test_ds,epochs=epochs
)

结果如下所示:

在epochs=20的情况下,模型的准确率在93%左右,比较可观。

保存模型:

model.save("E:/tmp/.keras/datasets/model.h5")

加载模型

new_model = tf.keras.models.load_model("E:/tmp/.keras/datasets/model.h5")

利用模型对图片进行预测:

plt.figure(figsize=(10,5))
plt.suptitle("预测结果展示")
for images,labels in test_ds:for i in range(8):ax = plt.subplot(2,4,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)#增加一维pre = new_model.predict(img_array)plt.title(all_label_names[np.argmax(pre)])plt.axis("off")break
plt.show()

4.混淆矩阵

混淆矩阵也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。在图像精度评价中,主要用于比较分类结果和实际测得值,可以把分类结果的精度显示在一个混淆矩阵里面。混淆矩阵是通过将每个实测像元的位置和分类与分类图像中的相应位置和分类相比较计算的。
我们最熟悉的混淆矩阵就是二分类的混淆矩阵:

TP = True Postive = 真阳性; FP = False Positive = 假阳性

FN = False Negative = 假阴性; TN = True Negative = 真阴性

至于多分类的混淆矩阵,与二分类的混淆矩阵相差不多。我们来绘制眼睛状态识别的混淆矩阵。

sns.heatmap是用来绘制混淆矩阵的主要工具,这是seaborn包下面的一个方法,具体如下:

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annot_kws=None, linewidths=0, linecolor='white', cbar=True, cbar_kws=None, cbar_ax=None, square=False, xticklabels='auto', yticklabels='auto', mask=None, ax=None, **kwargs)

其实除了第一个参数data外,其余的参数都是缺省参数,可以不用管。这里的data,如果接收的是干干净净的numpy二维数组的话,可以看到行标就是0,1,2,如果是DataFrame,就可以用列名来标记了。

所需要的库

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

定义一个绘制混淆矩阵的函数

#绘制混淆矩阵
def plot_cm(labels,pre):conf_numpy = confusion_matrix(labels,pre)#根据实际值和预测值绘制混淆矩阵conf_df = pd.DataFrame(conf_numpy,index=all_label_names,columns=all_label_names)#将data和all_label_names制成DataFrameplt.figure(figsize=(8,7))sns.heatmap(conf_df,annot=True,fmt="d",cmap="BuPu")#将data绘制为混淆矩阵plt.title('混淆矩阵',fontsize = 15)plt.ylabel('真实值',fontsize = 14)plt.xlabel('预测值',fontsize = 14)plt.show()

得到预测值与实际值

test_pre = []
test_label = []
for images,labels in test_ds:for image,label in zip(images,labels):img_array = tf.expand_dims(image,0)#增加一共维度pre = new_model.predict(img_array)#预测结果test_pre.append(all_label_names[np.argmax(pre)])#将预测结果传入列表test_label.append(all_label_names[np.argmax(label)])#将真实结果传入列表break#由于硬件问题。这里我只用了一个batch,一共64张图片。
plot_cm(test_label,test_pre)#绘制混淆矩阵


努力加油a啊

深度学习之眼睛状态识别混淆矩阵的绘制相关推荐

  1. 【深度学习】OCR文本识别

    OCR文字识别定义 OCR(optical character recognition)文字识别是指电子设备(例如扫描仪或数码相机)检查纸上打印的字符,然后用字符识别方法将形状翻译成计算机文字的过程: ...

  2. 【实战】深度学习构建人脸面部表情识别系统

    实战:深度学习构建人脸面部表情识别系统 一.表情数据集 数据集采用了kaggle面部表情识竞赛的人脸表情识别数据集. https://www.kaggle.com/c/challenges-in-re ...

  3. 基于深度学习的命名实体识别研究综述——论文研读

    基于深度学习的命名实体识别研究综述 摘要: 0引言 1基于深度学习的命名实体识别方法 1.1基于卷积神经网络的命名实体识别方法 1.2基于循环神经网络的命名实体识别方法 1.3基于Transforme ...

  4. 利用深度学习进行交通灯识别_通过深度学习识别交通信号灯

    利用深度学习进行交通灯识别 by David Brailovsky 戴维·布雷洛夫斯基(David Brailovsky) 通过深度学习识别交通信号灯 (Recognizing Traffic Lig ...

  5. 深度学习之视频人脸识别系列二:人脸检测与对齐

    作者 | 东田应子 [磐创AI导读]本文是深度学习之视频人脸识别系列的第二篇文章,介绍人脸检测与对齐的相关算法.欢迎大家关注我们的公众号:磐创AI. 一.人脸检测与关键点检测 问题描述: 人脸检测解决 ...

  6. DeepEye:一个基于深度学习的程序化交易识别与分类方法

    DeepEye:一个基于深度学习的程序化交易识别与分类方法 徐广斌,张伟 上海证券交易所资本市场研究所,上海 200120  上海证券交易所产品创新中心,上海 200120    摘要:基于沪市A股交 ...

  7. 一种基于深度学习的增值税发票影像识别系统

    一种基于深度学习的增值税发票影像识别系统-专利技术交底书 缩略语和关键术语定义 1.卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构 ...

  8. 基于深度学习的人脸性别识别系统(含UI界面,Python代码)

    摘要:人脸性别识别是人脸识别领域的一个热门方向,本文详细介绍基于深度学习的人脸性别识别系统,在介绍算法原理的同时,给出Python的实现代码以及PyQt的UI界面.在界面中可以选择人脸图片.视频进行检 ...

  9. 基于深度学习的轴承故障识别-构建基础的CNN模型

    上回书说到,处理序列的基本深度学习算法分别是循环神经网络(recurrent neural network)和一维卷积神经网络(1D convnet).上篇构建了基础的LSTM模型,这一篇自然轮到CN ...

最新文章

  1. linux @webserviceclient 访问超时_Linux系统调优
  2. How to change windows applicatioin's position via Win32 API
  3. caffe 加入 cudnn编译
  4. Anaconda 安装 Python 库(MySQLdb)的方法
  5. Windows Communication Foundation环境安装篇
  6. linux6.2 网络yum,配置RHEL6.2的YUM源
  7. Nginx学习总结(2)——Nginx手机版和PC电脑版网站配置
  8. function adapter bind(C++11)
  9. DEVC6.0使用教程
  10. Go 微服务开发框架 DMicro 的设计思路
  11. 基本概念学习(7002)---网络流量控制
  12. 大数据基础知识思维导图
  13. CDA数据分析LEVEL1--数据结构
  14. 不用科学梯子下载mokee-mkq-mr1分支安卓10.0源码
  15. border-radius简介
  16. CVE-2018-4990 漏洞详情分析
  17. 全力以赴地完成书稿中
  18. 第15章_存储过程与函数(创建存储过程、调用存储过程、存储函数的使用、存储过程和函数的查看、修改、删除)
  19. Unity基础学习六,网络同步
  20. 程序员算事业单位吗_程序员嘚瑟,嘲笑事业单位员工,一年收入才4万,众人:你还年轻...

热门文章

  1. 在C#中实现Socket端口复用
  2. IOS基础之打砖块项目演练
  3. php phpanalysis2.0,使用phpAnalysis打造PHP应用非侵入式性能分析器
  4. XSS跨站脚本(web应用)——XSS相关工具及使用(四)
  5. 如何保证进程间同步工作_冬季建房如何保证混凝土浇筑效果好,做好养护工作...
  6. 显示unc路径服务器根目录,错误:“您必须输入带有盘符的完整路径,例如:C:\ APPor形式的UNC路径:\\服务器\共享”...
  7. python坐标轴刻度设置_学习python中matplotlib绘图设置坐标轴刻度、文本
  8. c++11仔细地将参数传递给线程std::thread
  9. php tp框架做选中删除,关于thinkphp框架实现删除和批量删除的分析
  10. python报数组越界_python数组越界