一、前述

怎么样对训练出来的模型进行评估是有一定指标的,本文就相关指标做一个总结。

二、具体

1、混淆矩阵

混淆矩阵如图:

 第一个参数true,false是指预测的正确性。

 第二个参数true,postitives是指预测的结果。

 相关公式:

检测正列的效果:

检测负列的效果:

公式解释:

fp_rate:

tp_rate:

recall:(召回率)

值越大越好

presssion:(准确率)

TP:本来是正例,通过模型预测出来是正列

TP+FP:通过模型预测出来的所有正列数(其中包括本来是负例,但预测出来是正列)

值越大越好

F1_Score:

准确率和召回率是负相关的。如图所示:

通俗解释:

实际上非常简单,精确率是针对我们预测结果而言的,它表示的是预测为正的样本中有多少是真正的正样本。那么预测为正就有两种可能了,一种就是把正类预测为正类(TP),另一种就是把负类预测为正类(FP),也就是

召回率是针对我们原来的样本而言的,它表示的是样本中的正例有多少被预测正确了。那也有两种可能,一种是把原来的正类预测成正类(TP),另一种就是把原来的正类预测为负类(FN)。

其实就是分母不同,一个分母是预测为正的样本数,另一个是原来样本中所有的正样本数。

2、ROC曲线

过程:对第一个样例,预测对,阈值是0.9,所以曲线向上走,以此类推。

对第三个样例,预测错,阈值是0.7 ,所以曲线向右走,以此类推。

几种情况:

所以得出结论,曲线在对角线以上,则准确率好。

3、AUC面积

M是样本中正例数

N是样本中负例数

其中累加解释是把预测出来的所有概率结果按照分值升序排序,然后取正例所对应的索引号进行累加

通过AUC面积预测出来的可以知道好到底有多好,坏到底有多坏。因为正例的索引比较大,则AUC面积越大。

总结:

 4、交叉验证

为在实际的训练中,训练的结果对于训练集的拟合程度通常还是挺好的(初试条件敏感),但是对于训练集之外的数据的拟合程度通常就不那么令人满意了。因此我们通常并不会把所有的数据集都拿来训练,而是分出一部分来(这一部分不参加训练)对训练集生成的参数进行测试,相对客观的判断这些参数对训练集之外的数据的符合程度。这种思想就称为交叉验证。

 一般3折或者5折交叉验证就足够了。

 三、代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 文件名: mnist_k_cross_validate.pyfrom sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator #评估指标
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier# Alternative method to load MNIST, if mldata.org is down
from scipy.io import loadmat #利用Matlib加载本地数据
mnist_raw = loadmat("mnist-original.mat")
mnist = {"data": mnist_raw["data"].T,"target": mnist_raw["label"][0],"COL_NAMES": ["label", "data"],"DESCR": "mldata.org dataset: mnist_k_cross_validate-original",
}
print("Success!")
# mnist_k_cross_validate = fetch_mldata('MNIST_original', data_home='test_data_home')
print(mnist)X, y = mnist['data'], mnist['target'] # X 是70000行 784个特征 y是70000行 784个像素点
print(X.shape, y.shape)
#
some_digit = X[36000]
print(some_digit)
some_digit_image = some_digit.reshape(28, 28)#调整矩阵 28*28=784 784个像素点调整成28*28的矩阵 图片是一个28*28像素的图片 每一个像素点是一个rgb的值
print(some_digit_image)
#
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,interpolation='nearest')
plt.axis('off')
plt.show()
#
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[:60000]#6/7作为训练,1/7作为测试
shuffle_index = np.random.permutation(60000)#返回一组随机的数据 shuffle 打乱60000中每行的值 即每个编号的值不是原先的对应的值
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index] # Shuffle之后的取值
# #
y_train_5 = (y_train == 5)# 是5就标记为True,不是5就标记为false
y_test_5 = (y_test == 5)
print(y_test_5)
#这里可以直接写成LogGression
sgd_clf = SGDClassifier(loss='log', random_state=42)# log 代表逻辑回归 random_state或者random_seed 随机种子 写死以后生成的随机数就是一样的
sgd_clf.fit(X_train, y_train_5)#构建模型
print(sgd_clf.predict([some_digit]))# 测试模型 最终为5
# #
### K折交叉验证
##总共会运行3次
skfolds = StratifiedKFold(n_splits=3, random_state=42)# 交叉验证 3折 跑三次 在训练集中的开始1/3 中测试,中间1/3 ,最后1/3做验证
for train_index, test_index in skfolds.split(X_train, y_train_5):#可以把sgd_clf = SGDClassifier(loss='log', random_state=42)这一行放入进来,传不同的超参数 这里就不用克隆了clone_clf = clone(sgd_clf)# clone一个上一个一样的模型 让它不变了 每次初始随机参数w0,w1,w2都一样,所以设定随机种子是一样X_train_folds = X_train[train_index]#对应的是训练集中训练的X 没有阴影的y_train_folds = y_train_5[train_index]# 对应的是训练集中的训练y 没有阴影的X_test_folds = X_train[test_index]#对应的是训练集中的测试的X 阴影部分的y_test_folds = y_train_5[test_index]#对应的是训练集中的测试的Y 阴影部分的clone_clf.fit(X_train_folds, y_train_folds)#构建模型y_pred = clone_clf.predict(X_test_folds)#验证print(y_pred)n_correct = sum(y_pred == y_test_folds)# 如若预测对了加和 因为true=1 false=0print(n_correct / len(y_pred))#得到预测对的精度 #用判断正确的数/总共预测的 得到一个精度
# #PS:这里可以把上面的模型生成直接放在交叉验证里面传一些超参数比如阿尔法,看最后的准确率则知道什么超参数最好。#这是Sk_learn里面的实现的函数cv是几折,score评估什么指标这里是准确率,结果类似上面一大推代码
print(cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')) #这是Sk_learn里面的实现的函数cv是几折,score评估什么指标这里是准确率class Never5Classifier(BaseEstimator):#给定一个分类器,永远不会分成5这个类别 因为正负列样本不均匀,所以得出的结果是90%,所以只拿精度是不准确的。def fit(self, X, y=None):passdef predict(self, X):return np.zeros((len(X), 1), dtype=bool)never_5_clf = Never5Classifier()
print(cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring='accuracy'))#给每一个结果一个结果
# #
# #
##混淆矩阵 可以准确地知道哪一个类别判断的不准
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)#给每一个结果预测一个概率
print(confusion_matrix(y_train_5, y_train_pred))
# #
y_train_perfect_prediction = y_train_5
print(confusion_matrix(y_train_5, y_train_5))
#准确率,召回率,F1Score
print(precision_score(y_train_5, y_train_pred))
print(recall_score(y_train_5, y_train_pred))
print(sum(y_train_pred))
print(f1_score(y_train_5, y_train_pred))sgd_clf.fit(X_train, y_train_5)
y_scores = sgd_clf.decision_function([some_digit])
print(y_scores)threshold = 0 # Z的大小 wT*x的结果
y_some_digit_pred = (y_scores > threshold)
print(y_some_digit_pred)threshold = 200000
y_some_digit_pred = (y_scores > threshold)
print(y_some_digit_pred)y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')
print(y_scores)#直接得出Scoreprecisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
print(precisions, recalls, thresholds)def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')plt.xlabel("Threshold")plt.legend(loc='upper left')plt.ylim([0, 1])# plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
# plt.savefig('./temp_precision_recall')y_train_pred_90 = (y_scores > 70000)
print(precision_score(y_train_5, y_train_pred_90))
print(recall_score(y_train_5, y_train_pred_90))fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)def plot_roc_curve(fpr, tpr, label=None):plt.plot(fpr, tpr, linewidth=2, label=label)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1, 0, 1])plt.xlabel('False Positive Rate')plt.ylabel('True positive Rate')plot_roc_curve(fpr, tpr)
plt.show()
# plt.savefig('img_roc_sgd')print(roc_auc_score(y_train_5, y_scores))forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show()
# plt.savefig('./img_roc_forest')print(roc_auc_score(y_train_5, y_scores_forest))#
#

acc 看中整体

auc看中正例

【机器学习】--模型评估指标之混淆矩阵,ROC曲线和AUC面积相关推荐

  1. 分类模型的评估指标(2)---ROC曲线与AUC简介

    首先,我们需要了解一下,什么是ROC曲线? ROC曲线,即受试者工作特征曲线(Receiver Operating Characteristic curve,简称ROC曲线,是根据一系列不同的二分类方 ...

  2. 决策树分类评估指标之混淆矩阵

    问题的提出 如果决策树的目标是尽量捕获少数类,则准确率模型评估的意义不大,需要新的模型评估指标.简单来看,只需要查看模型在少数类上的准确率就好,只要能够将少数类尽量捕捉出来,就能够达到目的. 但是,新 ...

  3. 机器学习模型评估指标总结!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:太子长琴,Datawhale优秀学习者 本文对机器学习模型评估指标 ...

  4. 【机器学习基础】非常详细!机器学习模型评估指标总结!

    作者:太子长琴,Datawhale优秀学习者 本文对机器学习模型评估指标进行了完整总结.机器学习的数据集一般被划分为训练集和测试集,训练集用于训练模型,测试集则用于评估模型.针对不同的机器学习问题(分 ...

  5. 交叉验证分析每一折(fold of Kfold)验证数据的评估指标并绘制综合ROC曲线

    交叉验证分析每一折(fold of Kfold)验证数据的评估指标并绘制综合ROC曲线 Receiver Operating Characteristic (ROC) with cross valid ...

  6. ROC曲线,AUC面积

    AUC(Area under Curve):Roc曲线下的面积,介于0.1和1之间.Auc作为数值可以直观的评价分类器的好坏,值越大越好. 首先AUC值是一个概率值,当你随机挑选一个正样本以及负样本, ...

  7. 机器学习 模型评估指标 - ROC曲线和AUC值

    机器学习算法-随机森林初探(1) 随机森林拖了这么久,终于到实战了.先分享很多套用于机器学习的多种癌症表达数据集 https://file.biolab.si/biolab/supp/bi-cance ...

  8. 机器学习模型评估指标ROC、AUC详解

    我是小z ROC/AUC作为机器学习的评估指标非常重要,也是面试中经常出现的问题(80%都会问到).其实,理解它并不是非常难,但是好多朋友都遇到了一个相同的问题,那就是:每次看书的时候都很明白,但回过 ...

  9. 评估指标:混淆矩阵、PR、mAP、ROC、AUC

    文章目录 TP.TN.FP.FN 准确率 Accuracy 和 错误率 Error rate 混淆矩阵 confusion matrix 查准率 Precision 和 召回率 Recall PR 曲 ...

最新文章

  1. 昨天又写到个结构体排序,用多种cmp
  2. Android联系人Contacts详解
  3. weblogic启动服务报错
  4. 刘强东卸任!“二号位”徐雷:从摇滚青年到掌舵京东
  5. 在python中、如果异常并未被处理或捕捉_7、Python-异常
  6. ubuntu 配置 tftp 服务器
  7. CodeForces 650A Watchmen
  8. 四种Java跨域配置
  9. python中for循环的用法a+aa+aaa-Python练习题 013:求解a+aa+aaa……
  10. MS SQL2000 数据库置疑解决方法
  11. 纯干货,PSI 原理解析与应用
  12. 基本图像变换:线性变换,仿射变换,投影变换
  13. scratch课程设计
  14. Oracle中的commit与rollback
  15. C语言例题:用星号输出棱形图案。
  16. D2FQ(2021 FAST)
  17. Lenovo笔记本电脑触摸板无反应-解决方法
  18. 基于回调的观察者模式
  19. 小程序上传视频的php接口处理,微信小程序[第十二篇] -- 上传视频
  20. 泛微OA调用SAP接口

热门文章

  1. python的web抓取_python实现从web抓取文档的方法
  2. android 之Fragment(轻量级的Activity)详解
  3. 查看oracle系统信息,查看 ORACLE 系统级信息
  4. win10远程计算机管理,Win10系统下实现批量远程桌面管理的具体方法
  5. virtual box挂载 共享文件夹
  6. This tutorial code needs the xfeatures2d contrib module to be run.
  7. ssh 到另一台机器执行命令
  8. pytorch 入门(二) cnn 手写数字识别
  9. echarts实时更新数据_虎牙为S10拼了8.0年度更新!随时回放实时数据,还能养柴犬...
  10. garch预测 python_【2019年度合辑】手把手教你用Python做股票量化分析