Python 基于sklearn - svm实现MNIST手写数字识别

一、数据集:MNIST

数据地址:http://yann.lecun.com/exdb/mnist/

训练数据:MNIST中的60000张图像,0-9的手写数字

测试数据:MNIST中的10000张图像,0-9的手写数字

注意:训练和测试代码直接使用了ubyte格式数据,即只对原数据进行了解压,没有先转换为png/jpg,但也附上png数据转换代码。

数据格式转换:从ubyte转换到png格式,存储格式:mnist_train>label>.png,代码如下:

提示:PIL不再支持新版本,要额外安装Pillow库

import numpy as np
import structfrom PIL import Image
import osdata_file = 'train-images.idx3-ubyte'
# It's 47040016B, but we should set to 47040000B
data_file_size = 47040016
data_file_size = str(data_file_size - 16) + 'B'data_buf = open(data_file, 'rb').read()magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', data_buf, 0)
datas = struct.unpack_from('>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(numImages, 1, numRows, numColumns)label_file = 'train-labels.idx1-ubyte'# It's 60008B, but we should set to 60000B
label_file_size = 60008
label_file_size = str(label_file_size - 8) + 'B'label_buf = open(label_file, 'rb').read()magic, numLabels = struct.unpack_from('>II', label_buf, 0)
labels = struct.unpack_from('>' + label_file_size, label_buf, struct.calcsize('>II'))
labels = np.array(labels).astype(np.int64)datas_root = 'mnist_train'
if not os.path.exists(datas_root):os.mkdir(datas_root)for i in range(10):file_name = datas_root + os.sep + str(i)if not os.path.exists(file_name):os.mkdir(file_name)count = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
for ii in range(numLabels):img = Image.fromarray(datas[ii, 0, 0:28, 0:28])label = labels[ii]file_name = datas_root + os.sep + str(label) + os.sep + \str(label) + '_' + str(count[label]) + '.png'count[label] = count[label] + 1# file_name = datas_root + os.sep + str(label) + os.sep + \#             'mnist_train_' + str(ii) + '.png'img.save(file_name)data_file = 't10k-images.idx3-ubyte'
# It's 7840016B, but we should set to 7840000B
data_file_size = 7840016
data_file_size = str(data_file_size - 16) + 'B'data_buf = open(data_file, 'rb').read()magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', data_buf, 0)
datas = struct.unpack_from('>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(numImages, 1, numRows, numColumns)label_file = 't10k-labels.idx1-ubyte'# It's 10008B, but we should set to 10000B
label_file_size = 10008
label_file_size = str(label_file_size - 8) + 'B'label_buf = open(label_file, 'rb').read()magic, numLabels = struct.unpack_from('>II', label_buf, 0)
labels = struct.unpack_from('>' + label_file_size, label_buf, struct.calcsize('>II'))
labels = np.array(labels).astype(np.int64)datas_root = 'mnist_test'
if not os.path.exists(datas_root):os.mkdir(datas_root)for i in range(10):file_name = datas_root + os.sep + str(i)if not os.path.exists(file_name):os.mkdir(file_name)count = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
for ii in range(numLabels):img = Image.fromarray(datas[ii, 0, 0:28, 0:28])label = labels[ii]file_name = datas_root + os.sep + str(label) + os.sep + \str(label) + '_' + str(count[label]) + '.png'count[label] = count[label] + 1# file_name = datas_root + os.sep + str(label) + os.sep + \#             'mnist_test_' + str(ii) + '.png'img.save(file_name)

转换后的数据如下图

二、训练模型

import numpy as np
import struct
import pickle
from sklearn import svm
###用于做数据预处理
from sklearn import preprocessing##读取数据集
def load_mnist_train(labels_path, images_path):with open(labels_path, 'rb') as lbpath:magic, n = struct.unpack('>II', lbpath.read(8))labels = np.fromfile(lbpath, dtype=np.uint8)with open(images_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)return images, labelsif __name__ == '__main__':##读取训练数据labels_path = "train-labels.idx1-ubyte"images_path = "train-images.idx3-ubyte"train_images, train_labels = load_mnist_train(labels_path, images_path)##标准化X = preprocessing.StandardScaler().fit_transform(train_images)X_train = X[0:60000]y_train = train_labels[0:60000]##定义并训练模型model_svc = svm.SVC()model_svc.fit(X_train, y_train)file = open("model.pickle", "wb")##保存模型pickle.dump(model_svc, file)file.close()

三、测试模型

import numpy as np
import struct
import pickle
###用于做数据预处理
from sklearn import preprocessingdef test(images_path, labels_path, modelPath):# 读取测试图像with open(labels_path, 'rb') as lbpath:magic, n = struct.unpack('>II', lbpath.read(8))test_labels = np.fromfile(lbpath, dtype=np.uint8)with open(images_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))test_images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(test_labels), 784)##读取模型file = open(modelPath, "rb")model_svc = pickle.load(file)file.close()##评分并预测x = preprocessing.StandardScaler().fit_transform(test_images)x_test = x[0:10000]y_test = test_labels[0:10000]num = model_svc.predict(x_test)for i in range(10000):print("Real:", y_test[i], "Predict:", num[i])print("Accuracy:", model_svc.score(x_test, y_test))return numif __name__ == '__main__':images_path = "t10k-images.idx3-ubyte"labels_path = "t10k-labels.idx1-ubyte"modelPath = "model.pickle"num = test(images_path, labels_path, modelPath)

四、参考资料

图片格式转换: MNIST数据集格式ubyte转png_haoji007的博客-CSDN博客_ubyte

模型训练及测试:图像处理基本库的学习笔记2--SVM,MATLAB,Tensorflow下分别对mnist数据集进行训练,并且进行预测 - 灰信网(软件开发博客聚合)

sklearn-svm模型参数设置:机器学习笔记(3)-sklearn支持向量机SVM - 简书

模型保存和调用: 基于sklearn的SVM模型保存与调用_hellosonny的博客-CSDN博客_svm保存模型

单个图片测试:基于svm机器学习的手写数字识别_Brinshy的博客-CSDN博客_基于svm的手写数字识别

Python SVM手写数字识别相关推荐

  1. python手写数字识别教学_python实现基于SVM手写数字识别功能

    本文实例为大家分享了SVM手写数字识别功能的具体代码,供大家参考,具体内容如下 1.SVM手写数字识别 识别步骤: (1)样本图像的准备. (2)图像尺寸标准化:将图像大小都标准化为8*8大小. (3 ...

  2. python手写汉字识别_用python实现手写数字识别

    前言 在之前的学习中,已经对神经网络的算法具体进行了学习和了解.现在,我们可以用python通过两种方法来实现手写数字的识别.这两种方法分别是多元逻辑回归和神经网络方法. 用多元逻辑回归手写数字识别 ...

  3. 基于python的手写数字识别实验报告_联机手写数字识别实验报告

    1 联机手写数字识别设计 一.设计论述 模式识别是六十年代初迅速发展起来的一门学科. 由于它研究的是如何用机 器来实现人 ( 及某些动物 ) 对事物的学习. 识别和判断能力, 因而受到了很多科技 领域 ...

  4. python实现手写数字识别(小白入门)

    手写数字识别(小白入门) 今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博客,作为刚刚入门的小白,看到代码运行成功就有点小激动,这个实验没啥含金量,所以路过的大牛不要停留,我怕你们吐槽哈哈 ...

  5. 利用python卷积神经网络手写数字识别_卷积神经网络使用Python的手写数字识别

    为了使机器更智能,开发人员正在研究机器学习和深度学习技术.人类通过反复练习和重复执行任务来学习执行任务,从而记住了如何执行任务.然后,他大脑中的神经元会自动触发,它们可以快速执行所学的任务.深度学习与 ...

  6. 基于python的手写数字识别knn_KNN分类算法实现手写数字识别

    需求: 利用一个手写数字"先验数据"集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ...

  7. svm手写数字识别_KNN 算法实战篇如何识别手写数字

    上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算 ...

  8. Python神经网络手写数字识别代码解释

    使用了数据集MNIST中的部分数据. 1.读取数据集内容 #打开文件并获取其中的内容 data_file=open("mnist_train.csv",'r') #open()函数 ...

  9. svm手写数字识别python_SVM算法识别手写体数字

    sklearn内部集成了一些手写体数字图片数据集,现在我们使用这些数据,用SVM支持向量机算法进行训练识别的练习.笔者习惯用pycharm,今天手痒,用一下Spyder编辑,顺便对比一下哪一个好用.废 ...

  10. python数字的鲁棒输入_请教关于python的手写数字识别神经网络问题~~~~

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 """network.py~~~~~~~~~~ A module to implement the stochastic g ...

最新文章

  1. PicoBlaze 8 位微控制器
  2. spark2.2读写操作hive和mysql数据库
  3. 9 Redis 持久化AOF
  4. 命令查看mysql端口映射_【转载】烂泥:如何利用telnet命令检测端口映射是否成功...
  5. Android仿人人客户端(v5.7.1)——采用ViewGroup做父容器,实现左侧滑动菜单(三)...
  6. Unity 下载安装Standard Assets
  7. 基于AT89S52单片机的GPS液晶显示定位系统
  8. 电动自行车新国标正式实施 二季度数码市场需求好转
  9. C语言加油站程序,C语言解决 加油站问题
  10. java判断闰年中闰月_编程序:计算某年某月有多少天(区分闰年和闰月)?怎么编?...
  11. Java-你知道String为什么不可变吗?
  12. 双矩阵对策MATLAB,带有模糊收益的双矩阵对策研究
  13. 通达OA-公共文件柜在线阅读Word 文档失败:Word 无法创建工作文件,请检查临时环境变量
  14. mysql查询表的列名_查看表所有列名SQL
  15. 特殊符号备用——三角形
  16. 【MySQL】Spring Boot项目基于Sharding-JDBC和MySQL主从复制实现读写分离(8千字详细教程)
  17. 我就问你,半路接手嵌入式项目棘手不?
  18. c#—MemoryStream读图片存入ImageList
  19. 抽丝剥茧,C#面向对象快速上手
  20. 利用sqlmap进行文件读写

热门文章

  1. 【音频隐写提取】MP3Stego下载、命令、使用方法
  2. Android 用代码获取基站号(cell)和小区号(lac)
  3. 钢绞线的弹性模量的计算方法_钢绞线弹性模量的理论计算及其影响因素分析
  4. NPU-电工电子技术第一章作业讲评
  5. 第一代微型计算机中没有只有汇编语言,[]汇编语言教程2微型计算机系统的概述.ppt...
  6. excel流程图分叉 合并_excel流程图怎么画
  7. 程序人生之七:我的 2010
  8. 计算机组成模型机的视频教学,3CPU 3设计模型机 罗克露计算机组成原理课件(绝对与网上视频教程同步).pdf...
  9. 重装系统启动盘制作介绍
  10. java实现zip文件压缩和解压