学习scikit-learn中的knn使用

并自己实现一个封装

学习scikit-learn中的knn使用
ps:代码块加标题会让字变红

python 首先引入库

在这里插入代码片

#在这个库里面所有的算法都是以面向对象的形式进行包装的,使用时要先进行实例化

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

测试数据

Raw_data_X = [[3.4,2.3],[3.1,1.7],[1.3,3.3],[3.6,4.6],[2.2,2.8],[7.4,4.6],[5.7,3.5],[9.17,2.5],[7.8,3.4],[8,0.8]]
Raw_data_y = [0,0,0,0,0,1,1,1,1,1]

使用np数组加快计算过程

X_train = np.array(Raw_data_X)
y_train = np.array(Raw_data_y)

初始化模型

knn_classfier = KNeighborsClassifier(n_neighbors=6)#初始化一个值

fit得到’模型’

knn_classfier.fit(X_train,y_train)

测试点x

x = np.array([8,3.3])#2*1

x的shape为2行1列向量

我们输入时要把其转成矩阵,就这个例子来说就是转成1*2

X_predict = x.reshape(1,-1)#1*2的矩阵


调整好数据后进行预测(这个例子中只传入一个x,如何传入多个x呢)
预测结果为

y_predict =knn_classfier.predict(X_predict)
print(y_predict[0])

python 对于Knn来说,训练集就是模型




自己实现一个封装
这是一个函数实现

import numpy as np
from collections import Counterdef Knn_classify(k,X_train,y_train,x):# shape[0]指的就是行数,也就是X_train的数据点的个数,k显然必须小于总体样本个数assert 1 <= k <= X_train.shape[0]# 数据点的个数必须与数据标签个数相同assert X_train.shape[0] == y_train.shape[0]#x的维数必须和数据集中的X_train的维数保持一致assert x.shape[0] == X_train.shape[1]distances = [np.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]nearest = np.argsort(distances)topK_y = [y_train[i] for i in nearest[:k]]votes = Counter(topK_y)return votes.most_common(1)[0][0]

我们需要用面向对象的思想去重构我们的代码,让它更符合sklearn的的风格
KNN.
首先引入相关包

import numpy as np
from collections import Counter

定义类KNNClassifier(这个只能输入单个点)

class KNNClassifier:

初始化

    def __init__(self,k):"""初始化KNN分类器,需要传入K的值"""assert k>=1,"k必须合法"self.k = kself._X_train = None#将我们训练的数据私有化,加 _self._y_train = None#私有化

fit训练分类器

 def fit(self,X_train, y_train):"""根据训练数据集X_train和y_train训练分类器"""# shape[0]指的就是行数,也就是X_train的数据点的个数,k显然必须小于总体样本个数assert 1 <= self.k <= X_train.shape[0],'K必须在一个合理的范围'# 数据点的个数必须与数据标签个数相同assert X_train.shape[0] == y_train.shape[0],'数据点个数必须与数据标签个数相同'self._X_train = X_trainself._y_train = y_trainreturn self

预测
给定待遇测数据集X_predict,返回表示X_predict的结果向量

    def predict(self,X_predict):"""给定待遇测数据集X_predict,返回表示X_predict的结果向量"""#传入的数据不为空assert self._X_train != None and self._y_train != None,'传入数据不为空,要先fit'#数据的维数必须一致assert X_predict.shape[1] == self._X_train.shape[1],'the feature number of X_predict must be equal to X_train'y_predict = [self._predict(x) for  x in X_predict]

给定单个待测数据,返回x的预测值结果

   def _predict(self,x):"""给定单个待测数据x,返回x的预测结果值"""# x的维数必须和数据集中的X_train的维数保持一致assert x.shape[0] == self.X_train.shape[1]distances = [np.sqrt(np.sum((x_train - x) ** 2)) for x_train in self.X_train]nearest = np.argsort(distances)topK_y = [self.y_train[i] for i in nearest[:self.k]]votes = Counter(topK_y)return votes.most_common(1)[0][0]

显示属性

 def __repr__(self):"""显示属性"""return "KNN(k=%d)" % self.k

测试一下

from knn.KNN import KNNClassifier
knn_clf = KNNClassifier(6)
knn_clf.fit(X_train,y_train)


预测

y_predict = knn_classfier.predict(X_predict)
print(y_predict)


PS:10个点测试
需要修改的代码:
数据x
第一块

x = [[3.4,2.3],[3.1,1.7],[1.3,3.3],[3.6,4.6],[2.2,2.8],[7.4,4.6],[5.7,3.5],[9.17,2.5],[7.8,3.4],[8,0.8]]
x = np.array(x)

第二块
将reshape成10*2的矩阵

X_predict = x.reshape(len(x),-1)#1*2的矩阵

Python机器学习:KNN算法02scikit-learn中的机器学习算法封装相关推荐

  1. 机器学习算法实践-SVM中的SMO算法

    前言 前两篇关于SVM的文章分别总结了SVM基本原理和核函数以及软间隔原理,本文我们就针对前面推导出的SVM对偶问题的一种高效的优化方法-序列最小优化算法(Sequential Minimal Opt ...

  2. 算法 从 数中选出_算法可以选出胜出的nba幻想选秀吗

    算法 从 数中选出 Note from Towards Data Science's editors: While we allow independent authors to publish ar ...

  3. 【老生谈算法】matlab实现车牌识别中值滤波算法——车牌识别中值滤波算法

    基于Matlab的车牌识别中值滤波算法的研究与实现 1.原文下载: 本算法原文如下,有需要的朋友可以点击进行下载 序号 原文(点击下载) 本项目原文 [老生谈算法]基于Matlab的车牌识别中值滤波算 ...

  4. 国密局公开SM2和SM3算法或预示中国商密算法将走向开放

    临近2010年年底的时候,在国密局的网站上公布了基于椭圆曲线ECC的SM2公开密钥国密算法和SM3杂凑算法.加上原来的SM1商密对称算法,中国定义的算法终于开始成熟并且以一个大方的姿态展示出来了. 此 ...

  5. 在算法研究过程中如何进行算法创新

    创新一直是一个令人纠结的话题,研究生毕业设计多数需要算法的创新,而博士生毕业更需要大量的创新才行.这里,我们就团队这几年来的工作经验,谈谈如何进行合理的算法创新. 一.创新角度 通常,我们使用一个算法 ...

  6. matlab中存档算法代码,Matlab中的FCM算法代码及中文详解

    Matlab中的FCM算法代码及中文详解 转自:http://xiaozu.renren.com/xiaozu/106512/336681453 function [center, U, obj_fc ...

  7. lru算法实现 redis_Redis中的lru算法实现

    lru是什么 lru(least recently used)是一种缓存置换算法.即在缓存有限的情况下,如果有新的数据需要加载进缓存,则需要将最不可能被继续访问的缓存剔除掉.因为缓存是否可能被访问到没 ...

  8. java 寻路算法_游戏中的寻路算法解析

    游戏角色的自动寻路,已经是游戏中一个历史比较悠久的领域,较为成熟也有很多种实现.这里摘录一句后面所提的参考资料中的描述:"业内AI开发者中有一句话:"寻路已不是问题."我 ...

  9. matlab中值滤波法算法,基于MATLAB中值滤波算法的优化与实现

    总第238期2014年第4期 舰 船 电 子 工 程 Ship Electronic Engineering Vol.34No.437 基于 MATLAB中值滤波算法的优化与实现* 赵建春 刘力源 ( ...

  10. python分类算法_python数据挖掘中的分类算法有哪些?

    一直以来,对于机器学习领域,Python都是人们津津乐道的话题,大家知道我们在用python学习机器技术时候,用到的方法内容和一般情况下的是一样的吗?想必,了解过的小伙伴一定是知道有哪些的,不知道的小 ...

最新文章

  1. 深度学习平台你知道多少?
  2. make: warning:  Clock skew detected.  Your build may be incomplete.
  3. 大型网站技术架构:核心原理与案例分析 mobi_大数据技术经典学习路线
  4. (转)Struts2快速入门
  5. eclipse 连接 mysql
  6. AACL2022会议征稿
  7. Atom : 一些有意思的插件
  8. 固态硬盘受损或数据删除,怎么办?详解各种恢复SSD数据方法
  9. 红旗河工程,南水北调西线工程,藏水入疆工程三合一
  10. 流量星球:实操!利用“拼多多砍价群”日吸100+精准女粉引流技术
  11. 阿里巴巴矢量字体库更改设置
  12. 12123 上传照片到文件服务器失败,12123软件上传不了照片怎么回事(教你最合理的上传方法)...
  13. HDU 2037 (贪心或记忆化搜索)
  14. Ubuntu 16.04 显示器分辨率低
  15. vue中监听enter键触发事件
  16. 小明开了一家糖果店。他别出心裁:把水果糖包成4颗一包和7颗一包的两种。糖果不能拆包卖
  17. 非递归式查找树形数据
  18. C++:图片数字水印-基于OpenCV+LSB
  19. 一. pandas入门介绍(一)
  20. 聚焦“云XR如何赋能元宇宙”,3DCAT实时云渲染首届行业生态合作交流会成功举办

热门文章

  1. 【codevs2304】【BZOJ1875】HH去散步,第一次的矩阵加速DP
  2. redis命令执行流程分析
  3. VS系列IDE(2005、2008等)下使用cppunit的方法及使用示例
  4. Stata和Matlab联合处理金融数据
  5. 计算机十进制例子,verilog给你举个最简单的例子:以十进制计算为例:14
  6. python付费课程推荐知乎_新手小白学习Python,有什么课程推荐吗?
  7. Android:进度条加载
  8. GITHUB来获得UE4源代码
  9. Android音视频之AudioRecord录音(一)
  10. 大数据架构师学习方向---加油。