支持向量机(SVM)

简介

支持向量机(Support Vector Machine,SVM),是常见的一种判别方法。在机器学习领域,是一个监督学习模型,通常用来进行模式识别、分类及回归分析。与其他算法相比,支持向量机在学习复杂的非线性方程时提供了一种更为清晰、更加强大的方式。
支持向量机是20世纪90年代中期发展起来的基于统计学习理论的一种机器学习方法,通过寻求结构化风险最小来提高学习机泛化能力,实现经验风险和置信范围的最小化,从而达到在统计样本量较少的情况下,也能获得良好统计规律的目的。
通俗来讲,它是一种二类分类模型,其基本模型定义为特征空间上的间隔最大的线性分类器,即支持向量机的学习策略便是间隔最大化,最终可转化为一个凸二次规划问题的求解。

原理

几个概念

线性可分

如图,数据之间分隔得足够开,很容易在图中画出一条直线将这一组数据分开,这组数据就被称为线性可分数据。

分隔超平面
上述将数据集分隔开来的直线称为分隔超平面,由于上面给出的数据点都在二维平面上,所以此时的分隔超平面是一条直线。如果给出的数据集点是三维的,那么用来分隔数据的就是一个平面。因此,更高维的情况可以以此类推,如果数据是100维的,那么就需要一个99维的对象来对数据进行分隔,这些统称为超平面。
间隔
如图片所示,这条分隔线的好坏如何呢?我们希望找到离分隔超平面最近的点,确保它们离分割面的距离尽可能远。在这里点到分隔面的距离称为间隔。间隔尽可能大是因为如果犯错或在有限数据上训练分类器,我们希望分类器尽可能健壮。
支持向量
离分隔超平面最近的那些点是支持向量。

求解

接下来就是要求解最大支持向量到分隔面的距离,需要找到此类问题的求解方法。如图,分隔超平面可以写成wTx+bw^Tx+bwTx+b。要计算距离,值为∣wTA+b∣∣∣w∣∣\frac{|w^TA+b|}{||w||}∣∣w∣∣∣wTA+b∣​。这里的向量w和常数b一起描述了所给数据的超平面。
对wTx+bw^Tx+bwTx+b使用单位阶跃函数得到f(wTx+b)f(w^Tx+b)f(wTx+b),其中当u<0u<0u<0时,f(u)f(u)f(u)输出-1,反之则输出+1。这里使用-1和+1是为了教学上的方便处理,可以通过一个统一公式来表示间隔或数据点到分隔超平面的距离。间隔通过label×∣wTx+b∣∣∣w∣∣label\times{|w^Tx+b| \over ||w||}label×∣∣w∣∣∣wTx+b∣​来计算。如果数据处在正方向(+1)类里面且离分隔超平面很远的位置时,wTx+bw^Tx+bwTx+b是一个很大的正数。同时label×∣wTx+b∣∣∣w∣∣label\times{|w^Tx+b| \over ||w||}label×∣∣w∣∣∣wTx+b∣​也会是一个很大的正数。如果数据点处在负方向(-1)类且离分隔超平面很远的位置时,由于此时类别标签为-1,label×∣wTx+b∣∣∣w∣∣label\times{\frac{|w^Tx+b|}{||w||}}label×∣∣w∣∣∣wTx+b∣​仍然是一个很大的正数。
为了找到具有最小间隔的数据点,就要找到分类器中定义的www和bbb,最小间隔的数据点也就是支持向量。一旦找到支持向量,就需要对该间隔最大化,可以写作max⁡w,b(min⁡n(label×∣wTx+b∣∣∣w∣∣))\max_{w,b}\left(\min_n(label\times{|w^Tx+b| \over ||w||})\right) w,bmax​(nmin​(label×∣∣w∣∣∣wTx+b∣​))
但是直接求解上述问题是非常困难的,所以这里引入拉格朗日乘子法,可以把表达式写成下面的式子:
max⁡_α[∑_i=1mα12∑_i,j=1mlabel(i)×α_i×α_j(x(i),x(j))]{\max\_{\alpha}}\left[\sum\_{i=1}^m\alpha \frac{1}{2}\sum\_{i,j=1}^mlabel^{(i)}\times \alpha\_i\times\alpha\_j{(x^{(i)},x^{(j)})} \right] max_α[∑_i=1mα21​∑_i,j=1mlabel(i)×α_i×α_j(x(i),x(j))]
其中×\times×表示,两个向量的内积。约束条件为C≥α≥0C\geq\alpha\geq0 C≥α≥0 ∑i−1mα_i×label(i)=0\sum_{i-1}^m \alpha\_i\times label^{(i)}=0 i−1∑m​α_i×label(i)=0
这里的常数C用于控制最大化间隔和保证大部分点的函数间隔小于1.0这两个目标的权重。因为所有数据都可能有干扰数据,所以通过引入所谓的松弛变量,允许有些数据点可以处于分隔面错误的一侧。
根据上式可知,只要求出所有的 α\alphaα,那么分隔面就可以通过α\alphaα来表达,SVM的主要工作就是求α\alphaα。这样一步步解出分隔面,那么分类问题游刃而解。

算法优势

  • 泛化错误率低,具有良好的学习能力。
  • 几乎所有分类问题都可以使用SVM解决。
  • 节省内存开销。

实战

实现手写识别系统

为了简单起见,这里的手写识别只针对0到9的数字,为了方便,图像转为了文本。目录trainingDigits中含有2000个例子,用于训练,目录testDigits含有900个例子,用于测试。

虽然手写识别可以用KNN实现而且效果不错,但是KNN毕竟太占内存了,而且要保证性能不变的同时使用较少的内存。而对于SVM,只需要保留很少的支持向量就可以实现目标效果。

  • 流程

    • 准备数据
    • 分析数据
    • 使用SMO算法求出α\alphaα和bbb
    • 训练算法
    • 测试算法
    • 使用算法

代码实现

# -*- coding=UTF-8 -*-
from numpy import *def clipAlpha(aj, H, L):'''辅助函数,调整a的范围:param aj::param H::param L::return:'''if aj > H:aj = Hif L > aj:aj = Lreturn ajdef kernelTrans(X, A, kTup):'''修改kernel:param X::param A::param kTup::return:'''m, n = shape(X)K = mat(zeros((m, 1)))if kTup[0] == 'lin':K = X * A.Telif kTup[0] == 'rbf':for j in range(m):deltaRow = X[j, :] - AK[j] = deltaRow*deltaRow.T# numpy中除法意味着对矩阵展开计算而不是matlab的求矩阵的逆K = exp(K/(-1*kTup[1]**2))# 如果遇到无法识别的元组,程序抛出异常else:raise NameError('Houston We Have a Problem -- That Kernel is not recognized')return Kclass optStruct:'''保存所有重要值,实现对成员变量的填充'''def __init__(self,dataMatIn, classLabels, C, toler, kTup):  # Initialize the structure with the parametersself.X = dataMatInself.labelMat = classLabelsself.C = Cself.tol = tolerself.m = shape(dataMatIn)[0]self.alphas = mat(zeros((self.m,1)))self.b = 0self.eCache = mat(zeros((self.m,2))) #误差缓存self.K = mat(zeros((self.m,self.m)))for i in range(self.m):self.K[:,i] = kernelTrans(self.X, self.X[i,:], kTup)def calcEk(oS, k):'''计算E值 计算误差:param oS::param k::return:'''fXk = float(multiply(oS.alphas,oS.labelMat).T*oS.K[:,k] + oS.b)Ek = fXk - float(oS.labelMat[k])return Ekdef selectJrand(i, m):''':param i: a的下标:param m: a的总数:return:'''j = iwhile (j == i):# 简化版SMO,alpha随机选择j = int(random.uniform(0, m))return jdef selectJ(i, oS, Ei):'''选择第二个a的值以保证每次优化的最大步长(内循环):param i::param oS::param Ei::return:'''maxK = -1maxDeltaE = 0Ej = 0oS.eCache[i] =[1, Ei]validEcacheList = nonzero(oS.eCache[:, 0].A)[0]if(len(validEcacheList)) > 1:for k in validEcacheList:if k == i:continueEk = calcEk(oS, k)deltaE = abs(Ei-Ek)if(deltaE > maxDeltaE):maxK = kmaxDeltaE = deltaEEj = Ekreturn maxK, Ejelse:j = selectJrand(i, oS.m)Ej = calcEk(oS, j)return j, Ejdef updateEk(oS, k):'''计算误差值并存入缓存中:param oS::param k::return:'''Ek = calcEk(oS, k)oS.eCache[k] = [1,Ek]def innerL(i, oS):'''选择第二个a:param i::param oS::return:'''Ei = calcEk(oS, i)if ((oS.labelMat[i]*Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or ((oS.labelMat[i]*Ei > oS.tol) and (oS.alphas[i] > 0)):j, Ej = selectJ(i, oS, Ei)alphaIold = oS.alphas[i].copy()alphaJold = oS.alphas[j].copy()if (oS.labelMat[i] != oS.labelMat[j]):L = max(0, oS.alphas[j] - oS.alphas[i])H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i])else:L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C)H = min(oS.C, oS.alphas[j] + oS.alphas[i])if L == H:print("L==H")return 0eta = 2.0 * oS.K[i, j] - oS.K[i, i] - oS.K[j, j]if eta >= 0:print("eta>=0")return 0oS.alphas[j] -= oS.labelMat[j]*(Ei - Ej)/etaoS.alphas[j] = clipAlpha(oS.alphas[j], H, L)updateEk(oS, j)if (abs(oS.alphas[j] - alphaJold) < 0.00001):print ("j not moving enough")return 0oS.alphas[i] += oS.labelMat[j]*oS.labelMat[i]*(alphaJold - oS.alphas[j])updateEk(oS, i)b1 = oS.b - Ei - oS.labelMat[i]*(oS.alphas[i]-alphaIold)*oS.K[i, i] - oS.labelMat[j]*(oS.alphas[j]-alphaJold)*oS.K[i, j]b2 = oS.b - Ej - oS.labelMat[i]*(oS.alphas[i]-alphaIold)*oS.K[i, j] - oS.labelMat[j]*(oS.alphas[j]-alphaJold)*oS.K[j, j]if (0 < oS.alphas[i]) and (oS.C > oS.alphas[i]):oS.b = b1elif (0 < oS.alphas[j]) and (oS.C > oS.alphas[j]):oS.b = b2else:oS.b = (b1 + b2)/2.0return 1else:return 0def smoP(dataMatIn, classLabels, C, toler, maxIter,kTup=('lin', 0)):'''实现platt smo算法:param dataMatIn::param classLabels::param C::param toler::param maxIter::param kTup::return:'''oS = optStruct(mat(dataMatIn), mat(classLabels).transpose(), C, toler, kTup)iter = 0entireSet = True; alphaPairsChanged = 0while (iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)):alphaPairsChanged = 0if entireSet:for i in range(oS.m):alphaPairsChanged += innerL(i, oS)print("fullSet, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged))iter += 1else:nonBoundIs = nonzero((oS.alphas.A > 0) *(oS.alphas.A < C))[0]for i in nonBoundIs:alphaPairsChanged += innerL(i, oS)print("non-bound, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged))iter += 1if entireSet:entireSet = Falseelif (alphaPairsChanged == 0):entireSet = Trueprint("迭代次数: %d" % iter)return oS.b, oS.alphasdef img2vector(filename):'''二值化图像转为向量32*32转为1*1024:param filename: 文件名:return: 向量'''returnVect = zeros((1, 1024))fr = open(filename)for i in range(32):lineStr = fr.readline()for j in range(32):returnVect[0, 32*i+j] = int(lineStr[j])return returnVectdef loadImages(dirName):'''导入数据集:param dirName::return:'''from os import listdirhwLabels = []trainingFileList = listdir(dirName)m = len(trainingFileList)trainingMat = zeros((m, 1024))for i in range(m):fileNameStr = trainingFileList[i]fileStr = fileNameStr.split('.')[0]classNumStr = int(fileStr.split('_')[0])# 这里是二分类问题,只分类数字1和9,数字分类结果为9时返回-1if classNumStr == 9:hwLabels.append(-1)else:hwLabels.append(1)trainingMat[i, :] = img2vector('%s/%s' % (dirName, fileNameStr))return trainingMat, hwLabelsdef testDigits(kTup=('rbf', 10)):'''测试算法,使用smop训练:param kTup: 核函数:return:'''dataArr, labelArr = loadImages('data/trainingDigits')b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, kTup)datMat = mat(dataArr)labelMat = mat(labelArr).transpose()svInd = nonzero(alphas.A > 0)[0]sVs = datMat[svInd]labelSV = labelMat[svInd]print("有 %d 支持向量" % shape(sVs)[0])m, n = shape(datMat)errorCount = 0for i in range(m):kernelEval = kernelTrans(sVs, datMat[i, :], kTup)predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + bif sign(predict) != sign(labelArr[i]):errorCount += 1print("训练数据错误率是: %f" % (float(errorCount)/m))dataArr, labelArr = loadImages('data/testDigits')errorCount = 0datMat = mat(dataArr)labelMat = mat(labelArr).transpose()m,n = shape(datMat)for i in range(m):kernelEval = kernelTrans(sVs, datMat[i, :], kTup)predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + bif sign(predict) != sign(labelArr[i]):errorCount += 1print("测试数据错误率是: %f" % (float(errorCount)/m))def loadDataSet(filename):'''加载数据集:param filename: 文件名:return:'''dataMat = []labelMat = []fr = open(filename)for line in fr.readlines():lineArr = line.strip().split('\t')dataMat.append([float(lineArr[0]), float(lineArr[1])])labelMat.append(float(lineArr[2]))return dataMat, labelMatif __name__ == '__main__':testDigits(('rbf', 20))

补充说明

参考书《Python3数据分析与机器学习实战》,具体数据集和代码可以查看我的GitHub,欢迎star或者fork。

机器学习-分类之支持向量机(SVM)原理及实战相关推荐

  1. 机器学习-分类之多层感知机原理及实战

    多层感知机(Multi-Layer Perceptron) 简介 生物神经网络具有相互连接的神经元,神经元带有接受输入信号的树突,然后基于这些输入,它们通过轴突向另一个神经元产生输出信号.使用人工神经 ...

  2. 统计学习方法|支持向量机(SVM)原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

  3. 机器学习算法の03 支持向量机SVM

    机器学习算法の03 支持向量机SVM SVM的基本概念 线性可分支持向量机 非线性支持向量机和核函数 线性支持向量机与松弛变量 LR与SVM的区别与联系 SVM的基本概念 基本概念: 支持向量机(su ...

  4. 机器学习算法 08 —— 支持向量机SVM算法(核函数、手写数字识别案例)

    文章目录 系列文章 支持向量机SVM算法 1 SVM算法简介 1.1 引入 1.2 算法定义 2 SVM算法原理 2.1 线性可分支持向量机 2.2 SVM计算过程与算法步骤(有点难,我也没理解透,建 ...

  5. 支持向量机SVM原理解析

    支持向量机(SVM) 支持向量机(support vector machine,SVM)使用训练集中的一个子集来表示决策边界,边界用于样本分类,这个子集称作支持向量(support vector). ...

  6. 机器学习:支持向量机SVM原理与理解

    引言 --"举牌子:Support Vector Machines " 一直在犹豫要不要写SVM,因为网上已经有很多详细的SVM原理的解释甚至详细推导,而这东西又庞大复杂,想了解的 ...

  7. svm多分类_人工智能科普|机器学习重点知识——支持向量机SVM

    作为模式识别或者机器学习的爱好者,同学们一定听说过支持向量机这个概念,这可是一个,在机器学习中避不开的重要问题. 其实关于这个知识点,有一则很通俗有趣的传说: 很久以前,一个村庄里住着一位退隐的大侠, ...

  8. 文本分类之支持向量机SVM详解(6)机器学习

    1 支持向量机介绍   对两类样本点进行分类,如下图,有a线.b线.c线三条线都可以将两类样本点很好的分开类,我们可以观察到b线将两类样本点分类最好,原因是我们训练出来的分类模型主要应用到未知样本中, ...

  9. 支持向量机——SVM原理

    SVM--Support Vector Machine 5.11 update:拉格朗日对偶问题的推导 5.15 update:SMO算法推导 5.17 update:sklearn实现 文章目录 S ...

最新文章

  1. Vue.js 技术揭秘学习 (2) Vue 实例挂载的实现
  2. oracle表空间大小规划,关于oracle表空间的规划方法
  3. 中国电信发布转型升级战略:构建一横四纵生态圈
  4. 用SAP BSP应用实现图片灰度效果
  5. lynda ux_UX心态
  6. Oracle 11g系统自动收集统计信息
  7. 大厂机密!30 个提升团队研发效能的锦囊
  8. mysql 基本概念学习(时间,事务)
  9. java 远程监控文件系统_Java 文件系统监控(WatchService)
  10. Android大图片裁剪解决方案
  11. python--时间日期
  12. 23 模块代码编写基础
  13. 微信小程序 #项目笔记# | 从0到1实现婚礼邀请函小程序
  14. Linux 修改环境变量设置的三种方式
  15. 第2章第6节:使用Slider滑杆在指定的范围内选择一个数值 [SwiftUI快速入门到实战]
  16. BGP公网成本节省50%秘笈,共享流量包、共享带宽包,便宜到阿里云快哭了
  17. 【ArchSummit】社交元宇宙的技术挑战与探索
  18. vivo X90、vivo X90 Pro和vivo X90 Pro+的区别 参数对比哪个好
  19. c++ opencv 彩色图rgb 转换hsv 再通道分离
  20. android下面res目录

热门文章

  1. 如何理解Minor/Major/Full GC
  2. Netty 采用NIO 而非AIO 的理由
  3. Redis中的客户端重定向
  4. MybatisPlus入门Lombok的使用
  5. 存储函数和存储过程的区别
  6. 什么是设计模式(Design Patterns)
  7. Oracle 同义词、DBLINK、表空间的使用
  8. php5.6.16,OSX 10.11 中重新编译PHP5.6.16问题
  9. linux执行cd会使用系统调用,深入理解Linux系统调用
  10. 对于linux下指令的进一步扩充与巩固