具体描述见《统计学习方法》第三章。

  1 //
  2 //  main.cpp
  3 //  kNN
  4 //
  5 //  Created by feng on 15/10/24.
  6 //  Copyright © 2015年 ttcn. All rights reserved.
  7 //
  8
  9 #include <iostream>
 10 #include <vector>
 11 #include <algorithm>
 12 #include <cmath>
 13 using namespace std;
 14
 15 template<typename T>
 16 struct KdTree {
 17     // ctor
 18     KdTree():parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}
 19
 20     // KdTree是否为空
 21     bool isEmpty() { return root.empty(); }
 22
 23     // KdTree是否为叶子节点
 24     bool isLeaf() { return !root.empty() && !leftChild && !rightChild;}
 25
 26     // KdTree是否为根节点
 27     bool isRoot() { return !isEmpty() && !parent;}
 28
 29     // 判断KdTree是否为根节点的左儿子
 30     bool isLeft() { return parent->leftChild->root == root; }
 31
 32     // 判断KdTree是否为根节点的右儿子
 33     bool isRight() { return parent->rightChild->root == root; }
 34
 35     // 存放根节点的数据
 36     vector<T> root;
 37
 38     // 父节点
 39     KdTree<T> *parent;
 40
 41     // 左儿子
 42     KdTree<T> *leftChild;
 43
 44     // 右儿子
 45     KdTree<T> *rightChild;
 46 };
 47
 48
 49 /**
 50  *  矩阵转置
 51  *
 52  *  @param matrix 原矩阵
 53  *
 54  *  @return 原矩阵的转置矩阵
 55  */
 56 template<typename T>
 57 vector<vector<T>> transpose(const vector<vector<T>> &matrix) {
 58     size_t rows = matrix.size();
 59     size_t cols = matrix[0].size();
 60     vector<vector<T>> trans(cols, vector<T>(rows, 0));
 61     for (size_t i = 0; i < cols; ++i) {
 62         for (size_t j = 0; j < rows; ++j) {
 63             trans[i][j] = matrix[j][i];
 64         }
 65     }
 66
 67     return trans;
 68 }
 69
 70 /**
 71  *  找中位数
 72  *
 73  *  @param vec 数组
 74  *
 75  *  @return 数组中的中位数
 76  */
 77 template<typename T>
 78 T findMiddleValue(vector<T> vec) {
 79     sort(vec.begin(), vec.end());
 80     size_t pos = vec.size() / 2;
 81     return vec[pos];
 82 }
 83
 84 /**
 85  *  递归构造KdTree
 86  *
 87  *  @param tree  KdTree根节点
 88  *  @param data  数据矩阵
 89  *  @param depth 当前节点深度
 90  *
 91  *  @return void
 92  */
/**
 *  Recursively build a kd-tree (algorithm described in chapter 3 of
 *  《统计学习方法》 / "Statistical Learning Methods").
 *
 *  @param tree  pre-allocated node to fill in (must be non-null)
 *  @param data  sample points, one vector<T> per point; all points are
 *               assumed to share the same dimension — TODO confirm
 *  @param depth depth of `tree` in the whole kd-tree; selects the
 *               splitting axis as depth % dimension
 *
 *  @return void
 */
template<typename T>
void buildKdTree(KdTree<T> *tree, vector<vector<T>> &data, size_t depth) {
    // Number of input samples in this subtree.
    size_t samplesNum = data.size();

    // Empty subset: leave `tree` as an empty sentinel node.
    // NOTE(review): searchNearestNeighbor depends on these sentinels
    // (it tests rightChild->isEmpty()), so do not prune them.
    if (samplesNum == 0) {
        return;
    }

    // Single sample: store it here; the node stays a leaf (no children).
    if (samplesNum == 1) {
        tree->root = data[0];
        return;
    }

    // Dimension (number of attributes) of each sample point.
    size_t k = data[0].size();
    vector<vector<T>> transData = transpose(data);

    // Current splitting axis cycles through the dimensions with depth;
    // the split value is the median of the samples along that axis.
    size_t splitAttributeIndex = depth % k;
    vector<T> splitAttributes = transData[splitAttributeIndex];
    T splitValue = findMiddleValue(splitAttributes);

    vector<vector<T>> leftSubSet;
    vector<vector<T>> rightSubset;

    // Partition: the first sample whose value equals the median becomes
    // this node's point; strictly smaller values go left; everything
    // else (including further median duplicates) goes right, so each
    // recursion strictly shrinks and terminates.
    for (size_t i = 0; i < samplesNum; ++i) {
        if (splitAttributes[i] == splitValue && tree->isEmpty()) {
            tree->root = data[i];
        } else if (splitAttributes[i] < splitValue) {
            leftSubSet.push_back(data[i]);
        } else {
            rightSubset.push_back(data[i]);
        }
    }

    // NOTE(review): children are heap-allocated with raw `new` and are
    // never freed anywhere in this file — leaked unless the caller
    // walks and deletes the tree.
    tree->leftChild = new KdTree<T>;
    tree->leftChild->parent = tree;
    tree->rightChild = new KdTree<T>;
    tree->rightChild->parent = tree;
    buildKdTree(tree->leftChild, leftSubSet, depth + 1);
    buildKdTree(tree->rightChild, rightSubset, depth + 1);
}
136
137 /**
138  *  递归打印KdTree
139  *
140  *  @param tree  KdTree
141  *  @param depth 当前深度
142  *
143  *  @return void
144  */
145 template<typename T>
146 void printKdTree(const KdTree<T> *tree, size_t depth) {
147     for (size_t i = 0; i < depth; ++i) {
148         cout << "\t";
149     }
150
151     for (size_t i = 0; i < tree->root.size(); ++i) {
152         cout << tree->root[i] << " ";
153     }
154     cout << endl;
155
156     if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
157         return;
158     } else {
159         if (tree->leftChild) {
160             for (int i = 0; i < depth + 1; ++i) {
161                 cout << "\t";
162             }
163             cout << "left : ";
164             printKdTree(tree->leftChild, depth + 1);
165         }
166
167         cout << endl;
168
169         if (tree->rightChild) {
170             for (size_t i = 0; i < depth + 1; ++i) {
171                 cout << "\t";
172             }
173             cout << "right : ";
174             printKdTree(tree->rightChild, depth + 1);
175         }
176         cout << endl;
177     }
178 }
179
180 /**
181  *  节点之间的欧氏距离
182  *
183  *  @param p1 节点1
184  *  @param p2 节点2
185  *
186  *  @return 节点之间的欧式距离
187  */
/**
 *  Squared Euclidean distance between two points.
 *
 *  Note: returns the SQUARED distance (no sqrt) — every caller in this
 *  file only compares distances, and squaring is monotone, so the sqrt
 *  is unnecessary. Assumes p1 and p2 have the same dimension — TODO
 *  confirm at call sites.
 *
 *  @param p1 first point
 *  @param p2 second point
 *
 *  @return sum of squared per-coordinate differences
 */
template<typename T>
T calDistance(const std::vector<T> &p1, const std::vector<T> &p2) {
    T res = 0;
    for (std::size_t i = 0; i < p1.size(); ++i) {
        // d * d instead of pow(d, 2): exact for integral T (pow returns
        // double and truncated on +=) and far cheaper than the
        // general transcendental pow.
        const T d = p1[i] - p2[i];
        res += d * d;
    }
    return res;
}
197
198 /**
199  *  搜索目标节点的最近邻
200  *
201  *  @param tree KdTree
202  *  @param goal 待分类的节点
203  *
204  *  @return 最近邻节点
205  */
206 template <typename T>
207 vector<T> searchNearestNeighbor(KdTree<T> *tree, const vector<T> &goal ) {
208     // 节点数属性个数
209     size_t k = tree->root.size();
210     // 划分的索引
211     size_t d = 0;
212     KdTree<T> *currentTree = tree;
213     vector<T> currentNearest = currentTree->root;
214     // 找到目标节点的最叶节点
215     while (!currentTree->isLeaf()) {
216         size_t index = d % k;
217         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index]) {
218             currentTree = currentTree->leftChild;
219         } else {
220             currentTree = currentTree->rightChild;
221         }
222
223         ++d;
224     }
225     currentNearest = currentTree->root;
226     T currentDistance = calDistance(goal, currentTree->root);
227
228     KdTree<T> *searchDistrict;
229     if (currentTree->isLeft()) {
230         if (!(currentTree->parent->rightChild)) {
231             searchDistrict = currentTree;
232         } else {
233             searchDistrict = currentTree->parent->rightChild;
234         }
235     } else {
236         searchDistrict = currentTree->parent->leftChild;
237     }
238
239     while (!(searchDistrict->parent)) {
240         T districtDistance = abs(goal[(d + 1) % k] - searchDistrict->parent->root[(d + 1) % k]);
241
242         if (districtDistance < currentDistance) {
243             T parentDistance = calDistance(goal, searchDistrict->parent->root);
244
245             if (parentDistance < currentDistance) {
246                 currentDistance = parentDistance;
247                 currentTree = searchDistrict->parent;
248                 currentNearest = currentTree->root;
249             }
250
251             if (!searchDistrict->isEmpty()) {
252                 T rootDistance = calDistance(goal, searchDistrict->root);
253                 if (rootDistance < currentDistance) {
254                     currentDistance = rootDistance;
255                     currentTree = searchDistrict;
256                     currentNearest = currentTree->root;
257                 }
258             }
259
260             if (!(searchDistrict->leftChild)) {
261                 T leftDistance = calDistance(goal, searchDistrict->leftChild->root);
262                 if (leftDistance < currentDistance) {
263                     currentDistance = leftDistance;
264                     currentTree = searchDistrict;
265                     currentNearest = currentTree->root;
266                 }
267             }
268
269             if (!(searchDistrict->rightChild)) {
270                 T rightDistance = calDistance(goal, searchDistrict->rightChild->root);
271                 if (rightDistance < currentDistance) {
272                     currentDistance = rightDistance;
273                     currentTree = searchDistrict;
274                     currentNearest = currentTree->root;
275                 }
276             }
277
278         }
279
280         if (!(searchDistrict->parent->parent)) {
281             searchDistrict = searchDistrict->parent->isLeft()? searchDistrict->parent->parent->rightChild : searchDistrict->parent->parent->leftChild;
282         } else {
283             searchDistrict = searchDistrict->parent;
284         }
285         ++d;
286     }
287
288     return currentNearest;
289 }
290
291 int main(int argc, const char * argv[]) {
292     vector<vector<double>> trainDataSet{{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
293     KdTree<double> *kdTree = new KdTree<double>;
294     buildKdTree(kdTree, trainDataSet, 0);
295     printKdTree(kdTree, 0);
296
297     vector<double> goal{3, 4.5};
298     vector<double> nearestNeighbor = searchNearestNeighbor(kdTree, goal);
299
300     for (auto i : nearestNeighbor) {
301         cout << i << " ";
302     }
303     cout << endl;
304
305     return 0;
306 }

转载于:https://www.cnblogs.com/skycore/p/4908873.html

[Machine Learning]kNN代码实现(Kd tree)相关推荐

  1. Machine Learning Summary

    Machine Learning Summary General Idea No Free Lunch Theorem (no "best") CV for complex par ...

  2. Unity in Machine Learning

    这里写目录标题 小尝试 自定义object 的属性 C# 访问修饰符 C# NameSpace Object运动 例1 例2 Python检测 + C# 第1种 第2种 Object 运动方式 小尝试 ...

  3. Machine Learning In Action 第二章学习笔记: kNN算法

    本文主要记录<Machine Learning In Action>中第二章的内容.书中以两个具体实例来介绍kNN(k nearest neighbors),分别是: 约会对象预测 手写数 ...

  4. 特征工程的宝典-《Feature Engineering for Machine Learning》翻译及代码实现

    由O'Reilly Media,Inc.出版的<Feature Engineering for Machine Learning>(国内译作<精通特征工程>)一书,可以说是特征 ...

  5. KNN算法与Kd树(转载+代码详细解释)

    最近邻法和k-近邻法 下面图片中只有三种豆,有三个豆是未知的种类,如何判定他们的种类? 提供一种思路,即:未知的豆离哪种豆最近就认为未知豆和该豆是同一种类.由此,我们引出最近邻算法的定义:为了判定未知 ...

  6. kd tree python_Python实现KNN与KDTree

    KNN算法: KNN的基本思想以及数据预处理等步骤就不介绍了,网上挑了两个写的比较完整有源码的博客. 利用KNN约会分类 KNN项目实战--改进约会网站的配对效果 KNN 代码 ''' Functio ...

  7. Machine Learning | (7) Scikit-learn的分类器算法-决策树(Decision Tree)

    Machine Learning | 机器学习简介 Machine Learning | (1) Scikit-learn与特征工程 Machine Learning | (2) sklearn数据集 ...

  8. 机器学习(Machine Learning)、深度学习(Deep Learning)、NLP面试中常考到的知识点和代码实现

    网址:https://github.com/NLP-LOVE/ML-NLP 此项目是机器学习(Machine Learning).深度学习(Deep Learning).NLP面试中常考到的知识点和代 ...

  9. 【Machine Learning 学习笔记】Stochastic Dual Coordinate Ascent for SVM 代码实现

    [Machine Learning 学习笔记]Stochastic Dual Coordinate Ascent for SVM 代码实现 通过本篇博客记录一下Stochastic Dual Coor ...

最新文章

  1. Spring Cloud(三)服务提供者 Eureka + 服务消费者(rest + Ribbon)
  2. mysql 查询调试_使用MySQL慢速查询日志进行调试
  3. php 编译安装 png.h,PHP编译安装时常见错误解决办法【大全】
  4. 获取节点及元素的代码
  5. python turtle画彩虹-Python基础实例——绘制彩虹(turtle库的应用)
  6. 一种无限循环轮播图的实现原理
  7. python 读取excel图片_如何用Python读取Excel中图片?
  8. python3.7适用的opencv_通过python3.7.3使用openCV截图一个区域
  9. python如何快速登记凭证_如何高效地翻凭证?
  10. 开放接口的安全验证方案(AES+RSA)
  11. invest模型的python安装方法,两种方法
  12. python 图像检索_深度学习图像检索
  13. EBS INV:事务处理
  14. 快手短视频直播间怎么提高人气热度,直播间冷启动是什么?
  15. html5限制拖拽区域怎么实现,html5怎么实现拖拽
  16. 连接HC-05与HC-06
  17. SSH密钥登录系统报错Permissions 0644 for ‘xxx‘ re too open
  18. SD卡驱动学习(一)
  19. 基于JSP的在线调查问卷系统
  20. 微信h5网页点击链接跳转到默认浏览器是怎么弄的?gdtool一键实现该方案

热门文章

  1. 1280*720P和1920*1080P的视频在25帧30帧50帧60帧时的参数
  2. 滴滴专车——司机提现流程
  3. 转载:iPhone 6 Plus 屏幕宽度问题 375 vs 414
  4. Winform开发框架之通用人员信息管理实现代码介绍
  5. NI Measurement Studio 打包问题的解决(原创)
  6. php register_shutdown_function
  7. Microsoft Visual Studio Learning Pack 自动生成流程图插件(转)
  8. RMAN 系列(二) ---- RMAN 设置和配置
  9. Spring4新特性——集成Bean Validation 1.1(JSR-349)到SpringMVC
  10. 优秀的Java程序员必须了解GC的工作原理