本文首发于微信公众号【DeepDriving】,公众号后台回复关键字【手写数字识别】可获取本文代码链接。

前言

手写数字识别是机器学习和深度学习中一个非常著名的入门级图像识别项目,很多人都是从这个项目开始进入图像识别领域的。虽然现在深度学习在图像识别领域已经风靡一时,取得了令人瞩目的成就,但是不可否认的是经典的机器学习方法依然是不过时并且有用武之地的。本文将带大家用经典的提取HOG特征+SVM分类方法来实现手写数字识别。

下载数据集

手写数字识别数据集采用的是MNIST数据集,该数据集可以从官方网站上下载:http://yann.lecun.com/exdb/mnist/,也可以从格物钛的网站上下载:https://gas.graviti.cn/dataset/data-decorators/MNIST。数据集包括以下4个压缩文件:

  • train-images-idx3-ubyte.gz: 训练集图像数据
  • train-labels-idx1-ubyte.gz: 训练集标签数据
  • t10k-images-idx3-ubyte.gz: 测试集图像数据
  • t10k-labels-idx1-ubyte.gz: 测试集标签数据

其中训练集包含60000个样本,测试集包含10000个样本,每个样本都是28x28的灰度图像。数据集下载好以后我们可以取几个样本进行可视化,看一下这些样本是什么样的。

提取HOG特征

与深度学习不同的是,在机器学习中我们需要手动去提取和处理特征,然后再将这些处理好的特征送入分类器进行训练或预测。HOG(Histogram of Oriented Gradient,方向梯度直方图)是一种在计算机视觉和图像处理中常用的特征描述子。在OpenCV中,我们可以调用cv2.HOGDescriptor()来创建一个HOGDescriptor类对象:

def CreateHOGDescriptor():winSize = (28, 28)blockSize = (14, 14)blockStride = (7, 7)cellSize = (7, 7)nbins = 9derivAperture = 1winSigma = -1.histogramNormType = 0L2HysThreshold = 0.2gammaCorrection = 1nlevels = 64signedGradient = Truehog = cv2.HOGDescriptor(winSize, blockSize, blockStride, cellSize, nbins, derivAperture,winSigma, histogramNormType, L2HysThreshold, gammaCorrection, nlevels, signedGradient)return hog

创建HOGDescriptor对象时有些参数需要进行设置:

  • winSize: 这里设置为样本图像的大小。

  • cellSize: 该值决定了提取的特征向量的大小,越小的cellSize值得到的特征向量越大。

  • blockSize: block主要用来解决光照变化问题,大的blockSize值可以使得算法对图像的局部变化不那么敏感,通常blockSize设置为2*cellSize。

  • blockStride: 确定相邻块之间的重叠度并控制对比归一化的程度,通常blockStride设置为 blockSize的1/2。

  • nbins: 设置梯度直方图中bin的数量,HOG论文的作者推荐值为9,这样可以以20度为增量捕获0~180度之间的梯度。

  • signedGradients: 梯度是有符号的还是无符号的。

创建HOGDescriptor对象后,就可以调用compute()方法来计算图像的HOG特征了。

训练SVM模型

在OpenCV中,我们可以直接调用机器学习库中的SVM_create()函数创建一个SVM分类器模型,创建模型的时候需要设置一些参数,比如分类模型类型、核函数类型、正则化系数等。

def InitSVM(C=12.5, gamma=0.50625):model = cv2.ml.SVM_create()model.setGamma(gamma)model.setC(C)model.setKernel(cv2.ml.SVM_RBF)model.setType(cv2.ml.SVM_C_SVC)return model

要选择合适的SVM超参数是比较难的,不过比较好的是,OpenCV为我们提供了一个trainAuto()函数,该函数会通过K折交叉验证来寻找最优的参数。模型创建好后,我们可以调用该函数对模型进行训练。

def TrainSVM(model, samples, responses, kFold=10):model.trainAuto(samples, cv2.ml.ROW_SAMPLE, responses, kFold)return model

因为需要进行K折交叉验证,所以调用trainAuto()函数训练模型所需要的时间比较长。如果不想花那么多时间训练模型,可以减少K折交叉验证的K值,或者直接不用该函数而是用train()函数来训练模型。

模型训练完以后,我们可以将HOG和SVM模型保存到XML文件中,以便后续使用。

svm_model.save('svm.xml')
hog.save('hog_descriptor.xml')

测试模型

模型训练好以后,我们可以在测试集上测试一下模型的准确率。首先提取测试集中每个图像的HOG特征,然后将特征送入SVM分类模型进行预测并统计出模型的准确率。

def EvaluateSVM(model, samples, labels):predictions = SVMPredict(model, samples)accuracy = (labels == predictions).mean()print('Accuracy: %.2f %%' % (accuracy*100))confusion = np.zeros((10, 10), np.int32)for i, j in zip(labels, predictions):confusion[int(i), int(j)] += 1print('confusion matrix:')print(confusion)

我训练的模型在测试集上的准确率为99.46%,得到的混淆矩阵如下:

confusion matrix:
[[ 978    0    0    0    0    0    1    1    0    0][   0 1132    1    0    0    0    1    0    1    0][   1    0 1027    0    0    0    0    4    0    0][   0    0    2 1006    0    1    0    0    1    0][   0    0    0    0  976    0    1    0    0    5][   1    0    0    2    0  888    1    0    0    0][   4    2    1    0    0    1  949    0    1    0][   0    2    3    0    0    0    0 1022    0    1][   2    0    0    1    0    1    0    1  967    2][   0    1    0    1    3    0    0    1    2 1001]]

使用模型

训练好一个模型后,我们当然希望能够把它应用到实际生活中来帮我们解决一些问题。既然训练了一个手写数字识别模型,那么我们就让它来识别一下手写的数字,看看效果到底怎么样。

首先,我们拿一张白纸写上一些数字,然后用图像处理的方法将纸上的每个数字的区域提取出来,再执行前文所述的提取HOG特征+SVM分类的识别流程。下面是我的一些测试结果:

从上图中可以看到,前面从0~8这几列的识别准确率还是比较高的,但是9这列数字全部识别成了7,可能是我写的数字“9”与训练集中的数字7更相似而与数字9的差异比较大吧,读者有兴趣的话可以试一下。

参考资料

  • https://towardsdatascience.com/mnist-handwritten-digits-classification-from-scratch-using-python-numpy-b08e401c4dab
  • https://learnopencv.com/handwritten-digits-classification-an-opencv-c-python-tutorial/
  • https://opencv24-python-tutorials.readthedocs.io/en/latest/py_tutorials/py_ml/py_svm/py_svm_opencv/py_svm_opencv.html

欢迎关注我的公众号【DeepDriving】,我会不定期分享计算机视觉、机器学习、深度学习、无人驾驶等领域的文章。

[附代码] 如何用HOG+SVM实现手写数字识别相关推荐

  1. python svm实现手写数字识别——直接可用

    python svm实现手写数字识别--直接可用 1.训练 1.1.训练数据集下载--已转化成csv文件 1.2 .训练源码 2.预测单张图片 2.1.待预测图像 2.2.预测源码 2.3.预测结果 ...

  2. 【ML-SVM案例学习】svm实现手写数字识别

    文章目录 前言 一.源码分步解析 1.引入库 2. 设置属性防止中文乱码 3.加载数字图片数据 4.获取样本数量,并将图片数据格式化 5.模型构建 6.测试数据部分实际值和预测值获取 7.进行图片展示 ...

  3. 【ML实验5】SVM(手写数字识别、核方法)

    实验代码获取 github repo 山东大学机器学习课程资源索引 实验目的 实验内容 这里并不是通过 KTT 条件转化,而是对偶问题和原问题为强对偶关系,可以通过 KTT 条件进行化简. 令 x = ...

  4. SVM进行手写数字识别

    使用了TensorFlow中的mnist数据集 from sklearn import svm import numpy as np from sklearn.metrics import class ...

  5. DL之NN/Average_Darkness/SVM:手写数字图片识别(本地数据集50000训练集+数据集加4倍)比较3种算法Average_Darkness、SVM、NN各自的准确率

    DL之NN/Average_Darkness/SVM:手写数字图片识别(本地数据集50000训练集+数据集加4倍)比较3种算法Average_Darkness.SVM.NN各自的准确率 目录 数据集下 ...

  6. 【项目实践】:KNN实现手写数字识别(附Python详细代码及注释)

    ↑ 点击上方[计算机视觉联盟]关注我们 本节使用KNN算法实现手写数字识别.KNN算法基本原理前边文章已经详细叙述,盟友们可以参考哦! 数据集介绍 有两个文件: (1)trainingDigits文件 ...

  7. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  8. TensorFlow手写数字识别与一步一步实现卷积神经网络(附代码实战)

    编译 | fendouai 编辑 | 安可 [导读]:本篇文章将说明 TensorFlow 手写数字识别与一步一步实现卷积神经网络.欢迎大家点击上方蓝字关注我们的公众号:深度学习与计算机视觉. 手写数 ...

  9. 利用Tensorflow实现手写数字识别(附python代码)

    手写识别的应用场景有很多,智能手机.掌上电脑的信息工具的普及,手写文字输入,机器识别感应输出:还可以用来识别银行支票,如果准确率不够高,可能会引起严重的后果.当然,手写识别也是机器学习领域的一个Hel ...

最新文章

  1. codeforces 785D D. Anton and School - 2
  2. 用python画玫瑰花教程-使用Python画一朵玫瑰花
  3. 新一代视频AI服务 —— 阿里云智能视觉重磅发布
  4. 线性回归与梯度下降法
  5. SpringBoot-@ComponentScan、@Import
  6. Domain Socket本地进程间通信
  7. 简洁大气自适应后台登录模板单页源码
  8. 191202-GETJOB-捡历的写法
  9. python 条形图 负值_Python处理JSON数据并生成条形图
  10. python tushare获取股票数据并可视化_使用Python获取股票数据Tushare
  11. java中进行socket编程实现tcp、udp协议总结
  12. 溯源项目(全套源码)
  13. 8.5 向量应用(三)——知识补充和梳理(夹角、距离和平面束)
  14. 安装Cisco Packet Tracer
  15. 家庭账本应该怎样记简洁明了
  16. 华哥倒酒(二分答案)
  17. 第二章 Java流程控制 ① 笔记
  18. html图片缩放全部显示不全,100% width CSS 在缩小/放大窗口时候内容被截断或显示不全...
  19. python账号_基于Python打造账号共享浏览器功能
  20. JavaScript获取时间戳的坑

热门文章

  1. 受力分析软件_【硕士论文】供热管网管道支架载荷分析与优化设计
  2. 基于vue的日历H5
  3. 微信小程序-扫一扫 wx.scanCode() 扫码大变身
  4. [转]Openstack Havana Dashboard测试和使用
  5. netapp管理地址_NetApp ONTAP Simulator部署指南(1)
  6. 【Informatica Powercenter】关于log4j
  7. Mac进入和离开全屏模式
  8. 如何用php做商品价格计算表,利用ajax+php实现商品价格计算
  9. 中国公司商标在美国被抢注的对策与防范
  10. 群晖php pear,Synology 群晖DSM7.0 40850 beta 版本各机型固件下载链接