前言:

SVM(支持向量机)一种训练分类器的学习方法

mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个

LibSVM 一个常用的SVM框架

OpenCV3.0 中的ml包含了很多的ML框架接口,就试试了。

详细的OpenCV文档:http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html

mnist数据下载:http://yann.lecun.com/exdb/mnist/

LibSVM下载:http://www.csie.ntu.edu.tw/~cjlin/libsvm/

========================我是分割线=============================

训练的过程大致如下:

1. 读取mnist训练集数据

2. 训练

3. 读取mnist测试数据,对比预测结果,得到错误率

具体实现:

1. mnist给出的数据文件是二进制文件

四个文件,解压后如下

"train-images.idx3-ubyte" 二进制文件,存储了头文件信息以及60000张28*28图像pixel信息(用于训练)

"train-labels.idx1-ubyte" 二进制文件,存储了头文件信息以及60000张图像label信息

"t10k-images.idx3-ubyte"二进制文件,存储了头文件信息以及10000张28*28图像pixel信息(用于测试)

"t10k-labels.idx1-ubyte"二进制文件,存储了头文件信息以及10000张图像label信息

因为OpenCV中没有直接导入MINST数据的文件,所以需要自己写函数来读取

首先要知道,MNIST数据的数据格式

IMAGE FILE包含四个int型的头部数据(magic number,number_of_images, number_of_rows, number_of_columns)

余下的每一个byte表示一个pixel的数据,范围是0-255(可以在读入的时候scale到0~1的区间)

LABEL FILE包含两个int型的头部数据(magic number, number of items)

余下的每一个byte表示一个label数据,范围是0-9

注意(第一个坑):MNIST是大端存储,然而大部分的Intel处理器都是小端存储,所以对于int、long、float这些多字节的数据类型,就要一个一个byte地翻转过来,才能正确显示。

1 //翻转

2 int reverseInt(inti) {3 unsigned charc1, c2, c3, c4;4

5 c1 = i & 255;6 c2 = (i >> 8) & 255;7 c3 = (i >> 16) & 255;8 c4 = (i >> 24) & 255;9

10 return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) +c4;11 }

View Code

然后读取MNIST文件,但是它是二进制文件,打开方式

所以不能用

ifstream file(fileName);

而要改成

ifstream file(fileName, ios::binary);

注意(第二个坑):如果用第一条指令来打开文件,不会报错,但是数据会出现错误,头部数据仍然正确,但是后面的pixel数据大部分都是0,我刚开始没注意,开始training的时候发现等了很久...真的是很久...(7+ hours)...估计是达到迭代终止的最大次数了,才停下来的

嗯,stack overflow上也有类似的提问:

注意(第三个坑):

training时,IMAGE和LABEL的数据分别都放进一个MAT中存储,但是只能是CV32_F或者CV32_S的格式,不然会assertion报错

OPENCV给出的文档中,例子是这样的:(但是predict的时候又会要求label的格式是unsigned int)所以...可以设置data的Mat格式为CV_32FC1,label的Mat格式为CV_32SC1

顺便地,图像训练数据的转换存储格式(http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1)

最后,为了验证读取数据的正确性,一个有效的办法就是输出第一个和最后一个数据(可以输出答应第一个/最后一个image以及label)

2. 训练

(此处我是直接对原图像训练,并没有提取任何的特征)

也有人建议这里应该对图像做HOG特征提取,再配合label训练(我还没试过...不知道效果如何...)

opencv3.0和2.4的SVM接口有不同,基本可以按照以下的格式来执行:

ml::SVM::Params params;params.svmType =ml::SVM::C_SVC;params.kernelType =ml::SVM::POLY;params.gamma = 3;

Ptr<:svm> svm = ml::SVM::create(params);

Mat trainData;//每行为一个样本

Mat labels;

svm->train( trainData , ml::ROW_SAMPLE , labels );//...

svm->save("....");//文件形式为xml,可以保存在txt或者xml文件中

Ptr svm=statModel::load("....");

Mat query;//输入, 1个通道

Mat res; //输出

svm->predict(query, res);

但是要注意,如果报错的话最好去看opencv3.0的文档,里面有函数原型和解释,我在实际操作的过程中,也做了一些改动

1)设置参数

SVM的参数有很多,但是与C_SVC和RBF有关的就只有gamma和C,所以设置这两个就好,终止条件设置和默认一样,由经验可得(其实是查阅了很多的资料,把gamma设置成0.01,这样训练收敛速度会快很多)

Ptr svm =SVM::create();

svm->setType(SVM::C_SVC);

svm->setKernel(SVM::RBF);

svm->setGamma(0.01);

svm->setC(10.0);

svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000,FLT_EPSILON));

svm_type –指定SVM的类型,下面是可能的取值:

CvSVM::C_SVC C类支持向量分类机。 n类分组 (n \geq 2),允许用异常值惩罚因子C进行不完全分类。

CvSVM::NU_SVC \nu类支持向量分类机。n类似然不完全分类的分类器。参数为 \nu 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。

CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。

CvSVM::EPS_SVR \epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。

CvSVM::NU_SVR \nu类支持向量回归机。 \nu 代替了 p。

kernel_type –SVM的内核类型,下面是可能的取值:

CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.

CvSVM::POLY 多项式内核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.

CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.

CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).

degree – 内核函数(POLY)的参数degree。

gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数\gamma。

coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。

Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。

nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 \nu。

p – SVM类型(EPS_SVR)的参数 \epsilon。

class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。

term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。

2)训练

Mat trainData;

Mat labels;

trainData=read_mnist_image(trainImage);

labels=read_mnist_label(trainLabel);

svm->train(trainData, ROW_SAMPLE, labels);

3)保存

svm->save("mnist_dataset/mnist_svm.xml");

3. 测试,比对结果

(此处的FLT_EPSILON是一个极小的数,1.0 - FLT_EPSILON != 1.0)

Mat testData;

Mat tLabel;

testData=read_mnist_image(testImage);

tLabel=read_mnist_label(testLabel);float count = 0;for (int i = 0; i < testData.rows; i++) {

Mat sample=testData.row(i);float res = svm1->predict(sample);

res= std::abs(res - tLabel.at(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;

count+=res;

}

cout<< "正确的识别个数 count =" << count <

cout<< "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n";

这里没有使用svm->predict(query, res);

然后就查看了opencv的文档,当传入数据是Mat 而不是cvMat时,可以利用predict的返回值(float)来判断预测是否正确。

运行结果:

1)1000个训练数据/1000个测试数据

2)2000个训练数据/2000个测试数据

3)5000个训练数据/5000个测试数据

4)10000个训练数据/10000个测试数据

5)60000个训练数据/10000个测试数据

最后,关于运行时间(在程序正确的前提下,训练时长和初始的参数设置有关),给出我最的运行结果(1000张图是11s左右,6000张是1300s ~ 2000s左右)

代码:

1 #ifndef MNIST_H2 #define MNIST_H

3

4 #include

5 #include

6 #include

7 #include

8 #include

9

10 using namespacecv;11 using namespacestd;12

13 //小端存储转换

14 int reverseInt(inti);15

16 //读取image数据集信息

17 Mat read_mnist_image(const stringfileName);18

19 //读取label数据集信息

20 Mat read_mnist_label(const stringfileName);21

22 #endif

mnist.h

1 #include "mnist.h"

2

3 //计时器

4 doublecost_time;5 clock_t start_time;6 clock_t end_time;7

8 //测试item个数

9 int testNum = 10000;10

11 int reverseInt(inti) {12 unsigned charc1, c2, c3, c4;13

14 c1 = i & 255;15 c2 = (i >> 8) & 255;16 c3 = (i >> 16) & 255;17 c4 = (i >> 24) & 255;18

19 return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) +c4;20 }21

22 Mat read_mnist_image(const stringfileName) {23 int magic_number = 0;24 int number_of_images = 0;25 int n_rows = 0;26 int n_cols = 0;27

28 Mat DataMat;29

30 ifstream file(fileName, ios::binary);31 if(file.is_open())32 {33 cout << "成功打开图像集 ... \n";34

35 file.read((char*)&magic_number, sizeof(magic_number));36 file.read((char*)&number_of_images, sizeof(number_of_images));37 file.read((char*)&n_rows, sizeof(n_rows));38 file.read((char*)&n_cols, sizeof(n_cols));39 //cout << magic_number << " " << number_of_images << " " << n_rows << " " << n_cols << endl;

40

41 magic_number =reverseInt(magic_number);42 number_of_images =reverseInt(number_of_images);43 n_rows =reverseInt(n_rows);44 n_cols =reverseInt(n_cols);45 cout << "MAGIC NUMBER =" <(i, j) =pixel_value;65

66 //打印第一张和最后一张图像数据

67 if (i == 0) {68 s.at(j / n_cols, j % n_cols) =pixel_value;69 }70 else if (i == number_of_images - 1) {71 e.at(j / n_cols, j % n_cols) =pixel_value;72 }73 }74 }75 end_time =clock();76 cost_time = (end_time - start_time) /CLOCKS_PER_SEC;77 cout << "读取Image数据完毕......" << cost_time << "s\n";78

79 imshow("first image", s);80 imshow("last image", e);81 waitKey(0);82 }83 file.close();84 returnDataMat;85 }86

87 Mat read_mnist_label(const stringfileName) {88 intmagic_number;89 intnumber_of_items;90

91 Mat LabelMat;92

93 ifstream file(fileName, ios::binary);94 if(file.is_open())95 {96 cout << "成功打开Label集 ... \n";97

98 file.read((char*)&magic_number, sizeof(magic_number));99 file.read((char*)&number_of_items, sizeof(number_of_items));100 magic_number =reverseInt(magic_number);101 number_of_items =reverseInt(number_of_items);102

103 cout << "MAGIC NUMBER =" << magic_number << "; NUMBER OF ITEMS =" << number_of_items <

105 //-test-106 //number_of_items = testNum;107 //记录第一个label和最后一个label

108 unsigned int s = 0, e = 0;109

110 cout << "开始读取Label数据......\n";111 start_time =clock();112 LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);113 for (int i = 0; i < number_of_items; i++) {114 unsigned char temp = 0;115 file.read((char*)&temp, sizeof(temp));116 LabelMat.at(i, 0) = (unsigned int)temp;117

118 //打印第一个和最后一个label

119 if (i == 0) s = (unsigned int)temp;120 else if (i == number_of_items - 1) e = (unsigned int)temp;121 }122 end_time =clock();123 cost_time = (end_time - start_time) /CLOCKS_PER_SEC;124 cout << "读取Label数据完毕......" << cost_time << "s\n";125

126 cout << "first label =" << s <

1 /*

2 svm_type –3 指定SVM的类型,下面是可能的取值:4 CvSVM::C_SVC C类支持向量分类机。 n类分组 (n \geq 2),允许用异常值惩罚因子C进行不完全分类。5 CvSVM::NU_SVC \nu类支持向量分类机。n类似然不完全分类的分类器。参数为 \nu 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。6 CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。7 CvSVM::EPS_SVR \epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。8 CvSVM::NU_SVR \nu类支持向量回归机。 \nu 代替了 p。9

10 可从 [LibSVM] 获取更多细节。11

12 kernel_type –13 SVM的内核类型,下面是可能的取值:14 CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.15 CvSVM::POLY 多项式内核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.16 CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.17 CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).18

19 degree – 内核函数(POLY)的参数degree。20

21 gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数\gamma。22

23 coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。24

25 Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。26

27 nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 \nu。28

29 p – SVM类型(EPS_SVR)的参数 \epsilon。30

31 class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。32

33 term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。34

35 */

36

37

38 #include "mnist.h"

39

40 #include

41 #include

42 #include "opencv2/imgcodecs.hpp"

43 #include

44 #include

45

46 #include

47 #include

48

49 using namespacestd;50 using namespacecv;51 using namespacecv::ml;52

53 string trainImage = "mnist_dataset/train-images.idx3-ubyte";54 string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";55 string testImage = "mnist_dataset/t10k-images.idx3-ubyte";56 string testLabel = "mnist_dataset/t10k-labels.idx1-ubyte";57 //string testImage = "mnist_dataset/train-images.idx3-ubyte";58 //string testLabel = "mnist_dataset/train-labels.idx1-ubyte";59

60 //计时器

61 doublecost_time_;62 clock_t start_time_;63 clock_t end_time_;64

65 intmain()66 {67

68 //--------------------- 1. Set up training data ---------------------------------------

69 Mat trainData;70 Mat labels;71 trainData =read_mnist_image(trainImage);72 labels =read_mnist_label(trainLabel);73

74 cout << trainData.rows << " " << trainData.cols <

77 //------------------------ 2. Set up the support vector machines parameters --------------------

78 Ptr svm =SVM::create();79 svm->setType(SVM::C_SVC);80 svm->setKernel(SVM::RBF);81 //svm->setDegree(10.0);

82 svm->setGamma(0.01);83 //svm->setCoef0(1.0);

84 svm->setC(10.0);85 //svm->setNu(0.5);86 //svm->setP(0.1);

87 svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));88

89 //------------------------ 3. Train the svm ----------------------------------------------------

90 cout << "Starting training process" <train(trainData, ROW_SAMPLE, labels);93 end_time_ =clock();94 cost_time_ = (end_time_ - start_time_) /CLOCKS_PER_SEC;95 cout << "Finished training process...cost" << cost_time_ << "seconds..." <

97 //------------------------ 4. save the svm ----------------------------------------------------

98 svm->save("mnist_dataset/mnist_svm.xml");99 cout << "save as /mnist_dataset/mnist_svm.xml" <

101

102 //------------------------ 5. load the svm ----------------------------------------------------

103 cout << "开始导入SVM文件...\n";104 Ptr svm1 = StatModel::load("mnist_dataset/mnist_svm.xml");105 cout << "成功导入SVM文件...\n";106

107

108 //------------------------ 6. read the test dataset -------------------------------------------

109 cout << "开始导入测试数据...\n";110 Mat testData;111 Mat tLabel;112 testData =read_mnist_image(testImage);113 tLabel =read_mnist_label(testLabel);114 cout << "成功导入测试数据!!!\n";115

116

117 float count = 0;118 for (int i = 0; i < testData.rows; i++) {119 Mat sample =testData.row(i);120 float res = svm1->predict(sample);121 res = std::abs(res - tLabel.at(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;122 count +=res;123 }124 cout << "正确的识别个数 count =" << count <

127 system("pause");128 return 0;129 }

OpenCV的详细介绍:请点这里

OpenCV的下载地址:请点这里

linux手写数字识别,OpenCV 3.0中的SVM训练 mnist 手写字体识别相关推荐

  1. python手写数字识别实验报告_Python代码实现简单的MNIST手写数字识别(适合初学者看)...

    补充:由于很多同学找我要原数据集和代码,所以我上传到了资源里,https://download..net/download/zugexiaodui/10913834 初学机器学习,第一步是做一个简单的 ...

  2. 基于人工智能方法的手写数字图像识别_【工程分析】基于ResNet的手写数字识别...

    ねぇ 呐 私に気付いてよ 快点注意到我吧 もう そんな事 那种事 一定 望んでも 再去奢求 しょうがないだろ 也无可奈何吧 --真野あゆみ<Bipolar emotion>(作詞:Mits ...

  3. 深度学习3—用三层全连接神经网络训练MNIST手写数字字符集

    上一篇文章:深度学习2-任意结点数的三层全连接神经网络 距离上篇文章过去了快四个月了,真是时光飞逝,之前因为要考博所以耽误了更新,谁知道考完博后之前落下的接近半个学期的工作是如此之多,以至于弄到现在才 ...

  4. c# hdf5 写string_聊一聊C#8.0中的 await foreach

    (给DotNet加星标,提升.Net技能) 转自:码农阿宇 cnblogs.com/CoderAyu/p/10680805.html AsyncStreamsInCShaper 8.0 很开心今天能与 ...

  5. cxcy在c语言中表示坐标,c – OpenCV 3.0中的活动轮廓模型

    我正在尝试使用C语言中的Opencv 3.0实现Active Contour Models算法. 这个算法基于我为MatLab编写的脚本,并没有按预期工作. 这两个图像显示了两种算法运行的结果. Ma ...

  6. 正则表达式只能写数字_正则表达式真的很骚,可惜你不会写

    源:https://juejin.im/post/5b96a8e2e51d450e6a2de115 本文旨在用最通俗的语言讲述最枯燥的基本知识 文章提纲: 元字符 重复限定符 分组 转义 条件或 区间 ...

  7. 红旗系统linux忘了开机密码,红旗Linux6.0中忘记了root密码

    很久很久以前,用虚拟机安装了red flag6.0+windows xp 的双系统,很久很久以后,打开虚拟机red flag系统忘记了密码:在网上搜了下解决办法,转了先,不知管不管用!

  8. pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...

  9. 全连接神经网络实现MNIST手写数字识别

    在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...

最新文章

  1. mysql linux_linux下mysql下载安装
  2. VUE3模板ref引用子组件或者子组件的方法
  3. 阿里小米获运营商牌照;罗永浩吐槽苹果;谷歌曾私下求情欧盟 | 极客头条
  4. 【ASM】udev简介及配置、多路径(multipath)等
  5. JSONObject以及json(转)
  6. 页面可用性之浏览器默认字体与CSS 中文字体
  7. jacob调用word宏
  8. 在网上看到SpiceWorks是一个免费但很强大的HELPDESK系统
  9. 如何使用光盘启动计算机,怎么用光盘PE安装win7系统
  10. edm邮件html模板,EDM模板使用说明
  11. whether 连词或代词词性都不能作为疑问词
  12. java学习笔记第三周(二)
  13. 机器学习中的概率分布
  14. 1秒钟组装发动机,我震惊了
  15. 百度百聘企业简单信息获取
  16. 使用Pageoffice打开Office word报错0x80040154问题或者卸除WPS后Microsoft Office图标无法显示问题
  17. 新库上线 | 税收调查企业专利及引用被引用数据
  18. foxmail国外只能收邮件,不能发邮件
  19. 【图像处理 直方图 OpenCV实现】
  20. 浅谈性能优化有哪些指标

热门文章

  1. HTML5 canvas fillText() 方法
  2. 功能强大的STLINK-V3MINI(和V3MODS),尺寸仅15 x 30mm
  3. 移动软件开发四——高校新闻网
  4. 视频会议的昨天、今天和明天
  5. python学习记录一:关于分布式进程执行报错以及解决方案
  6. 关于HBase及HBase Shell的认识
  7. 如何将物流信息导出保存在EXCEL表格里面,物流查询
  8. 5.1字符串和常用的数据结构(列表、元组、集合、字典)
  9. html酷炫电子时钟效果,Html5时钟特效代码
  10. Trip.io:区块链在旅行住宿预订领域落地