KNN主要包括训练过程和分类过程。在训练过程上,需要将训练集存储起来。在分类过程中,将测试集和训练集中的每一张图片去比较,选取差别最小的那张图片。

如果数据集多,就把训练集分成两部分,一小部分作为验证集(假的测试集),剩下的都为训练集(一般来说是70%-90%,具体多少取决于需要调整的超参数的多少,如果超参数多,验证集占比就更大一点)。验证集的好处是用来调节超参数,如果数据集不多,使用交叉验证的方法来调节参数。但是交叉验证的代价比较高,K折交叉验证,K越大越好,但是代价也更高。

决策分类

明确K个邻居中所有数据类别的个数,将测试数据划分给个数最多的那一类。即由输入实例的 K 个最临近的训练实例中的多数类决定输入实例的类别。

常用决策规则:

多数表决法:多数表决法和我们日常生活中的投票表决是一样的,少数服从多数,是最常用的一种方法。

加权表决法:有些情况下会使用到加权表决法,比如投票的时候裁判投票的权重更大,而一般人的权重较小。所以在数据之间有权重的情况下,一般采用加权表决法。

优点:

所选择的邻居都是已经正确分类的对象

KNN算法本身比较简单,分类器不需要使用训练集进行训练,训练时间复杂度为0。本算法分类的复杂度与训练集中数据的个数成正比。

对于类域的交叉或重叠较多的待分类样本,KNN算法比其他方法跟合适。

缺点:

当样本分布不平衡时,很难做到正确分类

计算量较大,因为每次都要计算测试数据到全部数据的距离。

python代码实现:

import numpy as np

class kNearestNeighbor:

def init(self):

pass

def train(self, X, y):

self.Xtr = X

self.ytr = y

def predict(self, X, k=1):

num_test = X.shape[0]

Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

for i in range(num_test):

distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)

closest_y = y_train[np.argsort(distances)[:k]]

u, indices = np.unique(closest_y, return_inverse=True)

Ypred[i] = u[np.argmax(np.bincount(indices))]

return Ypred

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

load_CIFAR_batch()和load_CIFAR10()是用来加载CIFAR-10数据集的

import pickle

def load_CIFAR_batch(filename):

“”" load single batch of cifar “”"

with open(filename, ‘rb’) as f:

datadict = pickle.load(f, encoding=‘latin1’)

X = datadict[‘data’]

Y = datadict[‘labels’]

X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype(“float”)

Y = np.array(Y)

return X, Y

1

2

3

4

5

6

7

8

9

10

import os

def load_CIFAR10(ROOT):

“”" load all of cifar “”"

xs = []

ys = []

for b in range(1,6):

f = os.path.join(ROOT, ‘data_batch_%d’ %(b))

X, Y = load_CIFAR_batch(f)

xs.append(X)

ys.append(Y)

Xtr = np.concatenate(xs) #使变成行向量

Ytr = np.concatenate(ys)

del X,Y

Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, ‘test_batch’))

return Xtr, Ytr, Xte, Yte

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

Xtr, Ytr, Xte, Yte = load_CIFAR10(‘cifar10’)

Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)

Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)

1

2

3

#由于数据集稍微有点大,在电脑上跑的很慢,所以取训练集5000个,测试集500个

num_training = 5000

num_test = 500

x_train = Xtr_rows[:num_training, :]

y_train = Ytr[:num_training]

x_test = Xte_rows[:num_test, :]

y_test = Yte[:num_test]

1

2

3

4

5

6

7

8

9

knn = kNearestNeighbor()

knn.train(x_train, y_train)

y_predict = knn.predict(x_test, k=7)

acc = np.mean(y_predict == y_test)

print(‘accuracy : %f’ %(acc))

1

2

3

4

5

accuracy : 0.302000

1

#k值取什么最后的效果会更好呢?可以使用交叉验证的方法,这里使用的是5折交叉验证

num_folds = 5

k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]

x_train_folds = np.array_split(x_train, num_folds)

y_train_folds = np.array_split(y_train, num_folds)

k_to_accuracies = {}

for k_val in k_choices:

print('k = ’ + str(k_val))

k_to_accuracies[k_val] = []

for i in range(num_folds):

x_train_cycle = np.concatenate([f for j,f in enumerate (x_train_folds) if j!=i])

y_train_cycle = np.concatenate([f for j,f in enumerate (y_train_folds) if j!=i])

x_val_cycle = x_train_folds[i]

y_val_cycle = y_train_folds[i]

knn = kNearestNeighbor()

knn.train(x_train_cycle, y_train_cycle)

y_val_pred = knn.predict(x_val_cycle, k_val)

num_correct = np.sum(y_val_cycle == y_val_pred)

k_to_accuracies[k_val].append(float(num_correct) / float(len(y_val_cycle)))

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

k = 1

k = 3

k = 5

k = 8

k = 10

k = 12

k = 15

k = 20

k = 50

k = 100

1

2

3

4

5

6

7

8

9

10

for k in sorted(k_to_accuracies):

for accuracy in k_to_accuracies[k]:

print(‘k = %d, accuracy = %f’ % (int(k), accuracy))

1

2

3

k = 1, accuracy = 0.098000

k = 1, accuracy = 0.148000

k = 1, accuracy = 0.205000

k = 1, accuracy = 0.233000

k = 1, accuracy = 0.308000

k = 3, accuracy = 0.089000

k = 3, accuracy = 0.142000

k = 3, accuracy = 0.215000

k = 3, accuracy = 0.251000

k = 3, accuracy = 0.296000

k = 5, accuracy = 0.096000

k = 5, accuracy = 0.176000

k = 5, accuracy = 0.240000

k = 5, accuracy = 0.284000

k = 5, accuracy = 0.309000

k = 8, accuracy = 0.100000

k = 8, accuracy = 0.175000

k = 8, accuracy = 0.263000

k = 8, accuracy = 0.289000

k = 8, accuracy = 0.310000

k = 10, accuracy = 0.099000

k = 10, accuracy = 0.174000

k = 10, accuracy = 0.264000

k = 10, accuracy = 0.318000

k = 10, accuracy = 0.313000

k = 12, accuracy = 0.100000

k = 12, accuracy = 0.192000

k = 12, accuracy = 0.261000

k = 12, accuracy = 0.316000

k = 12, accuracy = 0.318000

k = 15, accuracy = 0.087000

k = 15, accuracy = 0.197000

k = 15, accuracy = 0.255000

k = 15, accuracy = 0.322000

k = 15, accuracy = 0.321000

k = 20, accuracy = 0.089000

k = 20, accuracy = 0.225000

k = 20, accuracy = 0.270000

k = 20, accuracy = 0.319000

k = 20, accuracy = 0.306000

k = 50, accuracy = 0.079000

k = 50, accuracy = 0.248000

k = 50, accuracy = 0.278000

k = 50, accuracy = 0.287000

k = 50, accuracy = 0.293000

k = 100, accuracy = 0.075000

k = 100, accuracy = 0.246000

k = 100, accuracy = 0.275000

k = 100, accuracy = 0.284000

k = 100, accuracy = 0.277000

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

可视化交叉验证的结果

import matplotlib.pyplot as plt

plt.rcParams[‘figure.figsize’] = (10.0, 8.0)

plt.rcParams[‘image.interpolation’] = ‘nearest’

plt.rcParams[‘image.cmap’] = ‘gray’

1

2

3

4

5

for k in k_choices:

accuracies = k_to_accuracies[k]

plt.scatter([k] * len(accuracies), accuracies)

accuracies_mean = np.array([np.mean(v) for k,v in sorted(k_to_accuracies.items())])

accuracies_std = np.array([np.std(v) for k,v in sorted(k_to_accuracies.items())])

plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)

plt.title(‘Cross-validation on k’)

plt.xlabel(‘k’)

plt.ylabel(‘Cross-validation accuracy’)

plt.show()

1

2

3

4

5

6

7

8

9

10

11

knn算法python理解与预测_理解KNN算法相关推荐

  1. mooc数据结构与算法python版期末考试_数据结构与算法Python版-中国大学mooc-试题题目及答案...

    数据结构与算法Python版-中国大学mooc-试题题目及答案 更多相关问题 婴儿出生一两天后就有笑的反应,这种笑的反应属于(). [判断题]填制原始凭证,汉字大写金额数字一律用正楷或草书书写,汉字大 ...

  2. 数据结构与算法python语言实现答案_数据结构与算法:Python语言实现 源代码 PPT 练习答案 源码.zip...

    1 60660-数据结构与算法:Python语言实现[练习答案]Solutions Manual.rar 943.25 KB 2018/11/1 12:03:34 2 __MACOSX 0 Bytes ...

  3. 卷积神经网络算法python实现车牌识别_车牌识别算法之CNN卷积神经网络

    原标题:车牌识别算法之CNN卷积神经网络 随着我国经济的发展,汽车,特别是小轿车的数量越来越多,智能交通管理系统应运而生.车牌智能自动识别作为智能交通管理系统中的重要组成部分,在智能交通管理中发挥着越 ...

  4. 数据结构与算法python版 期末考试_数据结构与算法Python版期末在线考试OJ部分

    1 二叉树路径(10分) 题目内容: 给定一个二叉查找树的节点插入顺序,请重新构建这个二叉查找树,并按从左至右顺序返回所有根节点至叶节点的路径 输入格式: 一行整数,以空格分隔 注:测试用例中不包含重 ...

  5. python xgb模型 预测_如何使用XGBoost模型进行时间序列预测

    字幕组双语原文:如何使用XGBoost模型进行时间序列预测 英语原文:How to Use XGBoost for Time Series Forecasting 翻译:雷锋字幕组(Shangru) ...

  6. mooc数据结构与算法python版期末测验_中国大学MOOC(慕课)_数据结构与算法Python版_测试题及答案...

    中国大学MOOC(慕课)_数据结构与算法Python版_测试题及答案 更多相关问题 采用fopen()函数打开文件,支持文件读取的参数有: [简答题]简单阐述高分子材料热-机械特征及成型加工的关系,并 ...

  7. mooc数据结构与算法python版期末测验_中国大学数据结构与算法Python版答案_MOOC慕课章节期末答案...

    中国大学数据结构与算法Python版答案_MOOC慕课章节期末答案 更多相关问题 java.lang 包的 Character 类的 isJavaIdentifierStart 方法的功能是用来判断某 ...

  8. 中文分词算法python代码_python实现中文分词FMM算法实例

    本文实例讲述了python实现中文分词FMM算法.分享给大家供大家参考.具体分析如下: FMM算法的最简单思想是使用贪心算法向前找n个,如果这n个组成的词在词典中出现,就ok,如果没有出现,那么找n- ...

  9. 人工免疫算法c语言实例,毕业设计_蚁群算法模拟系统的设计与实现.doc

    J I A N G S U U N I V E R S I T Y 本 科 毕 业 论 文 蚁群算法模拟系统的设计与实现 Ant Colony Simulation System Design and ...

最新文章

  1. [笔记][原创]Verilog HDL语法分析笔记
  2. 61二叉搜索树的第k个结点
  3. 清除nginx服务器网站缓存数据
  4. 腾讯2014年实习生招聘笔试面试经历
  5. 关于memcpy和 strcpy的区别 以及memset
  6. TX2安装CH341驱动 总结
  7. JavaScript详细解析
  8. 电脑桌面云便签怎么登录便签账号?
  9. java------jdkd的安装与配置变量环境
  10. 串口服务器是什么,有什么功能
  11. mysql时间自动填充_Mysql自动设置时间(自动获取时间,填充时间)
  12. 四种渠道打造网站高质量原创内容
  13. 自定义 SpringBoot Banner 图案
  14. 使用POI实现Excel导出导入 详细解释
  15. Xcode直接安装ipa
  16. 项目管理必备工具——甘特图
  17. Django Swagger文档库drf-spectacular
  18. word之插入LaTex公式
  19. 02 数据定义语言DDL
  20. Java毕设项目派大星水产商城mp4(java+VUE+Mybatis+Maven+Mysql)

热门文章

  1. mysql connector配置_mysql connector odbc配置注意事项
  2. SliceProceduralMesh的使用
  3. bbb sd6 无e2 修改
  4. azm335x 串口配置
  5. xLite连接asterisk提示sip408错误
  6. Windows Embedded CE 6.0开发初体验(六)平台定制
  7. Android usb 权限广播,[Android]USB开发
  8. 单体预聚合的目的是什么_高分子化学实验指导书-修改-2012
  9. LSGO软件技术团队2015~2016学年第九周(1026~1101)总结
  10. 【转】Postman系列一:Postman安装及使用过程中遇到的问题