1. 什么是混淆矩阵

混淆矩阵是评判模型结果的一种指标,属于模型评估的一部分,常用于评判分类器的优劣。即,混淆矩阵是评判模型结果的指标,属于模型评估的一部分。

此外,混淆矩阵多用于判断分类器(Classifier)的优劣,适用于分类型的数据模型,如

  • 分类树(Classification Tree)
  • 逻辑回归(Logistic Regression)
  • 线性判别分析(Linear Discriminant Analysis)等方法。

一句话解释版本:混淆矩阵就是分别统计分类模型归错类归对类的观测值个数,然后把结果放在一个表里展示出来。这个表就是混淆矩阵。


在分类型模型评判的指标中,常见的方法有如下三种:

  • 混淆矩阵(也称误差矩阵,Confusion Matrix)
  • ROC曲线
  • AUC面积


对于这个混淆矩阵,横坐标是真实标签(Ground Truth),纵坐标是模型预测的类别。对角线是我们最关注的信息,对角线代表预测正确的样本的个数。

Precision (精确率) Recall (召回率) Specificity (特异度)
类别1 0.825 0.963 0.965
类别2 0.989 0.898 0.996
类别3 0.906 0.892 0.980

注意:准确率(accuracy)和精确率(precision)不是一回事,准确率一般用于分类网络,而精确率用于目标检测。

2. 混淆矩阵前置知识

2.1 混淆矩阵的定义

混淆矩阵(Confusion Matrix),它的本质远没有它的名字听上去那么拉风。矩阵,可以理解为就是一张表格,混淆矩阵其实就是一张表格而已。

以分类模型中最简单的二分类为例,对于这种问题,我们的模型最终需要判断样本的结果是0还是1,或者说是positive还是negative。我们通过样本的采集,能够直接知道真实情况下,哪些数据结果是positive,哪些结果是negative。同时,我们通过用样本数据跑出分类模型的结果,也可以知道模型认为这些数据哪些是positive,哪些是negative

对于一个二分类网络,模型本质上只有1个类别,即模型的预测结果只有是这个类别(正样本)不是这个类别(负样本)这两种结果。

因此,我们就能得到这样四个基础指标,我称他们是一级指标(最底层的):

  • 真实值是positive,模型认为是positive的数量(True Positive=TP) -> 真阳性
  • 真实值是positive,模型认为是negative的数量(False Negative=FN):这就是统计学上的第二类错误(Type II Error) -> 假阴性
  • 真实值是negative,模型认为是positive的数量(False Positive=FP):这就是统计学上的第一类错误(Type I Error) -> 假阳性
  • 真实值是negative,模型认为是negative的数量(True Negative=TN) -> 真阴性

对于二分类网络,1 代表的就是Positive, 0 代表的就是Negative。


注意: Positive和Negative是针对网络的预测结果得到的,和真实值无关,真实值和True/False有关。

  • 模型预测的是1(Positive),与GT相符 -> TP -> 真阳性
  • 模型预测的是1(Positive),与GT不符 -> FP -> 假阳性
  • 模型预测的是0(Negative),与GT相符 -> TN -> 真阴性
  • 模型预测的是0(Negative),与GT不符 -> FN -> 假阴性

将这四个指标一起呈现在表格中,就能得到如下这样一个矩阵,我们称它为混淆矩阵(Confusion Matrix):

对于左上角的混淆矩阵来说,同样的,每一行代表真实值的标签,每一列代表预测值的标签。

  • Positive: 正样本(真实值)
  • Negative:负样本(预测值)

2.2 混淆矩阵的指标

预测性分类模型,肯定是希望越准越好。那么,对应到混淆矩阵中,那肯定是希望TP与TN的数量大(预测值和GT一致的情况),而FP与FN的数量小(预测值与GT不符的情况)。所以当我们得到了模型的混淆矩阵后,就需要去看有多少观测值在第二、四象限对应的位置,这里的数值越多越好;反之,在第一、三象限对应位置出现的观测值肯定是越少越好。

  • TP和TN越高越好
  • FP和FN越少越好

2.3 二级指标——准确率、精确率、灵敏度/召回率、特异度

但是,混淆矩阵里面统计的是个数,有时候面对大量的数据,光凭算个数,很难衡量模型的优劣。混淆矩阵是直接把所有的数据都摆了上来,实际上并没有什么解读,所以需要一些指标来衡量混淆矩阵的好坏

因此混淆矩阵在基本的统计结果上又延伸了如下4个指标,我们称它们为二级指标(通过最底层指标加减乘除得到的):

  • 准确率(Accuracy)—— 针对整个模型
  • 精确率(Precision)
  • 灵敏度(Sensitivity):就是召回率(Recall)
  • 特异度(Specificity)
二级指标 公式 描述 通俗解释
Accuracy (准确率) Accuracy=TP+TNTP+FP+TN+FN\large \mathrm{Accuracy = \frac{TP + TN}{TP + FP + TN + FN}}Accuracy=TP+FP+TN+FNTP+TN​ 模型分类正确样本个数(正样本+负样本)占总样本个数的比例 所有正负样本中模型预测对的比例
Precision (精确率) Precision=TPTP+FP\large \mathrm{Precision = \frac{TP}{TP + FP}}Precision=TP+FPTP​ 模型认为是正样本中,预测对的比例 模型认为是正样本中(不一定真的是正样本),预测对的比例
Recall (召回率/查全率) Recall=TPTP+FN\large \mathrm{Recall = \frac{TP}{TP + FN}}Recall=TP+FNTP​ 所有真实的正样本中,模型预测对的比例 真实的正样本中预测了对了多少(模型本应该预测出来的正样本中预测了对了多少)
Specificity (特异度) Specificity=TNTN+FP\large \mathrm{Specificity = \frac{TN}{TN + FP}}Specificity=TN+FPTN​ 所有真实的负样本中,模型预测对的比例 真实的负样本中预测了对了多少(模型本应该预测出来的负样本中预测了对了多少)

通过上面的四个二级指标,可以将混淆矩阵中数量的结果转化为 [0,1][0, 1][0,1] 之间的比率,便于进行标准化的衡量

在实际使用中,使用较多的是前三个指标(Accuracy, Precision, Recall)。

简单记忆

  • Accuracy: 模型预测的所有正负样本中,预测对了多少 —— 模型判断正确的数据(TP+TN)占总数据的比例
  • Precision: 模型预测的所有正样本中,预测对了多少 —— Precision高表示模型检测出的正样本中大部分确实是正样本,只有少量不是正样本被当成正样本
  • Recall: 模型本应该预测出来的正样本中预测了对了多少 —— 召回率也叫查全率,以目标检测为例,我们往往把图片中的物体作为正例,此时召回率高代表着模型可以找出图片中更多的物体!
  • Specificity: 模型本应该预测出来的负样本中预测了对了多少 —— 特异度高代表着模型可以找出图片中更多的背景(负样本)!

2.4 三级指标

在这四个指标的基础上在进行拓展,会产令另外一个三级指标。这个指标叫做F1 Score。它的计算公式是:
F1Score=2PRP+R∈[0,1]\mathrm{ F1 \ Score = \frac{2PR}{P + R} \in [0, 1] } F1 Score=P+R2PR​∈[0,1]

其中,P代表Precision,R代表Recall。

F1-Score指标综合了Precision与Recall的产出的结果。F1-Score的取值范围为[0, 1]:

  • 1代表模型的输出最好
  • 0代表模型的输出结果最差

3. 例子

3.1 准确率(Accuracy)

准确率简单来讲,就是对角线占所有的比例,即:
Accuracy=TP+TNTP+FP+TN+FN=10+15+2010+15+20+3+5+1+6+2+4=10+15+2066≈0.68\begin{aligned} \text{Accuracy} & = \mathrm{\frac{TP + TN}{TP + FP + TN + FN}} \\ & = \frac{10 + 15 + 20}{10+15+20+3+5+1+6+2+4} \\ & = \frac{10 + 15 +20}{66} \\ & \approx 0.68 \end{aligned} Accuracy​=TP+FP+TN+FNTP+TN​=10+15+20+3+5+1+6+2+410+15+20​=6610+15+20​≈0.68​

所有正负样本中,预测对了多少

3.2 精确率(Precision)

对于精确率来说,我们以“猫”为例,3分类可以变为2分类——“猫”和“不为猫”。


Precision=TPTP+FP=1010+3≈0.77\begin{aligned} \text{Precision} & = \mathrm{\frac{TP}{TP + FP}} \\ & = \frac{10}{10+3} \\ & \approx 0.77 \end{aligned} Precision​=TP+FPTP​=10+310​≈0.77​

模型预测的所有正样本中,预测对了多少

3.3 召回率(Recall)


Recall=TPTP+FN=1010+8≈0.56\begin{aligned} \text{Recall} & = \mathrm{\frac{TP}{TP + FN}} \\ & = \frac{10}{10+8} \\ & \approx 0.56 \end{aligned} Recall​=TP+FNTP​=10+810​≈0.56​

模型本应该预测出来的正样本中预测了对了多少

3.4 特异度(Specificity)

Specificity=TPTP+FN=4545+3≈0.94\begin{aligned} \text{Specificity} & = \mathrm{\frac{TP}{TP + FN}} \\ & = \frac{45}{45+3} \\ & \approx 0.94 \end{aligned} Specificity​=TP+FNTP​=45+345​≈0.94​

模型本应该预测出来的负样本中预测了对了多少

3.5 总结

对于二级指标来说:

  • accuracy是可以根据所有类别来进行计算的(就是所有类别中,模型预测对的比例)
  • 剩下的3个二级指标precision, recall, specificity就需要针对每一个类别进行计算(按照上面的例子那样做)。

4. 代码

代码来源于霹雳吧啦WZ老师。

import os
import jsonimport torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTablefrom model import MobileNetV2class ConfusionMatrix(object):"""注意,如果显示的图像不全,是matplotlib版本问题本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常需要额外安装prettytable库  将输出打印为列表"""def __init__(self, num_classes: int, labels: list):self.matrix = np.zeros((num_classes, num_classes))self.num_classes = num_classesself.labels = labelsdef update(self, preds, labels):for p, t in zip(preds, labels):  # p: predict, t: GTself.matrix[p, t] += 1def summary(self):# calculate accuracysum_TP = 0for i in range(self.num_classes):sum_TP += self.matrix[i, i]acc = sum_TP / np.sum(self.matrix)print("the model accuracy is ", acc)# precision, recall, specificitytable = PrettyTable()  # init a table for printtable.field_names = ["", "Precision", "Recall", "Specificity"]for i in range(self.num_classes):  # for each classTP = self.matrix[i, i]FP = np.sum(self.matrix[i, :]) - TPFN = np.sum(self.matrix[:, i]) - TPTN = np.sum(self.matrix) - TP - FP - FNPrecision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.table.add_row([self.labels[i], Precision, Recall, Specificity])print(table)def plot(self):  # plot confusion matrixmatrix = self.matrixprint(matrix)plt.imshow(matrix, cmap=plt.cm.Blues)  # color from white to blueplt.xticks(range(self.num_classes), self.labels, rotation=45)plt.yticks(range(self.num_classes), self.labels)# show colorbarplt.colorbar()plt.xlabel('True Labels')plt.ylabel('Predicted Labels')plt.title('Confusion matrix')# 在图中标注数量/概率信息thresh = matrix.max() / 2# Note:#       x: left -> right; y: top -> bottomfor x in range(self.num_classes):for y in range(self.num_classes):# 注意这里的matrix[y, x]不是matrix[x, y]info = int(matrix[y, x])plt.text(x, y, info,verticalalignment='center',horizontalalignment='center',color="white" if info > thresh else "black")plt.tight_layout()plt.show()if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "data path {} does not exist.".format(image_path)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform)batch_size = 16validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=2)net = MobileNetV2(num_classes=5)# load pretrain weightsmodel_weight_path = "./MobileNetV2.pth"assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)net.load_state_dict(torch.load(model_weight_path, map_location=device))net.to(device)# read class_indictjson_label_path = './class_indices.json'assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)json_file = open(json_label_path, 'r')class_indict = json.load(json_file)labels = [label for _, label in class_indict.items()]confusion = ConfusionMatrix(num_classes=5, labels=labels)net.eval()with torch.no_grad():for val_data in tqdm(validate_loader):val_images, val_labels = val_dataoutputs = net(val_images.to(device))outputs = torch.softmax(outputs, dim=1)outputs = torch.argmax(outputs, dim=1)confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())confusion.plot()confusion.summary()

参考

参考:

  1. 使用pytorch和tensorflow计算分类模型的混淆矩阵_哔哩哔哩_bilibili
  2. https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839

分类模型计算混淆矩阵相关推荐

  1. 人工智能学习07--pytorch11--分类网络:使用pytorch和tensorflow计算分类模型的混淆矩阵

    师兄说学目标检测之前先学分类 坏了,内容好多!学学学 感谢up主,好人一生平安 混淆矩阵 什么是混淆矩阵: 横坐标:每一列属于该类的所有验证样本.每一列所有元素对应真实类别. 纵坐标:网络的预测类别. ...

  2. 分类模型之混淆矩阵(Confusion Matrix)

    目录 混淆矩阵简介 混淆矩阵及其参数 混淆矩阵的其他指标 混淆矩阵简介 混淆矩阵是ROC曲线绘制的基础,同时它也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法. 可以简单理解为:将一个分类模 ...

  3. R语言使用randomForest包构建随机森林模型(Random forests)、使用importance函数查看特征重要度、使用table函数计算混淆矩阵评估分类模型性能、包外错误估计OOB

    R语言使用randomForest包中的randomForest函数构建随机森林模型(Random forests).使用importance函数查看特征重要度.使用table函数计算混淆矩阵评估分类 ...

  4. R语言使用rpart包构建决策树模型、使用prune函数进行树的剪枝、交叉验证预防过拟合、plotcp可视化复杂度、rpart.plot包可视化决策树、使用table函数计算混淆矩阵评估分类模型性能

    R语言使用rpart包构建决策树模型.使用prune函数进行树的剪枝.使用10折交叉验证选择预测误差最低的树来预防过拟合.plotcp可视化决策树复杂度.rpart.plot包可视化最终决策树.使用t ...

  5. R语言使用R基础安装中的glm函数构建乳腺癌二分类预测逻辑回归模型、分类预测器(分类变量)被自动替换为一组虚拟编码变量、summary函数查看检查模型、使用table函数计算混淆矩阵评估分类模型性能

    R语言使用R基础安装中的glm函数构建乳腺癌二分类预测逻辑回归模型(Logistic regression).分类预测器(分类变量)被自动替换为一组虚拟编码变量.summary函数查看检查模型.使用t ...

  6. R语言使用yardstick包的conf_mat函数计算多分类(Multiclass)模型的混淆矩阵、并使用summary函数基于混淆矩阵输出分类模型评估的其它详细指标(kappa、npv等13个)

    R语言使用yardstick包的conf_mat函数计算多分类(Multiclass)模型的混淆矩阵(confusion matrix).并使用summary函数基于混淆矩阵输出分类模型评估的其它详细 ...

  7. java计算混淆矩阵(分类指标:查准率P,查全率R,P和R的调和均值F1,正确率A)

    [0]README 本文使用 java 计算混淆矩阵,并利用 混淆矩阵值计算 分类指标:通用分类指标有: 查准率,查全率,查准率和查全率的调和均值F1值,正确率, AOC, AUC等:本文计算前4个指 ...

  8. sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并自定义子分类器的权重(weights)、计算融合模型的混淆矩阵、可视化混淆矩阵

    sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并自定义子分类器的权重(weights).计算融合模型的混淆矩阵.可视化混淆矩阵 目 ...

  9. sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并计算融合模型的混淆矩阵、可视化混淆矩阵(confusion matrix)

    sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并计算融合模型的混淆矩阵.可视化混淆矩阵(confusion matrix) 目录

最新文章

  1. 嫉恨别人不如审视自己
  2. matlab gui优化,matlabgui优化程序
  3. Yii框架中使用SRBAC作为权限管理模块时遇到的问题
  4. git rebase -i 修改提交
  5. js jquery 函数回调
  6. 《大道至简》第六章读后感及本次课后习题11.9
  7. 【codevs1068】乌龟棋noip10年TG----第二个A掉的钻石题
  8. 将python程序打包为exe及一些问题
  9. OpenShift 之 Quarkus(4)用集成的Prometheus监控Quarkus应用
  10. mysql 主从 不一致_揭秘MySQL主从数据不一致
  11. 转:如何编程实现启用禁用网卡
  12. ReentrantLock 源码解析
  13. c++实现八大排序算法
  14. python 多元线性回归的系数检验 t p值,多元线性回归检验t检验(P值),F检验,R方等参数的含义...
  15. 大数据培训(第三季)——Linux基础-徐培成-专题视频课程
  16. 2020年中国半导体划片机行业现状分析,国产替代+需求扩张,行业前景广阔「图」
  17. 【数据结构】选择排序
  18. 一个应届生是怎样搞定google、微软、百度、腾讯、搜狗的
  19. linux运行java程序内存过大_排查java应用linux环境内存占用过高的问题
  20. linux之getopt 函数

热门文章

  1. #10016 灯泡(数学)
  2. 【原创推荐】 计算机毕业设计之Python+Spark+LSTM电商爬虫 商品推荐系统 商品评论情感分析 电商大数据 电商推荐系统 大数据毕业设计
  3. 一行python代码,带你重温经典小游戏
  4. 百度李彦宏:人工智能时代将接替移动互联网时代
  5. 区块链 layer0、layer1和layer2是什么 有什么作用
  6. 大数据技术之_08_Hive学习_03_查询+函数
  7. 刻意练习-如何从新手到大师
  8. canvas仿知乎头像裁剪
  9. android查看轨迹,一种提取并分析Android手机行踪轨迹的方法
  10. 400+陶企云集,20+重磅活动,2021佛山潭洲陶瓷展开幕首日人气火爆!