在 Python 的 sklearn 工具包中有 KNN 算法。KNN 既可以做分类器,也可以做回归。

如果是做分类,你需要引用:

from sklearn.neighbors import KNeighborsClassifier

如果是做回归,你需要引用:

from sklearn.neighbors import KNeighborsRegressor

如何在 sklearn 中创建 KNN 分类器:

我们使用构造函数 KNeighborsClassifier(n_neighbors=5, weights=‘uniform’,algorithm=‘auto’, leaf_size=30),这里有几个比较主要的参数:

创建完 KNN 分类器之后,我们就可以输入训练集对它进行训练,这里我们使用 fit() 函数,传入训练集中的样本特征矩阵和分类标识,会自动得到训练好的 KNN 分类器。然后可以使用 predict() 函数来对结果进行预测,这里传入测试集的特征矩阵可以得到测试集的预测分类结果。

如何用 KNN 对手写数字进行识别分类

手写数字数据集是个非常有名的用于图像识别的数据集。数字识别的过程就是将这些图片与分类结果 0-9 一一对应起来。

完整的手写数字数据集 MNIST 里面包括了 60000 个训练样本,以及 10000 个测试样本。如果你学习深度学习的话,MNIST 基本上是你接触的第一个数据集。

今天我们用 sklearn 自带的手写数字数据集做 KNN 分类,它只包括了 1797 幅数字图像,每幅图像大小是 8*8 像素。

训练分三个阶段:

1、加载数据集;本地导入,线上调取,自带数据集;在这里,我们使用自带数据集;

2、准备阶段:可视化数据描述:样本量多少,图像长啥样,输入输出特征;数据处理:缺失值处理,异常值处理、特征工程构造;

3、分类阶段:通过训练可以得到分类器,然后用测试集进行准确率的计算。

#加载库from sklearn.model_selection import train_test_splitfrom sklearn import preprocessingfrom sklearn.metrics import accuracy_scorefrom sklearn.datasets import load_digitsfrom sklearn.neighbors import KNeighborsClassifierfrom sklearn.svm import SVCfrom sklearn.naive_bayes import MultinomialNBfrom sklearn.tree import DecisionTreeClassifierimport matplotlib.pyplot as plt# 加载数据digits = load_digits()data = digits.data# 数据探索print(data.shape)# 查看第一幅图像print(digits.images[0])# 第一幅图像代表的数字含义print(digits.target[0])# 将第一幅图像显示出来plt.gray()plt.imshow(digits.images[0])plt.show()

运行结果:

我们对原始数据集中的第一幅进行数据可视化,可以看到图像是个 8*8 的像素矩阵,上面这幅图像是一个“0”,从训练集的分类标注中我们也可以看到分类标注为“0”。

sklearn 自带的手写数字数据集一共包括了 1797 个样本,每幅图像都是 8*8 像素的矩阵。因为并没有专门的测试集,所以我们需要对数据集做划分,划分成训练集和测试集。

因为 KNN 算法和距离定义相关,我们需要对数据进行规范化处理,采用 Z-Score 规范化:

# 分割数据,将 25% 的数据作为测试集,其余作为训练集(你也可以指定其他比例的数据作为训练集)train_x, test_x, train_y, test_y = train_test_split(data, digits.target, test_size=0.25, random_state=33)# 采用 Z-Score 规范化ss = preprocessing.StandardScaler()train_ss_x = ss.fit_transform(train_x)test_ss_x = ss.transform(test_x)

75%数据作为训练集;train_x作为训练集的输入特征值矩阵,train_y作为训练集的输出特征值;
对训练集与测试集中的输入特征值进行z评分归一化;(记住一定要对测试集输入特征进行同样的处理!以后谈论的变换默认都是对输入特征进行!)
fit_transform是fit和transform两个函数都执行一次。所以ss是进行了fit拟合的。只有在fit拟合之后,才能进行transform
在进行test的时候,我们已经在train的时候fit过了,所以直接transform即可。
另外,如果我们没有fit,直接进行transform会报错,因为需要先fit拟合,才可以进行transform。

然后我们构造一个 KNN 分类器 knn,把训练集的数据传入构造好的 knn,并通过测试集进行结果预测,与测试集的结果进行对比,得到 KNN 分类器准确率,

# 创建 KNN 分类器knn = KNeighborsClassifier() knn.fit(train_ss_x, train_y) predict_y = knn.predict(test_ss_x) print("KNN 准确率: %.4lf" % accuracy_score(predict_y, test_y))

knn.fit(训练集输入特征,训练集输出特征)
knn.predict(测试集输入特征)=模型输出值
accquary_score(knn.predict(测试集输入特征),测试集输出特征)

KNN 准确率: 0.9756

我们选用之前学过的几个模型,进行预测:

# 创建 SVM 分类器svm = SVC()svm.fit(train_ss_x, train_y)predict_y=svm.predict(test_ss_x)print('SVM 准确率: %0.4lf' % accuracy_score(predict_y, test_y))# 采用 Min-Max 规范化mm = preprocessing.MinMaxScaler()train_mm_x = mm.fit_transform(train_x)test_mm_x = mm.transform(test_x)# 创建 Naive Bayes 分类器mnb = MultinomialNB()mnb.fit(train_mm_x, train_y) predict_y = mnb.predict(test_mm_x) print(" 多项式朴素贝叶斯准确率: %.4lf" % accuracy_score(predict_y, test_y))# 创建 CART 决策树分类器dtc = DecisionTreeClassifier()dtc.fit(train_mm_x, train_y) predict_y = dtc.predict(test_mm_x) print("CART 决策树准确率: %.4lf" % accuracy_score(predict_y, test_y))

运行结果:
SVM 准确率: 0.9867
多项式朴素贝叶斯准确率: 0.8844
CART 决策树准确率: 0.8356

这里需要注意的是,我们在做多项式朴素贝叶斯分类的时候,传入的数据不能有负数。

因为 Z-Score 会将数值规范化为一个标准的正态分布,即均值为 0,方差为 1,数值会包含负数。因此我们需要采用 Min-Max 规范化,将数据规范化到 [0,1] 范围内。

数据预处理:无量纲化处理(线性:均值化与标准化;非线性),降维
当输入特征接近正态分布,使用Z评分归一化;
当输入特征呈现高度偏斜,而我们模型对输入特征的要求是正态分布时,选用Box-Cox变换;
Z评分归一化的特点:变换结果均值为0,方差为1;变换时对数值进行平移和缩放的同时,保留了密度图的总体形态
Box-Cox变换:变换时对数值进行平移和缩放的同时,改变了整体形态,产生了比原始图偏斜更少的密度图。
降维:特征缩减,比如PCA-主成分分析
特征工程:根据原有的特征,设计新的特征(实际应用需要反复验证)

倘若同样的数据集,改变k值,会得出如下结论:

knn默认k值为5 准确率:0.9756
knn的k值为200的准确率:0.8489
SVM分类准确率:0.9867
高斯朴素贝叶斯准确率:0.8111
多项式朴素贝叶斯分类器准确率:0.8844
CART决策树准确率:0.8400
K值的选取如果过大,正确率降低。
算法效率排行 SVM > KNN(k值在合适范围内) >多项式朴素贝叶斯 > CART > 高斯朴素贝叶斯

分别用 KNN、SVM、朴素贝叶斯和决策树做分类器,并统计了四个分类器的准确率。在数据量不大的情况下,使用 sklearn 还是方便的。

如果数据量很大,比如 MNIST 数据集中的 6 万个训练数据和 1 万个测试数据,那么采用深度学习 +GPU 运算的方式会更适合。

因为深度学习的特点就是需要大量并行的重复计算,GPU 最擅长的就是做大量的并行计算。

总结:

1、各模型对输入特征是有分布要求的;比如多项式朴素贝叶斯分类要求数据非负;最小二乘法模型要求数据是正态分布,满足四大假设;神经网络则对数据分布无要求。

2、特征变换是针对输入特征的;特征变化是数据处理的子集,决定模型的上限,而模型的好坏只是逼近这个上限

3、skearn适合数据量较小的训练,若是数据量过大,可以采用深度学习框架+GPU运算。实际运用中,可以使用集成学习(机器学习+深度学习框架)完成。

参考:

数据分析实战45讲

用商业案例学R语言数据挖掘--特征工程

数据预处理

c++ 正态分布如何根据x求y_knn实战:如何对手写数字进行识别?相关推荐

  1. knn实战:如何对手写数字进行识别?

    在 Python 的 sklearn 工具包中有 KNN 算法.KNN 既可以做分类器,也可以做回归. 如果是做分类,你需要引用: from sklearn.neighbors import KNei ...

  2. 【机器学习实战】利用KNN和其他分类器对手写数字进行识别

    一.在sklearn中创建KNN分类器 KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_si ...

  3. 【北京大学】13 TensorFlow1.x的项目实战之手写英文体识别OCR技术

    目录 1 项目介绍 1.1 项目功能 1.2 评估指标 2 数据集介绍 2.1 数据特征 3 数据的预处理 3.1 数据增强 3.2 倾斜矫正 3.3 去横线 3.4 文本区域定位 4 网络结构 5 ...

  4. 深度学习实战——利用卷积神经网络对手写数字二值图像分类(附代码)

    系列文章目录 深度学习实战--利用卷积神经网络对手写数字二值图像分类(附代码) 目录 系列文章目录 前言 一.案例需求 二.MATLAB算法实现 三.MATLAB源代码 参考文献 前言 本案例利用MA ...

  5. 基于Paddle的计算机视觉入门教程——第7讲 实战:手写数字识别

    B站教程地址 https://www.bilibili.com/video/BV18b4y1J7a6/ 任务介绍 手写数字识别是计算机视觉的一个经典项目,因为手写数字的随机性,使用传统的计算机视觉技术 ...

  6. pytorch实战案例-手写数字分类-卷积模型——深度AI科普团队

    文章目录 数据准备 导入需要的模块 使用GPU训练 将数据转换为tensor 导入训练集和测试集 数据加载器 数据展示 创建模型 将模型复制到GPU 损失函数 定义训练和测试函数 开始训练 源码已经上 ...

  7. pytorch实战案例-手写数字分类-全链接模型——深度AI科普团队

    文章目录 @[TOC] 数据准备 导入需要的模块 将数据转换为tensor 导入训练集和测试集 数据加载器 数据展示 创建模型 定义损失函数 定义优化函数 定义训练和测试函数 开始训练 源码已经上传: ...

  8. python手写数字识别实验报告_机器学习python实战之手写数字识别

    看了上一篇内容之后,相信对K近邻算法有了一个清晰的认识,今天的内容--手写数字识别是对上一篇内容的延续,这里也是为了自己能更熟练的掌握k-NN算法. 我们有大约2000个训练样本和1000个左右测试样 ...

  9. TensorFlow2 入门指南 | 04 分类问题实战之手写数字识别

    前言: 本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地.高效地入门TensorFlow2 深度学习框架.如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大 ...

最新文章

  1. html中单双引号嵌套,[转]详细讲述asp中单引号与双引号(即引号多重嵌套)的用法...
  2. 取消学术型硕士,增扩博士,北京大学这个学院做出研究生培养结构调整
  3. 拨号、宽带接入“面面观”比较九种上网方式
  4. VC++获取屏幕大小第二篇 物理大小GetDeviceCaps 上
  5. PrimeFaces:在动态生成的对话框中打开外部页面
  6. php打印出函数的内容吗,PHP打印函数集合详解以及PHP打印函数对比详解(精)
  7. java swarm_科学网—Java_Swarm编程:遇到麻烦了...... - 高德华的博文
  8. 360压缩电脑版_震惊!360竟然出了一款这么良心的软件
  9. 扩展PHP内置的异常处理类
  10. 虚拟参考站(VRS)
  11. 小小技巧--BLOB视频加密
  12. Excel两行交换及两列交换,快速互换相邻表格数据的方法
  13. Flutter 项目实战 截图分享到微信|QQ|微博 十二
  14. uni-app省市区选择器
  15. JavaScript实现连缀
  16. 送给你的一份英语学习资料,请查收!
  17. 时序预测 | MATLAB实现ARIMA时间序列预测(GDP预测)
  18. 百度地图JavaScript API 学习之地址解析
  19. 云服务器php文件怎么运行,云服务器php文件怎么运行环境
  20. ffmpeg简易使用应用分享(m3u8下载与视频合并等)

热门文章

  1. 比 TensorFlow Lite 快 15.6 倍!业界首个移动 GPU BNN 加速引擎 PhoneBit 开源
  2. 百度大脑语音能力引擎论坛定档 11.28,邀你一同解码 AI 语音的奥秘
  3. 这款耳机性价比值得你看一下
  4. 前端开发大师修炼指南
  5. 刷爆抖音,评分9.7!这本Python书太酷了!程序员:太爱!
  6. 漫画:什么是希尔排序?
  7. 不得了!这个 AI 让企业家、技术人员、投资人同台“互怼”
  8. 让 Cloud Native 飞,Pick 干货,看这里、看这里!
  9. 阿里涉足零售 IoT 的猜想
  10. 分布式存储绝不简单 —— UCan下午茶-武汉站纪实