K近邻法也即KNN(k-NearestNeighbor),是一种基本的分类和回归方法,与感知机不一样,感知机是二分类,KNN可以多分类。

算法核心:如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。

知识点梳理:

a、 模型:
y=argmaxcj∑xi∈Nk(x)I(yi=Cj)i=1,2,…,N,j=1,2,…,Ky = arg\space max_{c_j} \sum_{x_i\in N_k(x)}\space I(y_i = C_j) \space\space i = 1,2,…,N,\space j = 1,2,…,Ky=arg maxcj​​xi​∈Nk​(x)∑​ I(yi​=Cj​)  i=1,2,…,N, j=1,2,…,K

b、 策略:k值的选择需要在训练集通过交叉验证获得对应数据集的最优k值

c、 分类决策规则:k个邻居多数表决法,最近质心分类算法

d、 距离度量方式: 闵可夫斯基距离 “minkowski”也称作LpL_pLp​距离
Lp=[∑i=1n∣xi−yi∣p]1/pL_p = [\sum_{i=1}^n|x_i - y_i|^p]^{1/p}Lp​=[i=1∑n​∣xi​−yi​∣p]1/p

  • p=1p = 1p=1时称为曼哈顿距离 “manhattan”
  • p=2p = 2p=2时称为欧式距离 “euclidean”
  • p=无穷大p = 无穷大p=无穷大 时称为切比雪夫距离“chebyshev”

e、构造kd树:当训练数据量大,维度高时需要提高搜索效率,kd树就是提高高维数据搜索效率的方式之一。kd树是二叉树,是一种对k维空间中实例点进行存储以便对其进行快速检索的树形数据结构。kd树表示对k维空间的一个划分(partition),构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个k维超矩形区域。

f、搜索kd树,kd树的搜索复杂度ο(logN)\omicron(logN)ο(logN),注意并不是k邻近法的复杂度.

示例,二维空间中p取不同值时,与原点的LPL_PLP​距离为1(Lp=1L_p = 1Lp​=1)的图形如下:

KNN 优点

  • 理论成熟,思想简单,既可以用来做分类也可以用来做回归
  • 可用于非线性分类
  • 训练时间复杂度比支持向量机之类的算法低,仅为O(n)
  • 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感
  • 由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合
  • 该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分

KNN 缺点

  • 计算量大,尤其是特征数非常多的时候
  • 样本不平衡的时候,对稀有类别的预测准确率低
  • KD树,球树之类的模型建立需要大量的内存
  • 使用懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢

优缺点来自于博客

1、加载数据,MINIST数据集

import numpy as np
import os# 训练集
with open('./minist_data/train-images.idx3-ubyte') as f:loaded = np.fromfile(file = f, dtype = np.uint8)X_train = loaded[16:].reshape((60000, 784))
#     X_train = X_train.astype(np.int32)
print('X_train:',X_train.shape) # (60000, 784)with open('./minist_data/train-labels.idx1-ubyte') as f:loaded = np.fromfile(file = f, dtype = np.uint8)y_train = loaded[8:]
#     y_train = y_train.astype(np.int32)
print('y_train:',y_train.shape) # (60000,)# 测试集
with open('./minist_data/t10k-images.idx3-ubyte') as f:loaded = np.fromfile(file=f, dtype=np.uint8)X_test = loaded[16:].reshape((10000, 784))
#     X_test = X_test.astype(np.int32)
print('X_test:',X_test.shape) # (10000, 784)with open('./minist_data/t10k-labels.idx1-ubyte') as f:loaded = np.fromfile(file=f, dtype=np.uint8)y_test = loaded[8:].reshape((10000))
#     y_test = y_test.astype(np.int32)
print('y_test:',y_test.shape) # (10000,)
X_train: (60000, 784)
y_train: (60000,)
X_test: (10000, 784)
y_test: (10000,)

1.1、构造数字0和数字1的二分类数据

# 数字0 和数字1二分类
X_train1 = X_train[y_train == 0]
X_train2 = X_train[y_train == 1]
X_train12 = np.vstack([X_train1,X_train2])
print (X_train12.shape)y12 = np.ones((X_train12.shape[0],1))
y12[:X_train1.shape[0],:] = 0all_data = np.hstack([X_train12,y12])
print (all_data.shape)
print ('*'*30)# 随机打乱数据,选取前80%作为训练集,后面20%作为测试集
np.random.shuffle(all_data)
XX_train = all_data[:int(0.8*all_data.shape[0]),:-1]
yy_train = all_data[:int(0.8*all_data.shape[0]),-1]
XX_test = all_data[int(0.8*all_data.shape[0]):,:-1]
yy_test = all_data[int(0.8*all_data.shape[0]):,-1]
print (XX_train.shape)
print (yy_train.shape)
print (XX_test.shape)
print (yy_test.shape)
print (yy_test[:10])
(12665, 784)
(12665, 785)
******************************
(10132, 784)
(10132,)
(2533, 784)
(2533,)
[1. 1. 1. 1. 0. 1. 1. 0. 0. 1.]

1.2、每一类数字内部实例之间距离分布

  • 可以看出数字1的距离与其他数字之间的距离有很大的区分度,其他数字的区分度不明显
  • 距离度量方式采用的是欧氏距离np.sqrt(np.sum(np.square(x1- x2)))
# 同一类点的距离分布
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(40,6),dpi=80)def cal_distance(x1,x2):
#     x1 = np.array(x1.tolist())
#     x2 = np.array(x2.tolist())return np.sqrt(np.sum(np.square(x1- x2)))label_list = range(10)
color_list = ['black','blue','yellow','green','red','chocolate','gold','orange','purple','seagreen']
for i in range(10):X_train_temp = X_train[y_train == i][:100]distance0 = []for n in range(len(X_train_temp)):for m in range(n+1,len(X_train_temp)):distance0.append(cal_distance(X_train_temp[n],X_train_temp[m]))print ('distance%d:'%i,np.sum(distance0)/len(distance0))plt.plot(distance0[:100],color = color_list[i], label='distance%d'%label_list[i])plt.legend(loc='best')
plt.show();
distance0: 140.9302377413214
distance1: 98.50867850966716
distance2: 138.59233471066668
distance3: 138.3567534182396
distance4: 128.89089543524952
distance5: 132.80861456562158
distance6: 129.8494792770674
distance7: 123.30475748459737
distance8: 137.4896308228642
distance9: 127.77494860942429

1.3、不同数字类别之间的距离

  • 下面的实例是以0与其他10个数字之间的距离,与数字1之间的距离最小,还是比较有区分度。与其他数字没有明显的区分度
  • 数字1-9都测试过与数字0呈现的规律一样
# 不同类之间距离分布
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(40,6),dpi=80)def cal_distance(x1,x2):return np.sqrt(np.sum(np.square(x1- x2)))
#     return np.sum(np.abs(x1-x2))label_list = range(10)
color_list = ['black','blue','yellow','green','red','chocolate','gold','orange','purple','seagreen']
X_train_temp1 = X_train[y_train == 0][:100]
for i in range(10):X_train_temp = X_train[y_train == i][:100]distance0 = []for n in range(len(X_train_temp)):for m in range(n+1,len(X_train_temp)):distance0.append(cal_distance(X_train_temp1[n],X_train_temp[m]))print ('distance%d:'%i,np.sum(distance0)/len(distance0))plt.plot(distance0[:100],color = color_list[i], label='distance%d'%label_list[i])plt.legend(loc='best')
plt.show();
distance0: 140.9302377413214
distance1: 129.9145429967284
distance2: 142.60880316936058
distance3: 143.29173677133426
distance4: 139.96523346878226
distance5: 140.76444705989658
distance6: 140.5789216539252
distance7: 137.41035418359536
distance8: 143.8551541797874
distance9: 140.54349920061563

1.4、由图可以看出数字0与数字1整体基本不相交,容易出现把数字1误分类成数字0,但是不容易把数字0误分类成数字1,欧式距离比曼哈顿距离更有区分度

import matplotlib.pyplot as plt
%matplotlib inlineimport time
start = time.time()
c0 = X_train1[:100]
c1 = X_train2[:100]def cal_distance(x1,x2):return np.sqrt(np.sum(np.square(x1- x2)))
#     return np.sum(np.abs(x1-x2))import pandas as pd
distance0 = []
for n in range(len(c0)):for m in range(n+1,len(c0)):distance0.append(cal_distance(c0[n],c0[m]))
number = len(distance0)
print (number)
print ('distance0:',np.sum(distance0)/number)distance1 = []
for n in range(len(c1)):for m in range(n+1,len(c1)):distance1.append(cal_distance(c1[n],c1[m]))
number = len(distance1)
print ('distance1:',np.sum(distance1)/number)distance01 = []
for n in range(len(c1)):for m in range(n+1,len(c1)):distance01.append(cal_distance(c0[n],c1[m]))
number = len(distance01)
print ('distance01:',np.sum(distance01)/number)distance10 = []
for n in range(len(c1)):for m in range(n+1,len(c1)):distance10.append(cal_distance(c1[n],c0[m]))
print ('distance10:',np.sum(distance10)/number)
print (number)end = time.time()
# print ('Run time:',end-start)plt.figure(figsize=(40,6),dpi=80)
plt.plot(distance0[1:501],color = 'red', label="distance0")
plt.plot(distance1[1:501],color = 'blue', label="distance1")
plt.plot(distance01[:500],color = 'green', label="distance01")
plt.plot(distance10[:500],color = 'yellow', label="distance10")
plt.legend(loc='best')
plt.show();
4950
distance0: 140.9302377413214
distance1: 98.50867850966716
distance01: 129.9145429967284
distance10: 127.82344364057488
4950

2、模型预测

2.1、构造KNN分类函数

  • 数字0和数字1的二分类精度非常高为1,从上面1.4的距离展示也能看出来,数字0与数字1距离有明显的的区分度
  • 对10个数字整体分类,精度非常之低为0.295,从上面的1.2和1.3能看出来,除了数字1其他数字之间的距离并没有明显的区分度,看了预测结果很多是预测成了数字1
def nearest_classifier(X_train,y_train,x,k):distance_list = []for i in range(len(X_train)):distance_list.append(np.sqrt(np.sum(np.square(x- X_train[i]))))topKList = np.argsort(np.array(distance_list))[:k]labelList = [0] * 10for index in topKList:labelList[int(y_train[index])] += 1  return labelList.index(max(labelList))def test(X_train, y_train, X_test, y_test, k):print('start test')errorCnt = 0for i in range(len(X_test)):   if not i%(len(X_test)/10):print ('Run Percent:',float(i/len(X_test)))       x = X_test[i]y = nearest_classifier(X_train, y_train, x, k)
#         print ('prd:',y)
#         print ('actual:',y_test[i])if y != y_test[i]: errorCnt += 1    return 1 - (errorCnt / len(X_test))
k = 25
acc1 = test(XX_train, yy_train, XX_test, yy_test, k)
acc2 = test(X_train, y_train, X_test[:200], y_test[:200], k)
print ('acc1:',acc1)
print ('acc2:',acc2)
start test
0
Run Percent: 0.0
0
start test
0
Run Percent: 0.0
20
Run Percent: 0.1
40
Run Percent: 0.2
60
Run Percent: 0.3
80
Run Percent: 0.4
100
Run Percent: 0.5
120
Run Percent: 0.6
140
Run Percent: 0.7
160
Run Percent: 0.8
180
Run Percent: 0.9
141
acc1: 1.0
acc2: 0.29500000000000004

2.2、用sklearn里面的模块KNeighborsClassifier进行二分类预测,预测精度挺高的

from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier()
model.fit(XX_train,yy_train)
acc = model.score(XX_test, yy_test)
print ('acc:',acc)
acc: 0.9996052112120016

2.3、用sklearn里面的模块KNeighborsClassifier对10个类别的数字进行预测,预测精度也挺高的,距离度量采用暴力搜索和kd_tree方式进行精度都一样

from sklearn.neighbors import KNeighborsClassifier
# model = KNeighborsClassifier(algorithm='brute')
model = KNeighborsClassifier(algorithm='kd_tree')model.fit(X_train,y_train)
acc = model.score(X_test, y_test)
print ('acc:',acc)
acc: 0.9688
2.1里面的结果与1.2-1.4的距离展示相符,但是与2.3的结果相差甚大。 是什么原因呢?
XX_train = XX_train.astype(np.int32)
XX_test = XX_test.astype(np.int32)
X_train = X_train.astype(np.int32)
X_test = X_test.astype(np.int32)k = 25
acc1 = test(XX_train, yy_train, XX_test, yy_test, k)
acc2 = test(X_train, y_train, X_test[:200], y_test[:200], k)
print ('acc1:',acc1)
print ('acc2:',acc2)
start test
Run Percent: 0.0
start test
Run Percent: 0.0
Run Percent: 0.1
Run Percent: 0.2
Run Percent: 0.3
Run Percent: 0.4
Run Percent: 0.5
Run Percent: 0.6
Run Percent: 0.7
Run Percent: 0.8
Run Percent: 0.9
acc1: 0.9996052112120016
acc2: 0.97
2.1里面的结果与1.2-1.4的距离展示相符,但是与2.3的结果相差甚大。原因是载入的数据是uint8无符号整数,0 至 255,在做减法和平方时造成溢出。上面更改了数据类型后重跑了下两种预测,与直接使用KNeighborsClassifier一样。

3、sklearn 中KNN相关的类库

此部分内容主要来自博客。

sklearn.neighbors包中有:

  • KNN分类树的类是KNeighborsClassifier
  • KNN回归树的类是KNeighborsRegressor
  • 限定半径最近邻分类树的类RadiusNeighborsClassifier
  • 限定半径最近邻回归树的类RadiusNeighborsRegressor
  • 最近质心分类算法NearestCentroid

这些算法中,KNN分类和回归的类参数完全一样。限定半径最近邻法分类和回归的类的主要参数也和KNN基本一样

  • kneighbors_graph类返回用KNN时和每个样本最近的K个训练集样本的位置
  • radius_neighbors_graph返回用限定半径最近邻法时和每个样本在限定半径内的训练集样本的位置

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights=’uniform’,algorithm=’auto’, leaf_size=30, p=2, metric=’minkowski’, metric_params=None,n_jobs=None, **kwargs)

参数 说明
n_neighbors K值的选择与样本分布有关,一般选择一个较小的K值,可以通过交叉验证来选择一个比较优的K值
weights,主要用于标识每个样本的近邻样本的权重 如果是KNN,就是K个近邻样本的权重,如果是限定半径最近邻,就是在距离在半径以内的近邻样本的权重。选择默认的"uniform",意味着所有最近邻样本权重都一样,在做预测时一视同仁。如果是"distance",则权重和距离成反比例,即距离预测目标更近的近邻具有更高的权重,这样在预测类别或者做回归时,更近的近邻所占的影响因子会更加大。
algorithm,KNN和限定半径最近邻法使用的算法 一共有4种可选输入,‘brute’对应第一种蛮力实现,‘kd_tree’对应第二种KD树实现,‘ball_tree’对应第三种的球树实现, ‘auto’则会在上面三种算法中做权衡
leaf_size,停止建子树的叶子节点阈值 这个值控制了使用KD树或者球树时, 停止建子树的叶子节点数量的阈值。这个值越小,则生成的KD树或者球树就越大,层数越深,建树时间越长,反之,则生成的KD树或者球树会小,层数较浅,建树时间较短。默认是30
p,距离度量附属参数 p是使用距离度量参数 metric 附属参数,只用于闵可夫斯基距离和带权重闵可夫斯基距离中p值的选择
metric,距离度量 闵可夫斯基距离 “minkowski”(默认参数)
metric_params,距离度量其他附属参数 一般都用不上,主要是用于带权重闵可夫斯基距离的权重
n_jobs,并行处理任务数 主要用于多核CPU时的并行处理,加快建立KNN树和预测搜索的速度。一般用默认的-1就可以了,即所有的CPU核都参与计算。

上面的例子直接使用KNeighborsClassifier进行预测,简单快速

from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier(n_neighbors=25, weights='uniform',algorithm='auto', leaf_size=30, p=2, metric='minkowski')
model.fit(X_train,y_train)
acc = model.score(X_test, y_test)
print ('acc:',acc)
acc: 0.9609

博客中的例子

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets.samples_generator import make_classification
# X为样本特征,Y为样本类别输出, 共1000个样本,每个样本2个特征,输出有3个类别,没有冗余特征,每个类别一个簇
X, Y = make_classification(n_samples=1000, n_features=2, n_redundant=0,n_clusters_per_class=1, n_classes=3,random_state=3)
plt.figure(figsize=(8,6),dpi=80)
plt.scatter(X[:, 0], X[:, 1], marker='o', c=Y)
plt.show()

from sklearn.neighbors import KNeighborsClassifier
clf = KNeighborsClassifier(n_neighbors = 15 , weights='distance')
clf.fit(X, Y)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=1, n_neighbors=15, p=2,weights='distance')
from matplotlib.colors import ListedColormap
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])#确认训练集的边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
#生成随机数据来做测试集,然后作预测
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),np.arange(y_min, y_max, 0.02))Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
print ('lenth X:',len(X))
acc1 = clf.score(X, Y)
print ('acc1:',acc1)# 画出测试集数据
Z = Z.reshape(xx.shape)
plt.figure(figsize=(8,6),dpi=80)
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)# 也画出所有的训练集数据
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=cmap_bold)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("3-Class classification (k = 15, weights = 'distance')" )
plt.show();
lenth X: 1000
acc1: 1.0

为什么有错误的分类,准确度还是1?

03_KNN_统计学习方法相关推荐

  1. 《机器学习与数据科学(基于R的统计学习方法)》——2.11 R中的SQL等价表述...

    本节书摘来异步社区<机器学习与数据科学(基于R的统计学习方法)>一书中的第2章,第2.11节,作者:[美]Daniel D. Gutierrez(古铁雷斯),更多章节内容可以访问云栖社区& ...

  2. 统计学习方法:朴素贝叶斯

    作者:桂. 时间:2017-04-20  18:31:37 链接:http://www.cnblogs.com/xingshansi/p/6740308.html 前言 本文为<统计学习方法&g ...

  3. 【组队学习】【32期】统计学习方法习题实战

    统计学习方法习题实战 航路开辟者:胡锐锋.王维嘉.王瀚翀.王茸茸.毛鹏志 领航员:张璇 航海士:胡锐锋.王维嘉.王瀚翀.王茸茸.毛鹏志.王天富.范佳慧 基本信息 开源内容:https://github ...

  4. 4000字超干货!《统计学习方法》啃书指南(1)

    事半功倍的啃书姿势内容分以下六个部分: 一.我的学习历程(不喜欢听啰嗦的可以从第二部分开始) 二.学习<统计学习方法>遇到的困难 三. 在学习过程中碰过的壁 四.最后解决问题的方法 五.学 ...

  5. 开源!北大研究生把《统计学习方法》书中全部算法都实现了!

    来源:开源最前线(ID:OpenSourceTop) 一个好的开源项目分享给大家. <统计学习方法>可以说是机器学习的入门宝典,许多机器学习培训班.互联网企业的面试.笔试题目,很多都参考这 ...

  6. 第1章统计学习方法概论之1.1统计学习

    1统计学习(也称统计机器学习) 1.1统计学习定义: 统计学习(statistical learning)是关于计算机基于数据构建概率统计模型并运用模型对数据进行预测与分析的一门学科. 统计学习就是计 ...

  7. 【统计学习方法】线性可分支持向量机对鸢尾花(iris)数据集进行二分类

    本文摘要 · 理论来源:[统计学习方法]第七章 SVM · 技术支持:pandas(读csv).numpy.sklearn.svm.svm思想.matplotlib.pyplot(绘图) · 代码目的 ...

  8. 【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测

    本文摘要 · 理论来源:[统计学习方法]第三四章 朴素贝叶斯 · 技术支持:pandas(读csv).numpy.sklearn.naive_bayes.GaussianNB(高斯朴素贝叶斯模型).s ...

  9. 【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类

    本文摘要 · 理论来源:[统计学习方法]第三章 K近邻 · 技术支持:pandas(读csv).collections.Counter(统计).numpy.sklearn.neighbors.KNei ...

最新文章

  1. 习题2.4 递增的整数序列链表的插入 (15 分)
  2. ubuntu adb 调试手机
  3. Opencv ---像素坐标转世界坐标(已知外参)
  4. 大型网站电商网站架构案例和技术架构的示例
  5. 一个程序员的水平能差到什么程度
  6. 最近在练习爬虫,分享一些简单入门的知识
  7. 第一章节 ASP.NET Web应用程序基础(一)
  8. 如何用MyEclipse在Resin中调试Web应用程序
  9. Linq的Distinct太不给力了
  10. hd Aruba wifi / honor
  11. python docx 图片_python-docx设置图片大小和对齐方式
  12. android 股票行情走势图K线控件 KLineView
  13. 计算机网络第一章概论
  14. django之 报错(1146, “Table ‘demo2.web‘ doesn‘t exist“)
  15. 【C语言】定义一个函数,求长方体的体积
  16. 硬实力 | 观成科技加密流量人工智能安全检测类产品荣获新技术新产品证书
  17. VSCode删除多余空行快捷方法
  18. ROS学习笔记publisher的编程实现c++详解
  19. 未来世界,已经不在遥远
  20. java后台jd_2017春季_京东_Java后端研发岗面经

热门文章

  1. iOS开发--AVFoundation自定义相机
  2. leetcode Sudoku java
  3. android学习日记13--数据存储之SharedPreference
  4. visual stdio 工程 宏
  5. plsql 常用函数
  6. 仿新浪新闻中异步替换关键字
  7. 修改HUDSON_HOME
  8. 搜索引擎广告计费系统如何防恶意点击
  9. 【elasticsearch】es一直重启,报错日志是分片无法分配
  10. 如何扩容LVM逻辑卷