import numpy as npfrom math import sqrtfrom collections import Counterfrom .metrics import accuracy_score

class KNNClassifier:

    def __init__(self, k):        """初始化kNN分类器"""        assert k >= 1, "k must be valid"        self.k = k        self._X_train = None        self._y_train = None

    def fit(self, X_train, y_train):        """根据训练数据集X_train和y_train训练kNN分类器"""        assert X_train.shape[0] == y_train.shape[0], \            "the size of X_train must be equal to the size of y_train"        assert self.k <= X_train.shape[0], \            "the size of X_train must be at least k."

        self._X_train = X_train        self._y_train = y_train        return self

    def predict(self, X_predict):        """给定待预测数据集X_predict,返回表示X_predict的结果向量"""        assert self._X_train is not None and self._y_train is not None, \                "must fit before predict!"        assert X_predict.shape[1] == self._X_train.shape[1], \                "the feature number of X_predict must be equal to X_train"

        y_predict = [self._predict(x) for x in X_predict]        return np.array(y_predict)

    def _predict(self, x):        """给定单个待预测数据x,返回x的预测结果值"""        assert x.shape[0] == self._X_train.shape[1], \            "the feature number of x must be equal to X_train"

        distances = [sqrt(np.sum((x_train - x) ** 2))                     for x_train in self._X_train]        nearest = np.argsort(distances)

        topK_y = [self._y_train[i] for i in nearest[:self.k]]        votes = Counter(topK_y)

        return votes.most_common(1)[0][0]

    def score(self, X_test, y_test):        """根据测试数据集 X_test 和 y_test 确定当前模型的准确度"""

        y_predict = self.predict(X_test)        return accuracy_score(y_test, y_predict)

    def __repr__(self):        return "KNN(k=%d)" % self.k

转载于:https://www.cnblogs.com/heguoxiu/p/10135546.html

KNNClassifier相关推荐

  1. python分类算法_用Python实现KNN分类算法

    本文实例为大家分享了Python KNN分类算法的具体代码,供大家参考,具体内容如下 KNN分类算法应该算得上是机器学习中最简单的分类算法了,所谓KNN即为K-NearestNeighbor(K个最邻 ...

  2. tensorflow 迁移学习_基于 TensorFlow.js 1.5 的迁移学习图像分类器

    在黑胡桃社区的体验案例中,有一个"人工智能教练",它其实是一个自定义的图像分类器.使用 TensorFlow.js 这个强大而灵活的 Javascript 机器学习库可以很轻松地构 ...

  3. 菜鸟进阶: C++实现KNN文本分类算法

    作者:finallyliuyu(转载请注明原作者和出处) (代码暂不发布源码下载版,以后会发布) KNN文本分类算法又称为(k nearest neighhor).它是一种基于事例的学习方法,也称懒惰 ...

  4. 计算机视觉编程——图像内容分类

    文章目录 图像内容分类 1 K近邻分类法(KNN) 1.1 一个简单的二维示例 1.2 用稠密SIFT作为图像特征 1.3 图像分类:手势识别 2 贝叶斯分类器 3 支持向量机 图像内容分类 1 K近 ...

  5. 一个不错的机器视觉库 SimpleCV: a kinder, gentler machine vision library

    Computer Vision platform using Python. WHAT IS IT? SimpleCV is an open source framework for building ...

  6. Python计算机视觉:第八章 图像类容分类

    第八章 图像类容分类 8.1 K最近邻 8.1.1 一个简单的二维例子 8.1.2 图像稠密(dense)sift特征) 8.1.3 图像分类--手势识别 8.1 session 和登录失败 8.1. ...

  7. [云炬python3玩转机器学习]4-3 训练数据集,测试数据集

    03 测试我们的算法 import numpy as np import matplotlib.pyplot as plt from sklearn import datasets iris = da ...

  8. MATLAB高光谱图像构建KNN图

    在高光谱图像的特征提取过程中,采用非线性降维的方式对高光谱图像降维的过程中,采用图自编码器来对数据进行降维,需要将利用高光谱图像的结构信息和内容信息,则需要将高光谱图像数据构造为一个图结构,图结构的构 ...

  9. Python机器学习:KNN算法04f分类准确度

    引入相关包 import numpy as np import matplotlib.pyplot as plt import matplotlib from sklearn import datas ...

最新文章

  1. 线性基+树上倍增 ---- BZOJ4568[线性基+树上倍增]
  2. R语言使用knitr生成机器学习模型全流程步骤示例:knitr与自动化结果报告、knitr常用参数
  3. 错误: libstdc++.so.6: cannot open shared object file: No such file or directory
  4. python 面向对象教程:访问限制
  5. java对象的包装_java中常见对象——基本包装类
  6. c++利用windows api遍历指定文件夹及其子文件夹中的文件
  7. 路由cpu负载过高检查
  8. 【黑马程序员】————预处理指令2-文件编译
  9. LINUX下载编译sqlite-jdbc(支持mips64el/loongarch64的jar包下载)
  10. linux 文本编辑器Vim/Vi详细介绍
  11. 聚合直播源码原生播放器php分享,原生聚合直播搭建源码
  12. android 添加文件打开方式
  13. Unity3D 性能优化
  14. sam文件获取与解密
  15. vc linux编译环境变量,CodeLite集成开发环境使用VC编译器开发.doc
  16. [转载]MIT人工智能实验室:如何做研究
  17. 布衣联盟XP SP2之国兴奥运版
  18. VMware:在部分链上无法执行所调用的函数,请打开父虚拟磁
  19. javaScript面试高频技术点(多为原生基础+框架集合)
  20. 版本发布 | IvorySQL Release - 2.2

热门文章

  1. L3-020 至多删三个字符 [DP]
  2. JAVA 类加载 随记
  3. java多线程总结一:线程的两种创建方式及优劣比较
  4. 13个不可不知的ASP.NET MVC扩展点
  5. Eigen(3)矩阵Matrix及其简单操作
  6. 蓝桥杯 2011年第二届C语言初赛试题(3)
  7. 在c++中qsort()排序函数的使用qsort函数应用大全
  8. uml 时序图_UML各种图总结:
  9. 少儿编程100讲轻松学python(七)-pycharm怎么删除项目
  10. ajax和for循环谁难,关于“for”循环中jquery $ .ajax的问题