SVM C++ 实现
去年的时候,使用过 C++ 版本的SVM 实现过基于无人机的道路检测,但是当时,对博大精深的 SVM 只是了解皮毛。最近,对 SVM 的基本公式及相关变体的公式,重新推导了一遍,并且分别用 Python 和 C++ 实现了一遍。此文,是用 C++ 实现的。
SVM 公式的推导是需要掌握的,其实,如果一步一步地推导,基本公式是不难推导的,比如目标函数啊,拉格朗日乘子法, 以及涉及的对偶问题、KKT 条件、SMO 算法等。当进一步往下推导,比如 核函数、软间隔以及正则化等,需要点耐心去理解。
数据集: http://download.csdn.net/download/wz2671/10172405
main.cpp
#include <Windows.h>
#include "SVM.h"
#include "matrix.h"
#include "mat.h"
#include <iostream>
#pragma comment(lib, "libmat.lib")
#pragma comment(lib, "libmx.lib")
using namespace std; const int fn = 13;
const int sn1 = 59;
const int sn2 = 71;
const int sn3 = 48;
const int sn = 178; int readData(double* &data, double* &label)
{ MATFile *pmatFile = NULL; mxArray *pdata = NULL; mxArray *plabel = NULL; int ndir;//矩阵数目 //读取数据文件 pmatFile = matOpen("wine_data.mat", "r"); if (pmatFile == NULL) return -1; /*获取.mat文件中矩阵的名称 char **c = matGetDir(pmatFile, &ndir); if (c == NULL) return -2; */ pdata = matGetVariable(pmatFile, "wine_data"); data = (double *)mxGetData(pdata); matClose(pmatFile); //读取类标 pmatFile = matOpen("wine_label.mat", "r"); if (pmatFile == NULL) return -1; plabel = matGetVariable(pmatFile, "wine_label"); label = (double *)mxGetData(plabel); matClose(pmatFile); } int main()
{ doubl *data ; double *label; readData(data, label); //需要注意从.mat文件中读取出的数据按列存储 double *d; double *l; SVM svm; //第一组数据集与第二组数据集 预处理 l = new double[sn1 + sn2]; for(int i=0; i<sn1+sn2; i++) { if (fabs(label[i] - 2)<1e-3) l[i] = -1; else l[i] = 1; } d = new double[(sn1 + sn2)*fn]; for (int i = 0; i < fn; i++) { for (int j = 0; j < sn1+sn2; j++) { d[j*fn + i] = data[i*sn + j]; } } /* for (int i = 0; i < sn1 + sn2; i++) { for (int j = 0; j < fn; j++) { cout << d[i*fn + j] << ' '; } cout << endl; } */ svm.initialize(d, l, sn1+sn2, fn); svm.SMO(); cout << "数据集1和数据集2"; svm.show(); delete l; delete d; //第二组数据集与第三组数据集 l = new double[sn2 + sn3]; for (int i = sn1; i < sn1 + sn2 + sn3; i++) { if (fabs(label[i] - 2) < 1e-3) l[i-sn1] = 1; else if (fabs(label[i] - 3) < 1e-3) l[i-sn1] = -1; } d = new double[(sn2 + sn3)*fn]; for (int i = 0; i < fn; i++) { for (int j = sn1; j < sn; j++) { d[(j - sn1)*fn + i] = data[i*sn + j]; } } svm.initialize(d, l , sn2+sn3, fn); svm.SMO(); cout << "\n数据集2和数据集3"; svm.show(); delete l; delete d; //第一组数据集和第三组数据集 l = new double[sn1 + sn3]; for (int i = 0; i < sn1 + sn2 + sn3; i++) { if (fabs(label[i] - 1) < 1e-3) l[i] = 1; else if (fabs(label[i] - 3) < 1e-3) l[i - sn2] = -1; } d = new double[(sn1 + sn3)*fn]; for (int i = 0; i < fn; i++) { for (int j = 0; j < sn1; j++) { d[j*fn + i] = data[i*sn + j]; } for (int j = sn1 + sn2; j < sn; j++) { d[(j - sn2)*fn + i] = data[i*sn + j]; } } svm.initialize(d, l, sn1 + sn3, fn); svm.SMO(); cout << "\n数据集1和数据集3"; svm.show(); delete l; delete d; getchar(); return 0;
}
SVM.h
/*
用支持向量机求解二分类问题
分离超平面为:w'·x+b=0
分类决策函数:f(x)=sign(w'·x+b)
*/
#include <iostream>
using namespace std; class SVM
{
private: int sampleNum; //样本数 int featureNum; //特征数 double **data; //存放样本 行:样本, 列:特征 double *label; //存放类标 double *alpha; //double *w; 对于非线性问题,涉及kernel,不方便算 double b; double *gx; double s_max(double, double); double s_min(double, double); int secondAlpha(int); void computeGx(); double kernel(int, int); void update(int , int ,double, double); bool isConvergence(); bool takeStep(int, int); public: ~SVM(); //初始化数据 void initialize(double *, double *, int, int); //序列最小最优算法 void SMO(); double objFun(int); void show();
};
SVM.cpp
#include "SVM.h"
#include <math.h>
using namespace std; #define eps 1e-2 //误差精度
const int C = 100; //惩罚参数 SVM::~SVM()
{ if (data) delete[]data; if (label) delete label; if (alpha) delete alpha; if (gx) delete gx;
} //d中为样本,每个样本按行存储; l标签(1或-1); sn样本个数; fn特征个数
void SVM::initialize(double *d, double *l, int sn, int fn)
{ this->sampleNum = sn; this->featureNum = fn; this->label = new double[sampleNum]; this->data = new double*[sampleNum]; for (int i = 0; i < sampleNum; i++) { this->label[i] = l[i]; } for (int i = 0; i < sampleNum; i++) { this->data[i] = new double[featureNum]; for (int j = 0; j < featureNum; j++) { data[i][j] = d[i*featureNum + j]; } } alpha = new double[sampleNum] {0}; gx = new double[sampleNum] {0}; } double SVM::s_max(double a, double b)
{ return a > b ? a : b;
} double SVM::s_min(double a, double b)
{ return a < b ? a : b;
} double SVM::objFun(int x)
{ int j = 0; //选择一个0 < alpha[j] < C for (int i = 0; i < sampleNum; i++) { if (alpha[i]>0 && alpha[i] < C) { j = i; break; } } //计算b double b = label[j]; for (int i = 0; i < sampleNum; i++) { b -= alpha[i] * label[i] * kernel(i, j); } //构造决策函数 double objf = b; for (int i = 0; i < sampleNum; i++) { objf += alpha[i] * label[i] * kernel(x, i); } return objf;
} //判断有无收敛
bool SVM::isConvergence()
{ //alpah[i] * y[i]求和等于0 //0 <= alpha[i] <= C //y[i] * gx[i]满足一定条件 double sum = 0; for (int i = 0; i < sampleNum; i++) { if (alpha[i] < -eps || alpha[i] > C + eps) return false; else { // alpha[i] = 0 if (fabs(alpha[i]) < eps && label[i] * gx[i] < 1 - eps) return false; // 0 < alpha[i] < C if (alpha[i] > -eps && alpha[i] < C + eps && fabs(label[i] * gx[i] - 1)>eps) return false; // alpha[i] = C if (fabs(alpha[i] - C) < eps && label[i] * gx[i] > 1 + eps) return false; } sum += alpha[i] * label[i]; } if (fabs(sum) > eps) return false; return true;
} //假装是个核函数
//两个向量做内积
double SVM::kernel(int i, int j)
{ double res = 0; for (int k = 0; k < featureNum; k++) { res += data[i][k] * data[j][k]; } return res;
} //计算g(xi),也就是对样本i的预测值
void SVM::computeGx()
{ for (int i = 0; i < sampleNum; i++) { gx[i] = 0; for(int j=0; j < sampleNum; j++) { gx[i] += alpha[j] * label[j] * kernel(i, j); } gx[i] += b; }
} //更新很多东西
void SVM::update(int a1, int a2, double x1, double x2)
{ //更新阈值b double b1_new = -(gx[a1] - label[a1]) - label[a1] * kernel(a1, a1)*(alpha[a1] - x1) - label[a2] * kernel(a2, a1)*(alpha[a2] - x2) + b; double b2_new = -(gx[a2] - label[a2]) - label[a1] * kernel(a1, a2)*(alpha[a1] - x1) - label[a2] * kernel(a2, a2)*(alpha[a2] - x2) + b; if (fabs(alpha[a1]) < eps || fabs(alpha[a1] - C) < eps || fabs(alpha[a2]) < eps || fabs(alpha[a2] - C) < eps) b = (b1_new + b2_new) / 2; else b = b1_new; /* int j = 0; //选择一个0 < alpha[j] < C for (int i = 0; i < sampleNum; i++) { if (alpha[i]>0 && alpha[i] < C) { j = i; break; } } //计算b double b = label[j]; for (int i = 0; i < sampleNum; i++) { b -= alpha[i] * label[i] * kernel(i, j); } */ //更新gx computeGx();
} //选取第二个变量
/*
先选择是对应E1-E2最大的
若没有,用启发式规则,选目标函数有足够下降的alpha2
还没有,选择新的alpha1
*/
int SVM::secondAlpha(int a1)
{ //先计算出所有的E,也就是样本xi的预测值与真实输出之差Ei=g(xi)-yi //若E1为正,选最小的Ei作为E2,反正选最大 bool pos = (gx[a1] - label[a1] > 0); double tmp = pos ? 100000000 : -100000000; double ei = 0; int a2 = -1; for (int i = 0; i < sampleNum; i++) { ei = gx[i] - label[i]; if (pos && ei < tmp || !pos && ei > tmp) { tmp = ei; a2 = i; } } //对于特殊情况,直接遍历间隔边界上的支持向量点,选择具有最大下降的值 return a2;
} //选定a1和a2,进行更新
bool SVM::takeStep(int a1, int a2)
{ if (a1 < -eps) return false; double x1, x2; //old alpha x2 = alpha[a2]; x1 = alpha[a1]; //计算剪辑的边界 double L, H; double s = label[a1] * label[a2];//a1 与 a2同号或异号 L = s < 0 ? s_max(0, alpha[a2] - alpha[a1]) : s_max(0, alpha[a2] + alpha[a1] - C); H = s < 0 ? s_min(C, C + alpha[a2] - alpha[a1]) : s_min(C, alpha[a2] + alpha[a1]); if (L >= H) return false; double eta = kernel(a1, a1) + kernel(a2, a2) - 2 * kernel(a1, a2); //更新alpah[a2] if (eta > 0) { alpha[a2] = x2 + label[a2] * (gx[a1] - label[a1] - gx[a2] + label[a2]) / eta; if (alpha[a2] < L) alpha[a2] = L; else if (alpha[a2] > H) alpha[a2] = H; } else//我也不知道为什么这么算,我抄的论文里的,意思是选到超平面距离近的边界 { alpha[a2] = L; double Lobj = objFun(a2); alpha[a2] = H; double Hobj = objFun(a2); if (Lobj < Hobj - eps) alpha[a2] = L; else if (Lobj > Hobj + eps) alpha[a2] = H; else alpha[a2] = x2; } //下降太少,忽略不计 if (fabs(alpha[a2] - x2) < eps*(alpha[a2] + x2 + eps)) { alpha[a2] = x2; return false; } //更新alpha[a1] alpha[a1] = x1 + s*(x2 - alpha[a2]); update(a1, a2, x1, x2); /* for (int ii = 0; ii < sampleNum; ii++) { cout << gx[ii] << endl; } */ return true; } //由SVM分类决策的对偶最优化问题求解alpha
/*
用序列最小最优化算法(SMO)求解alpha
step1:选取一对需要更新的变量alpha[i]和alpha[j]
step2:固定alpha[i]和alpha[j]以外的参数,求解对偶问题的最优化解获得更新后的alpha[i]和alpha[j]
参考:李航《统计学习方法》 JC Platt《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》
*/
void SVM::SMO()
{ //bool convergence = false; //判断有没有收敛 int a1, a2; bool Changed = true; //有没有更新 int numChanged = 0; //更新了多少次 int *eligSample = new int[sampleNum]; // 记录访问过的样本 int cnt = 0; //样本个数 computeGx(); do { numChanged = 0; cnt = 0; //选择第一个变量(最不满足KKT条件的样本点) //优先选 0 < alpha < C 的样本 , alpha会随着后面的迭代发生变化 for (int i = 0; i < sampleNum; i++) { //记录下不满足KKT条件的样本,做个缓存 if (Changed) { cnt = 0; for (int j = 0; j < sampleNum; j++) { if (alpha[j] > eps && alpha[j] < C - eps) { eligSample[cnt++] = j; } } Changed = false; } if (alpha[i] > eps && alpha[i] < C-eps) { a1 = i; //不满足KKT条件 if (fabs(label[i] * gx[i] - 1) > eps) { //选择第二个变量,优先选下降最多的 a2 = secondAlpha(i); Changed = takeStep(a1, a2); numChanged += Changed; if(Changed) continue; else //目标函数没有下降 { //先依次遍历间隔边界上的 for(int j=0; j<cnt;j++) { if (eligSample[j] == i) continue; a2 = eligSample[j]; Changed = takeStep(a1, a2); numChanged += Changed; if (Changed) break; } if (Changed) continue; //再遍历整个数据集 int k = 0; for (int j = 0; j < sampleNum; j++) { //这是上面已经试过的间隔上的点 if (eligSample[k] == j) { k++; continue; } a2 = j; Changed = takeStep(a1, a2); numChanged += Changed; if (Changed) break; } //找不到合适的alpha2, 换一个alpha1 } } } } if(numChanged)//已经有改变了 { Changed = false; continue; } //选其他不满足KKT条件的样本 for (int i = 0; i < sampleNum; i++) { a1 = i; if (fabs(alpha[i]) < eps && label[i] * gx[i] < 1 || fabs(alpha[i] - C) < eps && label[i] * gx[i] > 1) { //选择第二个变量,步骤同上 a2 = secondAlpha(i); Changed = takeStep(a1, a2); numChanged += Changed; if (Changed) continue; else //目标函数没有下降 { //先依次遍历间隔边界上的 //间隔边界上的点已经记录在eligSample中了 for(int j=0; j<cnt; j++) { if (eligSample[j] == i) continue; a2 = eligSample[j]; Changed = takeStep(a1, a2); numChanged += Changed; if (Changed) break; } if (Changed) continue; //再遍历整个数据集 int k = 0; for (int j = 0; j < sampleNum; j++) { if (j == eligSample[k]) { k++; continue; } a2 = j; Changed = takeStep(a1, a2); numChanged += Changed; if (Changed) break; } //找不到合适的alpha2, 换一个alpha1 } } } /*
// if (!Changed) { cout<<"num"<<numChanged<<endl; show(); } //《统计学习方法》里说的收敛条件是这个,但不管用 //所以改用JC Platt论文伪代码所提方法(也不是完全一样) convergence = isConvergence(); //show(); cnt++; if (cnt == 10000) { cout << "num" << numChanged << endl; show(); } */ }while (numChanged); delete eligSample;
} void SVM::show()
{ cout << "支持向量为:" << endl; for (int i = 0; i < sampleNum; i++) { if(alpha[i]>eps) cout <<i<<" 对应的alpha为:"<<alpha[i]<< endl; } cout << endl;
}
结果显示:
SVM C++ 实现相关推荐
- 支持向量机SVM序列最小优化算法SMO
支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出.序列最小优化算法(Sequential ...
- 线性回归、逻辑回归及SVM
1,回归(Linear Regression) 回归其实就是对已知公式的未知参数进行估计.可以简单的理解为:在给定训练样本点和已知的公式后,对于一个或多个未知参数,机器会自动枚举参数的所有可能取值(对 ...
- svm rbf人脸识别 yale_实操课——机器学习之人脸识别
SVM(Support Vector Machine)指的是支持向量机,是常见的一种判别方法.在机器学习领域,是一个有监督的学习模型,通常用来进行模式识别.分类以及回归分析.在n维空间中找到一个分类超 ...
- Python,OpenCV基于支持向量机SVM的手写数字OCR
Python,OpenCV基于支持向量机SVM的手写数字OCR 1. 效果图 2. SVM及原理 2. 源码 2.1 SVM的手写数字OCR 2.2 非线性SVM 参考 上一节介绍了基于KNN的手写数 ...
- 机器学习中的数学基础(4.1):支持向量机Support Vector Machine(SVM)
SVM可以说是一个很经典的二分类问题,属于有监督学习算法的一种.看过那么多的博客知乎解释SVM我一定要自己总结一篇,加深一下自己的理解. 带着问题去读文章会发现,柳暗花明又一村,瞬间李敏浩出现在眼前的 ...
- SVM算法实现光学字符识别
目录 1.数据来源 2.数据预处理 3.模型训练 4.模型性能评估 5.模型性能提升 5.1.核函数的选取 5.2.惩罚参数C的选取 OCR (Optical Character Recognitio ...
- 使用OpenCV进行SVM分类demo
代码来源 https://github.com/mbeyeler/opencv-machine-learning/blob/master/notebooks/06.01-Implementing-Yo ...
- SVM进行手写数字识别
使用了TensorFlow中的mnist数据集 from sklearn import svm import numpy as np from sklearn.metrics import class ...
- 机器学习(18)-- SVM支持向量机(根据身高体重分类性别)
目录 一.基础理论 二.身高体重预测性别 1.获取数据(男女生身高体重) 2.数据处理(合并数据) 3.设置标签 4.创建分类器(支持向量机) 4-1.创建svm分类器 4-2.设置分类器属性(线性核 ...
- Udacity机器人软件工程师课程笔记(二十二) - 物体识别 - 色彩直方图,支持向量机SVM
物体识别 1.HSV色彩空间 如果要进行颜色检测,HSV颜色空间是当前最常用的. HSV(Hue, Saturation, Value)是根据颜色的直观特性由A. R. Smith在1978年创建的一 ...
最新文章
- [知识库分享系列] 三、Web(高性能Web站点建设)
- 数据结构——二叉树的层次遍历进阶
- marquee滚动起始位置_巧用喵影关键帧制作滚动水印,让视频小偷无可盗
- TF-卷积函数 tf.nn.conv2d 介绍
- Redis2.6安装报错
- mysql表连接_SELECT中的多表连接
- 开发机至少要有16G内存
- java并发编程实战源码_java并发编程实战(附源码)
- (附源码)springboot电子阅览室app 毕业设计 016514
- Dos窗口文字背景颜色设置
- 安卓客户端使用矢量图
- 计算机博士、加班到凌晨也要化妆、段子手……IT 女神驾到!
- mysql latch和缓存关系_latch:cachebufferschains等待事件导致的latch争用的原理原因与...
- 使用ESP8266-01S 作为Station PC作为Server通讯出现 ERROR CLOSED问题的解决办法
- FreeSWITCH的传真发送
- 使用appium进行app自动化测试时遇到AppActivity设置正确但报Connect Appium Server Fail.A new session could not be created
- 从1234中选出3个组成不重复的三位数
- Scrapy爬取中国地震台网1年内地震数据
- java arraylist 无序_关于Java:按字母顺序排序arraylist(不区分大小写)
- 苹果手机可以微信分身吗_为什么手机自带的微信分身被腾讯微信限制登录呢?...