欢迎直接到我的博客查看最近文章:www.pkudodo.com。更新会比较快,评论回复我也能比较快看见,排版也会更好一点。

原始blog链接: http://www.pkudodo.com/2018/11/19/1-2/

前言

《统计学习方法》一书在前几天正式看完,由于这本书在一定程度上对于初学者是有一些难度的,趁着热乎劲把自己走过的弯路都写出来,后人也能走得更顺畅一点。

以下是我的github地址,其中有《统计学习方法》书中所有涉及到的算法的实现,也是在写博客的同时编写的。在编写宗旨上是希望所有人看完书以后,按照程序的思路再配合书中的公式,能知道怎么讲所学的都应用上。(有点自恋地讲)程序中的注释可能比大部分的博客内容都多。希望大家能够多多捧场,如果可以的话,为我的github打颗星,也能让更多的人看到。

github:

GitHub|手写实现李航《统计学习方法》书中全部算法

相关博文:

  1. 统计学习方法|感知机原理剖析及实现
  2. 统计学习方法|K近邻原理剖析及实现
  3. 统计学习方法|朴素贝叶斯原理剖析及实现
  4. 统计学习方法|决策树原理剖析及实现
  5. 统计学习方法|逻辑斯蒂原理剖析及实现

正文

K近邻的直观理解

我们先看一张图:

在感知机一节中我们看过这张图,当时使用划分超平面的方式来划分数据。那么除了划线,还有别的方式来划分数据吗?

观察一下,黄点和蓝点代表了两种标签,比如每个蓝点都是一个合格的产品,黄点是劣质的产品。事实上在图中可以看到,相同标记的样本点通常是成团的形式聚在一起,因为合格的产品在属性上一定是相同或相似的(合格的产品在属性上不太可能会跑到不合格的一类中去)。

那么我们预测过程中,查看被预测的样本x是属于哪一堆来判断它是黄豆还是蓝豆是不是可行呢?

当然可以啦,K近邻就是一种基于该原理的算法。从名字里就可以看到,K近邻样本的预测上,是看被预测样本x离哪一团最近,那它就是属于哪一类的。

接下来我们正儿八经地来看一下K近邻,看下面这张图。

图里有两种标记,就叫它黄豆、绿豆、紫豆好了(这里也能看到,K近邻不像感知机只能划分两种类,K近邻是一种多类划分的模型)。当我们要预测一个样本x时,将x的特征转换为在图中的坐标点,分别计算和绿豆、黄豆、紫豆的举例,比如说距离分别为1, 1.5, 0.8。选择其中距离最小的紫豆作为样本x的预测类别。

那啥是样本x和一团豆的距离?

1.找到样本x最近的点,该点的类就是样本的预测类:这是一种方法,但是如果有噪音呢(同一块区域又有黄点又有绿点)?比如说x实际上是黄豆,但是它的位置在黄豆和绿豆的边界上(例如上图黄点和绿点的交叉处),很可能它最近的点是一个绿点,所以....不太好

2.与每一团的中心点进行距离计算:分别计算绿色、黄色、紫色的中心点,判断距离最小的类即为预测输出类。这样会不会有问题吗?我们看一下上图中绿色和紫色交叉的地方,很明显在这个交叉位置离绿色很近,与紫色中心点较远,但实际上紫色的。所以.....不太好

3.找到样本点周围的K个点,其中占数目最多的类即预测输出的类:克服了前两种方法的弊端,实际上这就是K近邻所使用的算法

感知机的数学角度(配合《统计学习方法》食用更佳)

算法剖析

K近邻并没有显式的学习过程,也就是不需要对训练集进行学习。预测过程中直接遍历预测点与所有点的距离,并找到最近的K个点即可。找到K个最近点后,使用多数表决(即投票)的方式确定预测点的类别。式3.1I(yi=ci)中的I为指示函数,当括号内条件为真时I(true)=1,I(false)=0。argmax表示令后式数值最大时的参数,例如argmax(-X^2 + 1)的结果为0,因为x=0时-X^2 + 1结果为1,为最大值。

式3.1表示对于每一个类Cj,进行I(yi=cj)进行求和,就是计算这K个点中有多少个标记为Cj的点,例如K=25,一共有四个类分别为C1、C2、C3、C4,25个点中他们的个数分别有10、5、1、9个,那么最多数目的类别C1就是样本点的预测类别。

距离度量

在多维空间中有很多种方式可以计算点与点之间的举例,通常采用欧氏距离作为K近邻的度量单位(大部分模型中欧氏距离都是一种不错的选择)。其实就是样本A、B中每一个特征都互相相减,再平方、再求和。与二维中两点之间距离计算方式相同,只是扩展到了多维。

曼哈顿与P=无穷可以不用深究,在本文中使用曼哈顿准确度极差(仅针对Mnist数据集使用K近邻的情况),这两种方式目前仅作了解即可。

K近邻算法缺点:

1、在预测样本类别时,待预测样本需要与训练集中所有样本计算距离,当训练集数量过高时(例如Mnsit训练集有60000个样本),每预测一个样本都要计算60000个距离,计算代价过高,尤其当测试集数目也较大时(Mnist测试集有10000个)。

1、K近邻在高维情况下时(高维在机器学习中并不少见),待预测样本需要与依次与所有样本求距离。向量维度过高时使得欧式距离的计算变得不太迅速了。本文在60000训练集的情况下,将10000个测试集缩减为200个,整个过程仍然需要308秒(曼哈顿距离为246秒,但准确度大幅下降)。

使用欧氏距离还是曼哈顿距离,性能上的差别相对来说不是很大,说明欧式距离并不是制约计算速度的主要方式。最主要的是训练集的大小,每次预测都需要与60000个样本进行比对,同时选出距离最近的K项。

为了解决这一问题,前人提出了KD树算法。

KD树

KD树将整个特征空间划分成多个区域,直观上来看,首先将整个空间分成A、B区域,待测样本判断在A区的时候,那B区过远,内部的点就不需要再判断了,大幅度减少需要比较的样本数量。

但遗憾的是作者暂时没有对KD树进行实现,仅仅在理论层面上讲解并不是我所希望的。所以各位同学请耐心等待,等我将其实现后,写出来的内容可能会更有深度一些。

贴代码:(建议去本文最上方的github链接下载,有书中所有算法的实现以及详细注释)

#coding=utf-8
#Author:Dodo
#Date:2018-11-16
#Email:lvtengchao@pku.edu.cn'''
数据集:Mnist
训练集数量:60000
测试集数量:10000(实际使用:200)
------------------------------
运行结果:(邻近k数量:25)
向量距离使用算法——欧式距离正确率:97%运行时长:308s
向量距离使用算法——曼哈顿距离正确率:14%运行时长:246s
'''import numpy as np
import timedef loadData(fileName):'''加载文件:param fileName:要加载的文件路径:return: 数据集和标签集'''print('start read file')#存放数据及标记dataArr = []; labelArr = []#读取文件fr = open(fileName)#遍历文件中的每一行for line in fr.readlines():#获取当前行,并按“,”切割成字段放入列表中#strip:去掉每行字符串首尾指定的字符(默认空格或换行符)#split:按照指定的字符将字符串切割成每个字段,返回列表形式curLine = line.strip().split(',')#将每行中除标记外的数据放入数据集中(curLine[0]为标记信息)#在放入的同时将原先字符串形式的数据转换为整型dataArr.append([int(num) for num in curLine[1:]])#将标记信息放入标记集中#放入的同时将标记转换为整型labelArr.append(int(curLine[0]))#返回数据集和标记return dataArr, labelArrdef calcDist(x1, x2):'''计算两个样本点向量之间的距离使用的是欧氏距离,即 样本点每个元素相减的平方  再求和  再开方欧式举例公式这里不方便写,可以百度或谷歌欧式距离(也称欧几里得距离):param x1:向量1:param x2:向量2:return:向量之间的欧式距离'''return np.sqrt(np.sum(np.square(x1 - x2)))#马哈顿距离计算公式# return np.sum(x1 - x2)def getClosest(trainDataMat, trainLabelMat, x, topK):'''预测样本x的标记。获取方式通过找到与样本x最近的topK个点,并查看它们的标签。查找里面占某类标签最多的那类标签(书中3.1 3.2节):param trainDataMat:训练集数据集:param trainLabelMat:训练集标签集:param x:要预测的样本x:param topK:选择参考最邻近样本的数目(样本数目的选择关系到正确率,详看3.2.3 K值的选择):return:预测的标记'''#建立一个存放向量x与每个训练集中样本距离的列表#列表的长度为训练集的长度,distList[i]表示x与训练集中第## i个样本的距离distList = [0] * len(trainLabelMat)#遍历训练集中所有的样本点,计算与x的距离for i in range(len(trainDataMat)):#获取训练集中当前样本的向量x1 = trainDataMat[i]#计算向量x与训练集样本x的距离curDist = calcDist(x1, x)#将距离放入对应的列表位置中distList[i] = curDist#对距离列表进行排序#argsort:函数将数组的值从小到大排序后,并按照其相对应的索引值输出#例如:#   >>> x = np.array([3, 1, 2])#   >>> np.argsort(x)#   array([1, 2, 0])#返回的是列表中从小到大的元素索引值,对于我们这种需要查找最小距离的情况来说很合适#array返回的是整个索引值列表,我们通过[:topK]取列表中前topL个放入list中。#----------------优化点-------------------#由于我们只取topK小的元素索引值,所以其实不需要对整个列表进行排序,而argsort是对整个#列表进行排序的,存在时间上的浪费。字典有现成的方法可以只排序top大或top小,可以自行查阅#对代码进行稍稍修改即可#这里没有对其进行优化主要原因是KNN的时间耗费大头在计算向量与向量之间的距离上,由于向量高维#所以计算时间需要很长,所以如果要提升时间,在这里优化的意义不大。(当然不是说就可以不优化了,#主要是我太懒了)topKList = np.argsort(np.array(distList))[:topK]        #升序排序#建立一个长度时的列表,用于选择数量最多的标记#3.2.4提到了分类决策使用的是投票表决,topK个标记每人有一票,在数组中每个标记代表的位置中投入#自己对应的地方,随后进行唱票选择最高票的标记labelList = [0] * 10#对topK个索引进行遍历for index in topKList:#trainLabelMat[index]:在训练集标签中寻找topK元素索引对应的标记#int(trainLabelMat[index]):将标记转换为int(实际上已经是int了,但是不int的话,报错)#labelList[int(trainLabelMat[index])]:找到标记在labelList中对应的位置#最后加1,表示投了一票labelList[int(trainLabelMat[index])] += 1#max(labelList):找到选票箱中票数最多的票数值#labelList.index(max(labelList)):再根据最大值在列表中找到该值对应的索引,等同于预测的标记return labelList.index(max(labelList))def test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK):'''测试正确率:param trainDataArr:训练集数据集:param trainLabelArr: 训练集标记:param testDataArr: 测试集数据集:param testLabelArr: 测试集标记:param topK: 选择多少个邻近点参考:return: 正确率'''print('start test')#将所有列表转换为矩阵形式,方便运算trainDataMat = np.mat(trainDataArr); trainLabelMat = np.mat(trainLabelArr).TtestDataMat = np.mat(testDataArr); testLabelMat = np.mat(testLabelArr).T#错误值技术errorCnt = 0#遍历测试集,对每个测试集样本进行测试#由于计算向量与向量之间的时间耗费太大,测试集有6000个样本,所以这里人为改成了#测试200个样本点,如果要全跑,将行注释取消,再下一行for注释即可,同时下面的print#和return也要相应的更换注释行# for i in range(len(testDataMat)):for i in range(200):# print('test %d:%d'%(i, len(trainDataArr)))print('test %d:%d' % (i, 200))#读取测试集当前测试样本的向量x = testDataMat[i]#获取预测的标记y = getClosest(trainDataMat, trainLabelMat, x, topK)#如果预测标记与实际标记不符,错误值计数加1if y != testLabelMat[i]: errorCnt += 1#返回正确率# return 1 - (errorCnt / len(testDataMat))return 1 - (errorCnt / 200)if __name__ == "__main__":start = time.time()#获取训练集trainDataArr, trainLabelArr = loadData('../Mnist/mnist_train.csv')#获取测试集testDataArr, testLabelArr = loadData('../Mnist/mnist_test.csv')#计算测试集正确率accur = test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, 25)#打印正确率print('accur is:%d'%(accur * 100), '%')end = time.time()#显示花费时间
print('time span:', end - start)

统计学习方法|K近邻原理剖析及实现相关推荐

  1. 统计学习方法|支持向量机(SVM)原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

  2. 统计学习方法——K近邻模型

    0. 写在前面 在这一讲的讨论班中,我们将要讨论一下K近邻模型.可能有人会说,K近邻模型有什么好写的,那分明就是一个最简单的机器学习模型,哦,不,连机器学习也算不上的算法吧.但是这里,我想提醒的是,我 ...

  3. 统计学习方法 | K 近邻法

    一.简介 1.直观理解 定义:是一种基本的分类与回归方法 主要思想:假定给定一个训练数据集,其中实例标签已定,当输入新的实例时,可以根据其最近的K个训练实例的标签,预测新实例对应的标注信息 分类问题: ...

  4. 复习03统计学习方法(K近邻KNN)---图片版

  5. 《统计学习方法》—— 感知机原理、推导以及python3代码实现(一)

    前言 感知机是<统计学习方法>介绍的第一个算法,它解决的也是最基本的问题,即,面对已经标记的数据,如何根据标记将它们区分开来. 本文将从感知机问题的来源.感知机推导以及感知机的python ...

  6. 数据挖掘算法一:K近邻原理及python实现

    1 原理 1.1 绪论 k近邻算法(k-nearest neighbor, k-NN)是一种基本分类与回归方法,多用于分类.其输入为实例的特征向量,对应于特征空间中的点:输出为实例的类别,可以取多类. ...

  7. 统计学习方法 --- 感知机模型原理及c++实现

    参考博客 Liam Q博客 和李航的<统计学习方法> 感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而 ...

  8. AI算法连载10:统计之k 近邻法

    导语:在人工智能AI如火如荼的大潮下,越来越多的工程师们意识到算法是AI的核心.而面对落地的应用,不懂算法的AI产品经理将是空谈,不仅无法与工程师沟通,更无法深刻理解应用的性能与方式.所以业界逐渐形成 ...

  9. 统计学习方法|逻辑斯蒂原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

最新文章

  1. 知识图谱query与文本相似性如何处理
  2. 第三次冲刺12.16
  3. 新浪短链api java_php调用新浪短链接API的方法
  4. Latex论文排版技巧再总结
  5. 2019牛客暑期多校训练营(第七场)
  6. 信息学奥赛一本通——1000:入门测试题目
  7. docker安装eclipse che
  8. python最简单的画图代码
  9. 锐捷无线ap服务器怎么绑定mac,MacBook中的锐捷设置全攻略
  10. WebView文件下载
  11. keil编译器CODE、RO、RW、ZI的含义
  12. 3-5 单链表分段逆转
  13. Java编程那些事儿78——时间和日期处理
  14. 关闭计算机睡眠模式,电脑睡眠模式怎么关闭
  15. 图片切割 - 九宫格
  16. ubuntu 回到根目录,回到上一级 常用指令
  17. 信息学奥赛C++语言:蛋糕
  18. 写论文难,提纲该怎么写?
  19. pve远程连接 spcie_PVE相关 篇一:解决CX341a PVE 中报错PCIe Bus Error
  20. MongoDB学习笔记总结(含报错、问题、技巧)

热门文章

  1. LeetCode Maximal Square(最大子矩阵)
  2. 理解可变参数va_list、va_start、va_arg、va_end原理及使用方法
  3. UVa10000 - Longest Paths(为什么是WA)
  4. 题目1022:游船出租
  5. 基于Vue的WebApp项目开发(四)
  6. Razor @Html.Raw()的作用
  7. 利用MOG2背景模型提取运动目标的OpenCV代码
  8. linux下diff、patch制作补丁
  9. jquery easyui datagrid 分页 详解
  10. leetcode算法题--两句话中的不常见单词