关于KNN的介绍可以参考: http://blog.csdn.net/fengbingchun/article/details/78464169

这里给出KNN的C++实现,用于分类。训练数据和测试数据均来自MNIST,关于MNIST的介绍可以参考: http://blog.csdn.net/fengbingchun/article/details/49611549  , 从MNIST中提取的40幅图像,0,1,2,3四类各20张,每类的前10幅来自于训练样本,用于训练,后10幅来自测试样本,用于测试,如下图:

实现代码如下:

knn.hpp:

#ifndef FBC_NN_KNN_HPP_
#define FBC_NN_KNN_HPP_#include <memory>
#include <vector>namespace ANN {template<typename T>
class KNN {
public:KNN() = default;void set_k(int k);int set_train_samples(const std::vector<std::vector<T>>& samples, const std::vector<T>& labels);int predict(const std::vector<T>& sample, T& result) const;private:int k = 3;int feature_length = 0;int samples_number = 0;std::unique_ptr<T[]> samples;std::unique_ptr<T[]> labels;
};} // namespace ANN#endif // FBC_NN_KNN_HPP_

knn.cpp:

#include "knn.hpp"
#include <limits>
#include <algorithm>
#include <functional>
#include "common.hpp"namespace ANN {template<typename T>
void KNN<T>::set_k(int k)
{this->k = k;
}template<typename T>
int KNN<T>::set_train_samples(const std::vector<std::vector<T>>& samples, const std::vector<T>& labels)
{CHECK(samples.size() == labels.size());this->samples_number = samples.size();if (this->k > this->samples_number) this->k = this->samples_number;this->feature_length = samples[0].size();this->samples.reset(new T[this->feature_length * this->samples_number]);this->labels.reset(new T[this->samples_number]);T* p = this->samples.get();for (int i = 0; i < this->samples_number; ++i) {T* q = p + i * this->feature_length;for (int j = 0; j < this->feature_length; ++j) {q[j] = samples[i][j];}this->labels.get()[i] = labels[i];}
}template<typename T>
int KNN<T>::predict(const std::vector<T>& sample, T& result) const
{if (sample.size() != this->feature_length) {fprintf(stderr, "their feature length dismatch: %d, %d", sample.size(), this->feature_length);return -1;}typedef std::pair<T, T> value;std::vector<value> info;for (int i = 0; i < this->k + 1; ++i) {info.push_back(std::make_pair(std::numeric_limits<T>::max(), (T)-1.));}for (int i = 0; i < this->samples_number; ++i) {T s{ 0. };const T* p = this->samples.get() + i * this->feature_length;for (int j = 0; j < this->feature_length; ++j) {s += (p[j] - sample[j]) * (p[j] - sample[j]);}info[this->k] = std::make_pair(s, this->labels.get()[i]);std::stable_sort(info.begin(), info.end(), [](const std::pair<T, T>& p1, const std::pair<T, T>& p2) {return p1.first < p2.first; });}std::vector<T> vec(this->k);for (int i = 0; i < this->k; ++i) {vec[i] = info[i].second;}std::sort(vec.begin(), vec.end(), std::greater<T>());vec.erase(std::unique(vec.begin(), vec.end()), vec.end());std::vector<std::pair<T, int>> ret;for (int i = 0; i < vec.size(); ++i) {ret.push_back(std::make_pair(vec[i], 0));}for (int i = 0; i < this->k; ++i) {for (int j = 0; j < ret.size(); ++j) {if (info[i].second == ret[j].first) {++ret[j].second;break;}}}int max = -1, index = -1;for (int i = 0; i < ret.size(); ++i) {if (ret[i].second > max) {max = ret[i].second;index = i;}}result = ret[index].first;return 0;
}template class KNN<float>;
template class KNN<double>;} // namespace ANN

测试代码如下:

#include "funset.hpp"
#include <iostream>
#include "perceptron.hpp"
#include "BP.hpp""
#include "CNN.hpp"
#include "linear_regression.hpp"
#include "naive_bayes_classifier.hpp"
#include "logistic_regression.hpp"
#include "common.hpp"
#include "knn.hpp"
#include <opencv2/opencv.hpp>// =========================== KNN(K-Nearest Neighbor) ======================
int test_knn_classifier_predict()
{const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };const int K{ 3 };cv::Mat tmp = cv::imread(image_path + "0_1.jpg", 0);const int train_samples_number{ 40 }, predict_samples_number{ 40 };const int every_class_number{ 10 };cv::Mat train_data(train_samples_number, tmp.rows * tmp.cols, CV_32FC1);cv::Mat train_labels(train_samples_number, 1, CV_32FC1);float* p = (float*)train_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}// train datafor (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 1; j <= every_class_number; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = train_data.rowRange(i * every_class_number + j - 1, i * every_class_number + j);image.copyTo(tmp);}}ANN::KNN<float> knn;knn.set_k(K);std::vector<std::vector<float>> samples(train_samples_number);std::vector<float> labels(train_samples_number);const int feature_length{ tmp.rows * tmp.cols };for (int i = 0; i < train_samples_number; ++i) {samples[i].resize(feature_length);const float* p1 = train_data.ptr<float>(i);float* p2 = samples[i].data();memcpy(p2, p1, feature_length * sizeof(float));}const float* p1 = (const float*)train_labels.data;float* p2 = labels.data();memcpy(p2, p1, train_samples_number * sizeof(float));knn.set_train_samples(samples, labels);// predict dattacv::Mat predict_data(predict_samples_number, tmp.rows * tmp.cols, CV_32FC1);for (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 11; j <= every_class_number + 10; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = predict_data.rowRange(i * every_class_number + j - 10 - 1, i * every_class_number + j - 10);image.copyTo(tmp);}}cv::Mat predict_labels(predict_samples_number, 1, CV_32FC1);p = (float*)predict_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}std::vector<float> sample(feature_length);int count{ 0 };for (int i = 0; i < predict_samples_number; ++i) {float value1 = ((float*)predict_labels.data)[i];float value2;memcpy(sample.data(), predict_data.ptr<float>(i), feature_length * sizeof(float));CHECK(knn.predict(sample, value2) == 0);fprintf(stdout, "expected value: %f, actual value: %f\n", value1, value2);if (int(value1) == int(value2)) ++count;}fprintf(stdout, "when K = %d, accuracy: %f\n", K, count * 1.f / predict_samples_number);return 0;
}

执行结果如下:与OpenCV中KNN结果相似。

GitHub:  https://github.com/fengbingchun/NN_Test

K-最近邻法(KNN) C++实现相关推荐

  1. python机器学习案例系列教程——K最近邻算法(KNN)、kd树

    全栈工程师开发手册 (作者:栾鹏) python数据挖掘系列教程 K最近邻简介 K最近邻属于一种估值或分类算法,他的解释很容易. 我们假设一个人的优秀成为设定为1.2.3.4.5.6.7.8.9.10 ...

  2. K近邻法(KNN)原理小结

    K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法了,在我们平常的生活中也会不自主的应用.比如,我们判断一个人的人品,只需要观察他来往最密切的几个人的人品好坏就可以得出 ...

  3. 3. k 近邻法 k-NN

    1. k k k-NN k k k-NN 是一种基本的监督学习方法,它和感知机有些不同.具体地,它没有一个明确策略,也就是没有损失函数,因此它没有一个显式的学习过程. 1.1 模型概述 k k k-N ...

  4. 斯坦福CS231n项目实战(一):k最近邻(kNN)分类算法

    我的网站:红色石头的机器学习之路 我的CSDN:红色石头的专栏 我的知乎:红色石头 我的微博:RedstoneWill的微博 我的GitHub:RedstoneWill的GitHub 我的微信公众号: ...

  5. scikit-learn K近邻法类库使用小结

    1. scikit-learn 中KNN相关的类库概述 在scikit-learn 中,与近邻法这一大类相关的类库都在sklearn.neighbors包之中.KNN分类树的类是KNeighborsC ...

  6. 基于KD树的K近邻算法(KNN)算法

    文章目录 KNN 简介 KNN 三要素 距离度量 k值的选择 分类决策规则 KNN 实现 1,构造kd树 2,搜索最近邻 3,预测 用kd树完成最近邻搜索 K近邻算法(KNN)算法,是一种基本的分类与 ...

  7. 统计学习方法笔记(李航)———第三章(k近邻法)

    k 近邻法 (k-NN) 是一种基于实例的学习方法,无法转化为对参数空间的搜索问题(参数最优化 问题).它的特点是对特征空间进行搜索.除了k近邻法,本章还对以下几个问题进行较深入的讨 论: 切比雪夫距 ...

  8. wifi室内定位讲解——K邻近法

    摘要 对于室内复杂环境来说, 适用于室外定位的 GPS 系统和蜂窝移动网络在室内中的定位精度明显恶化, 无法满足室内用户精确定位的需求.因此, 研究一种适用于室内复杂环境的高精度.环境自适应性强的定位 ...

  9. K 近邻法(K-Nearest Neighbor, K-NN)

    文章目录 1. k近邻算法 2. k近邻模型 2.1 模型 2.2 距离度量 2.2.1 距离计算代码 Python 2.3 kkk 值的选择 2.4 分类决策规则 3. 实现方法, kd树 3.1 ...

  10. [机器学习-sklearn] KNN(k近邻法)学习与总结

    KNN 学习与总结 引言 一,KNN 原理 二,KNN算法介绍 三, KNN 算法三要素 1 距离度量 2. K 值的选择 四, KNN特点 KNN算法的优势和劣势 KNN算法优点 KNN算法缺点 五 ...

最新文章

  1. latex 表格单元格上下左右居中_Excel文字对齐技巧:学会这6种方式,快速整理规范表格...
  2. 电量检测芯片BQ27510使用心得
  3. bzoj3482,jzoj3238-超时空旅行hiperprostor【最短路,凸包,斜率优化】
  4. 1.2-Nginx编译安装
  5. Hive笔记之JOIN的左外链接和右外链接
  6. java 获得文件的行数据_Java 读取文件指定行数据
  7. C语言转义字符的使用
  8. 张鑫 css,元素有高度 但是css设置背景色不显示
  9. C++类中的访问权限问题---public/protected/private
  10. red hat linux yum,Red Hat Enterprise Linux(RHEL)中yum的repo文件详解
  11. rstudio创建矩阵_R中的矩阵
  12. linux和windows下,C/C++的sleep函数
  13. Jersey框架入门学习
  14. 半导体物理 第七章 金属半导体接触整流理论
  15. 晶振外匹配电容应该怎样选取
  16. 今日话题:微信再次更新搜索框,公众号会更加火爆?
  17. python判断回文序列_怎么用python3代码检查回文序列?
  18. Java数据结构第三个-链表-单链表
  19. 基于注意力机制的机器翻译——经典论文解读与代码实现
  20. 【前端】HTML Manual-HTML入门手册

热门文章

  1. PyCharm导入numpy包遇到的问题
  2. 力扣(LeetCode)刷题,简单+中等题(第34期)
  3. 自然语言处理:网购商品评论情感判定
  4. 【radar】毫米波雷达相关数据集(检测、跟踪、里程计、SLAM、定位、场景识别)总结(1)
  5. uniapp 分享缩略图过大怎么办_女性胸外扩怎么办|3步带你完成改变
  6. 使用 sched_setaffinity 将线程绑到CPU核上运行
  7. 关于ceph源码 backtrace 打印函数调用栈
  8. HDU1402(FFT入门)
  9. 地图收敛心得170405
  10. linux 安装输入法