之前写过一篇关于在scikit-learn工具包中,可视化estimator分类模型分类结果的confusion matrix混淆矩阵可视化的方法,具体可以参考看这里,看这里。今天这篇介绍一下如何使用scikit-learn工具中提供的相关方法,可视化其他任意框架(比如深度学习框架)的分类模型预测结果的混淆矩阵。

下面先说一下几个关键步骤:

1、确定类别列表,类别列表和one-hot的编码顺序一致,这里使用cifar-10的类别列表作为演示的例子。

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"

2、准备好样本的真实label,这里我手动构造一个1000个样本的label,每一类100个。

# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))

3、准备好样本的预测label,这里我也手动构造这1000个样本的预测label,构造时才用了一点规则,构造出来的预测结果保证从第0类到第9类的预测准确率是逐渐降低的。

# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):# 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值# 这样生成的预测准确率从0到9逐渐递减pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))

4、计算真是label和预测label的混淆矩阵,直接调用scikit-learn中的confusion_matrix方法

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))

5、混淆矩阵可视化,在scikit-learn工具中有一个plot_confusion_matrix方法可以可视化sklearn训练的模型estimator的混淆矩阵,具体参数如下:

但是,现在的问题是我们使用的是别的框架训练的模型,也就没有这个estimator参数可以供sklearn使用,怎么办?

我们看一下plot_confusion_matrix函数的代码可以发现,他其实内部调用了以下方法:

那么,我们也仿照这个调用方式来写一下试试,代码如下:

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(include_values=True,            # 混淆矩阵每个单元格上显示具体数值cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值ax=None,                        # 同上xticks_rotation="horizontal",   # 同上values_format="d"               # 显示的数值格式
)

6、将以上代码整合一下,输入数据的真实label和预测label,就可以可视化混淆矩阵了,并且不仅局限于评估scikit-learn的estimator,可以适用于所有框架的输出结果,完整代码如下:

import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as pltclasses = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]# 生成数据集的GT标签
gt_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):gt_labels[i] = i
gt_labels = gt_labels.reshape(1, -1).squeeze()
print("gt_labels.shape : {}".format(gt_labels.shape))
print("gt_labels : {}".format(gt_labels[::5]))# 生成数据集的预测标签
pred_labels = np.zeros(1000).reshape(10, -1)
for i in range(10):# 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值# 这样生成的预测准确率从0到9逐渐递减pred_labels[i] = np.random.randint(0, i + 1, 100)
pred_labels = pred_labels.reshape(1, -1).squeeze()
print("pred_labels.shape : {}".format(pred_labels.shape))
print("pred_labels : {}".format(pred_labels[::5]))# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(gt_labels, pred_labels)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
print("confusion_mat : {}".format(confusion_mat))# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes)
disp.plot(include_values=True,            # 混淆矩阵每个单元格上显示具体数值cmap="viridis",                 # 不清楚啥意思,没研究,使用的sklearn中的默认值ax=None,                        # 同上xticks_rotation="horizontal",   # 同上values_format="d"               # 显示的数值格式
)
plt.show()

7、混淆矩阵的可视化结果

上图中的可视化结果符合我们在生成预测label标签时使用的规则,就是对于每个类别 i 的预测结果是0-i之间的随机值,这样的话,每个类别的预测误差只会出现在类别编号比它小的部分,也就是上图中展示的下三角矩阵。

在混淆矩阵中,横轴上的标签标示样本的预测label,纵轴上的标签标示样本的实际label。所以,对角线上的数字表示预测label和真是label一致的数量,也就是预测正确的数量。对于其他位置的数字就表示预测错误的,举个例子,比如第2行、第1列,也就是对应着(airplane, automobile)位置的数字51,表示有51个真实label为automobile的样本被预测为了airplane。

通过可视化的混淆矩阵,模型的误差,以及效果分类不好的类别,以及为什么不好,以及容易和哪个类之间出现误识别就一目了然了。

参考:https://blog.csdn.net/cxx654/article/details/107296343

分类模型confusion matrix混淆矩阵可视化相关推荐

  1. 分类模型评估---从混淆矩阵到ROC,AUC,GINI,KS,Lift,Gain,MSE

    4.4.2分类模型评判指标(一) - 混淆矩阵(Confusion Matrix) https://blog.csdn.net/Orange_Spotty_Cat/article/details/80 ...

  2. 详细讲解分类模型评价指标(混淆矩阵)python示例

    前言 1.回归模型(regression): 对于回归模型的评估方法,通常会采用平均绝对误差(MAE).均方误差(MSE).平均绝对百分比误差(MAPE)等方法. 2.聚类模型(clustering) ...

  3. 决策树及分类模型评价指标(混淆矩阵,LIFT曲线 重要)

    决策树评价指标:ROC lift(提升度):类似提纯:按照decile从高到低排列,之后计算每个decile里响应数与该decile里行数的比值得到一个response rate,另外,单独计算所有行 ...

  4. R语言自定义多分类混淆矩阵可视化函数(mutlti class confusion matrix)、R语言多分类混淆矩阵可视化

    R语言自定义多分类混淆矩阵可视化函数(mutlti class confusion matrix).R语言多分类混淆矩阵可视化 目录

  5. R语言glmnet拟合lasso回归模型实战:lasso回归模型的模型系数及可视化、lasso回归模型分类评估计算(混淆矩阵、accuracy、Deviance)

    R语言glmnet拟合lasso回归模型实战:lasso回归模型的模型系数(lasso regression coefficients)及可视化.lasso回归模型分类评估计算(混淆矩阵.accura ...

  6. R语言glmnet拟合岭回归模型实战:岭回归模型的模型系数(ridge regression coefficients)及可视化、岭回归模型分类评估计算(混淆矩阵、accuracy、Deviance)

    R语言glmnet拟合岭回归模型(ridge regression)实战:岭回归模型的模型系数(ridge regression coefficients)及可视化.岭回归模型分类评估计算(混淆矩阵. ...

  7. Pytorch+Tensorboard混淆矩阵可视化

    引言 混淆矩阵是分类任务常用的一种评估方法.对角线元素表示预测标签等于真实标签的点数,而非对角线元素则是分类器未正确标记的点的数量. 混淆矩阵的对角线值越高越好,表明有许多正确的预测.1 尤其是在类别 ...

  8. 模型效果评价—混淆矩阵(原理及Python实现)

      对于分类模型,在建立好模型后,我们想对模型进行评价,常见的指标有混淆矩阵.KS曲线.ROC曲线.AUC面积等.也可以自己定义函数,把模型结果分割成n(100)份,计算top1的准确率.覆盖率.   ...

  9. 决策树分类评估指标之混淆矩阵

    问题的提出 如果决策树的目标是尽量捕获少数类,则准确率模型评估的意义不大,需要新的模型评估指标.简单来看,只需要查看模型在少数类上的准确率就好,只要能够将少数类尽量捕捉出来,就能够达到目的. 但是,新 ...

  10. Python混淆矩阵可视化:plt.colorbar函数自定义颜色条的数值标签、配置不同情况下颜色条的数值范围以及数据类型(整型、浮点型)

    Python混淆矩阵可视化:plt.colorbar函数自定义颜色条的数值标签.配置不同情况下颜色条的数值范围以及数据类型(整型.浮点型) 目录

最新文章

  1. 正则表达式用单个空格替换多个空格
  2. 笔记本网络计算机和设备不可见,WIN10局域网电脑和设备显示不完整
  3. 系统超时或者服务器会话丢失,第 17 章 配置 Web 服务器(Undertow)
  4. the Differences between abstract class interface in C#接口和抽象类的区别
  5. 同频切换的事件_LTE前台路测切换问题处理大礼包
  6. javascript实现简体与繁体的转换(可下载)
  7. github如何同步fork到自己仓库的代码
  8. 旅行商问题(TSP)、车辆路径问题(VRP,MDVRP,VRPTW)模型介绍
  9. DelayQueue 阻塞队列
  10. 为什么国内抖音没有网页版,原因竟然是这样!
  11. 什么促使计算机图形学发展,计 算 机 图 形 学 的 过 去、 现 在 和 未 来.doc
  12. 美国诚实签经验——医生的预约单和赴美生子的费用明细表
  13. Java工程师成神之路~(2018修订版)
  14. Mysql5.7 的错误日志中最常见的note级别日志解释
  15. easypoi 语法_高中英语 | 必修1选修8必须掌握的语法重难点汇总 (全八册)
  16. 计算机学报是期刊论文吗,《计算机学报》北大核心电子期刊发表技巧
  17. 0085 开头的电话拦截方法(小米手机有效)
  18. 用这9个问题来审视自己
  19. TCHAR和string的转换
  20. 讲讲如何写论文和发论文(通信类)

热门文章

  1. VMware Workstation 英文改中文界面
  2. 移动硬盘RAW格式修复
  3. 计算机组成原理 静态随机存储器实验,计算机组成原理静态随机存储器实验
  4. java基础企业级入门视频教程-周江超-专题视频课程
  5. 驱动开发---cc1: error: code model kernel does not support PIC mode(改文件Unhelp?try it)
  6. 婚姻中,不去表达爱,比不爱更可怕
  7. 钉钉企业微信与飞书模式区别
  8. 拉勾网positionAjax获取的时候(带有账号登陆的时候)频繁获取被拉黑
  9. JAVA一维数组求和
  10. spring data jpa 动态查询Specification(包括各个In、like、Between等等各种工具类,及完整(分页查询)用法步骤(到返回给前端的结果))