手写字体optdigits识别:

每一行代表一个手写字体图像,最大值为16,大小64,然后最后一列为该图片的标签值。

import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import accuracy_score
import os
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from time import timedef show_acc(a, b, tip):acc = a.ravel() == b.ravel()print('%s acc :%.2f%%' % (tip, 100*np.mean(acc)))def save_image(image, i):# 由于optdigits数据集的像素最大是16,所以这里对其reshapeimage *= 16.9# 图像取反为了好观察image = 255 - image# 转化为图像的uint8格式a = image.astype(np.uint8)output_path = './/handwriting'if not os.path.exists(output_path):os.mkdir(output_path)Image.fromarray(a).save(output_path + ('//%d.jpg' % i))if __name__ == '__main__':# 开始加载训练数据集data = np.loadtxt('optdigits.tra', dtype=np.float, delimiter=',')# 最后一列得到的是该手写字体图片的labelx, y = np.split(data, (-1,), axis=1)# 64x64大小images = x.reshape(-1, 8, 8)y = y.ravel().astype(np.int)# 加载测试数据集data_test = np.loadtxt('optdigits.tes', dtype=np.float, delimiter=',')x_test, y_test = np.split(data_test, (-1,), axis=1)images_test = x_test.reshape(-1, 8, 8)y_test = y_test.ravel().astype(np.int)plt.figure(figsize=(15, 15), facecolor='w')for index, image in enumerate(images[:16]):plt.subplot(4, 8, index+1)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')plt.title('trian image:%i' %y[index])for index, image in enumerate(images_test[:16]):plt.subplot(4, 8, index+17)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')save_image(image.copy(), index)plt.title('test image:%i' %y[index])plt.tight_layout(1.5)plt.show()params = {'C':np.logspace(0, 3, 7), 'gamma':np.logspace(-5, 0, 11)}model = svm.SVC(C=10, kernel='rbf', gamma=0.001)print('==============start training=================')start = time()model.fit(x, y)end = time()train_time = end - startprint('train time:%dseconds' % train_time)y_hat = model.predict(x)show_acc(y, y_hat, 'trian data')y_hat_test = model.predict(x_test)print('y_hat:\n', y_hat)print('y_test:\n', y_test)show_acc(y_test, y_hat_test, 'valiation data')# 测试集里面错分的数据# 测试集里面和预测值不同的图像err_images = images_test[y_test != y_hat_test]# 预测里面和测试不同的预测值err_y_hat = y_hat_test[y_test != y_hat_test]# 测试里面和预测不同的测试值err_y = y_test[y_test != y_hat_test]print('err_y_hat:\n', err_y_hat)print('err_y:\n', err_y)plt.figure(figsize=(15, 15), facecolor='w')for index, image in enumerate(err_images):if index >= 30:breakplt.subplot(5, 6, index+1)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')plt.title('error:%i, the real:%i' % (err_y_hat[index], err_y[index]))plt.tight_layout(4)plt.show()

接着我们更换训练方法,修改程序:
# model = svm.SVC(C=10, kernel='rbf', gamma=0.001)model = GridSearchCV(svm.SVC(kernel='rbf'), param_grid=params, cv=3)
训练时间要长很多,但准确率并没有提升。。。。

接着我们使用经典的MNIST数据集来做实验:

import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import accuracy_score
import pandas as pd
import os
import csv
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from time import time
from pprint import pprint
import warningsdef show_acc(a, b, tip):acc = a.ravel() == b.ravel()print('%s acc :%.2f%%' % (tip, 100*np.mean(acc)))def save_image(image, i):# 图像取反为了好观察image = 255 - image# 转化为图像的uint8格式a = image.astype(np.uint8)output_path = './/handwriting'if not os.path.exists(output_path):os.mkdir(output_path)Image.fromarray(a).save(output_path + ('//%d.jpg' % i))def save_model(model):data_test_hat = model.predict(data_test)with open('Prediction.csv', 'wt') as f:writer = csv.writer(f)writer.writerow(['ImageId', 'Label'])for i, d in enumerate(data_test_hat):writer.writerow([i, d])if __name__ == '__main__':warnings.filterwarnings('ignore')classifier_type = 'RF'print('loading train data......')start = time()data = pd.read_csv('MNIST.train.csv', header=0, dtype=np.int)print('loading finishing......')# 读取标签值y = data['label'].valuesx = data.values[:, 1:]print('the images numbers:%d, the pixs of images:%d' % (x.shape))# reshape成28x28的格式,还原成原始的图像格式images = x.reshape(-1, 28, 28)y = y.ravel()print(images)print(y)print('loading test data......')start = time()data_test = pd.read_csv('MNIST.test.csv', header=0, dtype=np.int)data_test = data_test.valuesimages_test_result = data_test.reshape(-1, 28, 28)print('data-test:\t', data_test)print('images-test-result:\t', images_test_result)print('loading finishing......')np.random.seed(0)x, x_test, y, y_test = train_test_split(x, y, train_size=0.8, random_state=1)images = x.reshape(-1, 28, 28)images_test = x_test.reshape(-1, 28, 28)print('x-shape:\t', x.shape)print('x-test-shape:\t', x_test.shape)# 显示我们使用的部分训练数据和测试数据plt.figure(figsize=(15, 9), facecolor='w')for index, image in enumerate(images[:16]):plt.subplot(4, 8, index+1)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')plt.title('train data:%d' % (y[index]))for index, image in enumerate(images_test_result[:16]):plt.subplot(4, 8, index+17)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')save_image(image.copy(), index)plt.title('test data')plt.tight_layout()plt.show()if classifier_type == 'SVM':model = svm.SVC(C=3000, kernel='rbf', gamma=1e-10)print('让我们荡起小浆,开始训练吧.............')t_start = time()model.fit(x, y)t_end = time()print('train time:%.3f' % (t_end - t_start))print('小船到岸,清下水......')# print('最优分类器:', model.best_estimator_)# print('最优参数:\t', model.best_params_)# print('model.cv_results_ = \n', model.cv_results_)t = time()y_hat = model.predict(x)t = time() - tprint('SVM训练集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y, y_hat), t))t = time()y_hat_test = model.predict(x_test)t = time() - tprint('SVM测试集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y_test, y_hat_test), t))save_model(model)elif classifier_type == 'RF':rfc = RandomForestClassifier(100, criterion='gini', min_samples_split=2, min_impurity_split=1e-10, bootstrap=True, oob_score=True)print('让我们再次荡起小浆,开始训练吧.............')t = time()rfc.fit(x, y)print('train time:%.3f' % (time() - t))print('OOB准确率:%.3f%%' %(rfc.oob_score_*100))print('小船到岸,清下水......')t = time()y_hat = rfc.predict(x)t = time() - tprint('SVM训练集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y, y_hat), t))t = time()y_hat_test = rfc.predict(x_test)t = time() - tprint('SVM测试集准确率:%.3f%%, 耗时:%.3f' %(accuracy_score(y_test, y_hat_test), t))save_model(rfc)err = (y_test != y_hat_test)err_images = images_test[err]err_y_hat = y_hat_test[err]err_y = y_test[err]print('err_y_hat:\n', err_y_hat)print('err_y:\n', err_y)plt.figure(figsize=(15, 15), facecolor='w')for index, image in enumerate(err_images):if index >= 20:breakplt.subplot(4, 5, index+1)plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')plt.title('err:%i, real:%i' % (err_y_hat[index], err_y[index]))plt.suptitle('Digital Handwriting recognition:Classifier--%s' % classifier_type, fontsize=15)plt.tight_layout(rect=(0, 0, 1, 0.94))plt.show()
相对来说,SVM和随机森林算法效果都已经不错,但随机森林表现的要好一点,分析可能是SVM还需要调参。

机器学习SVM--基于手写字体识别相关推荐

  1. 《MATLAB 神经网络43个案例分析》:第19章 基于SVM的手写字体识别

    <MATLAB 神经网络43个案例分析>:第19章 基于SVM的手写字体识别 1. 前言 2. MATLAB 仿真示例 3. 小结 1. 前言 <MATLAB 神经网络43个案例分析 ...

  2. matlab基于SVM的手写字体识别,机器学习SVM--基于手写字体识别

    每一行代表一个手写字体图像,最大值为16,大小64,然后最后一列为该图片的标签值. import numpy as np from sklearn import svm import matplotl ...

  3. 课程设计(毕业设计)—基于机器学习KNN算法手写数字识别系统—计算机专业课程设计(毕业设计)

    机器学习KNN算法手写数字识别系统 下载本文手写数字识别系统完整的代码和课设报告的链接(或者可以联系博主koukou(壹壹23七2五六98),获取源码和报告):https://download.csd ...

  4. numpy完成手写字体识别(机器学习作业02)

    numpy完成手写字体识别(机器学习02) 参考代码:mnielsen/neural-networks-and-deep-learning: 参考讲解:深度学习多分类任务的损失函数详解 - 知乎 (z ...

  5. 计算机视觉ch8 基于LeNet的手写字体识别

    文章目录 原理 LeNet的简单介绍 Minist数据集的特点 Python代码实现 原理 卷积神经网络参考:https://www.cnblogs.com/chensheng-zhou/p/6380 ...

  6. 基于Python神经网络的手写字体识别

    本文将分享实现手写字体识别的神经网络实现,代码中有详细注释以及我自己的一些体会,希望能帮助到大家 (≧∇≦)/ ############################################ ...

  7. 深度学习,实现手写字体识别(大数据人工智能公司)

    手写字体识别是指给定一系列的手写字体图片以及对应的标签,构建模型进行学习,目标是对于一张新的手写字体图片能够自动识别出对应的文字或数字.通过深度学习构建普通神经网络和卷积神经网络,处理手写字体数据.通 ...

  8. AI基础:KNN与K近邻距离度量说明、利用KNN手写字体识别分类实践

    KNN k近邻 文章目录 KNN算法 K近邻中近邻的距离度量 欧式距离 标准化欧式距离 曼哈顿距离 汉明距离 夹角余弦 杰卡德相似系数 皮尔逊系数 切比雪夫距离 闵可夫斯基距离 马氏距离 巴氏距离 各 ...

  9. 第六讲 Keras实现手写字体识别分类

    一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...

最新文章

  1. 「UI 测试自动化selenium」汇总
  2. 白话Elasticsearch50-深入聚合数据分析之基于doc values正排索引的聚合内部原理
  3. Slack推安全企业加密管理可轻易用密钥控制数据
  4. EHCache 初步使用指南
  5. LTE Module User Documentation(翻译5)——Mobility Model with Buildings
  6. hdu2019——数列有序解题报告
  7. 四年级上册数学计算机笔记,四年级数学下册笔记整理
  8. mysql何时会走索引
  9. 机器学习第18篇 - Boruta特征变量筛选(2)
  10. php5.1文件包含,包含文件 - ThinkPHP 5.1 完全开发手册
  11. mysql 内置存储过程_mysql 内置存储过程
  12. 从零开始写个编译器吧 - 程序流控制
  13. 怎么把word转换ppt?
  14. 微信小程序中引入图标
  15. 寄生虫技术计算机软件怎么样,2019寄生虫软件-某寄生虫软件分析
  16. 小米全系列机型代码查询与 制作rom分区架构图示
  17. 华为盒子-悦MEC6108V9C-强刷固件-4.4.2版本
  18. UE4Possess切换控制Pawn
  19. UC 浏览器曝中间人攻击漏洞,官方:已修复,国内版不受影响
  20. 基于java流浪动物救助管理系统获取(java毕业设计)

热门文章

  1. uniapp写微信授权登录
  2. 推荐7款非常棒的将代码片段转换成图片的工具
  3. 微信公众号如何分享课件PPT?
  4. B站粉丝计数器!基于microByte
  5. 清除Chrome浏览器下默认浅黄色背景(保存密码时出现)
  6. sublime text 白色边框方框解决方法
  7. DeepFool运行代码中间问题
  8. 计算机系一班班会,计算机学院计算机类20级1班举行“爱在身边,温馨家园”主题班会...
  9. 加强统筹布局和顶层设计,以技能、平台、应用为三大着力点推动人工智能突破发展...
  10. 云流化助力虚拟展厅,更炫酷的展示方案