凌云时刻 · 技术

导读:机器学习算法中有一个重要环节就是评判算法的好坏,我们在之间的笔记中讲过多种评价回归算法的评测标准,比如均方误差(MSE)、均方根误差(RMSE)、平均绝对误差(MAE)、   (R Squared)。但是在分类问题中我们一直使用分类准确度这一个指标,也就是预测对分类的样本数量除以总预测样本数量。但是这个方法存在很大的一个缺陷,所以这篇笔记主要介绍评价分类问题的方式方法。

作者 | 计缘

来源 | 凌云时刻(微信号:linuxpk)

精准率和召回率之间的平衡

在上一篇笔记中,我们了解了逻辑回归的决策边界,比如在二分类问题中,决策边界公式为:

当  大于0时,我们认为分类是1,当小于0时,我们认为分类为0。

如上图所示,黑色直线表‍‍示  ,橘黄色直线所在位置表示区分类别为1还是0的分界点,既大于0是蓝色点类型,小于0是红色点类型。那如果我们让‍‍  不‍‍等于0,而等于一个‍‍阀值threshold呢‍‍?

那上面的图就会是下面这样:‍‍

从上面的图看‍‍,threshold是大‍‍于0的,这样就相当于调整了区别分类的分界点位置。那么会影响到什么呢?

从上图可以看到‍‍,当threshold为0时‍‍,示例中的精准率是0.86,召回率是0.75。

当调‍‍整threshold大于‍‍0后,示例中的精准率是1,召回率是0.38。

当调‍‍整threshold小‍‍于0后,示例中的精准率是0.7,召回率是0.88。

从这三种情况可以看出,精‍‍准率和召回率是互相牵制的,精准率高了,召回率就低。召回率高,精准率就低。所以threshold就又是一个超参数,用来调节使精准率和召回率达到平衡。‍‍

 通过程序验证精准率和召回率的平衡关系

我们还是使用手写数据的样本数据来验证:

import numpy as np
from sklearn import datasets
# 使用手写数据作为样本数据
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()# 将多分类问题转换为二分类问题,同时让样本数据产生极度偏斜,
# 也就是我们关注的数据占总数据的1/9
y[digits.target == 9] = 1
y[digits.target != 9] = 0from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
y_predict = log_reg.predict(X_test)from sklearn.metrics import f1_score
f1_score(y_test, y_predict)
# 结果
0.86746987951807231from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_predict)
# 结果
array([[403,   2],[  9,  36]])from sklearn.metrics import precision_score
precision_score(y_test, y_predict)
# 结果
0.94736842105263153from sklearn.metrics import recall_score
recall_score(y_test, y_predict)
# 结果
0.80000000000000004

我们如何设‍‍置threshold呢,其实Scikit Learn中的逻辑回归提供了一个获取评判分数的函数,也就是上图中黑色直线的Score值:

decision_score = log_reg.decision_function(X_test)
decision_score.shape
# 结果
(450,)

Scikit Learn中的confusion_matrixprecision_scorerecall_score函数都是基于threshold‍‍为0计算的,也就是判断decision_score中的所有值,如果大于0就分类为1,如果小于0就分类为0。那我们现在‍‍将threshold调‍‍大一点,比如将5作为区分1和0的分界点,那么我们的预测值就可以这样求:

y_predict2 = np.array(decision_score >= 5, dtype='int')

然后我们再来看看精准率和召回率:

precision_score(y_test, y_predict2)
# 结果
0.95999999999999996recall_score(y_test, y_predict2)
# 结果
0.53333333333333333

‍‍再将threshold调小看看‍‍:

y_predict3 = np.array(decision_score >= -5, dtype='int')
precision_score(y_test, y_predict3)
# 结果
0.72727272727272729recall_score(y_test, y_predict3)
# 结果
0.88888888888888884

通过代码我们可以很明显的看到调‍‍节threshold后,精准率和召回率的变化。

PR曲线

‍‍通过上一小节我们知道精准率和召回率是相互牵制的,我也认识了一个新的超参数threshold,通过它能调节精准率和召回率。那么我们如何找到一个平衡点,使得精准率和召回率都在一个比较好的水平,换句话说也就是如何找到好的超参‍‍数threshold。‍‍‍‍

这一小节就介绍一个工具,帮助我们更好的找到这个超参数,这就是PR曲线(Precision-Recall曲线)。我们直接来看看Scikit Learn中提供的函数:

from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_test, decision_score)import matplotlib.pyplot as plt
plt.plot(thresholds, precisions[:-1])
plt.plot(thresholds, recalls[:-1])
plt.show()

上图的横‍‍轴是threshold值,蓝色曲线是精准率,黄色曲线是召回率,他们相交点的‍‍threshold值,‍‍就是PR达到平衡的点。‍‍

plt.plot(precisions, recalls)
plt.show()

上图中,横轴是精准率,纵轴是召回率。这个图反应了PR的总体趋势。通过这个PR曲线我们除了可以判断选择最优的threshold值,还可以判断不同模型的好坏程度。‍‍

比如上图中的模型A和模型B可以是通过不同的算法训练的出的模型,也可以是同一个算法,通过不同超参数组合训练出的模型。显然模型B要比模型A好,因为模型B无论是精准率还是召回率都要比模型A的高。

ROC曲线

这一小节我们来看一个新的指标,ROC曲线,既接收者操作特征曲线,是Receiver Operation Characteristic Curve缩写,最早出现在信号检测理论中,后来被广泛应用在不同领域。在机器学习中,ROC用来描述分类模型的TPR和FPR之间的关心,从而确定分类模型的好坏。

 FPR和TPR

FPR和TPR同样是基于混淆矩阵而来的,FPR的公式为:

TPR的公式为:

可以看到TPR其实就是Recall指标,而FPR是和TPR相反的指标。下面我们使用Scikit Learn中封装的方法来看看手写数据的TPR、FPR和ROC曲线:

from sklearn.metrics import roc_curve
fprs, tprs, thresholds = roc_curve(y_test, decision_score)plt.plot(fprs, tprs)
plt.show()

从ROC曲线图可以看出,随着FPR的增大,TPR也是随之增大的。我们通过观察这根曲线下的面积大小来判断分类模型的好坏程度,面积越大,说明分类模型越好。Scikit Learn中也提供了计算这个面积的函数:

from sklearn.metrics import roc_auc_score
roc_auc_score(y_test, decision_score)# 结果
0.98304526748971188

ROC曲线和PR曲线有一个不同之处是,ROC曲线对极度有偏的数据是不敏感的。所以如果样本数据有极度有偏的情况时,通常还是主要使用PR曲线来判断模型的好坏,ROC曲线辅助判断。

多分类问题中的混淆矩阵

我们之前讲的混淆矩阵和精准、召回率都是在二分类问题的前提下。这篇笔记的最后来看看多分类问题中的混淆矩阵。我们同样使用手写数字数据,但这次不再对数据做极度有偏处理了:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasetsdigits = datasets.load_digits()
X = digits.data
y = digits.targetfrom sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.8)from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)# 结果
0.93949930458970787

Scikit Learn 的precision_score方法有一个average参数,默认值为binary,既默认计算二分类问题。如果要计算多分类问题,需要将average参数设置为micro

from sklearn.metrics import precision_scoreprecision_score(y_test, y_predict, average='micro')# 结果
0.93949930458970787

下面来看看这个手写数字十分类问题的混淆矩阵:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_predict)# 结果
array([[141,   0,   0,   0,   0,   0,   0,   0,   1,   0],[  0, 132,   0,   0,   0,   0,   2,   0,   4,   2],[  0,   2, 141,   0,   0,   0,   0,   0,   0,   0],[  0,   0,   1, 131,   0,   5,   1,   0,  10,   0],[  0,   1,   0,   0, 136,   0,   0,   1,   1,   4],[  0,   0,   1,   0,   0, 141,   0,   1,   0,   0],[  0,   2,   0,   0,   1,   0, 146,   0,   1,   0],[  0,   1,   0,   0,   0,   0,   0, 137,   2,   2],[  0,   9,   3,   0,   0,   4,   4,   1, 120,   3],[  0,   1,   0,   6,   0,   1,   0,   0,   9, 126]])

看多分类问题的混淆矩阵和二分类问题的混淆矩阵方法一样,同样行表示真值,列表示预测值。从上面的结果可看到,混淆矩阵的对角线数值最大,这个对角线就是真值和预测值相同的TP值。我们将这个多分类混淆矩阵通过Matplotlib的matshow方法绘制出来,直观的看一下:

cfm = confusion_matrix(y_test, y_predict)
# cmap是colormap,既将绘制的矩阵的每个点映射成什么颜色,这里映射成灰度值
plt.matshow(cfm, cmap=plt.cm.gray)
plt.show()

上面这个图可以很清晰的看到TP值,但是我们希望能从图上直观的分析问题,既这个模型预测错误的数据。下面我们将混淆矩阵做一下转换,求出错误矩阵,既FP值矩阵:

# 首先求出一个向量,这个向量的每个元素表示每个手写数字有多少个样本,也就是将混淆矩阵在列方向,将每行的数加起来。
row_sums = np.sum(cfm, axis=1)
# 然后让混淆矩阵中的每个元素和它所在那一行的求和相除,既得到了每个数字的预测召回率
err_matrix = cfm / row_sums
# 通过Numpy的fill_diagonal方法,将错误矩阵的对角线的值都替换成0,因为我们主要看FP,所以要消除掉最高的精准率
np.fill_diagonal(err_matrix, 0)
# 最后同样用灰度值将错误矩阵绘制出来
plt.matshow(err_matrix, cmap=plt.cm.gray)
plt.show()

上图中,颜色约亮的格子表示预测错误的数量越多,比如左上角那个白色的格子就表示真值为3,但是有不少样本数据被预测成了8。左下角的白色格子表示真值为8,但是有不少样本数据被预测成了1。所以从这个错误矩阵上可以很好的分析出具体的预测错误点,从而根据这些信息调整分类模型或者样本数据。

END

往期精彩文章回顾

机器学习笔记(二十三):算法精准率、召回率

机器学习笔记(二十二):逻辑回归中使用模型正则化

机器学习笔记(二十一):决策边界

机器学习笔记(二十):逻辑回归(2)

机器学习笔记(十九):逻辑回归

机器学习笔记(十八):模型正则化

机器学习笔记(十七):交叉验证

机器学习笔记(十六):多项式回归、拟合程度、模型泛化

机器学习笔记(十五):人脸识别

机器学习笔记(十四):主成分分析法(PCA)(2)

长按扫描二维码关注凌云时刻

每日收获前沿技术与科技洞见

机器学习笔记(二十四):召回率、混淆矩阵相关推荐

  1. 机器学习笔记二十四 中文分词资料整理

    一.常见的中文分词方案 1. 基于字符串匹配(词典) 基于规则的常见的就是最大正/反向匹配,以及双向匹配. 规则里糅合一定的统计规则,会采用动态规划计算最大的概率路径的分词. 以上说起来很简单,其中还 ...

  2. 嵌入式Linux驱动笔记(二十四)------framebuffer之使用spi-tft屏幕(上)

    你好!这里是风筝的博客, 欢迎和我一起交流. 最近入手了一块spi接口的tft彩屏,想着在我的h3板子上使用framebuffer驱动起来. 我们知道: Linux抽象出FrameBuffer这个设备 ...

  3. 【机器学习】二分类问题中的混淆矩阵、准确率、召回率等 (Python代码实现)

    文章目录 混淆矩阵 召回率与准确率 准确度Accuracy sklearn代码示例 混淆矩阵 混淆矩阵(Confusion Matrix):将分类问题按照真实情况与判别情况两个维度进行归类的一个矩阵, ...

  4. 机器学习笔记(十四):异常检测

    目录 1)Problem motivation 2)Gaussian distribution 3)Algorithm 4)Developing and evaluating an anomaly d ...

  5. 【Visual C++】游戏开发笔记二十四 由DirectX的几个版本说开去

    分享一下我老师大神的人工智能教程!零基础,通俗易懂!http://blog.csdn.net/jiangjunshow 也欢迎大家转载本篇文章.分享知识,造福人民,实现我们中华民族伟大复兴! 本系列文 ...

  6. 机器学习笔记(十四)——HMM估计问题和前向后向算法

    一.隐马尔科夫链的第一个基本问题 估计问题:给定一个观察序列O=O1O2-OTO=O_1O_2\dots O_T和模型u=(A,B,π)u = (\boldsymbol{A,B,\pi}),如何快速地 ...

  7. Java8的其它 新特性(笔记二十四)

    标题 Java 8新特性简介 一.Lambda表达式 1.为什么使用Lambda表达式 2.使用举例 二.函数式(Functional)接口 1.什么是函数式(Functional)接口 2.如何理解 ...

  8. 机器学习(二十四)——数据不平衡问题, 强化学习

    https://antkillerfarm.github.io/ 数据不平衡问题 https://mp.weixin.qq.com/s/e0jXXCIhbaZz7xaCZl-YmA 如何处理不均衡数据 ...

  9. 机器学习知识点(二十四)隐马尔可夫模型HMM维特比Viterbi算法Java实现

    1.隐马尔可夫模型HMM    学习算法,看中文不如看英文,中文喜欢描述的很高深.    http://www.comp.leeds.ac.uk/roger/HiddenMarkovModels/ht ...

  10. [傅里叶变换及其应用学习笔记] 二十四. 级联,脉冲响应

    我们上节课学习了 在离散有限维空间中,任何线性系统都是通过矩阵间的相乘得到的 在连续无限维空间中,任何线性系统都是通过对核函数的积分得到的 脉冲响应(impulse response) 级联线性系统( ...

最新文章

  1. 图像处理之基础---极坐标系及其与直角坐标系的关系
  2. react native 组件之switch组件的用法
  3. python给函数设置超时时间_在 Linux/Mac 下为Python函数添加超时时间的方法
  4. ZjDroid工具介绍及脱壳详细示例
  5. OpenCV检测面部特征点的实例(附完整代码)
  6. java beans_java beans的概念及应用?
  7. 航行金税盘_通过陌生事物的情感进行统计好奇心航行
  8. 个人练习-jq 鼠标移上移出查看图片(放大)提示
  9. LabVIEW入门教程
  10. 移动硬盘提示格式化的处理
  11. 如何在桌面电脑上使用 SAS 硬盘
  12. 联系人管理系统 Python GUI版
  13. Python中Print()函数的用法___实例详解(全,例多)
  14. 报错解决:Lammps中lmp_mpi编译出错
  15. h5 移动开发 html页面跳转,iosh5混合开发项目仿app页面跳转优化
  16. C++中的同名二义性和路径二义性
  17. 六大接口管理平台,总有一款适合你的!
  18. 并行测试和变异测试三篇文献总结(二)
  19. 国产男装「升级潮」下,九牧王、劲霸、海澜之家们顺利「上分」了吗?
  20. 2021-01-12 DataGrip2020.3 离线安装驱动

热门文章

  1. 【库】JavaScript——滚动条( 不是很完善 )
  2. 收到“【有奖话题】虚拟空间“筑梦师”,谈谈微软虚拟化 ”礼物一个
  3. 【缅怀妈妈系列诗歌】之一:去医院的路,好长
  4. poj3264 Balanced Lineup(树状数组)
  5. java android 读写西门子PLC数据,包含S7协议和Fetch/Write协议,s7支持200smart,300PLC,1200PLC,1500PLC...
  6. Python学习之not,and,or篇
  7. android loginDemo +WebService用户登录验证
  8. [Python]将Excel文件中的数据导入MySQL
  9. 处理数据类型转换,数制转换、编码转换相关的类
  10. 经典200例-003 为项目添加已有类