[Machine Learning]kNN代码实现(Kd tree)
具体描述见《统计学习方法》第三章。
//
//  main.cpp
//  kNN
//
//  Created by feng on 15/10/24.
//  Copyright © 2015 ttcn. All rights reserved.
//
//  Nearest-neighbour search with a kd-tree
//  (algorithm description: "Statistical Learning Methods", chapter 3).
//

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

/**
 * A node of a kd-tree.
 *
 * A non-empty node stores one sample point in `root`. Child nodes are
 * allocated with `new` by buildKdTree and are owned by their parent, so
 * destroying a node releases its entire subtree.
 */
template<typename T>
struct KdTree {
    KdTree() : parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}

    // Frees the owned subtrees (deleting nullptr is a no-op).
    ~KdTree() {
        delete leftChild;
        delete rightChild;
    }

    // A copy would share, and later double-delete, the owned children.
    KdTree(const KdTree &) = delete;
    KdTree &operator=(const KdTree &) = delete;

    // True when the node holds no point.
    bool isEmpty() const { return root.empty(); }

    // True for a node that holds a point and has no children.
    bool isLeaf() const { return !root.empty() && !leftChild && !rightChild; }

    // True for a non-empty node without a parent.
    bool isRoot() const { return !isEmpty() && !parent; }

    // True when this node is its parent's left child. Compares node addresses
    // rather than stored points, so duplicate points cannot confuse it, and
    // returns false (instead of dereferencing null) for the root.
    bool isLeft() const { return parent != nullptr && parent->leftChild == this; }

    // True when this node is its parent's right child (see isLeft).
    bool isRight() const { return parent != nullptr && parent->rightChild == this; }

    // The sample point stored at this node.
    std::vector<T> root;

    // Non-owning pointer to the parent node (nullptr for the root).
    KdTree<T> *parent;

    // Owned left subtree, or nullptr.
    KdTree<T> *leftChild;

    // Owned right subtree, or nullptr.
    KdTree<T> *rightChild;
};


/**
 * Matrix transpose.
 *
 * @param matrix row-major matrix; every row must have matrix[0].size() entries
 *
 * @return the transposed matrix (empty if `matrix` is empty)
 */
template<typename T>
std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>> &matrix) {
    if (matrix.empty()) {
        return {};
    }
    const std::size_t rows = matrix.size();
    const std::size_t cols = matrix[0].size();
    std::vector<std::vector<T>> trans(cols, std::vector<T>(rows));
    for (std::size_t i = 0; i < cols; ++i) {
        for (std::size_t j = 0; j < rows; ++j) {
            trans[i][j] = matrix[j][i];
        }
    }

    return trans;
}

/**
 * Returns the upper median of `vec`: the element that would sit at index
 * vec.size() / 2 if `vec` were sorted.
 *
 * @param vec values (taken by value, so the caller's vector is not reordered)
 *
 * @return the median value
 * @throws std::invalid_argument if `vec` is empty
 */
template<typename T>
T findMiddleValue(std::vector<T> vec) {
    if (vec.empty()) {
        throw std::invalid_argument("findMiddleValue: empty input");
    }
    const std::size_t pos = vec.size() / 2;
    // Only the element at `pos` is needed, so a partial selection suffices.
    std::nth_element(vec.begin(), vec.begin() + static_cast<std::ptrdiff_t>(pos), vec.end());
    return vec[pos];
}

/**
 * Recursively builds a kd-tree from `data`.
 *
 * At depth `depth` the points are split on coordinate `depth % k`
 * (k = point dimension). The node keeps the first point (in `data` order)
 * whose coordinate equals the median; points with a smaller coordinate go to
 * the left subtree and all remaining points to the right subtree. A child node
 * is only allocated when its subset is non-empty.
 *
 * @param tree  node to fill; expected to be a freshly constructed, empty node
 * @param data  sample points, all of the same dimension
 * @param depth depth of `tree` (pass 0 for the root)
 *
 * @throws std::invalid_argument if the points (when there are two or more)
 *         are zero-dimensional or of unequal dimension
 */
template<typename T>
void buildKdTree(KdTree<T> *tree, const std::vector<std::vector<T>> &data, std::size_t depth) {
    const std::size_t samplesNum = data.size();

    if (tree == nullptr || samplesNum == 0) {
        return;
    }

    if (samplesNum == 1) {
        tree->root = data[0];
        return;
    }

    const std::size_t k = data[0].size();
    if (k == 0) {
        throw std::invalid_argument("buildKdTree: points must have at least one dimension");
    }
    const std::size_t splitAttributeIndex = depth % k;

    // Only the split column is needed, so extract it directly instead of
    // transposing the whole matrix.
    std::vector<T> splitAttributes;
    splitAttributes.reserve(samplesNum);
    for (const auto &sample : data) {
        if (sample.size() != k) {
            throw std::invalid_argument("buildKdTree: all points must have the same dimension");
        }
        splitAttributes.push_back(sample[splitAttributeIndex]);
    }
    const T splitValue = findMiddleValue(splitAttributes);

    std::vector<std::vector<T>> leftSubset;
    std::vector<std::vector<T>> rightSubset;
    bool rootPlaced = false;
    for (std::size_t i = 0; i < samplesNum; ++i) {
        if (!rootPlaced && splitAttributes[i] == splitValue) {
            tree->root = data[i];
            rootPlaced = true;
        } else if (splitAttributes[i] < splitValue) {
            leftSubset.push_back(data[i]);
        } else {
            rightSubset.push_back(data[i]);
        }
    }

    // Absent children stay nullptr rather than becoming empty placeholder
    // nodes, so isLeaf() reflects the real shape of the tree.
    if (!leftSubset.empty()) {
        tree->leftChild = new KdTree<T>;
        tree->leftChild->parent = tree;
        buildKdTree(tree->leftChild, leftSubset, depth + 1);
    }
    if (!rightSubset.empty()) {
        tree->rightChild = new KdTree<T>;
        tree->rightChild->parent = tree;
        buildKdTree(tree->rightChild, rightSubset, depth + 1);
    }
}

namespace {

// Prints `node` indented by `depth` tabs and prefixed by `label`, then its
// children one level deeper.
template<typename T>
void printKdSubtree(const KdTree<T> *node, std::size_t depth, const char *label) {
    std::cout << std::string(depth, '\t') << label;
    for (const T &value : node->root) {
        std::cout << value << ' ';
    }
    std::cout << '\n';

    if (node->leftChild != nullptr) {
        printKdSubtree(node->leftChild, depth + 1, "left : ");
    }
    if (node->rightChild != nullptr) {
        printKdSubtree(node->rightChild, depth + 1, "right : ");
    }
}

}  // namespace

/**
 * Recursively prints a kd-tree, one node per line, children indented one tab
 * deeper than their parent and labelled "left : " / "right : ".
 *
 * @param tree  tree to print (nothing is printed for nullptr)
 * @param depth indentation level of `tree`
 */
template<typename T>
void printKdTree(const KdTree<T> *tree, std::size_t depth) {
    if (tree == nullptr) {
        return;
    }
    printKdSubtree(tree, depth, "");
}

/**
 * Squared Euclidean distance between two points.
 *
 * The square root is deliberately omitted: it is monotonic, so comparing
 * squared distances gives the same ordering without the extra cost.
 *
 * @param p1 first point
 * @param p2 second point
 *
 * @return sum over i of (p1[i] - p2[i])^2
 * @throws std::invalid_argument if the points differ in dimension
 */
template<typename T>
T calDistance(const std::vector<T> &p1, const std::vector<T> &p2) {
    if (p1.size() != p2.size()) {
        throw std::invalid_argument("calDistance: points differ in dimension");
    }
    T res = 0;
    for (std::size_t i = 0; i < p1.size(); ++i) {
        const T diff = p1[i] - p2[i];
        res += diff * diff;
    }

    return res;
}

namespace {

// Recursive nearest-neighbour search below `node`, which splits on axis
// depth % k. `best` / `bestDistance` carry the closest node found so far and
// its squared distance to `goal`; `best` is nullptr until a candidate is seen.
template<typename T>
void searchKdSubtree(const KdTree<T> *node, const std::vector<T> &goal, std::size_t depth,
                     std::size_t k, const KdTree<T> *&best, T &bestDistance) {
    if (node == nullptr || node->isEmpty()) {
        return;
    }

    const std::size_t axis = depth % k;
    // Same rule as buildKdTree: smaller coordinates live in the left subtree.
    const bool goalOnLeft = goal[axis] < node->root[axis];
    const KdTree<T> *nearSide = goalOnLeft ? node->leftChild : node->rightChild;
    const KdTree<T> *farSide = goalOnLeft ? node->rightChild : node->leftChild;

    // 1. Search the region that contains the goal first.
    searchKdSubtree(nearSide, goal, depth + 1, k, best, bestDistance);

    // 2. The point stored at this node is itself a candidate.
    const T distance = calDistance(goal, node->root);
    if (best == nullptr || distance < bestDistance) {
        best = node;
        bestDistance = distance;
    }

    // 3. The other region can only hold a closer point if the ball of radius
    //    sqrt(bestDistance) around the goal crosses the splitting hyperplane.
    //    Both sides of the comparison are squared distances.
    const T planeOffset = goal[axis] - node->root[axis];
    if (planeOffset * planeOffset < bestDistance) {
        searchKdSubtree(farSide, goal, depth + 1, k, best, bestDistance);
    }
}

}  // namespace

/**
 * Finds the stored point nearest (in Euclidean distance) to `goal`.
 *
 * Assumes `tree` was produced by buildKdTree starting at depth 0, so the node
 * at depth d splits on coordinate d % k. Ties are resolved in favour of the
 * point found first.
 *
 * @param tree kd-tree to search
 * @param goal query point
 *
 * @return the nearest stored point; an empty vector if `tree` is null or empty
 * @throws std::invalid_argument if goal.size() differs from the point dimension
 */
template <typename T>
std::vector<T> searchNearestNeighbor(KdTree<T> *tree, const std::vector<T> &goal) {
    if (tree == nullptr || tree->isEmpty()) {
        return {};
    }

    const std::size_t k = tree->root.size();
    if (goal.size() != k) {
        throw std::invalid_argument("searchNearestNeighbor: goal has the wrong dimension");
    }

    const KdTree<T> *best = nullptr;
    T bestDistance = 0;
    searchKdSubtree(tree, goal, 0, k, best, bestDistance);

    // `tree` is non-empty, so the search always records at least one node.
    return best->root;
}

int main() {
    std::vector<std::vector<double>> trainDataSet{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};

    // The root lives on the stack; its destructor frees every node that
    // buildKdTree allocates beneath it.
    KdTree<double> kdTree;
    buildKdTree(&kdTree, trainDataSet, 0);
    printKdTree(&kdTree, 0);

    // For this data set the nearest training point to (3, 4.5) is (2, 3).
    const std::vector<double> goal{3, 4.5};
    const std::vector<double> nearestNeighbor = searchNearestNeighbor(&kdTree, goal);

    for (const double value : nearestNeighbor) {
        std::cout << value << ' ';
    }
    std::cout << std::endl;

    return 0;
}
转载于:https://www.cnblogs.com/skycore/p/4908873.html
[Machine Learning]kNN代码实现(Kd tree)相关推荐
- Machine Learning Summary
Machine Learning Summary General Idea No Free Lunch Theorem (no "best") CV for complex par ...
- Unity in Machine Learning
这里写目录标题 小尝试 自定义object 的属性 C# 访问修饰符 C# NameSpace Object运动 例1 例2 Python检测 + C# 第1种 第2种 Object 运动方式 小尝试 ...
- Machine Learning In Action 第二章学习笔记: kNN算法
本文主要记录<Machine Learning In Action>中第二章的内容.书中以两个具体实例来介绍kNN(k nearest neighbors),分别是: 约会对象预测 手写数 ...
- 特征工程的宝典-《Feature Engineering for Machine Learning》翻译及代码实现
由O'Reilly Media,Inc.出版的<Feature Engineering for Machine Learning>(国内译作<精通特征工程>)一书,可以说是特征 ...
- KNN算法与Kd树(转载+代码详细解释)
最近邻法和k-近邻法 下面图片中只有三种豆,有三个豆是未知的种类,如何判定他们的种类? 提供一种思路,即:未知的豆离哪种豆最近就认为未知豆和该豆是同一种类.由此,我们引出最近邻算法的定义:为了判定未知 ...
- kd tree python_Python实现KNN与KDTree
KNN算法: KNN的基本思想以及数据预处理等步骤就不介绍了,网上挑了两个写的比较完整有源码的博客. 利用KNN约会分类 KNN项目实战--改进约会网站的配对效果 KNN 代码 ''' Functio ...
- Machine Learning | (7) Scikit-learn的分类器算法-决策树(Decision Tree)
Machine Learning | 机器学习简介 Machine Learning | (1) Scikit-learn与特征工程 Machine Learning | (2) sklearn数据集 ...
- 机器学习(Machine Learning)、深度学习(Deep Learning)、NLP面试中常考到的知识点和代码实现
网址:https://github.com/NLP-LOVE/ML-NLP 此项目是机器学习(Machine Learning).深度学习(Deep Learning).NLP面试中常考到的知识点和代 ...
- 【Machine Learning 学习笔记】Stochastic Dual Coordinate Ascent for SVM 代码实现
[Machine Learning 学习笔记]Stochastic Dual Coordinate Ascent for SVM 代码实现 通过本篇博客记录一下Stochastic Dual Coor ...
最新文章
- Spring Cloud(三)服务提供者 Eureka + 服务消费者(rest + Ribbon)
- mysql 查询调试_使用MySQL慢速查询日志进行调试
- php 编译安装 png.h,PHP编译安装时常见错误解决办法【大全】
- 获取节点及元素的代码
- python turtle画彩虹-Python基础实例——绘制彩虹(turtle库的应用)
- 一种无限循环轮播图的实现原理
- python 读取excel图片_如何用Python读取Excel中图片?
- python3.7适用的opencv_通过python3.7.3使用openCV截图一个区域
- python如何快速登记凭证_如何高效地翻凭证?
- 开放接口的安全验证方案(AES+RSA)
- invest模型的python安装方法,两种方法
- python 图像检索_深度学习图像检索
- EBS INV:事务处理
- 快手短视频直播间怎么提高人气热度,直播间冷启动是什么?
- html5限制拖拽区域怎么实现,html5怎么实现拖拽
- 连接HC-05与HC-06
- SSH密钥登录系统报错 Permissions 0644 for 'xxx' are too open
- SD卡驱动学习(一)
- 基于JSP的在线调查问卷系统
- 微信h5网页点击链接跳转到默认浏览器是怎么弄的?gdtool一键实现该方案
热门文章
- 1280*720P和1920*1080P的视频在25帧30帧50帧60帧时的参数
- 滴滴专车——司机提现流程
- 转载:iPhone 6 Plus 屏幕宽度问题 375 vs 414
- Winform开发框架之通用人员信息管理实现代码介绍
- NI Measurement Studio 打包问题的解决(原创)
- php register_shutdown_function
- Microsoft Visual Studio Learning Pack 自动生成流程图插件(转)
- RMAN 系列(二) ---- RMAN 设置和配置
- Spring4新特性——集成Bean Validation 1.1(JSR-349)到SpringMVC
- 优秀的Java程序员必须了解GC的工作原理