二元分类器在两个类中区分,而多类分类器(也称为多项分类器)可以区分两个以上的类。有一些算法(如随机森林分类器或朴素贝叶斯分类器)可以直接处理多个类。也有一些严格的二元分类器(如支持向量机分类器或线性分类器)。但是,有多种策略可以让你用几个二元分类器实现多类分类的目的。

我们这里使用的是mnist数据集。

OVR
要创建一个系统将数字图片分为10类(从0到9),一种方法是训练10个二元分类器,每个数字一个(0检测器、1检测器、2检测器,以此类推)。然后,当你需要对一张图片进行检测分类时,获取每个分类器的决策分数,哪个分类器给分最高,就将其分为哪个类。这称为一对剩余(OvR)策略,也称为一对多(oneversusall)。

OVO
另一种方法是为每一对数字训练一个二元分类器:一个用于区分0和1,一个区分0和2,一个区分1和2,以此类推。这称为一对一(OvO)策略。如果存在N个类别,那么这需要训练N×(N1)/2个分类器。对于MNIST问题,这意味着要训练45个二元分类器!当需要对一张图片进行分类时,你需要运行45个分类器来对图片进行分类,最后看哪个类获胜最多。OvO的主要优点在于,每个分类器只需要用到部分训练集对其必须区分的两个类进行训练。

有些算法(例如支持向量机分类器)在数据规模扩大时表现糟糕。对于这类算法,OvO是一个优先的选择,因为在较小训练集上分别训练多个分类器比在大型数据集上训练少数分类器要快得多。但是对大多数二元分类器来说,OvR策略还是更好的选择。ScikitLearn可以检测到你尝试使用二元分类算法进行多类分类任务,它会根据情况自动运行OvR或者OvO。我们用sklearn.svm.SVC类来试试SVM分类器:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

二分类模型

我们先准备好数据:

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X,y = mnist['data'], mnist['target']X_train, X_test = X[:6000], X[6000:]
y_train, y_test = y[:6000].astype(np.uint8), y[6000:].astype(np.uint8)

数据准备好了,我们开始训练模型:

from sklearn.svm import SVC
svm_clf = SVC()
svm_clf.fit(X_train, y_train)
SVC()

OK,模型训练好了,我们预测几个数据:

svm_clf.predict([X[0], X[1]])
array([5, 0], dtype=uint8)
print(y[0], y[1])
5 0

完美,预测正确。

这段代码使用原始目标类0到9(y_train)在训练集上对SVC进行训练,而不是以“5”和“剩余”作为目标类(y_train_5),然后做出预测(在本例中预测正确)。而在内部,ScikitLearn实际上训练了45个二元分类器,获得它们对图片的决策分数,然后选择了分数最高的类。

要想知道是不是这样,可以调用decision_function()方法。它会返回10个分数,每个类1个,而不再是每个实例返回1个分数

some_scores = svm_clf.decision_function([X[1]])
print(some_scores)
[[ 9.31124939  0.70612005  6.21459611  4.98107511 -0.29750684  8.277479743.82932284  1.74975607  3.81848706  6.05566815]]

我们可以看出,第0个结果的分数最高,所以分类结果为0。

如果想要强制ScikitLearn使用一对一或者一对剩余策略,可以使用OneVsOneClassifier或OneVsRestClassifier类。只需要创建一个实例,然后将分类器传给其构造函数(它甚至不必是二元分类器)。

from sklearn.multiclass import OneVsRestClassifier
ovr_clf = OneVsRestClassifier(SVC())
ovr_clf.fit(X_train, y_train)
ovr_clf.predict([X[0]])
array([5], dtype=uint8)

多分类模型

上述方式使用的是一个二分类模型来计算多分类问题,我们也可以直接使用SGDClassifier, RandomForestClassifer等可用于多分类的模型:

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier()
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([X[0]])
array([5], dtype=uint8)

我们看一下分数:

sgd_clf.decision_function([X[0]])
array([[-562212.94461254, -707645.64712029, -293553.98530211,-55603.9170907 , -842804.94696506,   41561.46574389,-642257.63484493, -793582.58056233, -381185.27236082,-359218.48451935]])

模型好像相当自信,除了5以为,其它score都是负数。

我们再看一下准确率:

from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring='accuracy')
array([0.8595, 0.868 , 0.8625])

所有的测试折叠上都超过了85%。如果是一个纯随机分类器,准确率大概是10%,所以这个结果不是太糟,但是依然有提升的空间。例如,将输入进行简单缩放可以将准确率稍为提高:

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaler = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaler, y_train, cv=3, scoring='accuracy')
array([0.867, 0.898, 0.889])

多分类问题的误差分析

首先看看混淆矩阵。就像之前做的,使用cross_val_predict()函数进行预测,然后调用confusion_matrix()函数:

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_train_pred = cross_val_predict(sgd_clf, X_train_scaler, y_train, cv=3)
confusion_matrix = confusion_matrix(y_train, y_train_pred)
print(confusion_matrix)
[[570   0   3   2   1  11   3   0   2   0][  0 633   8   2   1   6   0   1  18   2][ 10   8 499  10   9   4  11   9  18   3][  3   5  21 512   1  36   1   8  11  10][  4   2   8   0 557   2   9   7  10  24][  5   6   5  25  14 411  10   1  24  13][  4   4   6   0   5   7 578   1   3   0][  3   5   9   3  10   0   1 575   2  43][  3  18  14  17   4  10   8   3 461  13][  7   4   6   8  17   3   0  29   8 519]]

我们可以看到数字主要集中在对角线,说明效果还是比较理想的。

我们使用更直观的图线方式来看一下混淆矩阵:

plt.matshow(confusion_matrix, cmap=plt.cm.gray)
plt.show
<function matplotlib.pyplot.show(close=None, block=None)>

数字5看起来暗一点,这以为着数据集中被判断为5的照片比较少,有可能是5真的少,也有可能我们把5分到了其它类。所以我们看一下错误率:

row_sums = confusion_matrix.sum(axis=1, keepdims=True)
nor_conf_mx = confusion_matrix / row_sums

用0填充对角线,只保留错误:

np.fill_diagonal(nor_conf_mx, 0)
plt.matshow(nor_conf_mx, cmap=plt.cm.gray)
plt.show()

现在可以清晰地看到分类器产生的错误种类了。记住,每行代表实际类,而表示预测类。第5列看起来非常亮,说明有许多图片被错误地分类为数字5了。

分析混淆矩阵通常可以帮助你深入了解如何改进分类器。通过上图来看,你的精力可以花在改进数字5的分类错误上。例如,可以试着收集更多看起来像数字8的训练数据,以便分类器能够学会将它们与真实的数字区分开来。

反正,发现问题后,你可以通过各种方式来解决这个问题,比如检查样本,把样本一个一个打印出来。你可能发现样本中的3和5非常类型。毕竟我们这里使用的是线性分类器,一些写的不规范的3和5很容易被混淆。

sklearn综合示例5:多分类分类器相关推荐

  1. sklearn综合示例9:分类问题的onehot与预测阈值调整

    本文介绍了: 如何将多个标签做onehot,比如说总共有1000个标签,用户带了其中100个标签,那就是一个1000维的feautre,其中100维=1,其余900维=0. 调整分类算法的分类阈值,比 ...

  2. sklearn综合示例7:集成学习与随机森林

    假设你创建了一个包含1000个分类器的集成,每个分类器都只有51%的概率是正确的(几乎不比随机猜测强多少).如果你以大多数投票的类别作为预测结果,可以期待的准确率高达75%.但是,这基于的前提是所有的 ...

  3. sklearn综合示例8:SVM

    1.线性SVM SVM特别适用于中小型复杂数据集的分类. SVM对特征的缩放非常敏感 SVM的基本思想可以用一些图来说明.图51所示的数据集来自第4章末尾引用的鸢尾花数据集的一部分.两个类可以轻松地被 ...

  4. sklearn综合示例3:逻辑回归

    文章目录 API 模型参数 penalty dual tol C fit_intercept class_weight random_state solver max_iter verbose war ...

  5. sklearn综合示例2:决策树

    scikit-learn 是适用于数据处理和机器学习处理非常强大的库.提供数据降维.回归.聚类.分类等功能,是机器学习从业者的必备库之一. 示例一 案例:鸢尾属植物数据集(iris)分类. 鸢尾属植物 ...

  6. python使用sklearn中的make_classification函数生成分类模型(classification)需要的仿真数据、使用pandas查看生成数据的特征数据、目标数据

    python使用sklearn中的make_classification函数生成分类模型(classification)需要的仿真数据.使用pandas查看生成数据的特征数据(features).目标 ...

  7. make--变量与函数的综合示例 自动生成依赖关系

    一.变量与函数的示例 示例的要求 1.自动生成target文件夹存放可执行文件 2.自动生成objs文件夹存放编译生成的目标文件 3.支持调试版本的编译选项 4.考虑代码的扩展性 完成该示例所需的 1 ...

  8. C结构体工具DirectStruct(综合示例二)

    2019独角兽企业重金招聘Python工程师标准>>> C结构体工具DirectStruct(综合示例二) 1.编写定义文件,用工具dsc处理之,自动生成XML转换代码和ESQL代码 ...

  9. spark数据处理示例一:分类

    spark数据处理示例一:分类 @(SPARK)[spark, ML] spark数据处理示例一分类 知识点 1slice 2NaN 3mapValue 4groupBy 5state 6isNaN ...

最新文章

  1. mysql innodb引擎丢失_【MySQL】InnoDB引擎ibdata文件损坏/删除后使用frm和ibd文件恢复数据...
  2. Assembly of long, error-prone reads using repeat graphs 使用重复图组装长且容易出错的读操作
  3. 独家|一文解读合成数据在机器学习技术下的表现
  4. JTable表头也就是标题行给隐藏
  5. 当下网络营销市场中为何企业可通过网络营销提升自我价值?
  6. .Net中url传递中文的解决方案
  7. Centos设置程序开机自启的方法
  8. POJ 1182 食物链,并查集的拓展
  9. Python 学习日记第二篇 -- 列表,元组
  10. AttributeError: ‘str‘ object has no attribute ‘copy
  11. Linux设备树 .dtb文件,内核使用dtb文件的过程
  12. feign传递多个对象_面向对象
  13. MYSQL 字符集问题
  14. 502php,php502是什么问题
  15. 关于十字翻转棋的解法研究
  16. 利用阿里云邮件推送免费发邮件,每天免费200封,速度快,还高大上
  17. MATLAB将.mat矩阵写成.tif图片
  18. Voldemort博客
  19. 《道德经》 老子部分阅读笔记
  20. 如何编辑二维码内容并批量生成

热门文章

  1. 【二分法万能模板】Leecode 74. 搜索二维矩阵——Leecode日常刷题系列
  2. (*长期更新)软考网络工程师学习笔记——数据链路层与网络层的相关计算题
  3. 通俗易懂了解Vue双向绑定原理及实现
  4. python开发测试岗_作为测试开发岗的面试官,我都是怎么选人的?
  5. python3.6字典有序_Python如何按值对字典进行排序?
  6. PHP域名查墙代码,怎么查看域名是否被墙检测(教你一招域名被墙解决办法)
  7. win98 老电脑 文件导出_首次装电脑之前何不先模拟一番,这款练手神器可以帮你...
  8. 研华数据采集卡如何采集压力信号转化为数字信号_涨知识啦!PLC编程中如何使用开关、模拟、脉冲量...
  9. 微x怎么设置主题_红人堂:抖音直播预告文案怎么写?5个小技巧提高你的文案吸引力!...
  10. java传统的项目有哪些内容_请问java全套内容都有什么呢?