深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码),矩阵,样本,模型,类别,真实

深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)

深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)1、什么是混淆矩阵2、分类模型评价指标3、两种多分类混淆矩阵3.1直接打印出每一个类别的分类准确率。3.2打印具体的分类结果的数值4、总结

1、什么是混淆矩阵

深度学习中,混淆矩阵是ROC曲线绘制的基础,同时它也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法。它可以直观地了解分类模型在每一类样本里面表现,常作为模型评估的一部分。它可以非常容易的表明多个类别是否有混淆(也就是一个class被预测成另一个class)。

首先要明确几个概念:

T或者F:该样本 是否被正确分类。

P或者N:该样本 原本是正样本还是负样本。

真正例(True Positive,TP):预测正确;模型预测也是正例,样本的真实类别是正例,

真负例(True Negative,TN):预测正确:模型预测为负例,样本的真实类别是负例,

伪正例(False Positive,FP):预测错误:模型预测为正例,样本的真实类别是负例,

伪负例(False Negative,FN):预测错误;模型预测为负例,样本的真实类别是正例,

将这四个指标一起呈现在表格中,就能得到如下这样一个矩阵,我们称它为混淆矩阵(Confusion Matrix),这里从其他博客偷了张图:

在混线矩阵中,以对角线为分界线。以上图为例子:对角线的位置表示预测正确,对角线以外的位置表示把样本错误的预测为其他样本。

2、分类模型评价指标

从混淆矩阵可以直观地看出各个参数的数值大小。

查准率是在模型预测为正的所有样本中,模型预测对的比重,即:“分类器认为是正类并且确实是正类的部分占所有分类器认为是正类的比例”。计算公式如下式所示:

Precision=TP/(TP+FP)

Precision=TP/(TP+FP)

Precision=TP/(TP+FP)

召回率是在真实值是正的所有样本中,模型预测对的比重,即:“分类器认为是正类并且确实是正类的部分占所有确实是正类的比例”。计算公式如下式所示:

Recall=TP/(TP+FN)

Recall=TP/(TP+FN)

Recall=TP/(TP+FN)

F1-Score指标综合了Precision与Recall的产出的结果。F1-Score的取值范围从0到1的,1代表模型的输出最好,0代表模型的输出结果最差,计算公式如下式所示:

除了F1分数之外,F2分数和F0.5分数在统计学中也得到大量的应用。其中,F2分数中,召回率的权重高于精准率,而F0.5分数中,精准率的权重高于召回率。

3、两种多分类混淆矩阵

多分类混淆矩阵根据不同需求可以绘制不同的矩阵:

1、直接打印出每一个类别的分类准确率。

2、打印具体的分类结果的数值,方便数据的分析和各类指标的计算

在介绍具体代码之前,首先来介绍confusion_matrix()函数,它是Python中的sklearn库提供的输出矩阵数据的方法:

def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):

参数意义:

y_true: 是样本真实分类结果,y_pred: 是样本预测分类结果

y_pred:预测结果

labels:是所给出的类别,通过这个可对类别进行选择

sample_weight : 样本权重

3.1直接打印出每一个类别的分类准确率。# 显示混淆矩阵

def plot_confuse(model, x_val, y_val):

# 获得预测结果

predictions = predict(model,x_val)

#获得真实标签

truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label

cm = confusion_matrix(y_true=truelabel, y_pred=predictions)

plt.figure()

# 指定分类类别

classes = range(np.max(truelabel)+1)

title='Confusion matrix'

#混淆矩阵颜色风格

cmap=plt.cm.jet

cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] plt.imshow(cm, interpolation='nearest', cmap=cmap)

plt.title(title)

plt.colorbar()

tick_marks = np.arange(len(classes))

plt.xticks(tick_marks, classes, rotation=45)

plt.yticks(tick_marks, classes)

thresh = cm.max() / 2.

# 按照行和列填写百分比数据

for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):

plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",

color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()

plt.ylabel('True label')

plt.xlabel('Predicted label')

plt.show()

3.2打印具体的分类结果的数值# 显示混淆矩阵

def plot_confuse_data(model, x_val, y_val):

classes = range(0,6)

predictions = predict(model,x_val)

truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label

confusion = confusion_matrix(y_true=truelabel, y_pred=predictions)

#颜色风格为绿。。。。

plt.imshow(confusion, cmap=plt.cm.Greens)

# ticks 坐标轴的坐标点

# label 坐标轴标签说明

indices = range(len(confusion))

# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表

plt.xticks(indices, classes)

plt.yticks(indices, classes)

plt.colorbar()

plt.xlabel('Predicted label')

plt.ylabel('True label')

plt.title('Confusion matrix')

# plt.rcParams两行是用于解决标签不能显示汉字的问题

plt.rcParams['font.sans-serif']=['SimHei'] plt.rcParams['axes.unicode_minus'] = False

# 显示数据

for first_index in range(len(confusion)): #第几行

for second_index in range(len(confusion[first_index])): #第几列

plt.text(first_index, second_index, confusion[first_index][second_index])

# 显示

plt.show()

4、总结

1、混淆矩阵是深度学习中分类模型最常用的评估指标。网上大部分都是显示各类的分类正确率,不够灵活。显示具体数值灵活性大,可以计算自己想要的指标。

2、多分类的混淆矩阵中 查准率为主对角线上的值除以该值所在列的和;召回率等于主对角线上的值除以该值所在行的和。

作者:胖大海pyh以上就是关于对深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)的详细介绍。欢迎大家对深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)内容提出宝贵意见

python多分类混淆矩阵代码_深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)...相关推荐

  1. bp 神经网络 优点 不足_深度学习之BP神经网络--Stata和R同步实现(附Stata数据和代码)

    说明:本文原发于"计量经济圈"公众号,在此仅展示Stata的部分.R部分请移步至本人主页的"R语言与机器学习--经济学视角"专栏,或点击下方链接卡跳转. 盲区行 ...

  2. python 归一化feed-dict程序代码_深度学习-中国大学mooc-题库零氪

    第一讲 人工智能导论 2.29日考勤 1.以下关于深度学习和机器学习的关系,描述正确的是: A.机器学习的范畴包含深度学习: B.深度学习的范畴包含机器学习: C.二者没有关系: D.二者等同. 第三 ...

  3. python实现目标检测源代码包_深度学习目标检测系列:faster RCNN实现|附python源码...

    摘要: 本文在讲述RCNN系列算法基本原理基础上,使用keras实现faster RCNN算法,在细胞检测任务上表现优异,可动手操作一下. 目标检测一直是计算机视觉中比较热门的研究领域,有一些常用且成 ...

  4. 深度置信网络预测算法matlab代码_深度学习双色球彩票中的应用研究资料

    点击蓝字关注我们 AI研习图书馆,发现不一样的世界 深度学习在双色球彩票中的应用研究 前言 人工神经网络在双色球彩票中的应用研究网上已经有比较多的研究论文和资料,之前比较火的AlphaGo中用到的深度 ...

  5. pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)

    点击关注我哦 autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现 ...

  6. matlab数据归一化代码_深度学习amp;Matlab-LeNet实现图像分类

    明天准备用卷积神经网络处理分类问题,数据集大概有几万张图片,打算取其中的两类做一个简单分类.在这里先回顾一下以前在Matlab上利用LeNet对Mnist数据集做分类的代码. 本来是准备用Pytorc ...

  7. python在无人驾驶中的应用_深度学习在无人驾驶汽车中的应用

    人工智能及识别技术 本栏目责任编辑 : 唐一东 Computer Knowledge and Technology 电脑知识 与技术 第 11 卷第 24 期 (2015 年 8 月 ) 深度学习在无 ...

  8. 基于深度学习的人脸检测和关键点检测推理实践(OpenCV实现,含代码)

    目录 一.任务概述 二.环境准备 三.实现步骤 3.1 Python推理 3.2 C++推理 3.2.1 环境准备 3.2.2 推理 3.3 Java推理 一.任务概述 最近项目中大量场景需要用到人脸 ...

  9. adam算法效果差原因_深度学习优化器-Adam两宗罪

    在上篇文章中,我们用一个框架来回顾了主流的深度学习优化算法.可以看到,一代又一代的研究者们为了我们能炼(xun)好(hao)金(mo)丹(xing)可谓是煞费苦心.从理论上看,一代更比一代完善,Ada ...

最新文章

  1. 600页!分享珍藏很久的《推荐系统学习手册》(附链接)
  2. 关于java中的各种流
  3. Java多线程,实现卖电影票的业务
  4. 《深度学习的数学》二刷总结
  5. 深度 | 宽客玩转华尔街:谁才是新的“华尔街之王”?
  6. 精彩回顾|2021 中国 .NET 开发者峰会
  7. mock java_JAVA的mock工具mockito简介
  8. Android Service、IntentService,Service和组件间通信
  9. 百度地图加载空白颜色_详细解析百度收录和百度排名关系
  10. Hyperledger Fabric chaincode 开发(疑难解答)
  11. 《R语言预测实战》PDF,数据及代码
  12. 可以放游戏网站云服务器,游戏网站选择哪个云服务器好?游戏服务器配置方案?...
  13. 笔记本计算机的连接无线网络连接,笔记本电脑连接wifi的方法步骤
  14. 小米笔记本开机提示:no bootable device -- insert boot disk and press any key
  15. 他人的建议和意见对自已做决定的影响
  16. 一般业务系统的数据字典表结构
  17. PHP错题本功能实现,今天教你如何制作错题本!
  18. 智能家居为我们主要提供什么服务,主要实现了什么功能
  19. 国际项目投标那些事(三)海外项目招标文件的组成
  20. 微信、支付宝二码合一扫码支付实现思路

热门文章

  1. Kafka消息格式中的变长字段(Varints)
  2. 曹大带我学 Go(10)—— 如何给 Go 提性能优化的 pr
  3. Linux线程(五)
  4. C++中类的6个默认成员函数
  5. LiveVideoStack线上分享第四季(二):基于内容的自适应视频传输算法及其应用...
  6. LiveVideoStack线上交流分享 ( 七) —— 舞台现场直播技术实践
  7. 使用级联SFU改善媒体质量和规模
  8. 展望2018:人工智能为媒体服务赋能
  9. 《Go语言圣经》学习笔记 第八章 Groroutines和Channels
  10. TCP/IP详解--TIME_WAIT状态存在的原因