介绍

在我遇到的所有机器学习算法中,KNN是最容易上手的。尽管它很简单,但事实上它其实在某些任务中非常有效(正如你将在本文中看到的那样)。

甚至它可以做的更好?它可以用于分类和回归问题!然而,它其实更擅长用于分类问题。我很少看到KNN在任何回归任务上实现。我在这里的目的是说明并强调,当目标变量本质上是连续的时,KNN是如何有效的运作的。

在本文中,我们将首先了解KNN算法背后的思维,研究计算点与点之间距离的不同方法,然后最终在Big Mart Sales数据集上用Python实现该算法。让我们动起来吧

1.用简单的例子来理解KNN背后的逻辑

让我们从一个简单的例子开始。请考虑下表 - 它包含10人的身高,年龄和体重(目标)值。如你所见,缺少ID11的重量值。我们需要根据他们的身高和年龄来预测这个人的体重。

注意:此表中的数据不代表实际值。它仅用作一个例子来解释这个概念。

为了更清楚地了解这一点,下面是上表中高度与年龄的关系图:

在上图中,y轴表示人的身高(以英尺为单位),x轴表示年龄(以年为单位)。这些点是根据ID值进行编号。黄点(ID 11)是我们的测试点。

如果我要求你根据图来确定ID11的重量,你的答案会是什么?你可能会说,因为ID11 更接近第 5点和第1点,所以它必须具有与这些ID类似的重量,可能在72-77千克之间(表中ID1和ID5的权重)。这实际上是有道理的,但你认为算法会如何预测这些值呢?让我们在下边进行试验讨论。

2. KNN算法是怎样工作的

如上所述,KNN可用于分类和回归问题。该算法使用“ 特征相似性 ”来预测任何新数据点的值。这意味着新的点将根据其与训练集中的点的接近程度而进行分配。从我们的例子中,我们知道ID11的高度和年龄类似于ID1和ID5,因此重量也大致相同。

如果这是一个分类问题,我们会采用该模式作为最终预测。在这种情况下,我们有两个重量值--72和77.猜猜最终值是如何计算的?是取两个重量的平均值来作为最终的预测值。

以下是该算法的逐步说明:首先,计算新的点与每个训练点之间的距离。选择最接近的k个数据点(基于距离)。在我们演示的例子中,如果k的值为3,则将选择点1,5,6。我们将在本文后面进一步探索选择正确的k值的方法。这些数据点的平均值是新点的最终预测值。在这里,我们的ID11的重量为 =(77 + 72 + 60)/ 3 = 69.66千克。

在接下来的几节中,我们将详细讨论这三个步骤中的每一个。

3.点与点之间距离的计算方法

所述第一步骤是计算新点和每个训练点之间的距离。计算该距离有多种方法,其中最常见的方法是 - 欧几里德,曼哈顿(用于连续)和汉明距离(用于分类)。欧几里德距离:欧几里德距离计算为新点(x)和现有点(y)之间的差的平方和的平方根。

曼哈顿距离:这是实际向量之间的距离,使用它们的绝对差值之和表示。汉明距离:用于分类变量。如果值(x)和值(y)相同,则距离D将等于0。否则D = 1。

一旦一个新的观测值与我们训练集中的点之间的距离被测量出来,下一步就是要选择最近的点。要考虑的点的数量由k的值定义。

4.如何选择k因子

第二个步骤是选择k值。这决定了我们在为任何新的观察值赋值时所要考虑到的邻居的数量。

在我们的示例中,k值 = 3,最近的点是ID1,ID5和ID6。

ID11的重量预测将是:

ID11 =(77 + 72 + 60)/ 3 ID11 = 69.66千克

如果k的值 = 5的话,那么距离最近的点将是ID1,ID4,ID5,ID6,ID10。

那么ID11的预测将是:

ID 11 =(77 + 59 + 72 + 60 + 58)/ 5 ID 11 = 65.2千克

我们注意到,基于k值,最终结果将趋于变化。那我们怎样才能找出k的最优值呢?让我们根据我们的训练集和验证集的误差计算来决定它(毕竟,最小化误差是我们的最终目标!)。

请看下面的图表,了解不同k值的训练误差和验证误差。

对于非常低的k值(假设k = 1),模型过度拟合训练数据,这导致验证集上的高错误率。另一方面,对于k的高值,该模型在训练集和验证集上都表现不佳。如果仔细观察,验证误差曲线在k = 9的值处达到最小值。那么该k值就是是模型的最佳K值(对于不同的数据集,它将有所不同)。该曲线称为“ 肘形曲线 ”(因为它具有类似肘部的形状),通常用于确定k值。

你还可以使用网格搜索技术来查找最佳k值。我们将在下一节中实现这一点。

5.处理数据集(Python代码)

到目前为止,你应该清楚的了解这个算法。我们现在将继续在数据集上实现该算法。我使用Big Mart销售数据集来进行代码实现,你可以从此链接下载它,邀请码为b543。

1.阅读文件

import pandas as pd df = pd.read_csv('train.csv') df.head()

2.计算缺失值

df.isnull().sum() #输入Item_weight和Outlet_size中缺少的值 mean = df['Item_Weight'].mean() #imputing item_weight with mean df['Item_Weight'].fillna(mean, inplace =True) mode = df['Outlet_Size'].mode() #imputing outlet size with mode df['Outlet_Size'].fillna(mode[0], inplace =True)

3.处理分类变量并删除id列

df.drop(['Item_Identifier', 'Outlet_Identifier'], axis=1, inplace=True) df = pd.get_dummies(df)

4.创建训练集和测试集

from sklearn.model_selection import train_test_split train , test = train_test_split(df, test_size = 0.3) x_train = train.drop('Item_Outlet_Sales', axis=1) y_train = train['Item_Outlet_Sales'] x_test = test.drop('Item_Outlet_Sales', axis = 1) y_test = test['Item_Outlet_Sales']

5.预处理 - 扩展功能

from sklearn.preprocessing import MinMaxScaler scaler = MinMaxScaler(feature_range=(0, 1)) x_train_scaled = scaler.fit_transform(x_train) x_train = pd.DataFrame(x_train_scaled) x_test_scaled = scaler.fit_transform(x_test) x_test = pd.DataFrame(x_test_scaled)

6.查看不同K值的错误率

#导入所需要的包 from sklearn import neighbors from sklearn.metrics import mean_squared_error from math import sqrt import matplotlib.pyplot as plt %matplotlib inline rmse_val = [] #存储不同K值的RMSE值 for K in range(20): K = K+1 model = neighbors.KNeighborsRegressor(n_neighbors = K) model.fit(x_train, y_train) #合适的模型 pred=model.predict(x_test) #对测试集进行测试 error = sqrt(mean_squared_error(y_test,pred)) #计算RMSE值 rmse_val.append(error) #存储RMSE值 print('RMSE value for k= ' , K , 'is:', error)

输出:

RMSE value for k = 1 is: 1579.8352322344945 RMSE value for k = 2 is: 1362.7748806138618 RMSE value for k = 3 is: 1278.868577489459 RMSE value for k = 4 is: 1249.338516122638 RMSE value for k = 5 is: 1235.4514224035129 RMSE value for k = 6 is: 1233.2711649472913 RMSE value for k = 7 is: 1219.0633086651026 RMSE value for k = 8 is: 1222.244674933665 RMSE value for k = 9 is: 1219.5895059285074 RMSE value for k = 10 is: 1225.106137547365 RMSE value for k = 11 is: 1229.540283771085 RMSE value for k = 12 is: 1239.1504407152086 RMSE value for k = 13 is: 1242.3726040709887 RMSE value for k = 14 is: 1251.505810196545 RMSE value for k = 15 is: 1253.190119191363 RMSE value for k = 16 is: 1258.802262564038 RMSE value for k = 17 is: 1260.884931441893 RMSE value for k = 18 is: 1265.5133661294733 RMSE value for k = 19 is: 1269.619416217394 RMSE value for k = 20 is: 1272.10881411344

#根据K值绘制RMSE值 curve = pd.DataFrame(rmse_val) #elbow curve curve.plot()

正如我们所讨论的,当我们取k = 1时,我们得到一个非常高的RMSE值。随着我们增加k值,RMSE值不断减小。在k = 7时,RMSE约为1219.06,并且随着K值在进一步增加,RMSE值会迅速上升。我们可以有把握地说,在这种情况下,k = 7会给我们带来最好的结果。

这些是使用我们的训练数据集进行的预测。现在让我们预测测试数据集的值并进行提交。

7.对测试数据集的预测

#阅读测试和提交文件 test = pd.read_csv('test.csv') submission = pd.read_csv('SampleSubmission.csv') submission['Item_Identifier'] = test['Item_Identifier'] submission['Outlet_Identifier'] = test['Outlet_Identifier'] #预处理测试数据集 test.drop(['Item_Identifier', 'Outlet_Identifier'], axis=1, inplace=True) test['Item_Weight'].fillna(mean, inplace =True) test = pd.get_dummies(test) test_scaled = scaler.fit_transform(test) test = pd.DataFrame(test_scaled) #预测测试集并创建提交文件 predict = model.predict(test) submission['Item_Outlet_Sales'] = predict submission.to_csv('submit_file.csv',index=False)

在提交此文件后,我得到的RMSE为1279.5159651297。

8.实现GridsearchCV

为了确定k的值,每次绘制肘部曲线是一个繁琐且繁琐的过程。你只需使用gridsearch即可简单的找到最佳值。

from sklearn.model_selection import GridSearchCV params = {'n_neighbors':[2,3,4,5,6,7,8,9]} knn = neighbors.KNeighborsRegressor() model = GridSearchCV(knn, params, cv=5) model.fit(x_train,y_train) model.best_params_

输出:

{'n_neighbors': 7}

6.结束语和其他资源

在本文中,我们介绍了KNN算法的工作原理及其在Python中的实现。它是最基本但最有效的机器学习技术之一。并且在本文中,我们是直接调用了Sklearn库中的KNN模型,如果你想更仔细的研究一下KNN的话,我建议你可以手敲一下有关KNN的源代码。

本文作者介绍了如何使用KNN算法去进行完成回归任务,大家如果感兴趣的话,可以跟着本文敲一遍代码,进行练习,毕竟看10篇文章也不如去敲一遍代码,毕竟看文章看看也就过去了,如果敲一遍代码的话,就会加深自己的印象,如果想深入的去了解KNN算法的话,可以自己去研究一下KNN的源代码,然后敲一遍,我们后边也会放出有关KNN源代码的文章,当然其他算法的文章我们也会发布,请大家到时候多多捧场。

A Practical Introduction to K-Nearest Neighbors Algorithm for Regression (with Python code)

k近邻回归算法python_K近邻算法用作回归的使用介绍(使用Python代码)相关推荐

  1. A* 算法原理以及在二维环境地图中的应用 -- Python 代码实现

    上节学习了 Dijkstra 路径规划规划算法,虽然能够找到最短的路径,但是其遍历的搜索过程具有盲目性,因此效率比较低,计算量非常大.而实际中电子地图的结点数量是非常庞大的,Dijkstra 算法在有 ...

  2. MATLAB算法实战应用案例精讲-【回归算法】偏最小二乘回归(PLS)(附MATLAB、R语言和Python代码)

    目录 前言 算法原理 建模方法 建模步骤 算法特点 算法步骤

  3. KMeans算法实现步骤介绍及Python代码

    文章目录 一.KMeans算法的步骤 二.KMeans实现过程中需要注意的地方 1.初始聚类中心的确定 2. 常用的距离度量 3. 聚类效果的衡量 SSE 4.迭代结束条件 5.空簇的处理 三.结果展 ...

  4. 【图像识别算法】像素级提取图像关键特征、内容 --python代码

    像素级提取图像关键特征算法-rgb 关键词:python像素级处理图像,python提取图片关键特征. 基于knn的图像识别技术主要涉及到以下概念: 色彩成像原理 [图像原理]rgb数字图片概念 计算 ...

  5. 逻辑回归 - sklearn (LR、LRCV、MLP、RLR)- Python代码实现

    目录 LR(LogisticRegression) - 线性回归 LRCV(LogisticRegressionCV )- 逻辑回归 MLP(MLPRegressor) - 人工神经网络 RLR(Ra ...

  6. python svr回归_SVR入门介绍(Python代码)

    一. SVR原理简述 在前面的文章中详细讨论过关于线性回归的公式推导,线性回归传送站.线性回归的基本模型为: ,从某方面说这和超平面的的表达式: 有很大的相似性.但SVR认为只要 与 不要偏离太大即算 ...

  7. python必备基础代码-机器学习算法基础(使用Python代码)

    介绍 谷歌的自动驾驶汽车和机器人受到了很多媒体的关注,但该公司真正的未来是在机器学习领域,这种技术能使计算机变得更聪明,更个性化.-Eric Schmidt(Google董事长) 我们可能生活在人类历 ...

  8. 机器学习算法基础(使用Python代码)

    介绍 谷歌的自动驾驶汽车和机器人受到了很多媒体的关注,但该公司真正的未来是在机器学习领域,这种技术能使计算机变得更聪明,更个性化.-Eric Schmidt(Google董事长) 我们可能生活在人类历 ...

  9. 机器学习算法基础之使用python代码

    介绍 谷歌的自动驾驶汽车和机器人受到了很多媒体的关注,但该公司真正的未来是在机器学习领域,这种技术能使计算机变得更聪明,更个性化.-Eric Schmidt(Google董事长) 我们可能生活在人类历 ...

最新文章

  1. 离开百度三年多,吴恩达纽交所敲钟,身价再增20亿
  2. 学计算机所需要的英语单词,学计算机最少要懂的英语单词是什么
  3. 计算机基础课程在线教学授课方案,《计算机基础》课程分层次模块化教学实施方案...
  4. 常用的几种卷积神经网络介绍
  5. ★(在人生的过程中我的65047777
  6. 一个题目涉及到的50个Sql语句
  7. NeHe OpenGL教程 第四十四课:3D光晕
  8. FreeFileSync 免费文件同步软件 实时自动备份重要资料
  9. 凤凰os linux双系统,逍遥安卓模拟器双系统无缝连接完胜Remix、凤凰OS
  10. c语言回溯法解决倒桥本分数式,回溯法 经典题目 八皇后 桥本分数
  11. ubuntu中文智能拼音输入法配置
  12. MapReduce 内部实现机制,你真的懂吗?
  13. starUML for MAC 破解方法
  14. 微信在线EXCEL自动统计人数
  15. 互联网大鱼吃小鱼背后:十亿消费者的推崇
  16. 【独家】硅谷创业公司在中国常跌的五个坑|禾赛科技CEO李一帆柏林亚太周主题演讲...
  17. 5分钟告诉你如何成为一名黑客?从萌新成为大佬,只需掌握这5点(思维、编程语言、网络安全、入侵实操、法律)
  18. 基于微信小程序的学生选课管理系统 小程序 uniapp
  19. Python 使用office365邮箱自动发送邮件
  20. 创业第12天,百度竞价助手3.0正式发布,上传各大软件下载站

热门文章

  1. 波卡链Substrate (6)Babe协议二“分配slot机制”
  2. 区块链BaaS云服务(12)易居(中国) 房地产 EBaaS(Estate Blockchain as a Service)
  3. 【数据库复习】第二章关系数据库
  4. buu 信息化时代的步伐
  5. [HOW TO]-ubuntu20.10环境上安装workpress
  6. Linux Kernel中断下半部分实现的三种方式
  7. [gic]-gicv3/gicv4的feature总结
  8. 如何编写一个测试HIDL接口的vts(gtest)的testcase
  9. 密码学基础知识(五)序列密码
  10. matlab中大figure怎样修改,操作Matlab的Figure窗口(一)