引言

混淆矩阵是分类任务常用的一种评估方法。对角线元素表示预测标签等于真实标签的点数,而非对角线元素则是分类器未正确标记的点的数量。 混淆矩阵的对角线值越高越好,表明有许多正确的预测。1

尤其是在类别数量不平衡的情况下,相比accuracy,混淆矩阵(confusion matrix)对哪个类被错误分类具有更直观的解释

在平时做简单的数据实验时,可以仅用from sklearn.metrics import plot_confusion_matrix或者seaborn对混淆矩阵进行可视化。但是在深度学习训练模型的过程中,在tensorboard中可视化混淆矩阵会更方便结果记录和对照。

混淆矩阵

在tensorboard中的可视化效果:

代码实现

代码参考facebook的SlowFast工程2

引用库

import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix

计算混淆矩阵

从pytorch模型输出的预测结果preds、真值labels,计算混淆矩阵。

def get_confusion_matrix(preds, labels, num_classes, normalize="true"):"""Calculate confusion matrix on the provided preds and labels.Args:preds (tensor or lists of tensors): predictions. Each tensor is inin the shape of (n_batch, num_classes). Tensor(s) must be on CPU.labels (tensor or lists of tensors): corresponding labels. Each tensor isin the shape of either (n_batch,) or (n_batch, num_classes).num_classes (int): number of classes. Tensor(s) must be on CPU.normalize (Optional[str]) : {‘true’, ‘pred’, ‘all’}, default="true"Normalizes confusion matrix over the true (rows), predicted (columns)conditions or all the population. If None, confusion matrixwill not be normalized.Returns:cmtx (ndarray): confusion matrix of size (num_classes x num_classes)"""if isinstance(preds, list):preds = torch.cat(preds, dim=0)if isinstance(labels, list):labels = torch.cat(labels, dim=0)# If labels are one-hot encoded, get their indices.if labels.ndim == preds.ndim:labels = torch.argmax(labels, dim=-1)# Get the predicted class indices for examples.preds = torch.flatten(torch.argmax(preds, dim=-1))labels = torch.flatten(labels)cmtx = confusion_matrix(labels, preds, labels=list(range(num_classes)))#, normalize=normalize) 部分版本无该参数return cmtx

绘制混淆矩阵

输入get_confusion_matrix获取的混淆矩阵cmtx,类别数量和类别名称,进行混淆矩阵绘制。

def plot_confusion_matrix(cmtx, num_classes, class_names=None, figsize=None):"""A function to create a colored and labeled confusion matrix matplotlib figuregiven true labels and preds.Args:cmtx (ndarray): confusion matrix.num_classes (int): total number of classes.class_names (Optional[list of strs]): a list of class names.figsize (Optional[float, float]): the figure size of the confusion matrix.If None, default to [6.4, 4.8].Returns:img (figure): matplotlib figure."""if class_names is None or type(class_names) != list:class_names = [str(i) for i in range(num_classes)]figure = plt.figure(figsize=figsize)plt.imshow(cmtx, interpolation="nearest", cmap=plt.cm.Blues)plt.title("Confusion matrix")plt.colorbar()tick_marks = np.arange(len(class_names))plt.xticks(tick_marks, class_names, rotation=45)plt.yticks(tick_marks, class_names)# Use white text if squares are dark; otherwise black.threshold = cmtx.max() / 2.0for i, j in itertools.product(range(cmtx.shape[0]), range(cmtx.shape[1])):color = "white" if cmtx[i, j] > threshold else "black"plt.text(j,i,format(cmtx[i, j], ".2f") if cmtx[i, j] != 0 else ".",horizontalalignment="center",color=color,)plt.tight_layout()plt.ylabel("True label")plt.xlabel("Predicted label")return figure

在tensorboard中添加混淆矩阵

将plot_confusion_matrix返回的绘制图像显示在tensorboard中。

from torch.utils.tensorboard import SummaryWriterdef add_confusion_matrix(writer,cmtx,num_classes,global_step=None,subset_ids=None,class_names=None,tag="Confusion Matrix",figsize=None,
):"""Calculate and plot confusion matrix to a SummaryWriter.Args:writer (SummaryWriter): the SummaryWriter to write the matrix to.cmtx (ndarray): confusion matrix.num_classes (int): total number of classes.global_step (Optional[int]): current step.subset_ids (list of ints): a list of label indices to keep.class_names (list of strs, optional): a list of all class names.tag (str or list of strs): name(s) of the confusion matrix image.figsize (Optional[float, float]): the figure size of the confusion matrix.If None, default to [6.4, 4.8]."""if subset_ids is None or len(subset_ids) != 0:# If class names are not provided, use class indices as class names.if class_names is None:class_names = [str(i) for i in range(num_classes)]# If subset is not provided, take every classes.if subset_ids is None:subset_ids = list(range(num_classes))sub_cmtx = cmtx[subset_ids, :][:, subset_ids]sub_names = [class_names[j] for j in subset_ids]sub_cmtx = plot_confusion_matrix(sub_cmtx,num_classes=len(subset_ids),class_names=sub_names,figsize=figsize,)# Add the confusion matrix image to writer.writer.add_figure(tag=tag, figure=sub_cmtx, global_step=global_step)

在训练过程中绘制混淆矩阵

    model.train()# 预测值和标注值,用于绘制混淆矩阵preds=[]labels=[]for i, (inputs, targets) in enumerate(data_loader):targets = targets.to(device, non_blocking=True)#shape: (n_batch,)try:outputs = model(inputs)#shape: (n_batch,n_classes)loss = criterion(outputs, targets)            # 需将tensor从gpu转到cpu上preds.append(outputs.cpu())labels.append(targets.cpu())acc,recall = calculate_precision_and_recall(outputs, targets,pos_label=0)losses.update(float(loss.item()), inputs.size(0))accuracies.update(float(acc), inputs.size(0))recalls.update(float(recall), inputs.size(0))#total_loss+=float(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()"""混淆矩阵可视化"""preds = torch.cat(preds,dim=0)labels = torch.cat(labels,dim=0)cmtx = get_confusion_matrix(preds,labels,len(class_names))add_confusion_matrix(tb_writer,cmtx,num_classes=len(class_names),class_names=class_names,tag="Train Confusion Matrix",figsize=[10,8])

  1. https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html ↩︎

  2. https://github.com/facebookresearch/SlowFast/tree/master/slowfast/visualization ↩︎

Pytorch+Tensorboard混淆矩阵可视化相关推荐

  1. R语言自定义多分类混淆矩阵可视化函数(mutlti class confusion matrix)、R语言多分类混淆矩阵可视化

    R语言自定义多分类混淆矩阵可视化函数(mutlti class confusion matrix).R语言多分类混淆矩阵可视化 目录

  2. Python混淆矩阵可视化:plt.colorbar函数自定义颜色条的数值标签、配置不同情况下颜色条的数值范围以及数据类型(整型、浮点型)

    Python混淆矩阵可视化:plt.colorbar函数自定义颜色条的数值标签.配置不同情况下颜色条的数值范围以及数据类型(整型.浮点型) 目录

  3. matplotlib绘制混淆矩阵_混淆矩阵及其可视化

    混淆矩阵(Confusion Matrix)是机器学习中用来总结分类模型预测结果的一个分析表,是模式识别领域中的一种常用的表达形式.它以矩阵的形式描绘样本数据的真实属性和分类预测结果类型之间的关系,是 ...

  4. python matplotlib绘制混淆矩阵并配色

    文章目录 步骤1:网络测试结果保存 步骤2:矩阵绘制 混淆矩阵绘制结果 步骤1:网络测试结果保存 以pytorch为例,在测试阶段保存结果的参考代码为: resultTxtName = "r ...

  5. 混淆矩阵评价指标_机器学习:模型训练和评估——分类效果的评价

    图 | 源网络文 | 5号程序员 分类模型建立好后,这个模型到底符不符合要求要怎么评判呢? 事实上是有评价标准的. 要评价模型在测试集上预测结果的好坏,可以使用Sklearn库中的metrics模块方 ...

  6. 模型评价ROC\AUC\查准率\查全率\F-score\混淆矩阵\KS曲线\PR曲线等

    文章目录 一.ROC_AUC 1.1 ROC_AUC 概念 1.2 常见评价指标 1.3 sklearn.metrics.roc_curve()参数解释 1.4 ROC_AUC 曲线 二.混淆矩阵 2 ...

  7. R语言使用rpart包构建决策树模型、使用prune函数进行树的剪枝、交叉验证预防过拟合、plotcp可视化复杂度、rpart.plot包可视化决策树、使用table函数计算混淆矩阵评估分类模型性能

    R语言使用rpart包构建决策树模型.使用prune函数进行树的剪枝.使用10折交叉验证选择预测误差最低的树来预防过拟合.plotcp可视化决策树复杂度.rpart.plot包可视化最终决策树.使用t ...

  8. sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并自定义子分类器的权重(weights)、计算融合模型的混淆矩阵、可视化混淆矩阵

    sklearn使用投票器VotingClassifier算法构建多模型融合的软投票器分类器(soft voting)并自定义子分类器的权重(weights).计算融合模型的混淆矩阵.可视化混淆矩阵 目 ...

  9. R语言使用caret包的confusionMatrix函数计算混淆矩阵、使用编写的自定义函数可视化混淆矩阵(confusion matrix)

    R语言使用caret包的confusionMatrix函数计算混淆矩阵.使用编写的自定义函数可视化混淆矩阵(confusion matrix) 目录

最新文章

  1. 清华大学《高级机器学习》课件和Fellow专家特邀报告(附pdf下载)
  2. Python解析json字符串,json字符串用法
  3. 分布式事务,EventBus 解决方案:CAP【中文文档】
  4. android默认开机动画,修改安卓开机动画(除了部分系统 如MIUI等)
  5. ASP.NET五大核心对象解析
  6. oracle存储过程如何传递一个bean对象_对象传输,序列化和反序列化
  7. 如何启用或禁用错过的呼叫skype for business通知
  8. 兼容多种浏览器“复制到剪贴板”的解决方案
  9. 浅谈找到***点后的处理(清理***)
  10. AI头发笔刷_5G大量PS笔刷AI笔刷打包下载(超过1000款笔刷)
  11. 用 Python j进行一次短视频音频创作
  12. 浅论如何优化搜索引擎排名机制
  13. 局域网内查询嵌入式设备IP的几种方式
  14. [附源码]计算机毕业设计Python+uniapp作业批改系统APP4238p(程序+lw+APP+远程部署)
  15. 【聚类算法】MiniBatchKMeans算法
  16. 贵州建筑施工劳务资质备案流程
  17. python根据图片网址下载图片
  18. mysql不包含模糊查询
  19. 有限差分法、一阶向前差分、一阶向后差分
  20. 钱多多第二阶段冲刺03

热门文章

  1. Trie树(c++实现)
  2. 硬件反垃圾邮件网关|反垃圾邮件软件产品|反垃圾邮箱邮件系统
  3. MVC3.0 将网站设为首页和加为收藏的实现(IE/Firefox)
  4. 2020前端面试(一面面试题)
  5. Nginx配置示例文件
  6. Oracle RMAN备份与还原 - 联机备份讲解
  7. 分布式事务 - 如何解决分布式事务问题?
  8. Intellij IDEA 社区版集成 Database Navigator 数据库管理工具
  9. Fedora安装Mariadb数据库
  10. BGP——权重选路(讲解+配置命令)