每一行代表一个手写字体图像,最大值为16,大小64,然后最后一列为该图片的标签值。

import numpy as np from sklearn import svm import matplotlib.colors import matplotlib.pyplot as plt from PIL import Image from sklearn.metrics import accuracy_score import os from sklearn.model_selection import train_test_split from sklearn.model_selection import GridSearchCV from time import time def show_acc(a, b, tip): acc = a.ravel() == b.ravel() print('%s acc :%.2f%%' % (tip, 100*np.mean(acc))) def save_image(image, i): # 由于optdigits数据集的像素最大是16,所以这里对其reshape image *= 16.9 # 图像取反为了好观察 image = 255 - image # 转化为图像的uint8格式 a = image.astype(np.uint8) output_path = './/handwriting' if not os.path.exists(output_path): os.mkdir(output_path) Image.fromarray(a).save(output_path + ('//%d.jpg' % i)) if __name__ == '__main__': # 开始加载训练数据集 data = np.loadtxt('optdigits.tra', dtype=np.float, delimiter=',') # 最后一列得到的是该手写字体图片的label x, y = np.split(data, (-1,), axis=1) # 64x64大小 images = x.reshape(-1, 8, 8) y = y.ravel().astype(np.int) # 加载测试数据集 data_test = np.loadtxt('optdigits.tes', dtype=np.float, delimiter=',') x_test, y_test = np.split(data_test, (-1,), axis=1) images_test = x_test.reshape(-1, 8, 8) y_test = y_test.ravel().astype(np.int) plt.figure(figsize=(15, 15), facecolor='w') for index, image in enumerate(images[:16]): plt.subplot(4, 8, index+1) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('trian image:%i' %y[index]) for index, image in enumerate(images_test[:16]): plt.subplot(4, 8, index+17) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') save_image(image.copy(), index) plt.title('test image:%i' %y[index]) plt.tight_layout(1.5) plt.show() params = {'C':np.logspace(0, 3, 7), 'gamma':np.logspace(-5, 0, 11)} model = svm.SVC(C=10, kernel='rbf', gamma=0.001) print('==============start training=================') start = time() model.fit(x, y) end = time() train_time = end - start print('train time:%dseconds' % train_time) y_hat = model.predict(x) show_acc(y, y_hat, 'trian data') y_hat_test = model.predict(x_test) print('y_hat:\n', y_hat) print('y_test:\n', y_test) show_acc(y_test, y_hat_test, 'valiation data') # 测试集里面错分的数据 # 测试集里面和预测值不同的图像 err_images = images_test[y_test != y_hat_test] # 预测里面和测试不同的预测值 err_y_hat = y_hat_test[y_test != y_hat_test] # 测试里面和预测不同的测试值 err_y = y_test[y_test != y_hat_test] print('err_y_hat:\n', err_y_hat) print('err_y:\n', err_y) plt.figure(figsize=(15, 15), facecolor='w') for index, image in enumerate(err_images): if index >= 30: break plt.subplot(5, 6, index+1) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('error:%i, the real:%i' % (err_y_hat[index], err_y[index])) plt.tight_layout(4) plt.show()

接着我们更换训练方法,修改程序:

# model = svm.SVC(C=10, kernel='rbf', gamma=0.001) model = GridSearchCV(svm.SVC(kernel='rbf'), param_grid=params, cv=3)

训练时间要长很多,但准确率并没有提升。。。。

接着我们使用经典的MNIST数据集来做实验:

import numpy as np from sklearn import svm import matplotlib.colors import matplotlib.pyplot as plt from PIL import Image from sklearn.metrics import accuracy_score import pandas as pd import os import csv from sklearn.model_selection import train_test_split from sklearn.model_selection import GridSearchCV from sklearn.ensemble import RandomForestClassifier from time import time from pprint import pprint import warnings def show_acc(a, b, tip): acc = a.ravel() == b.ravel() print('%s acc :%.2f%%' % (tip, 100*np.mean(acc))) def save_image(image, i): # 图像取反为了好观察 image = 255 - image # 转化为图像的uint8格式 a = image.astype(np.uint8) output_path = './/handwriting' if not os.path.exists(output_path): os.mkdir(output_path) Image.fromarray(a).save(output_path + ('//%d.jpg' % i)) def save_model(model): data_test_hat = model.predict(data_test) with open('Prediction.csv', 'wt') as f: writer = csv.writer(f) writer.writerow(['ImageId', 'Label']) for i, d in enumerate(data_test_hat): writer.writerow([i, d]) if __name__ == '__main__': warnings.filterwarnings('ignore') classifier_type = 'RF' print('loading train data......') start = time() data = pd.read_csv('MNIST.train.csv', header=0, dtype=np.int) print('loading finishing......') # 读取标签值 y = data['label'].values x = data.values[:, 1:] print('the images numbers:%d, the pixs of images:%d' % (x.shape)) # reshape成28x28的格式,还原成原始的图像格式 images = x.reshape(-1, 28, 28) y = y.ravel() print(images) print(y) print('loading test data......') start = time() data_test = pd.read_csv('MNIST.test.csv', header=0, dtype=np.int) data_test = data_test.values images_test_result = data_test.reshape(-1, 28, 28) print('data-test:\t', data_test) print('images-test-result:\t', images_test_result) print('loading finishing......') np.random.seed(0) x, x_test, y, y_test = train_test_split(x, y, train_size=0.8, random_state=1) images = x.reshape(-1, 28, 28) images_test = x_test.reshape(-1, 28, 28) print('x-shape:\t', x.shape) print('x-test-shape:\t', x_test.shape) # 显示我们使用的部分训练数据和测试数据 plt.figure(figsize=(15, 9), facecolor='w') for index, image in enumerate(images[:16]): plt.subplot(4, 8, index+1) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('train data:%d' % (y[index])) for index, image in enumerate(images_test_result[:16]): plt.subplot(4, 8, index+17) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') save_image(image.copy(), index) plt.title('test data') plt.tight_layout() plt.show() if classifier_type == 'SVM': model = svm.SVC(C=3000, kernel='rbf', gamma=1e-10) print('让我们荡起小浆,开始训练吧.............') t_start = time() model.fit(x, y) t_end = time() print('train time:%.3f' % (t_end - t_start)) print('小船到岸,清下水......') # print('最优分类器:', model.best_estimator_) # print('最优参数:\t', model.best_params_) # print('model.cv_results_ = \n', model.cv_results_) t = time() y_hat = model.predict(x) t = time() - t print('SVM训练集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y, y_hat), t)) t = time() y_hat_test = model.predict(x_test) t = time() - t print('SVM测试集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y_test, y_hat_test), t)) save_model(model) elif classifier_type == 'RF': rfc = RandomForestClassifier(100, criterion='gini', min_samples_split=2, min_impurity_split=1e-10, bootstrap=True, oob_score=True) print('让我们再次荡起小浆,开始训练吧.............') t = time() rfc.fit(x, y) print('train time:%.3f' % (time() - t)) print('OOB准确率:%.3f%%' %(rfc.oob_score_*100)) print('小船到岸,清下水......') t = time() y_hat = rfc.predict(x) t = time() - t print('SVM训练集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y, y_hat), t)) t = time() y_hat_test = rfc.predict(x_test) t = time() - t print('SVM测试集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y_test, y_hat_test), t)) save_model(rfc) err = (y_test != y_hat_test) err_images = images_test[err] err_y_hat = y_hat_test[err] err_y = y_test[err] print('err_y_hat:\n', err_y_hat) print('err_y:\n', err_y) plt.figure(figsize=(15, 15), facecolor='w') for index, image in enumerate(err_images): if index >= 20: break plt.subplot(4, 5, index+1) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('err:%i, real:%i' % (err_y_hat[index], err_y[index])) plt.suptitle('Digital Handwriting recognition:Classifier--%s' % classifier_type, fontsize=15) plt.tight_layout(rect=(0, 0, 1, 0.94)) plt.show()

相对来说,SVM和随机森林算法效果都已经不错,但随机森林表现的要好一点,分析可能是SVM还需要调参。

matlab基于SVM的手写字体识别,机器学习SVM--基于手写字体识别相关推荐

  1. android 手写字体识别,一种基于Android系统的手写数学公式识别及生成MathML的方法...

    专利名称:一种基于Android系统的手写数学公式识别及生成MathML的方法 技术领域: 本发明属于模式识别技术领域,涉及数学公式中字符间的空间结构分析,具体涉及一种基于Android系统的手写数学 ...

  2. matlab手写字母识别,一种基于MATLAB的手写字母的神经网络识别方法

    文章编号 :1009 - 671X(2001) 10 - 0028 - 03 一种基于 MATLAB 的手写字母的神经网络识别方法 邓铭辉 ,孙 枫 ,张 志(哈尔滨工程大学 自动化学院 ,黑龙江 哈 ...

  3. 基于深度学习的手写数字实现及超简单的英文字母识别

    本文章大致分为5个板块,分别是MNIST数据库,深度学习神经网络的构建,图像预处理,图像识别,简单的英文字母识别展示. 1.MNIST数据库 总所周知,MNIST数据库是专门用于为手写数字识别系统提供 ...

  4. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  5. 课程设计(毕业设计)—基于机器学习KNN算法手写数字识别系统—计算机专业课程设计(毕业设计)

    机器学习KNN算法手写数字识别系统 下载本文手写数字识别系统完整的代码和课设报告的链接(或者可以联系博主koukou(壹壹23七2五六98),获取源码和报告):https://download.csd ...

  6. 残疾人手语交流辅助系统手语识别与翻译基于数据手套的虚拟手的实现

    残疾人手语交流辅助系统手语识别与翻译&&基于数据手套的虚拟手的实现 1. 特征提取 原始数据含有背景等大量无用成分,通过骨架技术,直接获得手的运动信息,减小问题复杂度.硬件和软件使用的 ...

  7. 【opencv机器学习】基于SVM和神经网络的车牌识别

    基于SVM和神经网络的车牌识别 深入理解OpenCV:实用计算机视觉项目解析 本文用来学习的项目来自书籍<实用计算机视觉项目解析>第5章Number Plate Recognition 提 ...

  8. 《视觉SLAM进阶:从零开始手写VIO》第三讲 基于优化的IMU预积分与视觉信息融合 作业

    <视觉SLAM进阶:从零开始手写VIO>第三讲 基于优化的IMU预积分与视觉信息融合 作业 文章目录 <视觉SLAM进阶:从零开始手写VIO>第三讲 基于优化的IMU预积分与视 ...

  9. python识别手写文字_Python3实现简单可学习的手写体识别(实例讲解)

    1.前言 版本:Python3.6.1 + PyQt5 + SQL Server 2012 以前一直觉得,机器学习.手写体识别这种程序都是很高大上很难的,直到偶然看到了这个视频,听了老师讲的思路后,瞬 ...

最新文章

  1. 怎样学好网络(1)-正确的定位
  2. SLF4J log4j 学习笔记一
  3. C编译器、链接器、加载器详解
  4. 为什么要看源码、如何看源码,高手进阶必看
  5. linux系统里常用的抓图工具,Linux系统下屏幕截图常用方法
  6. Python---时间函数
  7. java buffalo_随你怎么玩!Buffalo 网络硬盘新潮流
  8. 引用到网站绝对路径Server.MapPath(~/myfile.mdb)
  9. web导入excel数据
  10. 使用ASP.NET MVC Futures 中的异步Action 【转】
  11. Java 加密扩展(JCE)框架 之 Cipher 加密与解密
  12. Linux电源管理-Operating Performance Points(OPP)
  13. NYOJ题目66-分数拆分
  14. DL_C1_week_2_2(Logistic Regression)
  15. vmware-tools for LFS
  16. 软件测试——selenium环境搭建及自动化测试
  17. 深信服AC1000路由部署模式怎么配置线路负载均衡
  18. java标签用setbounds_setBounds的用法
  19. matlab打开界面模糊,matlab模糊逻辑(一)
  20. 亚马逊抛出“下一代贸易链”整合解决方案:中国跨境电商如何借道转型?

热门文章

  1. sanic入门(一)
  2. Cordova 环境搭建+打包Android APK
  3. python学习第二天
  4. 【技术干货】GD32VF103C-START 入门
  5. 转载:西门子 Step7 V5.5 SP4中文版
  6. 二进制文件、文本文件
  7. Oracle 函数使用:CURSOR游标简单案例
  8. SCRT连不上本地虚拟机的linux解决方法
  9. 宝塔面板能打开, 但wordpress 网站不能打开,提示建立数据库连接时出错--解决办法
  10. Latex 数学符号显示为文本模式 数学模式转为文本模式