转自http://blog.csdn.net/firefight/article/details/6452188

是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

使用OPENCV训练手写数字识别分类器

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

[cpp] view plaincopyprint?
  1. #include "stdafx.h"
  2. #include <fstream>
  3. #include "opencv2/opencv.hpp"
  4. #include <vector>
  5. using namespace std;
  6. using namespace cv;
  7. #define SHOW_PROCESS 0
  8. #define ON_STUDY 0
  9. class NumTrainData
  10. {
  11. public:
  12. NumTrainData()
  13. {
  14. memset(data, 0, sizeof(data));
  15. result = -1;
  16. }
  17. public:
  18. float data[64];
  19. int result;
  20. };
  21. vector<NumTrainData> buffer;
  22. int featureLen = 64;
  23. void swapBuffer(char* buf)
  24. {
  25. char temp;
  26. temp = *(buf);
  27. *buf = *(buf+3);
  28. *(buf+3) = temp;
  29. temp = *(buf+1);
  30. *(buf+1) = *(buf+2);
  31. *(buf+2) = temp;
  32. }
  33. void GetROI(Mat& src, Mat& dst)
  34. {
  35. int left, right, top, bottom;
  36. left = src.cols;
  37. right = 0;
  38. top = src.rows;
  39. bottom = 0;
  40. //Get valid area
  41. for(int i=0; i<src.rows; i++)
  42. {
  43. for(int j=0; j<src.cols; j++)
  44. {
  45. if(src.at<uchar>(i, j) > 0)
  46. {
  47. if(j<left) left = j;
  48. if(j>right) right = j;
  49. if(i<top) top = i;
  50. if(i>bottom) bottom = i;
  51. }
  52. }
  53. }
  54. //Point center;
  55. //center.x = (left + right) / 2;
  56. //center.y = (top + bottom) / 2;
  57. int width = right - left;
  58. int height = bottom - top;
  59. int len = (width < height) ? height : width;
  60. //Create a squre
  61. dst = Mat::zeros(len, len, CV_8UC1);
  62. //Copy valid data to squre center
  63. Rect dstRect((len - width)/2, (len - height)/2, width, height);
  64. Rect srcRect(left, top, width, height);
  65. Mat dstROI = dst(dstRect);
  66. Mat srcROI = src(srcRect);
  67. srcROI.copyTo(dstROI);
  68. }
  69. int ReadTrainData(int maxCount)
  70. {
  71. //Open image and label file
  72. const char fileName[] = "../res/train-images.idx3-ubyte";
  73. const char labelFileName[] = "../res/train-labels.idx1-ubyte";
  74. ifstream lab_ifs(labelFileName, ios_base::binary);
  75. ifstream ifs(fileName, ios_base::binary);
  76. if( ifs.fail() == true )
  77. return -1;
  78. if( lab_ifs.fail() == true )
  79. return -1;
  80. //Read train data number and image rows / cols
  81. char magicNum[4], ccount[4], crows[4], ccols[4];
  82. ifs.read(magicNum, sizeof(magicNum));
  83. ifs.read(ccount, sizeof(ccount));
  84. ifs.read(crows, sizeof(crows));
  85. ifs.read(ccols, sizeof(ccols));
  86. int count, rows, cols;
  87. swapBuffer(ccount);
  88. swapBuffer(crows);
  89. swapBuffer(ccols);
  90. memcpy(&count, ccount, sizeof(count));
  91. memcpy(&rows, crows, sizeof(rows));
  92. memcpy(&cols, ccols, sizeof(cols));
  93. //Just skip label header
  94. lab_ifs.read(magicNum, sizeof(magicNum));
  95. lab_ifs.read(ccount, sizeof(ccount));
  96. //Create source and show image matrix
  97. Mat src = Mat::zeros(rows, cols, CV_8UC1);
  98. Mat temp = Mat::zeros(8, 8, CV_8UC1);
  99. Mat img, dst;
  100. char label = 0;
  101. Scalar templateColor(255, 0, 255 );
  102. NumTrainData rtd;
  103. //int loop = 1000;
  104. int total = 0;
  105. while(!ifs.eof())
  106. {
  107. if(total >= count)
  108. break;
  109. total++;
  110. cout << total << endl;
  111. //Read label
  112. lab_ifs.read(&label, 1);
  113. label = label + '0';
  114. //Read source data
  115. ifs.read((char*)src.data, rows * cols);
  116. GetROI(src, dst);
  117. #if(SHOW_PROCESS)
  118. //Too small to watch
  119. img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
  120. resize(dst, img, img.size());
  121. stringstream ss;
  122. ss << "Number " << label;
  123. string text = ss.str();
  124. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
  125. //imshow("img", img);
  126. #endif
  127. rtd.result = label;
  128. resize(dst, temp, temp.size());
  129. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
  130. for(int i = 0; i<8; i++)
  131. {
  132. for(int j = 0; j<8; j++)
  133. {
  134. rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
  135. }
  136. }
  137. buffer.push_back(rtd);
  138. //if(waitKey(0)==27) //ESC to quit
  139. //  break;
  140. maxCount--;
  141. if(maxCount == 0)
  142. break;
  143. }
  144. ifs.close();
  145. lab_ifs.close();
  146. return 0;
  147. }
  148. void newRtStudy(vector<NumTrainData>& trainData)
  149. {
  150. int testCount = trainData.size();
  151. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
  152. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
  153. for (int i= 0; i< testCount; i++)
  154. {
  155. NumTrainData td = trainData.at(i);
  156. memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));
  157. res.at<unsigned int>(i, 0) = td.result;
  158. }
  159. /START RT TRAINNING//
  160. CvRTrees forest;
  161. CvMat* var_importance = 0;
  162. forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
  163. CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
  164. forest.save( "new_rtrees.xml" );
  165. }
  166. int newRtPredict()
  167. {
  168. CvRTrees forest;
  169. forest.load( "new_rtrees.xml" );
  170. const char fileName[] = "../res/t10k-images.idx3-ubyte";
  171. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
  172. ifstream lab_ifs(labelFileName, ios_base::binary);
  173. ifstream ifs(fileName, ios_base::binary);
  174. if( ifs.fail() == true )
  175. return -1;
  176. if( lab_ifs.fail() == true )
  177. return -1;
  178. char magicNum[4], ccount[4], crows[4], ccols[4];
  179. ifs.read(magicNum, sizeof(magicNum));
  180. ifs.read(ccount, sizeof(ccount));
  181. ifs.read(crows, sizeof(crows));
  182. ifs.read(ccols, sizeof(ccols));
  183. int count, rows, cols;
  184. swapBuffer(ccount);
  185. swapBuffer(crows);
  186. swapBuffer(ccols);
  187. memcpy(&count, ccount, sizeof(count));
  188. memcpy(&rows, crows, sizeof(rows));
  189. memcpy(&cols, ccols, sizeof(cols));
  190. Mat src = Mat::zeros(rows, cols, CV_8UC1);
  191. Mat temp = Mat::zeros(8, 8, CV_8UC1);
  192. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
  193. Mat img, dst;
  194. //Just skip label header
  195. lab_ifs.read(magicNum, sizeof(magicNum));
  196. lab_ifs.read(ccount, sizeof(ccount));
  197. char label = 0;
  198. Scalar templateColor(255, 0, 0);
  199. NumTrainData rtd;
  200. int right = 0, error = 0, total = 0;
  201. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
  202. while(ifs.good())
  203. {
  204. //Read label
  205. lab_ifs.read(&label, 1);
  206. label = label + '0';
  207. //Read data
  208. ifs.read((char*)src.data, rows * cols);
  209. GetROI(src, dst);
  210. //Too small to watch
  211. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
  212. resize(dst, img, img.size());
  213. rtd.result = label;
  214. resize(dst, temp, temp.size());
  215. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
  216. for(int i = 0; i<8; i++)
  217. {
  218. for(int j = 0; j<8; j++)
  219. {
  220. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
  221. }
  222. }
  223. if(total >= count)
  224. break;
  225. char ret = (char)forest.predict(m);
  226. if(ret == label)
  227. {
  228. right++;
  229. if(total <= 5000)
  230. right_1++;
  231. else
  232. right_2++;
  233. }
  234. else
  235. {
  236. error++;
  237. if(total <= 5000)
  238. error_1++;
  239. else
  240. error_2++;
  241. }
  242. total++;
  243. #if(SHOW_PROCESS)
  244. stringstream ss;
  245. ss << "Number " << label << ", predict " << ret;
  246. string text = ss.str();
  247. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
  248. imshow("img", img);
  249. if(waitKey(0)==27) //ESC to quit
  250. break;
  251. #endif
  252. }
  253. ifs.close();
  254. lab_ifs.close();
  255. stringstream ss;
  256. ss << "Total " << total << ", right " << right <<", error " << error;
  257. string text = ss.str();
  258. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
  259. imshow("img", img);
  260. waitKey(0);
  261. return 0;
  262. }
  263. void newSvmStudy(vector<NumTrainData>& trainData)
  264. {
  265. int testCount = trainData.size();
  266. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
  267. Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
  268. Mat res = Mat::zeros(testCount, 1, CV_32SC1);
  269. for (int i= 0; i< testCount; i++)
  270. {
  271. NumTrainData td = trainData.at(i);
  272. memcpy(m.data, td.data, featureLen*sizeof(float));
  273. normalize(m, m);
  274. memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
  275. res.at<unsigned int>(i, 0) = td.result;
  276. }
  277. /START SVM TRAINNING//
  278. CvSVM svm = CvSVM();
  279. CvSVMParams param;
  280. CvTermCriteria criteria;
  281. criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
  282. param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
  283. svm.train(data, res, Mat(), Mat(), param);
  284. svm.save( "SVM_DATA.xml" );
  285. }
  286. int newSvmPredict()
  287. {
  288. CvSVM svm = CvSVM();
  289. svm.load( "SVM_DATA.xml" );
  290. const char fileName[] = "../res/t10k-images.idx3-ubyte";
  291. const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
  292. ifstream lab_ifs(labelFileName, ios_base::binary);
  293. ifstream ifs(fileName, ios_base::binary);
  294. if( ifs.fail() == true )
  295. return -1;
  296. if( lab_ifs.fail() == true )
  297. return -1;
  298. char magicNum[4], ccount[4], crows[4], ccols[4];
  299. ifs.read(magicNum, sizeof(magicNum));
  300. ifs.read(ccount, sizeof(ccount));
  301. ifs.read(crows, sizeof(crows));
  302. ifs.read(ccols, sizeof(ccols));
  303. int count, rows, cols;
  304. swapBuffer(ccount);
  305. swapBuffer(crows);
  306. swapBuffer(ccols);
  307. memcpy(&count, ccount, sizeof(count));
  308. memcpy(&rows, crows, sizeof(rows));
  309. memcpy(&cols, ccols, sizeof(cols));
  310. Mat src = Mat::zeros(rows, cols, CV_8UC1);
  311. Mat temp = Mat::zeros(8, 8, CV_8UC1);
  312. Mat m = Mat::zeros(1, featureLen, CV_32FC1);
  313. Mat img, dst;
  314. //Just skip label header
  315. lab_ifs.read(magicNum, sizeof(magicNum));
  316. lab_ifs.read(ccount, sizeof(ccount));
  317. char label = 0;
  318. Scalar templateColor(255, 0, 0);
  319. NumTrainData rtd;
  320. int right = 0, error = 0, total = 0;
  321. int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
  322. while(ifs.good())
  323. {
  324. //Read label
  325. lab_ifs.read(&label, 1);
  326. label = label + '0';
  327. //Read data
  328. ifs.read((char*)src.data, rows * cols);
  329. GetROI(src, dst);
  330. //Too small to watch
  331. img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
  332. resize(dst, img, img.size());
  333. rtd.result = label;
  334. resize(dst, temp, temp.size());
  335. //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
  336. for(int i = 0; i<8; i++)
  337. {
  338. for(int j = 0; j<8; j++)
  339. {
  340. m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
  341. }
  342. }
  343. if(total >= count)
  344. break;
  345. normalize(m, m);
  346. char ret = (char)svm.predict(m);
  347. if(ret == label)
  348. {
  349. right++;
  350. if(total <= 5000)
  351. right_1++;
  352. else
  353. right_2++;
  354. }
  355. else
  356. {
  357. error++;
  358. if(total <= 5000)
  359. error_1++;
  360. else
  361. error_2++;
  362. }
  363. total++;
  364. #if(SHOW_PROCESS)
  365. stringstream ss;
  366. ss << "Number " << label << ", predict " << ret;
  367. string text = ss.str();
  368. putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
  369. imshow("img", img);
  370. if(waitKey(0)==27) //ESC to quit
  371. break;
  372. #endif
  373. }
  374. ifs.close();
  375. lab_ifs.close();
  376. stringstream ss;
  377. ss << "Total " << total << ", right " << right <<", error " << error;
  378. string text = ss.str();
  379. putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
  380. imshow("img", img);
  381. waitKey(0);
  382. return 0;
  383. }
  384. int main( int argc, char *argv[] )
  385. {
  386. #if(ON_STUDY)
  387. int maxCount = 60000;
  388. ReadTrainData(maxCount);
  389. //newRtStudy(buffer);
  390. newSvmStudy(buffer);
  391. #else
  392. //newRtPredict();
  393. newSvmPredict();
  394. #endif
  395. return 0;
  396. }
  397. //from: http://blog.csdn.net/yangtrees/article/details/7458466

学习OpenCV——SVM 手写数字检测相关推荐

  1. python手写数字识别教学_python实现基于SVM手写数字识别功能

    本文实例为大家分享了SVM手写数字识别功能的具体代码,供大家参考,具体内容如下 1.SVM手写数字识别 识别步骤: (1)样本图像的准备. (2)图像尺寸标准化:将图像大小都标准化为8*8大小. (3 ...

  2. 【学习日记】手写数字识别及神经网络基本模型

    2021.10.7 [学习日记]手写数字识别及神经网络基本模型 1 概述 张量(tensor)是数字的容器,是矩阵向任意维度的推广,其维度称为轴(axis).深度学习的本质是对张量做各种运算处理,其分 ...

  3. 书接上文,基于藏文手写数字数据开发构建yolov5n轻量级藏文手写数字检测识别系统

    在上一篇文章中: <python基于轻量级CNN模型开发构建手写藏文数字识别系统> 开发实现了轻量级的藏文手写数字识别系统,这里主要是想基于前文的数据,整合目标检测模型来进一步挖掘藏文手写 ...

  4. 深度学习数字仪表盘识别_【深度学习系列】手写数字识别实战

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  5. 深度学习项目实战——手写数字识别项目

    摘要 本文将介绍的有关于的paddle的实战的相关的问题,并分析相关的代码的阅读和解释.并扩展有关于的python的有关的语言.介绍了深度学习步骤: 1. 数据处理:读取数据 和 预处理操作 2. 模 ...

  6. CNN学习MNIST实现手写数字识别

    CNN的实现 我们之前已经实现了卷积层和池化层,现在来组合这些层,搭建进行手写数字识别的CNN. # 初始化权重 self.params = {'W1': weight_init_std * np.r ...

  7. 深度学习入门实践学习——手写数字识别(百度飞桨平台)——上篇

    一.项目平台 百度飞桨 二.项目框架 1.数据处理: 2.模型设计:网络结构,损失函数: 3.训练配置:优化器,资源配置: 4.训练过程: 5.保存加载. 三.手写数字识别任务 1.构建神经网络流程: ...

  8. 【第一个深度学习模型应用-手写数字识别】

    基于BP神经网络的手写数字识别报告 基于BP神经网络的手写数字识别报告 一.任务描述 二.数据集来源 三.方法 3.1 数据集处理方法 3.2.模型结构设计 3.3.模型算法 四.实验 4.1.实验环 ...

  9. svm手写数字识别_KNN 算法实战篇如何识别手写数字

    上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算 ...

  10. 基于opencv的手写数字字符识别

    摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...

最新文章

  1. 轻松获取LAMP,LNMP环境编译参数配置
  2. Tomcat 源码阅读记录(1)
  3. solidworks入门
  4. c#语言编写汉诺塔游戏,c#实现汉诺塔问题示例
  5. pads最新版本是多少_电路EDA软件究竟有多少?
  6. Xshell5 提示要继续使用此程序,您必须应用最新的更新或使用新版本
  7. .NET 6新特性试用 | TryGetNonEnumeratedCount
  8. C# -- 多线程向同一文件写入
  9. C#4.0 命名参数可选参数
  10. hrbp 牵着鼻子走_防止被下属牵着鼻子走的四个经典方法,学会了,下属就好管了...
  11. 怎么关闭vivo系统自检_手机系统越来越卡,把握这几个优化设置,让手机流畅起来...
  12. Microsoft Office 2010 中的 Office 检测到此文件有问题
  13. github 从0开始的基本操作到fork和pr项目
  14. 信号完整性分析的基础概念
  15. redis主从配置及主从切换
  16. 应用场景:征信和权属管理
  17. 怎么在MAC系统下查看系统详细信息?新手快来看!
  18. 2 pygraphviz在windows10 64位下的安装问题(反斜杠的血案)
  19. 【开源】java做游戏之QQ连连看java单机高仿版(算是目前最高仿的了)
  20. 驱动开发:运用VAD隐藏R3内存思路

热门文章

  1. [C++ primer]优化内存分配
  2. Windows 7 64位下使用ADB驱动
  3. WildPacket AiroPeek EtherPeek OmniPeek
  4. 第一章:WTL的5个W
  5. Spring 通过XML配置装配Bean
  6. MyBatis中解决字段名与实体类属性名不相同的冲突
  7. spring mvc 的上传图片是怎么实现的?
  8. C语言图形库简单对比及EGE库的安装小手册
  9. ubuntu14上安装ros教程
  10. Android ListView 指定显示最后一行