三、基于SVM算法实现手写数字识别

作为一个工科生,而非数学专业的学生,我们研究一个算法,是要将它用于实际领域的。下面给出基于OpenCV3.0的SVM算法手写数字识别程序源码(参考http://blog.csdn.net/firefight/article/details/6452188)程序略有改动。

本部分将基于OpenCV实现简单的数字识别,待识别图像如下图所示,通过以下几个步骤实现图像中的数字的自动识别。

1.使用OpenCV训练手写数字识别分类器;

2.图像预处理及图像分割;

3.应用分类器进行识别。

3.1使用OpenCV训练手写数字识别分类器

所谓学习分类器就是根据训练样本,选取模型训练产生数字分类器,这里采用上文提到的SVM算法。

训练集使用MNIST,这个MNIST数据库是一个手写数字的数据库,它提供了六万的训练集和一万的测试集。它的图片是被规范处理过的,是一张被放在中间部位的28px*28px的灰度图。总共包含4个文件,每一个文件头部几个字节都记录着这些图片的信息,然后才是储存的图片信息,关于文件信息的具体描述可以参考下面这个网站:https://www.jianshu.com/p/4195577585e6

下面是利用OpenCV 3.2.0的SVM相关API学习MNIST样本库产生样本函数的主要代码:(值得注意的是MNIST库中的图像是黑底白字的)

svm.h头文件

#pragma once
#include <stdio.h>
#include <tchar.h>
#include<opencv/cv.h>
#include<opencv/highgui.h>#include <windows.h>
#include <stdlib.h>
#include <iostream>
using namespace std;
using namespace cv;class NumTrainData
{
public:NumTrainData(){memset(data, 0, sizeof(data));//Sets buffers to a specified character. Init the dataresult = -1;}
public:float data[64];int result;
};extern vector<NumTrainData> buffer;int ReadTrainData(int maxCount);
void newSvmStudy(vector<NumTrainData>& trainData);
char JpgPredict(Mat src);

svm.cpp文件

#include "svm.h"#include "opencv2/opencv.hpp"using namespace cv;
using namespace std;
using namespace cv::ml;#define SHOW_PROCESS 0
#define ON_STUDY 0
int featureLen = 64;void swapBuffer(char *buf)//0123->3210
{char temp;temp = *(buf);*buf = *(buf + 3);*(buf + 3) = temp;temp = *(buf + 1);*(buf + 1) = *(buf + 2);*(buf + 2) = temp;
}//获取ROI区域
void GetROI(Mat& src, Mat& dst)
{int left, right, top, bottom;left = src.cols;right = 0;top = src.rows;bottom = 0;//右下角为原点//Get valid areafor (int i = 0; i < src.rows; i++){for (int j = 0; j < src.cols; j++){if (src.at<uchar>(i, j) > 0)//获取src中i,j点的像素值,为灰度图像,值为0-255{if (j < left) left = j;if (j > right) right = j;if (i < top) top = i;if (i > bottom) bottom = i;}}}//将原点置于含有像素点的方框的左上角//Point center;//center.x=(left+right)/2;//center.y=(top+bottom)/2;int width = right - left;int height = bottom - top;int len = (width < height) ? height : width;//create a squredst = Mat::zeros(len, len, CV_8UC1);//Copy valid data to squre centerRect dstRect((len - width) / 2, (len - height) / 2, width, height);Rect srcRect(left, top, width, height);Mat dstROI = dst(dstRect);Mat srcROI = src(srcRect);srcROI.copyTo(dstROI);}int ReadTrainData(int maxCount)
{//Open image and label fileconst char fileName[] = "res//train-images.idx3-ubyte";//图像信息,以二进制方式存储  28*28const char LabelFileName[] = "res//train-labels.idx1-ubyte";//标签信息,以二进制方式存储//ofstream是从内存到硬盘,ifstream是从硬盘到内存,读取标准样本库ifstream lab_ifs(LabelFileName, ios_base::binary);ifstream ifs(fileName, ios_base::binary);if (ifs.fail() == true)//读取文件失败return -1;if (lab_ifs.fail() == true)//读取文件失败return -1;//Read train data number and image rows/closchar magicNum[4], ccount[4], crows[4], ccols[4];ifs.read(magicNum, sizeof(magicNum));//Read block of dataifs.read(ccount, sizeof(ccount));ifs.read(crows, sizeof(crows));ifs.read(ccols, sizeof(ccols));int count, rows, cols;swapBuffer(ccount);//Copies bytes between buffers.swapBuffer(crows);swapBuffer(ccols);memcpy(&count, ccount, sizeof(count));//Copies bytes between buffers.memcpy(&rows, crows, sizeof(rows));memcpy(&cols, ccols, sizeof(cols));//Just skip label headerlab_ifs.read(magicNum, sizeof(magicNum));lab_ifs.read(ccount, sizeof(ccount));//Create source and show image matrixMat src = Mat::zeros(rows, cols, CV_8UC1);//28*28 piex single channel imageMat temp = Mat::zeros(8, 8, CV_8UC1);Mat img, dst;char label = 0;Scalar templateColor(255, 0, 255);NumTrainData rtd;//int loop=1000;int total = 0;while (!ifs.eof())//Indicates if the end of a stream has been reached.{if (total >= count)//total train data numberbreak;total++;//cout << total << endl;//Read labellab_ifs.read(&label, 1);//读取标签,1个字节label = label + '0';//转换为ASCII码中的罗马数字//Read source dataifs.read((char*)src.data, rows*cols);//读取训练图像数据;每个像素被转成了0-255,0代表着白色,255代表着黑色。GetROI(src, dst);#if(SHOW_PROCESS)//Too small to watchimg = Mat::zeros(dst.rows * 10, dst.cols * 10, CV_8UC1);resize(dst, img, img.size());stringstream ss;ss << "Number" << label;string text = ss.str();putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, template);#endifrtd.result = label;resize(dst, temp, temp.size());//将dst缩放成一个8*8的temp矩阵//tehreshold(temp,temp,10,1,CT_THRESH_BINARY);for (int i = 0; i < 8; i++){for (int j = 0; j < 8; j++){rtd.data[i * 8 + j] = temp.at<uchar>(i, j);}}buffer.push_back(rtd);//if(waitKey(0)==27)//ESC to quit//break;maxCount--;if (maxCount == 0){//cout << "maxcount=" << maxCount << endl;system("pause");break;}}//buffer中存储了maxcount个8*8的矩阵和它所具有的标签ifs.close();lab_ifs.close();return 0;
}void newSvmStudy(vector<NumTrainData>& trainData)
{int testCount = trainData.size();//60000Mat m = Mat::zeros(1, featureLen, CV_32FC1);Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);Mat res = Mat::zeros(testCount, 1, CV_32SC1);for (int i = 0; i < testCount; i++){NumTrainData td = trainData.at(i);memcpy(m.data, td.data, featureLen * sizeof(float));normalize(m, m);memcpy(data.data + i*featureLen * sizeof(float), m.data, featureLen * sizeof(float));res.at<int>(i, 0) = td.result;//res.at<unsigned int>(i, 0) = td.result;//存储标签}START RT TRAINNING/////设置SVM参数Ptr<SVM> svm = SVM::create();svm->setType(SVM::C_SVC);//用于多类分类svm->setKernel(SVM::RBF);//采用高斯核函数svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));svm->setDegree(10.0);//高斯核的参数设置svm->setGamma(8.0);svm->setCoef0(1.0);svm->setC(10.0);svm->setNu(0.5);svm->setP(0.1);//训练Ptr<TrainData> tData = TrainData::create(data, ROW_SAMPLE, res);svm->train(tData);svm->save("res\\SVM_DATA.xml");}//预测数据
char JpgPredict(Mat src)
{Ptr<SVM> svm = Algorithm::load<ml::SVM>("res\\SVM_DATA.xml");svm->load("res\\SVM_DATA.xml");threshold(src, src, 230, 250, CV_THRESH_BINARY);Mat temp = Mat::zeros(8, 8, CV_8UC1);Mat m = Mat::zeros(1, featureLen, CV_32FC1);Mat element = getStructuringElement(MORPH_RECT, Size(2, 2));dilate(src, src, element);imshow("1", src);waitKey(30);resize(src, temp, temp.size());for (int i = 0; i < 8; i++){for (int j = 0; j < 8; j++){m.at<float>(0, j + i * 8) = temp.at<uchar>(i, j);}}normalize(m, m);// 该函数归一化输入数组使它的范数或者数值范围在一定的范围内。char ret = (char)svm->predict(m);//如果值为true而且是一个2类问题则返回判决函数值,否则返回类标签return ret;}

3.2 图像预处理及图像分割

前面通过学习产生了分类器,但我们输入图像中的数字并不能直接作为测试输入。图像中的数字笔画有时并不规整,还可能相互重叠。因为本文例子为了简化用的是屏幕截图,所以位置形变校正,色彩亮度校正等等都省去了,但仍需要一些简单处理。下面先对输入图像进行一把简单的预处理,主要目的是将图像转成二值图,这样便于我们下一步分割和识别。这样做还有个好处,就是把其余的噪声也顺带去掉了。

接下来,就可以对图像进行分割了。由于我们的分类器只能对数字一个一个地识别,所以首先要把每个数字分割出来。基本思想是先用findContours()函数把基本轮廓找出来,然后通过简单验证以确认是否为数字的轮廓。对于那些通过验证的轮廓,接下去会用boundingRect()找出它们的包围盒。

Process.h文件

#pragma once
#include "svm.h"
#include "opencv2/opencv.hpp"class Coordinate     //坐标类
{
public:double x, y;    //轮廓位置int order;      //轮廓向量contours中的第几个bool operator<(Coordinate &m)   //运算符重载,在sort()排序函数中使用{if (x < m.x)return true;elsereturn false;}
};void ImageProcess(Mat &srcImage);
void ImageFindRectangle(Mat &srcImage);

Process.cpp文件

#include "Process.h"using namespace cv;
using namespace std;Coordinate con[100] = { 0 }; //存放分割好的矩阵的中心坐标
vector<vector<Point>> contours;//定义一个存放边缘矩阵的容器
vector<Vec4i> hierarchy;  //定义一个存放树节点的前后关系的容器
Rect rect[100];            //定义一个存放分割好图像的矩阵,注意数据溢出关系
int i = 0;//全局变量void ImageFindRectangle(Mat &srcImage)
{//使用contours迭代器遍历每一个轮廓,找到并画出包围这个轮廓的最小矩阵vector<vector<Point>>::iterator It;for (It = contours.begin(); It < contours.end(); It++){//画出可包围数字的最小矩形Point2f vertex[4];rect[i] = boundingRect(*It);  //计算轮廓的垂直边界最小矩形,矩形是与图像上下边界平行的//矩形左上角的点vertex[0] = rect[i].tl();//矩形左下角的点vertex[1].x = (float)rect[i].tl().x, vertex[1].y = (float)rect[i].br().y;//矩形右下角的点vertex[2] = rect[i].br();//矩形右上方的点vertex[3].x = (float)rect[i].br().x, vertex[3].y = (float)rect[i].tl().y;for (int j = 0; j < 4; j++)line(srcImage, vertex[j], vertex[(j + 1) % 4], Scalar(0, 0, 255), 1);con[i].x = (vertex[0].x + vertex[1].x + vertex[2].x + vertex[3].x) / 4.0;//根据中心点判断图图像的位置con[i].y = (vertex[0].y + vertex[1].y + vertex[2].y + vertex[3].y) / 4.0;con[i].order = i;i++;}sort(con, con + i);  //将con按升序排列
}void ImageProcess(Mat &srcImage)
{Mat Image = Mat::zeros(srcImage.size(), CV_8U);Mat grayImage = Mat::zeros(srcImage.size(), CV_8U);//图像预处理cvtColor(srcImage, srcImage, COLOR_BGR2GRAY);   //转化为灰度图像threshold(srcImage, srcImage, 230, 255, CV_THRESH_BINARY);//阈值化//寻找图像边缘findContours(srcImage, contours, hierarchy, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);//寻找图像边缘;函数用法参数见笔记Mat dstImage = Mat::zeros(Image.size(), CV_8U);drawContours(dstImage, contours, -1, Scalar(255, 0, 255));//在dstImage图像中画出边缘//进行分割ImageFindRectangle(dstImage);//存储分割矩阵Mat num[11];for (int j = 0; j < i; j++){int k;k = con[j].order;srcImage(rect[k]).copyTo(num[j]);}cout << "i=" << i << endl;vector<char> res;for (int j = 0; j < i; j++){res.push_back(JpgPredict(num[j]));//cout << JpgPredict(num[j]) << endl;}cout << "Predicted number is:";for (const auto&number : res){cout <<number;// system("pause");}}

3.3 应用分类器进行识别

Main.cpp函数

#include "svm.h"
#include "Process.h"#include <fstream>
#include <vector>#include <opencv2/opencv.hpp>using namespace cv;
using namespace std;vector<NumTrainData> buffer;#define ON_STUDY 0
#define ON_PROCESS 1int main(void)
{
#if ON_STUDYint maxCount = 30000;ReadTrainData(maxCount);newSvmStudy(buffer);
#endif
#if ON_PROCESSMat img = imread("Sample3.jpg");ImageProcess(img);waitKey(0);
#endifreturn 0;
}

识别结果如下:

结果检测,SVM算法可以较好的识别手写数字,但是在编写代码的过程中发现一个问题,那就是这个算法对“1”数字的识别精度非常差,可能10张图中只能正确识别一次,不知道有没有大神能够给出一些建议?

上一篇:基于OpenCV的 SVM算法实现数字识别(三)---SMO求解

基于OpenCV的 SVM算法实现数字识别(四)---代码实现相关推荐

  1. 基于CNN的MINIST手写数字识别项目代码以及原理详解

    文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...

  2. [机器学习]基于OpenCV实现最简单的数字识别

    http://blog.csdn.net/jinzhuojun/article/details/8579416 本文将基于OpenCV实现简单的数字识别.这里以游戏Angry Birds为例,通过以下 ...

  3. 【图像识别】基于卷积神经网络CNN手写数字识别matlab代码

    1 简介 针对传统手写数字的随机性,无规律性等问题,为了提高手写数字识别的检测准确性,本文在研究手写数字区域特点的基础上,提出了一种新的手写数字识别检测方法.首先,对采集的手写数字图像进行预处理,由于 ...

  4. 基于opencv实现的手写数字识别

    一.使用模板匹配算法 match.py: import os import Function root_dir = "digits/train2" file7_7 = open(& ...

  5. Python基于深度学习的手写数字识别

    Python基于深度学习的手写数字识别 1.代码的功能和运行方法 2. 网络设计 3.训练方法 4.实验结果分析 5.结论 1.代码的功能和运行方法 代码可以实现任意数字0-9的识别,只需要将图片载入 ...

  6. 基于深度学习的手写数字识别算法Python实现

    摘 要 深度学习是传统机器学习下的一个分支,得益于近些年来计算机硬件计算能力质的飞跃,使得深度学习成为了当下热门之一.手写数字识别更是深度学习入门的经典案例,学习和理解其背后的原理对于深度学习的理解有 ...

  7. java图片降噪_Java基于opencv实现图像数字识别(四)—图像降噪

    Java基于opencv实现图像数字识别(四)-图像降噪 我们每一步的工作都是基于前一步的,我们先把我们前面的几个函数封装成一个工具类,以后我们所有的函数都基于这个工具类 这个工具类呢,就一个成员变量 ...

  8. OpenCV基于LeNet-5和连接组件分析的数字识别的实例(附完整代码)

    OpenCV基于LeNet-5和连接组件分析的数字识别的实例 OpenCV基于LeNet-5和连接组件分析的数字识别的实例 OpenCV基于LeNet-5和连接组件分析的数字识别的实例 #includ ...

  9. mser python车牌识别_基于MSER与SVM算法的车牌定位识别方法

    基于 MSER 与 SVM 算法的车牌定位识别方法 胡成伟 ; 袁明辉 [期刊名称] <软件> [年 ( 卷 ), 期] 2020(041)002 [摘要] 针对实际车牌识别系统中车牌位置 ...

最新文章

  1. K8S - Kubernetes简介
  2. Android activity属性
  3. 主成分分析中特征值分解与SVD(奇异值分解)的比较及其相关R语言的实现
  4. 2017-11-14【Python】爬虫练习
  5. PHP no input file specified 三种解决方法
  6. 【干货】大数据驱动的因果建模在滴滴的应用实践
  7. bash配置文件的修改
  8. Markdown编辑器初步使用
  9. 认知之经济学:经济是如何运行的
  10. E盾网络验证企业版个人版离线版易语言源码加密对接好的自绘界面1
  11. postgresql点云las_三维点云目标提取总结【转】
  12. 3Dmax专用快捷键大全(保姆式手把手教)
  13. 书单 | 做数字化转型,离不开这10本书!
  14. 基于STM32的倾斜仪设计(二)—— 硬件设计(2)
  15. 家庭智能插座一Homekit智能
  16. 将SVG文件转换为XML文件
  17. matlab离群值处理,数据平滑和离群值检测
  18. 主流开源流媒体服务器有哪些?
  19. 静态页面-HTML5+CSS大作业——传统节日--中秋节(2页)
  20. web前端-国际化-自动翻译(免费)

热门文章

  1. 牛客专项练习之设计模式
  2. 【第一个深度学习模型应用-手写数字识别】
  3. SAP的VBAK、VBAP和VBEP表
  4. 鬼火(irrlicht)的复燃
  5. java找出和最接近指定值_如何找到数组元素与特定值最接近的和?
  6. 搜狐合并Chinaren.com (转)
  7. Cisco Packet Tracer 6.0下载安装及汉化包使用方法无积分版
  8. xbox服务器中断,Xbox Live服务出现重大中断 目前问题已基本解决
  9. MacOS下使用UClient
  10. 校园心理网站html模板,校园心理微电影剧本