什么是KD树

要说KD树,我们得先说一下什么是KNN算法。
KNN是k-NearestNeighbor的简称,原理很简单:当你有一堆已经标注好的数据时,你知道哪些是正类,哪些是负类。当新拿到一个没有标注的数据时,你想知道它是哪一类的。只要找到它的邻居(离它距离短)的点是什么类别的,所谓近朱者赤近墨者黑,KNN就是采用了类似的方法。

如上图,当有新的点不知道是哪一类时,只要看看离它最近的几个点是什么类别,我们就判断它是什么类别。

举个例子:我们将k取3(就是每次看看新来的数据点的三个住的最近的邻居),那么我们将所有数据点和新来的数据点计算一次距离,然后排序,取前三个数据点,让它们举手表决。两票及以上的类别我们就认为是新的数据点的类别。

很简单也很好的想法,但是,我们要注意到当测试集数据比较大时,由于每次未标注的数据点都要和全部的已标注的数据点进行一次距离计算,然后排序。可以说时间开销非常大。我们在此基础上,想到了一种存储点与点之间关系的算法来通过空间换时间。
有一篇博文写KD树还不错
点击此处查看
举个例子:有一个二维的数据集: T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}

通过你已经学习的KD树的算法,按照依次选择维度,取各维中位数,是否得出和下面一样的KD树?


异常检测

我们的数据来自于KDD Cup 1999 Data 点我下载数据

数据格式如下图

数据的含义如下:

我们这次实验针对正常和DDOS攻击两种情况进行检测。
取特征范围为(1,9)U(22,31)的特征中的数值型特征,最终得到16维的特征向量。将数据随机化处理后按照7:3的比例分成训练集和测试集。下面是我们组做好的训练集和测试集
点我下载

处理完的训练集和测试集如下图:

下面是具体实现:

# coding=utf8
import math
import time
import matplotlib.pyplot as plt
import numpy as npold_settings = np.seterr(all='ignore')#定义节点类型
class KD_node:def __init__(self, point=None, split=None, left=None, right=None):''':param point: 数据点的特征向量:param split: 切分的维度:param left: 左儿子:param right: 右儿子'''self.point = pointself.split = splitself.left = leftself.right = right

计算方差,以利用方差大小进行判断在哪一维进行切分

def computeVariance(arrayList):''':param arrayList: 所有数据某一维的向量:return: 返回'''for ele in arrayList:ele = float(ele)LEN = float(len(arrayList))array = np.array(arrayList)sum1 = float(array.sum())array2 = np.dot(array, array.T)sum2 = float(array2.sum())mean = sum1 / LENvariance = sum2 / LEN - mean ** 2return variance

建树

def createKDTree(root, data_list):''':param root: 输入一个根节点,以此建树:param data_list: 数据列表:return: 返回根节点'''LEN = len(data_list)if LEN == 0:return# 数据点的维度dimension = len(data_list[0]) - 1 #去掉了最后一维标签维# 方差max_var = 0# 最后选择的划分域split = 0for i in range(dimension):ll = []for t in data_list:ll.append(t[i])var = computeVariance(ll) #计算出在这一维的方差大小if var > max_var:max_var = varsplit = i# 根据划分域的数据对数据点进行排序data_list = list(data_list)data_list.sort(key=lambda x: x[split]) #按照在切分维度上的大小进行排序data_list = np.array(data_list)# 选择下标为len / 2的点作为分割点point = data_list[LEN / 2]root = KD_node(point, split)root.left = createKDTree(root.left, data_list[0:(LEN / 2)])#递归的对切分到左儿子和右儿子的数据再建树root.right = createKDTree(root.right, data_list[(LEN / 2 + 1):LEN])return root
def computeDist(pt1, pt2):''':param pt1: 特征向量1:param pt2: 特征向量2:return: 两个向量的欧氏距离'''sum_dis = 0.0for i in range(len(pt1)):sum_dis += (pt1[i] - pt2[i]) ** 2#实现的欧氏距离计算,效率很低的版本,可以改成numpy的向量运算return math.sqrt(sum_dis)
def findNN(root, query):''':param root: 建立好的KD树的树根:param query: 查询数据:return: 与这个数据最近的前三个节点'''# 初始化为root的节点NN = root.pointmin_dist = computeDist(query, NN)nodeList = []temp_root = rootdist_list = [temp_root.point, None, None] #用来存储前三个节点##二分查找建立路径while temp_root:nodeList.append(temp_root) #对向下走的路径进行压栈处理dd = computeDist(query, temp_root.point) #计算当前最近节点和查询点的距离大小if min_dist > dd:NN = temp_root.pointmin_dist = dd# 当前节点的划分域temp_split = temp_root.splitif query[temp_split] <= temp_root.point[temp_split]:temp_root = temp_root.leftelse:temp_root = temp_root.right##回溯查找while nodeList:back_point = nodeList.pop()back_split = back_point.splitif dist_list[1] is None:dist_list[2] = dist_list[1]dist_list[1] = back_point.pointelif dist_list[2] is None:dist_list[2] = back_point.pointif abs(query[back_split] - back_point.point[back_split]) < min_dist: #当查询点和回溯点的距离小于当前最小距离时,另一个区域有希望存在更近的节点#如果大于这个距离,可以理解为假设在二维空间上,直角三角形的直角边已经不满足要求了,那么斜边也一定不满足要求if query[back_split] < back_point.point[back_split]:temp_root = back_point.rightelse:temp_root = back_point.leftif temp_root:nodeList.append(temp_root)curDist = computeDist(query, temp_root.point)if min_dist > curDist:min_dist = curDistdist_list[2] = dist_list[1]dist_list[1] = dist_list[0]dist_list[0] = temp_root.pointelif dist_list[1] is None or curDist < computeDist(dist_list[1], query):dist_list[2] = dist_list[1]dist_list[1] = temp_root.pointelif dist_list[2] is None or curDist < computeDist(dist_list[1], query):dist_list[2] = temp_root.pointreturn dist_list

进行判断

def judge_if_normal(dist_list):''':param dist_list: 利用findNN查找出的最近三个节点进行投票表决:return: '''normal_times = 0except_times = 0for i in dist_list:if abs(i[-1] - 0.0) < 1e-7: #浮点数的比较normal_times += 1else:except_times += 1if normal_times > except_times: #判断是normalreturn Trueelse:return False

数据预处理

def pre_data(path):f = open(path)lines = f.readlines()f.close()lstall = []for line in lines:lstn = []lst = line.split(",")u = 0y = 0for i in range(0, 9):if lst[i].isdigit():lstn.append(float(lst[i]))u += 1else:passfor j in range(21, 31):try:lstn.append(float(lst[j]))y += 1except:passif lst[len(lst) - 1] == "smurf.\n" or lst[len(lst) - 1] == "teardrop.\n":lstn.append(int("1"))else:lstn.append(int("0"))lstall.append(lstn)nplst = np.array(lstall, dtype=np.float16)return nplst

下面就是个人的测试代码了,大概运行了40分钟才全跑完

def my_test(all_train_data, all_test_data, train_data_num):train_data = all_train_data[:train_data_num]train_time_start = time.time()root = KD_node()root = createKDTree(root, train_data)train_time_end = time.time()train_time = train_time_end - train_time_startright = 0error = 0test_time_start = time.time()for i in range(len(all_test_data)):if judge_if_normal(findNN(root, all_test_data[i])) is True and abs(all_test_data[i][-1] - 0.0) < 1e-7:right += 1elif judge_if_normal(findNN(root, all_test_data[i])) is False and abs(all_test_data[i][-1] - 1.0) < 1e-7:right += 1else:error += 1test_time_end = time.time()test_time = test_time_end - test_time_startright_ratio = float(right) / (right + error)return right_ratio, train_time, test_timedef draw(train_num_list=[10, 100, 1000, 10000], train_data=[], test_data=[]):train_time_list = []test_time_list = []right_ratio_list = []for i in train_num_list:print 'start run ' + i.__str__()temp = my_test(train_data, test_data, i)right_ratio_list.append(temp[0])train_time_list.append(temp[1])test_time_list.append(temp[2])plt.title('train data num from ' + train_num_list[0].__str__() + ' to ' + train_num_list[:-1].__str__())plt.subplot(311)plt.plot(train_num_list, right_ratio_list, c='b')plt.xlabel('train data num')plt.ylabel('right ratio')plt.grid(True)plt.subplot(312)plt.plot(train_num_list, train_time_list, c='r')plt.xlabel('train data num')plt.ylabel('time of train data (s)')plt.grid(True)plt.subplot(313)plt.plot(train_num_list, test_time_list, c='g')plt.xlabel('train data num')plt.ylabel('time of test data (s)')plt.grid(True)plt.show()data = pre_data('KDD-test\ddos+normal_70.txt')
data2 = pre_data('KDD-test\ddos+normal_30.txt')
'''
建议开始将测试数据调小点,因为时间很长,下面这是全部训练集和全部测试集,共花费了40分钟才跑完。我是第六代i7 6700HQ+16G内存+1070+win10
'''
draw(train_num_list=[10, 100, 500, 1000, 2000, 3000, 5000, 10000, 15000, 20000, 50000, 100000, 265300],train_data=data[:], test_data=data2[:])

跑完的效果如图所示:

正确率最终达到95%以上。开始出现的波动我们怀疑是数据在开始没有达到良好的随机效果。
训练时间与训练数据量明显成线性关系
测试时间确实和理论一致,是Nlog(M)的时间复杂度。应该说这种时间复杂度的降低是我们使用KD树而不是原版的KNN最重要的地方。在原来的KNN算法下,假设训练集大小为M,测试集大小为N,则查询时间复杂度可以达到O(MN),但是我们降低到O(Nlog(M)),还是挺合算的。

本次实验可以优化的地方很多,但是时间匆忙,没有做更深的扩展。欢迎大家提出更多建议。

转载于:https://www.cnblogs.com/chuxiuhong/p/5982580.html

利用KD树进行异常检测相关推荐

  1. 利用kd树进行平面内最近点搜索

    要求: 对平面内每一个点都找到离它最近的点.通常想法是knn,那么由于是对n个点求,直接算的话就是n平方. 我们利用kd树来搜索,那么复杂度就变为了nlogn,甚至说由于a与b距离最小等价于b与a最小 ...

  2. 在浏览器中进行深度学习:TensorFlow.js (十二)异常检测算法

    2019独角兽企业重金招聘Python工程师标准>>> 异常检测是机器学习领域常见的应用场景,例如金融领域里的信用卡欺诈,企业安全领域里的非法入侵,IT运维里预测设备的维护时间点等. ...

  3. 机器学习算法(二十五):KD树详解及KD树最近邻算法

    目录 1 KD树 1.1 什么是KD树 1.2 KD树的构建 1.3 KD树的插入 1.4 KD树的删除 1.5 KD树的最近邻搜索算法 1.5.1 举例:查询点(2.1,3.1) 1.5.2 举例: ...

  4. KD树详解及KD树最近邻算法

    参考:http://blog.csdn.net/app_12062011/article/details/51986805 http://www.cnblogs.com/snake-hand/arch ...

  5. 异常检测(Anomaly detection)方法小结

    文章目录 一.基于规则处理 二.基于统计学处理 三.基于机器学习处理 3.1 无监督方法 3.2 半监督方法 3.3 有监督方法 四.数据类型 4.1. 高维数据 4.2. 时间序列数据 4.3. 文 ...

  6. 浅谈KNN算法原理及python程序简单实现、KD树、球树

    最近比较空闲,打算利用这一段时间理一下机器学习的一些常见的算法.第一个是KNN算法: KNN 1.原理: KNN,K-NearestNeighbor---K最近邻 K最近邻,就是K个最近的邻居的意思, ...

  7. k-d树+bbf算法的介绍与实现

    最近还是一直在研究SIFT算法,而SIFT特征点匹配是一个比较经典的问题,使用暴力匹配的话确实可以得到结果,但是运行速度较慢.我的计算机处理是i5的二代系列,匹配两张各检测有2000+个SIFT特征点 ...

  8. 基于深度模型的日志序列异常检测

    基于深度模型的日志序列异常检测 ​ 日志异常检测的核心是利用人工智能算法自动分析系统日志来发现并定位故障.根据送入检测模型的数据格式,日志异常检测算法模型分为序列模型和频率模型,其中序列模型又可以分为 ...

  9. 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法

    原文出自:http://blog.csdn.net/v_JULY_v/article/details/8203674 前言 前两日,在微博上说:"到今天为止,我至少亏欠了3篇文章待写:1.K ...

最新文章

  1. python——文件和数据格式化
  2. 化整为零,一步一步教你搭建Prometheus监控报警系统
  3. RedrawWindow, UpdateWindow,InvalidateRect 用法
  4. 夫曼编码译码系统课程设计实验报告(含源代码c++_c语言),哈夫曼编码译码系统课程设计实验报告(含源代码C++_C语言)[1]...
  5. 小汤学编程之JavaEE学习day08——Maven
  6. 01 安装ansible
  7. 微信发ascii_微信公众平台开发(107) 分享到朋友圈和发送给好友
  8. 转 在SQL Server中创建用户角色及授权(使用SQL语句)
  9. Ra-08系列开发板入门教程,标准LoRaWAN对接私有服务器。
  10. android 视频播放器 加载字幕,Android 实现视频字幕Subtitle和横竖屏切换示例
  11. 新商用密码产品认证梳理——参考资料篇
  12. Kotlin基础学习 17
  13. 谷歌 不支持 activeX插件
  14. DLP4500 投影图案(一)
  15. Magento2 入门指南(新手必读)
  16. C++ 哈希表及unordered_set + unordered_map容器
  17. 【OpenCV-Python】教程:3-1 颜色空间转换与抠图
  18. Web应用接入Github登录
  19. 自建邮件系统的三大优势
  20. Excel表格如何转换成pdf格式

热门文章

  1. 改变层级_3DMAX基础,可编辑多边形层级介绍及概念
  2. linux命令:sosreport
  3. list.action.php,教你利用 PHP 实现高性能微服务部署
  4. 浏览器字体大小设置_全新内核 Edge 浏览器来了,这回或许能成为你的真 · 默认浏览器...
  5. Docker使用中遇到的问题
  6. dubbo 无法访问消费端_Dubbo最佳实践,我整理了以下9点
  7. 安卓设置菊花动画_Android仿ios加载loading菊花图效果
  8. js date转string_JS之你到底是什么类型?
  9. 赞!Google 资深软件工程师 LeetCode 刷题笔记首次公开
  10. java中数据结构的应用_Java集合入门 (二)常用数据结构和应用场景-数组