机器学习实战之路 —— 5 SVM支持向量机

  • 1. 支持向量机概述
    • 1.1 线性分类
    • 1.2 非线性分类
  • 2. 支持向量机分类中的问题
    • 2.1 核函数的选择
    • 2.2 多类分类
    • 2.3 不平衡数据的处理
    • 2.4 主要算法实现步骤
  • 3. 实战
    • 3.1 多类分类
    • 3.2 不平衡数据的处理
    • 3.3 手写数字识别
      • 3.3.1 数据集描述
      • 3.3.2 识别分类实现
      • 3.3.3 不同核函数下的识别率
  • 4.参考学习的书目及论文

1. 支持向量机概述

统计学习是当欠缺合适的理论模型时,对大量的观测数据采用的分析推理方法。在传统的统计模式下,分类问题的研究往往都是在数据集的数据量非常庞大的前提下进行的,而实际应用过程中,数据集中的数据量都是有限的,尤其在基于高维特征空间的分类问题中,想要得到趋于无穷大数目的数据更加困难。统计学习理论从机理上探讨并研究了小样本数据的期望风险与数据经验风险间的联系,并研究如何应用统计学习理论开发新的机器学习算法等一系列问题,其为模式识别等机器学习问题提供了坚实的理论基础。其中,支持向量机(Support Vector Machines,SVM)就是基于统计学习理论发展起来的一种数据机器学习方法,是一种基于结构风险最小化原则,实现有序风险最小化的方法。
SVM已经在许多实际问题中得到了广泛成功的应用,如图像处理中的图像过滤、视频字幕提取、图像分类和检索,语音识别以及手写体识别;在网络流量的特征选择和提取、流量的识别及分类等领域也得以广泛应用;在汽车领域的应用中也取得了非常好的效果,比如对驾驶员的超车意图和汽车行驶的并线意图识别,车辆辅助驾驶系统以及智能交通等。

1.1 线性分类

支持向量机在解决数据的分类问题时一般包括数据线性可分和数据线性不可分的情况,支持向量机是首先由线性可分发展而来,下面先讨论线性可分的二分类情况。
下图中二分类情况下,两类样本分别是蓝色星形点和红色圆形点样本,在二维平面中,两类样本为线性可分的,H为分类线,H1、H2为平行于H,且经过两类样本离分类线H最近的直线,H1、H2之间的直线距离称为分类间隔(Margin)。SVM的最优分类线需要将两类数据样本正确的分开,且需要使分类间隔最大化。这是因为正确的分离两类样本可以使经验风险最小,而要使泛化能力边界中的置信范围最小,则需要将分类间隔最大化,从而保证实际风险最小。将此思想推广到高维空间,最优分类线就转化为最优分类平面。

2∣∣ω∣∣\frac{2}{||ω||}∣∣ω∣∣2​ 为分类间隔,支持向量机的核心思想是使间隔最大化,保证正确分类的同时,使离决策超平面最近的数据也最大几何间隔的分开,以使超平面具备更好的预测和泛化能力。求∣∣ω∣∣2||ω||^2∣∣ω∣∣2最小即是求分类间隔最大。则最优分类线 H 就是最大化分类间隔,即使12∣∣ω∣∣2\frac{1}{2}||ω||^221​∣∣ω∣∣2 最小的分类线。推广到分类平面,支持向量(Support Vector,SV)就是平行于分类平面,且经过两类样本离分类平面最近的超平面H1、H2上的样本,由它们构成了最优分类平面。因此,可以将最优分类平面问题转化为具有线性约束的凸二次规划寻优问题。如上图的点A、C、D就是支持向量,而B,E不是。

1.2 非线性分类

在实际的应用中的分类样本通常是非线性可分的,当分类样本为线性不可分的情况时,在原始空间中无法构建一个最优决策超平面以将两类样本最大间隔的分开,因此需要引入核函数将分类样本数据从一个空间映射变换到另一个空间,从而使它转化为线性可分,变换转化的基本思想大致如下图所示:首先选择一个非线性映射ϕ(x)\phi(x)ϕ(x),然后将x映射到高维特征空间F,最后在F中构造最优超平面。

在上示例图中,可将原始数据的二维空间映射到一个合适的三维空间 ,就能找到一个合适的分离超平面。如果原始数据空间是有限维,即属性数有限,那么一定存在一个高维特征空间使样本可分。通过这种非线性变换在高维特征空间中实现线性分类,且计算复杂度保持不变。

2. 支持向量机分类中的问题

2.1 核函数的选择

在进行参数优化之前需要选择适当的核函数,核函数在不增加计算复杂度的条件下,能够有效的将特征样本映射到高维空间中进行线性问题的分类。核函数的选择确定了不同的映射和特征空间,同时核函数超参数的改变也隐含的改变了特征子空间分布的复杂度,即维度。复杂度高的超平面可能得不到泛化性能好的分类器,因此选择一个合适的核函数将数据投影到一个合适的特征空间,对于SVM来说有重要的意义,因为泛化能力是SVM算法实用化关键。核函数的形式决定了SVM的类型和复杂度,然而关于核函数的构造和选择没有统一的标准和规则,需要根据具体的分类问题具体分析。 目前SVM 中常用的核函数主要有以下几种:

此外,还可通过函数组合得到,例如:

  • 若K1K_1K1​和K2K_2K2​为核函数,则对于任意正数γ1γ_1γ1​、γ2γ_2γ2​,其线性组合

γ1K1+γ2K2γ1K_1+γ2K_2 γ1K1​+γ2K2​

也是核函数;

  • 若K1K_1K1​和K2K_2K2​为核函数,则而核函数的直积

K1⊙K2(x,z)=K1(x,z)K2(x,z)K_1⊙K_2(x,z)=K_1(x,z)K_2(x,z) K1​⊙K2​(x,z)=K1​(x,z)K2​(x,z)

也是核函数;

  • 若K1K_1K1​为核函数,则对于任意函数g(x)g(x)g(x)

K(x,z)=g(x)K1(x,z)g(z)K(x,z)=g(x)K_1(x,z)g(z) K(x,z)=g(x)K1​(x,z)g(z)

也是核函数。

2.2 多类分类

从设计的原理看,支持向量机只能用于二分类问题,不能将其直接用来解决多分类问题,然而现实中涉及到的分类问题绝大部分是多分类问题的。国内外现有处理多类分类问题的思路主要有两大类:一是直接修改支持向量机的二次规划形式,在所有样本的基础上求解一个大的二次规划问题从一次性解决多类分类问题,称之为直接多分类方法;二是按照某种规则构造一系列二类分类问题,最终多类分类问题转化为多个二类分类问题来解决,称为基于二分类的多分类方法。

直接多分类的目标函数比较复杂,计算复杂度也非常高,实现困难。综合考虑分类精度与时间复杂度,在样本数量较多的情况下,直接多分类方法不适于实际问题的应用。基于二分类的多分类方法,主要将多分类问题分解为多个二分类问题,思想简单易于操作,且每次处理的支持向量要少的多,因此花费的训练时间比较短。基于这些优点,基于二分类的多分类成为实践应用的主流方法。
基于二分类的多分类有两种常用方法:一对一(One vs One,OVO)和一对其余(One vs Rest,OVR)方法。其中 OVO 算法解决问题的思路是:假设有 n 个类别的分类任务,将这 N个类别两两组合,就会产生 N(N-1)/2 个二分类任务。在测试过程中,把新样本同时提交给所有分类器,这样就会得到N(N-1)/2 个结果,而最终的测试结果会综合所有的结果,把预测的最多的类别作为最终分类的结果。OVR 算法则每次训练中选取一个类的样例作为正例,而所有其他类的样例作为负例来训练 N 个分类器。在测试时,若仅有一个分类器预测为正类,则对应的类别标记为最终分类结果。两种方法的算法示意图如下所示:

2.3 不平衡数据的处理

不平衡数据广泛存在于医疗诊断、雷达图像监测、诈骗检测、文本分类、金融贷款管理、企业破产预测、电子设备故障预测等多个领域,具有极高的应用前景和现实意义。这里的不平衡数据是指在训练样本集中某些类别样本的数量远大于其它类别样本的数量,其中样本数量少的类为少数类(正类),样本数量多的类为多数类(负类)。研究表明在样本数量不平衡的情况下,一般的机器学习方法对少类样本分类和预测的准确率远低于对多类样本分类和预测的准确率,即算法偏向于多类。但是在很多实际应用中,例如故障诊断、欺诈检测等,通常更关注少类样本分类的准确率。
例如下图给出的一个不平衡数据示例,从中可以看出绿点样本数量远大于红点样本数量,但对于这些红点也就是正类的分类准确率才是我们更关心的。

目前,多数分类方法都是建立在样本集平衡的假设之下的,即使用的分类错误率来评价其分类性能。然而在样本集呈现不平衡分布时,由于我们更关心少类样本的分类准确率,因此不再适合简单的使用分类错误率,这就需要一些新的评价方法和指标来衡量各种不平衡数据分类方法的分类性能。
为表述方便,下面首先给出数据集混合矩阵的相关定义。对于不平衡数据集的每一个待分类样本,两分类分类方法有四种可能的判决结果。通常以关注的类为正类(少数类),其他类为负类,以下我们记为:
TP(True Positive):属于正类且被判别为正类的样本个数;
FP(False Positive):属于负类且被判别为正类的样本个数;
FN(False Negative):属于正类且被判别为负类的样本个数;
TN(True Negative):属于负类且被判别为负类的样本个数。
则数据集的混合矩阵定义和其评价方法和指标如下所示:


目前,不平衡数据处理方法主要分为两大类,一是数据预处理的方法,包括对少类样本过采样技术和多类样本的欠采样技术,其中过采样技术通过复制或内插的方法来增加少类样本的数量,但是这种方法遵循一种假设:多类样本之间仍是多类样本,少类样本附近仍然是少类样本,这个假设与数据特性相关,有时是不成立的,因此该方法对数据特性依赖较强;而欠采样技术使用采样和修剪的方法减少多类样本数量,该方法会造成数据集分类信息的部分丢失。另一大类是基于SVM算法改进的方法,主要包括在高维特征空间上移动分类超平面;为不同类别的样本分配不同的惩罚系数;以及根据样本特性修正核函数形式等,算法改进的方法一方面充分利用现有的信息,另一方面又基本上不增加算法的计算复杂度,它同数据预处理的方法相互补充,共同解决不平衡数据分类问题。

2.4 主要算法实现步骤

输入:训练集T={(x1,y1),(x2,y2),...,(xN,yN)}T= \big\{(x_1,y_1), (x_2,y_2),...,(x_N,y_N)\big\}T={(x1​,y1​),(x2​,y2​),...,(xN​,yN​)},其中xi∈X=Rn,yi∈Y={−1,+1},i=1,2,...,N;x_i∈X=R^n,y_i∈Y=\{-1,+1\},i=1,2,...,N;xi​∈X=Rn,yi​∈Y={−1,+1},i=1,2,...,N;

输出:分类决策树

(1)选择适当的核函数K(x,z)K(x,z)K(x,z)和适当的惩罚参数C>0,构造并求解最优化问题
min⁡α12∑i=1N∑j=1NαiαjyiyjK(xi⋅xj)−∑i=1Nαis.t.∑i=1Nαiyi=00≤αi≤C,i=1,2,...,N\begin{aligned} \min_\alpha \quad &\frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\alpha_i\alpha_jy_iy_jK\big(x_i·x_j\big)-\sum_{i=1}^{N}\alpha_i \\ s.t.\quad&\sum_{i=1}^{N}\alpha_iy_i=0\\ &0≤\alpha_i≤C,i=1,2,...,N \end{aligned}αmin​s.t.​21​i=1∑N​j=1∑N​αi​αj​yi​yj​K(xi​⋅xj​)−i=1∑N​αi​i=1∑N​αi​yi​=00≤αi​≤C,i=1,2,...,N​

求得最优解α∗=(α1∗,α2∗,...,αN∗)T\alpha^*=\big(\alpha_1^*,\alpha_2^*,...,\alpha_N^*\big)^Tα∗=(α1∗​,α2∗​,...,αN∗​)T

(2)选择α∗\alpha^*α∗的一个正分量C>α∗>0C>\alpha^*>0C>α∗>0,计算
b∗=yi−∑i=1Nαi∗yiK(xi⋅xj)b^*=y_i-\sum_{i=1}^{N}\alpha_i^*y_iK(x_i·x_j) b∗=yi​−i=1∑N​αi∗​yi​K(xi​⋅xj​)
(3)构造决策函数:
f(x)=sign(∑i=1Nαi∗yiK(x⋅xi)+b∗)f(x)=sign\big(\sum_{i=1}^{N}\alpha_i^*y_iK(x·x_i)+b^*\big) f(x)=sign(i=1∑N​αi∗​yi​K(x⋅xi​)+b∗)

3. 实战

sklearn中的SVM分类算法对应的是svm.SVC。

class sklearn.svm.SVC(*, C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=- 1, decision_function_shape='ovr', break_ties=False, random_state=None)

关于参数的详细介绍也可在官网的论述中查看。本节只给出一些关键参数的简要说明如下:
C:惩罚参数,默认值是1.0,一般参数值的范围可为0.0001~10000,惩罚参数值越大,则对错误例惩罚程度越大,但可能会导致过拟合。
kernel :核函数类型,可选类型为’linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’,默认是 ‘rbf’。
degree :多项式’poly’核函数的维度,默认值为3,若为其他核函数类型则该参数将被自动忽略。
gamma :‘rbf’,‘poly’和’sigmoid’核函数的系数。当前默认值为’auto’,gamma值为1 / n_features; 如果gamma=‘scale’,则gamma值为1 / (n_features * X.var())。
coef0 :核函数的常数项,默认值为3,只对’poly’和’sigmoid’核函数类型生效。

3.1 多类分类

本例将通过scipy中的multivariate_normal创建四分类随机多元正态分布样本集,使用svm.SVC对样本分类,实现代码如下:

import numpy as np
from sklearn import svm
from scipy import stats
from sklearn.metrics import accuracy_score
import matplotlib as mpl
import matplotlib.pyplot as pltdef extend(a, b, r=0.01):return a * (1 + r) - b * r, -a * r + b * (1 + r)if __name__ == "__main__":# 创建四分类样本集np.random.seed(100)         # 撒固定的种子,保证每次样本集数据相同N = 200                     # 每个分类200个样本点x = np.empty((4*N, 2))      # 随机返回800*2的元组means = [(-1, 1), (1, 1), (1, -1), (-1, -1)]                # 各个类的分布均值sigmas = [np.eye(2), 2*np.eye(2), np.diag((1,2)), np.array(((3, 2), (2, 3)))]   # 各个类的分布协方差矩阵for i in range(4):mn = stats.multivariate_normal(means[i], sigmas[i]*0.1) # 根据分布均值和协方差创建多元正态分布x[i*N:(i+1)*N, :] = mn.rvs(N)                           # 随机抽取创建多元正态分布,并对样本特征X赋值a = np.array((0,1,2,3)).reshape((-1, 1))y = np.tile(a, N).flatten() # 创建标签值# 支持向量分类机模型,高斯核,错误项的惩罚参数C=1,核函数系数gamma=1,一对一分类决策函数clf = svm.SVC(C=1, kernel='rbf', gamma=1, decision_function_shape='ovo')clf.fit(x, y)                           # 训练SVCy_hat = clf.predict(x)                  # 预测acc = accuracy_score(y, y_hat)          # 训练集精度np.set_printoptions(suppress=True)print('预测正确的样本个数:%d,正确率:%.2f%%' % (round(acc*4*N), 100*acc))print('支撑向量数目:', clf.n_support_)print(clf.decision_function(x))         # 样本点距超平面距离# 画图x1_min, x2_min = np.min(x, axis=0)x1_max, x2_max = np.max(x, axis=0)x1_min, x1_max = extend(x1_min, x1_max)x2_min, x2_max = extend(x2_min, x2_max)x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]x_test = np.stack((x1.flat, x2.flat), axis=1)y_test = clf.predict(x_test)y_test = y_test.reshape(x1.shape)cm_light = mpl.colors.ListedColormap(['#FF8080', '#80FF80', '#8080FF', '#F0F080'])cm_dark = mpl.colors.ListedColormap(['r', 'g', 'b', 'y'])mpl.rcParams['font.sans-serif'] = ['SimHei']mpl.rcParams['axes.unicode_minus'] = Falseplt.figure(facecolor='w')plt.pcolormesh(x1, x2, y_test, cmap=cm_light)plt.contour(x1, x2, y_test, levels=(0,1,2), colors='k', linestyles='--')plt.scatter(x[:, 0], x[:, 1], s=20, c=y, cmap=cm_dark, edgecolors='k', alpha=0.7)plt.xlabel('$X_1$', fontsize=11)plt.ylabel('$X_2$', fontsize=11)plt.xlim((x1_min, x1_max))plt.ylim((x2_min, x2_max))plt.grid(b=True)plt.tight_layout(pad=2.5)plt.title('SVM多分类方法:One/One or One/Rest', fontsize=14)plt.show()

输出结果如下:

预测正确的样本个数:792,正确率:99.00%
支撑向量数目: [14 29 22 23]
[[ 1.50432191  1.2820055   1.71149331  0.38199111 -0.55217345 -0.41357172][ 1.6535127   1.3253407   1.75495585  0.57391031 -0.08517941 -0.32364579][ 1.23434954  1.14042349  1.36757566  0.23062588 -0.19802485 -0.16228541]...[-0.03906226 -0.04228719 -1.47537791  0.13425707 -1.1147555  -1.37442757][-0.21845916 -0.12649473 -1.1387974   0.16443125 -1.17337707 -1.19144718][ 0.02383455 -0.63753932 -1.56095266 -0.49305719 -1.30025723 -1.00020888]]

SVM多分类结果输出图形如下:

3.2 不平衡数据的处理

本例将创建正数类样本数量为10、负数类样本数量为990的不平衡样本数据集,分别使用线性核与高斯核,对正负样本赋予不同权重值进行样本分类,实现代码如下:

import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.exceptions import UndefinedMetricWarning
import warningsif __name__ == "__main__":np.random.seed(0)   # 撒固定的种子,保证每次样本集数据相同c1 = 990        # 反例样本数c2 = 10         # 正例样本数N = c1 + c2     # 样本总数x_c1 = 3*np.random.randn(c1, 2)             # 随机取反例坐标值(二维)x_c2 = 0.5*np.random.randn(c2, 2) + (4, 4)  # 随机取正例坐标值(二维)x = np.vstack((x_c1, x_c2))                 # 堆叠正例反例坐标值y = np.ones(N)                              # 对正例赋标签值y[:c1] = -1                                 # 对反例赋标签值# 样本点显示大小s = np.ones(N) * 30s[:c1] = 10# 权重值weight = [1,30,1,30]# 对正例使用权重值,使用线型核、高斯核的SVC分类器clfs = [svm.SVC(C=1, kernel='linear', class_weight={-1:1, 1:weight[0]}),svm.SVC(C=1, kernel='linear', class_weight={-1:1, 1:weight[1]}),svm.SVC(C=0.8, kernel='rbf', gamma=0.5, class_weight={-1:1, 1:weight[2]}),svm.SVC(C=0.8, kernel='rbf', gamma=0.5, class_weight={-1:1, 1:weight[3]})]titles = [('Linear, Weight=%d' % weight[0]), ('Linear, Weight=%d' % weight[1]), ('RBF, Weight=%d' % weight[2]),('RBF, Weight=%d' % weight[3])]x1_min, x2_min = np.min(x, axis=0)x1_max, x2_max = np.max(x, axis=0)x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]grid_test = np.stack((x1.flat, x2.flat), axis=1)  # 测试点cm_light = matplotlib.colors.ListedColormap(['#77E0A0', '#FF8080'])cm_dark = matplotlib.colors.ListedColormap(['g', 'r'])matplotlib.rcParams['font.sans-serif'] = ['SimHei']matplotlib.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(10, 8), facecolor='w')# 遍历分类器for i, clf in enumerate(clfs):clf.fit(x, y)               # 训练SVCy_hat = clf.predict(x)      # 预测训练集# 输出性能度量值print(i+1, '次:')print('accuracy(精确度):\t', accuracy_score(y, y_hat))print('precision(准确率):\t', precision_score(y, y_hat, pos_label=1))print('recall(召回率):\t\t', recall_score(y, y_hat, pos_label=1))print('F1-score(F1度量):\t', f1_score(y, y_hat, pos_label=1))print()# 画图plt.subplot(2, 2, i+1)grid_hat = clf.predict(grid_test)           # 预测分类值grid_hat = grid_hat.reshape(x1.shape)       # 使之与输入的形状相同plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light, alpha=0.8)plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k', s=s, cmap=cm_dark)      # 样本的显示plt.xlim(x1_min, x1_max)plt.ylim(x2_min, x2_max)plt.title(titles[i])plt.grid(b=True, ls=':')plt.suptitle('SVC对不平衡数据的处理', fontsize=18)plt.tight_layout(1.5)plt.subplots_adjust(top=0.92)plt.show()

输出结果如下:

1 次:
accuracy(精确度):  0.99
precision(准确率):     0.0
recall(召回率):    0.0
F1-score(F1度量):     0.02 次:
accuracy(精确度):  0.941
precision(准确率):     0.14492753623188406
recall(召回率):    1.0
F1-score(F1度量):     0.253164556962025333 次:
accuracy(精确度):  0.994
precision(准确率):     0.7
recall(召回率):    0.7
F1-score(F1度量):     0.74 次:
accuracy(精确度):  0.994
precision(准确率):     0.625
recall(召回率):    1.0
F1-score(F1度量):     0.7692307692307693

不平衡数据样本集的分类结果输出图形如下:

3.3 手写数字识别

阿拉伯数字作为唯一被世界各国通用的符号,其在各行各业的地位是无可取代的。随着科技的发展,越来越多的数据信息需要被录入于计算机之中随后进行处理,通过人工识别纸张上数字的方法由于效率低下已不适用于海量数据的识别。目前手写体数字识别技术对各行各业发展的影响也越来越大。银行票据、财务报表、邮政编码等都可以通过手写体数字识别进行数字识别,不仅使人们摆脱重复且易出错的操作,而且极大地提升了工作效率。
数字只有 0-9 共 10 个数字,看似笔画简单且数量很少,但由于不同国家及个人的书写习惯的各不不同,以及部分数字之间的差异性较小,所以手写数字在识别方面的通用性并不太好;而且手写数字还应用在银行及财务等重要领域,因此对其识别的准确性具有很高的要求。手写体数字识别在追求准确、高效和低拒识率方面的研究也一直是热点方向。本节也将对SVM在手写数字识别方面的应用展开学习和探讨。

3.3.1 数据集描述

本小节实战将以经典的手写体数字图片数据集为例,原始数据可以从UCI官方数据地址获取下载:
https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
该数据集也已经集成作为Scikit-learn的内部的手写体数字图片数据集。训练数据文件optdigits.tra中包含的图像数据共有3823条,每张手写体数字图片是8∗88*88∗8的像素。测试数据文件optdigits.tes中共有1797条数据,图片同样是8∗88*88∗8像素。以下显示了部分0-9数字的训练数据图像示例:

3.3.2 识别分类实现

使用sklearn中的GridSearchCV(超参数自动搜索模块),核函数类型选取为高斯核函数(RBF),寻找svm.SVC对样本分类的最优的C和gama参数组合,实现代码如下:

import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import accuracy_score
import os
from sklearn.model_selection import GridSearchCV
from time import timeif __name__ == "__main__":# 读取数据-训练数据print('Load Training File Start...')data = np.loadtxt('optdigits.tra', dtype=np.float, delimiter=',')x, y = np.split(data, (-1, ), axis=1)print("训练样本个数:%d,特征数:%d \n" % x.shape)# 数据处理images = x.reshape(-1, 8, 8)            # 每个样本转换为8*8数组y = y.ravel().astype(np.int)            # 转置为行,并将值转换为整型# 读取数据-测试数据print('Load Test Data Start...')data = np.loadtxt('optdigits.tes', dtype=np.float, delimiter=',')x_test, y_test = np.split(data, (-1, ), axis=1)print("测试样本个数:%d,特征数:%d \n" % x_test.shape)# 数据处理images_test = x_test.reshape(-1, 8, 8)  # 每个样本转换为8*8数组y_test = y_test.ravel().astype(np.int)  # 转置为行,并将值转换为整型print('Load Data OK...')# SVC的C和gamma值选取范围params = {'C':np.logspace(0, 3, 7), 'gamma':np.logspace(-5, 0, 11)}# 超参数自动搜索模块GridSearchCV,系统地遍历多种参数组合,通过交叉验证确定最佳效果参数 3折交叉验证model = GridSearchCV(svm.SVC(kernel='rbf'), param_grid=params, cv=3)print('Start Learning...')t0 = time()model.fit(x, y)     # 训练数据t1 = time()t = t1 - t0         # 训练耗时print('训练+CV耗时:%d分钟%.3f秒' % (int(t/60), t - 60*int(t/60)))print('最优参数:\t', model.best_params_)print('Learning is OK...')print('训练集准确率:', accuracy_score(y, model.predict(x)))y_hat = model.predict(x_test)print('测试集准确率:', accuracy_score(y_test, y_hat))# 选取错分样本err_images = images_test[y_test != y_hat]err_y_hat = y_hat[y_test != y_hat]err_y = y_test[y_test != y_hat]print('错分样本的预测值:', err_y_hat)        # 错分样本的预测值print('错分样本的实际值:', err_y)            # 错分样本的实际值# 输出错分样本图片matplotlib.rcParams['font.sans-serif'] = ['SimHei']matplotlib.rcParams['axes.unicode_minus'] = Falseplt.figure(figsize=(10, 8), facecolor='w')for index, image in enumerate(err_images):if index >= 12:breakplt.subplot(3, 4, index + 1)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')plt.title('错分为:%i,真实值:%i' % (err_y_hat[index], err_y[index]))plt.tight_layout()plt.show()

输出结果如下:

Load Training File Start...
训练样本个数:3823,特征数:64 Load Test Data Start...
测试样本个数:1797,特征数:64 Load Data OK...
Start Learning...
训练+CV耗时:5分钟18.484秒
最优参数:     {'C': 10.0, 'gamma': 0.001}
Learning is OK...
训练集准确率: 1.0
测试集准确率: 0.9827490261547023
[0 1 2 ... 8 9 8]
[0 1 2 ... 8 9 8]
错分样本的预测值: [9 1 1 1 1 9 5 9 9 9 9 9 9 8 1 0 1 3 8 9 9 3 5 9 1 7 3 5 8 5 1]
错分样本的实际值: [5 2 2 2 8 7 7 5 7 7 7 7 7 1 8 6 8 9 9 3 8 8 8 7 8 3 9 9 3 3 8]

输出图形如下:

3.3.3 不同核函数下的识别率

基于上节的高斯核函数(RBF)的计算结果,我们可再分别计算在线性核函数(linear)、多项式核函数(poly)、sigmoid 核函数类型下的识别情况,进一步综合考察不同核函数类型下模型的训练优化时间和准确率等对比结果,进而可得到一些有意义的结论。
下图为不同核函数类型的识别对比情况:


从以上识别对比结果我们可以看出,在本例中:
(1)当采用RBF核函数时,手写数字在测试集及训练集的分类准确率都是最高的,但其计算时间花费最多;当采用linear核函数时,虽然计算时间最少和测试集分类准确率很高,但在测试集的准确率上却是最低的。
(2)四种核函数模型的测试集准确率:
RBF 核函数 >poly 核函数>sigmoid 核函数 >linear 核函数
(3)四种核函数模型的训练优化时间:
linear 核函数 <poly 核函数<sigmoid 核函数 <rbf 核函数
在实际运用中,通常需要兼顾计算资源和结果准确率等因素综合考量,根据具体的分类问题具体分析,选择相应合适的核函数类型。

4.参考学习的书目及论文

  1. 机器学习算法视频 - 邹博
  2. 《统计学习方法》第7章 支持向量机
  3. 《机器学习实战》第6章 支持向量机
  4. 《机器学习 - 周志华》第6章 支持向量机
  5. 《集体编程智慧》第9章 高阶分类:核方法与SVM
  6. 《基于支持向量机SVM的车牌识别》 余承波 2018 硕士论文 第3章 基于 SVM(支持向量机)的车牌字符识别
  7. 《基于驾驶意图识别的纯电动汽车驱动控制策略研究》李晓东 2018 硕士论文 第3章 纯电动汽车驾驶意图识别研究
  8. 《基于智能网联汽车的CAN总线攻击与防御检测技术研究》杨宏 2017 硕士论文 第6章 基于支持向量机的 CAN 总线异常检测
  9. 《基于支持向量机的行人检测技术研究》杨萌 2018 硕士论文 第3章 基于支持向量机的分类器设计

=文档信息=
本学习笔记由博主整理编辑,仅供非商用学习交流使用
由于水平有限,错误和纰漏之处在所难免,欢迎大家交流指正
如本文涉及侵权,请随时留言博主,必妥善处置
版权声明:非商用自由转载-保持署名-注明出处
署名(BY) :zhudj
文章出处:https://zhudj.blog.csdn.net/

机器学习实战之路 —— 5 SVM支持向量机相关推荐

  1. 刻意学习:机器学习实战--Task03分类问题:支持向量机

    刻意学习:机器学习实战–Task03分类问题:支持向量机 1 什么是SVM? 首先,支持向量机不是一种机器,而是一种机器学习算法. 1.SVM - Support Vector Machine ,俗称 ...

  2. 《机器学习实战》:通俗理解支持向量机

    代码.数据集.文章我都是放到了https://github.com/AAAZC/SVM_blog 上面了,文章在issues里面,建议上这个网站看 <机器学习实战>:通俗理解支持向量机 关 ...

  3. 机器学习实战之路 —— 4 Boosting算法

    机器学习实战之路 -- 4 Boosting算法 1. Boosting算法概述 2. 主要算法实现 2.1 AdaBoost 2.2 GBDT 2.3 XGBoost 3. 实战 - 鸢尾花数据集分 ...

  4. 机器学习(MATLAB实现)——SVM支持向量机(一)

    SVM支持向量机 支持向量机理论概述 二分类支持向量机 多分类支持向量机 libsvm工具箱使用简介 训练函数 预测函数 libsvm参数实例 一点拓展 参考文献 支持向量机理论概述 核函数用于将支持 ...

  5. 【机器学习算法-python实现】svm支持向量机(1)—理论知识介绍

    (转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景      强烈推荐阅读(http://www.cnblogs.com/jerrylead/archiv ...

  6. 【手把手机器学习入门到放弃】SVM支持向量机

    支持向量机 打仗的时候只有站最前面的人在打而已 支持向量机也是完成分类问题的一个工具,不同于逻辑回归,在支持向量机解决的分类问题中,只有最靠近对方阵营的样本对分界线的确定起到作用,而远离分界线的那些样 ...

  7. 【机器学习算法-python实现】svm支持向量机(3)—核函数

    (转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景知识    前面我们提到的数据集都是线性可分的,这样我们可以用SMO等方法找到支持向量的集合.然而当我们 ...

  8. 【机器学习算法-python实现】svm支持向量机(2)—简化版SMO算法

    (转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景知识      通过上一节我们通过引入拉格朗日乗子得到支持向量机变形公式.详细变法可以参考这位大神的博客 ...

  9. Python3《机器学习实战》学习笔记(八):支持向量机原理篇之手撕线性SVM

    原 Python3<机器学习实战>学习笔记(八):支持向量机原理篇之手撕线性SVM 置顶 2017年09月23日 17:50:18 阅读数:12644 转载请注明作者和出处: https: ...

最新文章

  1. 一个理想主义者关于爱情和美女、事业与金钱的疯人痴语
  2. 安装php openssl扩展
  3. gis属性表怎么导成excel_第022篇:ArcGIS中将属性表直接导出为Excel的方法
  4. 软件著作权 开源框架_开源软件分享-基于.net core 3.1的快速开发框架
  5. 电脑换ip_代理ip地址怎么换
  6. AI实战:从入门到精通系列——用感知器实现情感分类(一)
  7. android webview js 失效,Android WebView注入JQuery、JS脚本及执行无效的问题解决
  8. POJ-3590 The shuffle Problem 置换+DP | DFS
  9. 烂泥:perl中CPAN的安装
  10. gVim取消自动备份
  11. mac 安装mysql 找不到_mac安装mysql遇到的坑
  12. 小牛电动股权曝光:李彦持股4.4% 李一男持股降至28.1%
  13. 计算机知识太多了记不住,内容太多记不住?教你提高记忆力
  14. 001_linux基础命令
  15. 40多年祖传中医的临床经验总结(收藏)
  16. PCB设计之阻抗不连续性,如何解决?
  17. Java学习笔记之 Lambda表达式
  18. 关于内外网数据同步解决方案
  19. 记一次UDP接入服务的性能测试
  20. DBeaver:开源、跨平台、强大的数据库管理工具

热门文章

  1. 【嵌入式百科】001——字长、比特、字节、字、双字
  2. JavaScript相等与全等区别
  3. GMGC昆山数娱峰会:VR爆发,差的不只是一层窗户纸
  4. magicwatch2可以鸿蒙吗,荣耀手表2可以发微信吗?荣耀MagicWatch2支持微信消息回复吗...
  5. unicode转中文的代码html,unicode的html页面编码转换成中文
  6. 2022好用不亏的数码产品推荐、趁着618还没结束赶紧入
  7. android华为和小米,都是用安卓,为什么用小米和华为体验完全不一样?网友的评论扎心了!...
  8. 个人注册CSDN后第一篇分享关于测试工程师工作心得的文章
  9. Nginx-域名跳转到另外一个域名
  10. 【jQuery学习】淘宝精品栏案例