作者 | 叶庭云

来源 | 修炼Python

头图 | 下载于视觉中国

KNN算法简介

KNN(K-Nearest Neighbor)最邻近分类算法是数据挖掘分类(classification)技术中常用算法之一,其指导思想是"近朱者赤,近墨者黑",即由你的邻居来推断出你的类别。

KNN最邻近分类算法的实现原理:为了判断未知样本的类别,以所有已知类别的样本作为参照,计算未知样本与所有已知样本的距离,从中选取与未知样本距离最近的 K 个已知样本,再根据少数服从多数的投票法则(majority-voting),将未知样本与 K 个最邻近样本中所属类别占比较多的归为一类。

KNN算法的核心思想:寻找最近的k个数据,来预测新数据的分类

KNN算法的关键:

  • 样本的所有特征都要做可比较的量化,若是样本特征中存在非数值的类型,必须采取手段将其量化为数值。例如样本特征中包含颜色,可通过将颜色转换为灰度值来实现距离计算。

  • 样本特征要做归一化处理,样本有多个参数,每一个参数都有自己的定义域和取值范围,他们对距离计算的影响不一样,如取值较大的影响力会盖过取值较小的参数。所以样本参数必须做一些 scale 处理,最简单的方式就是所有特征的数值都采取归一化处置。

  • 需要一个距离函数以计算两个样本之间的距离 通常使用的距离函数有:欧氏距离、曼哈顿距离、切比雪夫距离等,一般选欧氏距离作为距离度量,但是这是只适用于连续变量。在文本分类这种非连续变量情况下,汉明距离可以用来作为度量。通常情况下,如果运用一些特殊的算法来计算度量的话,K近邻分类精度可显著提高,如运用大边缘最近邻法或者近邻成分分析法。

以计算二维空间中的A(x1,y1)、B(x2,y2)两点之间的距离为例,常用的欧氏距离的计算方法如下图所示:

KNN算法的优点:

  • 简单,易于理解,易于实现,无需估计参数,无需训练;

  • 适合对稀有事件进行分类;

  • 特别适合于多分类问题(multi-modal,对象具有多个类别标签), KNN比 SVM 的表现要好。

KNN算法的缺点:

  • 只适合小数据集:正是因为这个算法太简单,每次预测新数据都需要使用全部的数据集,所以如果数据集太大,就会消耗非常长的时间,占用非常大的存储空间。

  • 数据不平衡效果不好:如果数据集中的数据不平衡,有的类别数据特别多,有的类别数据特别少,那么这种方法就会失效了,因为特别多的数据最后在投票的时候会更有竞争优势。

  • 需要做数据标准化:由于使用距离来进行计算,如果数据量纲不同,数值较大的字段影响就会变大,所以需要对数据进行标准化,比如都转换到 0-1 的区间。

  • 不适合特征维度太多的数据:由于我们只能处理小数据集,如果数据的维度太多,那么样本在每个维度上的分布就很少。比如我们只有三个样本,每个样本只有一个维度,这比每个样本有三个维度特征要明显很多。

关于 K 的选取:

  • K 值的选取会影响到模型的效果。当 K 越小的时候容易过拟合,因为结果的判断与某一个点强相关。而 K 越大的时候容易欠拟合,因为要考虑所有样本的情况,那就等于什么都不考虑。

  • 对于 K 的取值,一种显而易见的办法就是从 1 开始不断地尝试,查看准确率。随着 K 的增加,一般情况下准确率会先变大后变小,然后选取效果最好的那个 K 值就好了。当然,关于 K 最好使用奇数,因为偶数在投票的时候就困难了,如果两个类别的投票数量是一样的,那就没办法抉择了,只能随机选一个。

  • 所以选取一个合适的 K 值也是 KNN 算法在实现时候的一个难点,需要根据经验和效果去进行尝试。

鸢尾花数据分类

以经典的鸢尾花数据分类为例,熟悉 KNN 算法基本原理。使用 sklearn 自带的鸢尾花数据集,这个数据集里面有 150 条数据,共有 3 个类别,即 Setosa 鸢尾花、Versicolour 鸢尾花和 Virginica 鸢尾花,每个类别有 50 条数据,每条数据有 4 个维度,分别记录了鸢尾花的花萼长度、花萼宽度、花瓣长度和花瓣宽度。

导入需要的依赖库

from sklearn import datasets   # sklearn自带的数据集
from sklearn.neighbors import KNeighborsClassifier   # sklearn模块的KNN类
import numpy as np    # 矩阵运算库numpy# 设置随机种子,不设置的话默认是按系统时间作为参数
# 设置后可以保证我们每次产生的随机数是一样的,便于测试
np.random.seed(6)

加载数据

iris = datasets.load_iris()
iris_x = iris.data      # 数据部分
iris_y = iris.target    # 类别部分
print(iris_x)
print(iris_y)

结果如下:这个数据集里面有 150 条数据,共有 3 个类别,即 Setosa 鸢尾花、Versicolour 鸢尾花和 Virginica 鸢尾花,每个类别有 50 条数据,每条数据有 4 个维度,分别记录了鸢尾花的花萼长度、花萼宽度、花瓣长度和花瓣宽度。

KNN预测分类

# permutation 接收一个数作为参数(这里为数据集长度150) 产生一个0-149乱序一维数组
randomarr= np.random.permutation(len(iris_x))
# 随机从150条数据中选125条作为训练集,25条作为测试集
iris_x_train = iris_x[randomarr[:-25]] # 训练集数据
iris_y_train = iris_y[randomarr[:-25]] # 训练集标签
iris_x_test = iris_x[randomarr[-25:]]  # 测试集数据
iris_y_test = iris_y[randomarr[-25:]]  # 测试集标签
# 定义一个KNN分类器对象
knn = KNeighborsClassifier()
# 调用该对象的训练方法,主要接收两个参数:训练数据集及其类别标签
knn.fit(iris_x_train, iris_y_train)
# 调用预测方法,主要接收一个参数:测试数据集
iris_y_predict = knn.predict(iris_x_test)
# 计算各测试样本预测的概率值 这里我们没有用概率值,但是在实际工作中可能会参考概率值来进行最后结果的筛选,而不是直接使用给出的预测标签
probility = knn.predict_proba(iris_x_test)
# 计算与最后一个测试样本距离最近的5个点,返回的是这些样本的距离和序号组成的数组
neighborpoint = knn.kneighbors([iris_x_test[-1]], 5)
print(neighborpoint)
print('------------------------------------------------------------------')
# 调用该对象的打分方法,计算出准确率
score = knn.score(iris_x_test, iris_y_test, sample_weight=None)
# 输出测试的结果
print('iris_y_predict = ')
print(iris_y_predict)
print('------------------------------------------------------------------')
# 输出原始测试数据集的正确标签,以方便对比
print('iris_y_test = ')
print(iris_y_test)
print('------------------------------------------------------------------')
# 输出准确率计算结果
print('Accuracy:', score)

结果如下:

经过上面的一个动手尝试,我们成功地实践了 KNN 算法,并使用它对鸢尾花数据进行了分类计算,预测准确率在 90% 以上。

手写KNN算法实现思路

为了加深对 KNN 算法的理解,我们手动实现,而不用 sklearn 模块的 KNN 类。

要自己动手用 Python 实现 KNN 算法,主要有以下三个步骤:

  • 算距离:给定待分类样本,计算它与已分类样本中的每个样本的距离;

  • 找邻居:圈定与待分类样本距离最近的 K 个已分类样本,作为待分类样本的近邻;

  • 做分类:根据这 K 个近邻中的大部分样本所属的类别来决定待分类样本该属于哪个分类;

预测城市空气质量

数据来源:http://www.tianqihoubao.com/aqi/chengdu-201901.html

以预测城市空气质量为例,对已获取的几个城市 2019 年的空气质量数据进行处理,划分测试集、训练集。

读取数据集

def read_dataset(filename1, filename2, trainingSet, testSet):with open(filename1, 'r') as csvfile:lines = csv.reader(csvfile)  # 读取所有的行dataset1 = list(lines)       # 转化成列表for x in range(len(dataset1)):  # 每一行数据for y in range(8):dataset1[x][y] = float(dataset1[x][y])   # 8个参数转换为浮点数testSet.append(dataset1[x])    # 生成测试集with open(filename2, 'r') as csvfile:lines = csv.reader(csvfile)  # 读取所有的行dataset2 = list(lines)       # 转化成列表for x in range(len(dataset2)):   # 每一行数据for y in range(8):dataset2[x][y] = float(dataset2[x][y])  # 8个参数转换为浮点数trainingSet.append(dataset2[x])  # 生成训练集

计算欧氏距离

def calculateDistance(testdata, traindata, length):  # 计算距离distance = 0  # length表示维度 数据共有几维for x in range(length):distance += pow((int(testdata[x]) - int(traindata[x])), 2)return round(math.sqrt(distance), 3)    # 保留3位小数

找 K 个相邻最近的邻居

def getNeighbors(self, trainingSet, test_instance, k):  # 返回最近的k个边距distances = []length = len(test_instance)# 对训练集的每一个数计算其到测试集的实际距离for x in range(len(trainingSet)):dist = self.calculateDistance(test_instance, trainingSet[x], length)print('训练集:{} --- 距离:{}'.format(trainingSet[x], dist))distances.append((trainingSet[x], dist))distances.sort(key=operator.itemgetter(1))  # 按距离从小到大排列# print(distances)neighbors = []# 排序完成后取距离最小的前k个for x in range(k):neighbors.append(distances[x][0])print(neighbors)return neighbors

计算比例最大的分类

def getResponse(neighbors):   # 根据少数服从多数,决定归类到哪一类class_votes = {}for x in range(len(neighbors)):response = neighbors[x][-1]  # 统计每一个分类的多少  空气质量的数字标识if response in class_votes:class_votes[response] += 1else:class_votes[response] = 1print(class_votes.items())sortedVotes = sorted(class_votes.items(), key=operator.itemgetter(1), reverse=True)  # 按分类大小排序  降序return sortedVotes[0][0]    # 分类最大的  少数服从多数   为预测结果

预测准确率计算

def getAccuracy(test_set, predictions):correct = 0for x in range(len(test_set)):# predictions预测的与testset实际的比对  计算预测的准确率if test_set[x][-1] == predictions[x]:correct += 1else:# 查看错误预测print(test_set[x], predictions[x])print('有{}个预测正确,共有{}个测试数据'.format(correct, len(test_set)))return (correct / (len(test_set))) * 100.0

run函数调用

def run(self):training_set = []    # 训练集test_set = []        # 测试集self.read_dataset('./train_4/test.txt', './train_4/train.txt', training_set, test_set)  # 数据划分print('Train set: ' + str(len(training_set)))print('Test set: ' + str(len(test_set)))# generate predictionspredictions = []k = 7  # 取最近的7个数据for x in range(len(test_set)):  # 对所有的测试集进行测试neighbors = self.getNeighbors(training_set, test_set[x], k)  # 找到8个最近的邻居result = self.getResponse(neighbors)  # 找这7个邻居归类到哪一类predictions.append(result)accuracy = self.getAccuracy(test_set, predictions)print('预测准确度为:  {:.2f}%'.format(accuracy))   # 保留2位小数

运行效果如下:

测试集上预测准确率在 95% 以上。可以通过增加训练集城市空气质量数据量,调节找邻居的数量k,提高预测准确率。

福 利

CSDN给大家发压岁钱啦!

2月4日到2月11日每天上午11点

价值198元的芒果TV年卡,价值99元的CSDN月卡现金红包,CSDN电子书月卡等奖品大放送!百分百中奖

电脑端点击链接参与:

https://t.csdnimg.cn/gAkN

更多精彩推荐
☞Python 分析热卖年货,今年春节大家都在送啥?☞爬了20W+条猫咪交易数据,它不愧是人类团宠☞英超引入 AI 球探,寻找下一个足球巨星
☞三年投 1000 亿,达摩院何以仗剑走天涯?☞2021年浅谈多任务学习
点分享点收藏点点赞点在看

机器学习 KNN算法实践相关推荐

  1. 机器学习KNN算法实践:预测城市空气质量

    出品:Python数据之道 作者:叶庭云 整理:Lemon 机器学习KNN算法实践 预测城市空气质量 「Python数据之道」导读: 之前在公众号上分享过 "图解KNN算法" 的内 ...

  2. 机器学习经典算法实践_服务机器学习算法的系统设计-不同环境下管道的最佳实践

    机器学习经典算法实践 "Eureka"! While working on a persistently difficult-to-solve problem, you disco ...

  3. 机器学习knn算法学习笔记使用sklearn库 ,莺尾花实例

    ** 机器学习knn算法学习笔记使用sklearn库 ,莺尾花实例. 具体knn算法是怎样的我这里就不再详细论述.在这里我注意总结我使用knn算法进行一个分类的分析 ** 分析过程 1.前期准备 引入 ...

  4. 课程设计(毕业设计)—基于机器学习KNN算法手写数字识别系统—计算机专业课程设计(毕业设计)

    机器学习KNN算法手写数字识别系统 下载本文手写数字识别系统完整的代码和课设报告的链接(或者可以联系博主koukou(壹壹23七2五六98),获取源码和报告):https://download.csd ...

  5. 机器学习 —— KNN算法简单入门

    机器学习 -- KNN算法简单入门 第1关:手动实现简单kNN算法 1 KNN算法简介 1.1 kNN 算法的算法流程 1.2 kNN 算法的优缺点 1.3 编程要求+参数解释 2. 代码实现 3. ...

  6. 机器学习——KNN算法

    机器学习--KNN算法 文章目录 机器学习--KNN算法 前言 一.KNN原理基础 二.sklearn的基本建模流程 三.KNN算法调优:选取最优的K值 四.KNN中距离的相关讨论 1. KNN使用的 ...

  7. 经典实战案例:用机器学习 KNN 算法实现手写数字识别 | 原力计划

    作者 | 奶糖猫 来源 | CSDN 博客,责编 | 夕颜 头图 | CSDN 下载自视觉中国 出品 | CSDN(ID:CSDNnews) 算法简介 手写数字识别是KNN算法一个特别经典的实例,其数 ...

  8. 开根号的笔算算法图解_机器学习KNN算法之手写数字识别

    1.算法简介 手写数字识别是KNN算法一个特别经典的实例,其数据源获取方式有两种,一种是来自MNIST数据集,另一种是从UCI欧文大学机器学习存储库中下载,本文基于后者讲解该例. 基本思想就是利用KN ...

  9. python手写字母识别_机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

最新文章

  1. PHP 10条有用的建议
  2. arrays中copyof复制两个数组_数组,及二维数组
  3. php计算器使用方法,php--计算器的算法实现(-)
  4. php后台地址检测,[thinkphp] 隐藏后台地址
  5. LIVE555再学习 -- testRTSPClient 源码分析
  6. LeetCode移掉k位数字(贪心算法)python
  7. 【数据结构与算法】之深入解析“K个逆序对数组”的求解思路与算法示例
  8. 怎么查询局域网内全部电脑IP和mac地址..
  9. 算法训练营 重编码_编码训练营后如何找到工作
  10. linux下的进程间通信-管道及共享内存
  11. LeetCode 496. 下一个更大元素 I
  12. 远程服务器搭建建站助手,windows + 管理助手建站指南
  13. it行业 平均年龄_IT行业一线员工现状调查报告
  14. matlab图像增强实验总结,图像处理实验报告
  15. git bash返回上一级目录
  16. pycharm PEP8规范(python)
  17. 什么是TOR 官方文档
  18. 支付宝小程序模板开发,一整套流程
  19. 细细品味C#——重构的艺术
  20. 【百度】 快速精准搜索

热门文章

  1. 5.7-基于Binlog+Position的复制搭建
  2. 开源交换需新框架 技术团队也待整合
  3. echo使用说明,参数详解
  4. SeaJS基本开发原则
  5. Android 对象型数据库 db4o
  6. Api 函数: GetCursorPos 与转换
  7. mysql常见日期查询问题
  8. 功能很全的图书馆管理系统
  9. 图解CAN总线数据的组成和帧格式
  10. php获取WdatePicker值,WdatePicker日历控件使用方法