写前思考

目前的几个问题,如何找到簇的中心点。
如果要找簇的中心点,只需要求出那个簇所在地点的均值,然后将其赋值给新的中心点就可以了。
如何增加或者减少k的值(如何选择k值)
这还是一个没有解决的问题,现在我要去百度下嘿嘿。可以是由sse方法,找到数据变化的点
没错,现在可以写程序了。

什么是K-means

k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是,预将数据分为K组,则随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

分析k-means的算法流程

  1. 预将数据分为K组,则随机选取K个对象作为初始的聚类中心;
  2. 计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。
  3. 重新计算各个聚类的中心,并将其作为新的聚类中心。
  4. 这个过程(步骤1-3)将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

主要函数的编写

首先,随机产生k个聚类中心,这里我们可以使用np的函数。
但是,我们要先得出这个数据集的边界,也就是x,y的最大最小值。
步骤: 首先读取数据集,然后使用np.max方法求最大最小值。然后生成一个(0,1)的矩阵。
之后使用生成的矩阵乘增量+初始值的方式。可能这个方式有点原始,有好的办法之后,会更新。

def create_centroids(k,data_set):min_x,max_x=min(data_set[:,0]),max(data_set[:,0])min_y,max_y=min(data_set[:,1]),max(data_set[:,1])centroids =np.random.random((k,2))centroids_x=min_x+centroids[:,0]*(max_x-min_x)centroids_y=min_y+centroids[:,1]*(max_y-min_y)centroids=[]for i in range(len(centroids_x)):centroids.append([centroids_x[i],centroids_y[i]])return np.array(centroids)

下面这个函数是划分样本到最近的聚类中心:

def findClosestCentroids(data,centroids):cluster_indexs=[]for i  in range(len(data)):diff=data[i]-centroids#这样减法会生成K行,这样一下子就能算出来 一个样本到k个聚类中心的坐标差# 下面求欧氏距离dist=0for j in range(len(diff[0])):dist+=diff[:,j]**2 # 求x,和y的平方和min_index=np.argmin(dist) #然后找到距离最小的值 并将其标注为簇编号cluster_indexs.append(min_index)return np.array(cluster_indexs) # 返回对应索引的簇编号

使用上述函数,我们就可以计算出,每个样本距离最近的聚类中心,接下来我们需要找出这个簇的真实中心,怎么找呢?
我们只需要找这个簇中样本的平均坐标就行了,同时np.mean()方法可以提供很好的帮助。
此外可以使用一点小技巧划分数据集。Datas[clustering==i]先举个例子。
如果数据集为datas=[[1,2],[1,3],[1,4]],那么datas[True ,False,True]=[[1,2],[1,4]] 所以我们可以用i对比整个簇编号列表。就会得到一个布尔型的列表,如果 和i相同 就会显示True 否则就是False 。这样就一步取出了所有的属于i的样本,然后在使用numpy.mean()方法求均值。

def computMeans(Datas, clustering):centroids = []# print(np.unique(clustering))for i in range(len(np.unique(clustering))):  # np.unique计算聚类个数u_k = np.mean(Datas[clustering==i], axis=0)  # 求每列的平均值centroids.append(u_k)return np.array(centroids)

然后就完成了步骤1-3的主要函数的编写,下面只要重复这些步骤,就可以求出来。
下面开始编辑主函数

def K_means(k,data_set):# 随机生成k矩阵centroids =create_centroids(k,data_set)# 使用生成算法生成聚类中心centroids_list=[]# 用于记录中心点的移动轨迹for i in range(30):# 训练三十遍clustering=findClosestCentroids(data_set,centroids)centroids_new=computMeans(data_set,clustering)centroids_list.append(centroids1)# 说明有点,没有被任何样本选中,说明中心点已经多了#比如说,我们自动生成了4个点,但是样本选择离自己最近的点的时候,发现没有一个样本选择该点,所以计算簇的中心点的时候,会有一个点是空的。所以返回的矩阵会少一行。这样两个矩阵就不一样了。但是这也说明,三个点就足够了。所以就不用继续迭代了if centroids.shape!=centroids_new.shape:print("出现未选中的中心点")return centroids_list,cluster,False#返回中心点的移动轨迹和最终的簇和退出状态if np.max(centroids-centroids_new)==0: # 如果两个矩阵没有差别了,就可以退出循环breakcentroids=centroids_newcluster=findClosestCentroids(data_set,centroids)# 找到最终的簇return centroids_list,cluster,True #返回中心点的移动轨迹和最终的簇和退出的状态

现在,主函数已经结束了,但是我们还不知道最优的k,这就是我开头说的第二个难题:
我们可以使用误差平方和( sum of squared errors)SSE。

手肘法求k值

手肘法的核心思想:随着聚类数k的增大,样本划分会更加精细,每个簇的聚合程度会逐渐提高,那么误差平方和SSE自然会逐渐变小。并且,当k小于真实聚类数时,由于k的增大会大幅增加每个簇的聚合程度,故SSE的下降幅度会很大,而当k到达真实聚类数时,再增加k所得到的聚合程度回报会迅速变小,所以SSE的下降幅度会骤减,然后随着k值的继续增大而趋于平缓,也就是说SSE和k的关系图是一个手肘的形状,而这个肘部对应的k值就是数据的真实聚类数。当然,这也是该方法被称为手肘法的原因。
S S E = ∑ i = 1 k ∑ p ∈ C i ∣ p − m i ∣ 2 SSE=\sum_{i=1}^{k} \sum_{p\in C_i} |p-m_i|^2 SSE=i=1∑k​p∈Ci​∑​∣p−mi​∣2其中,Ci是第i个簇,p是Ci中的样本点,mi是Ci的质心(Ci中所有样本的均值),SSE是所有样本的聚类误差,代表了聚类效果的好坏。将其转化为代码,则如下:

def calculate_sse(centroids,clustering,data_set):sum=0for i in range(data_set.shape[0]):diff= data_set[i]-centroids[clustering[i]]sum+=diff[0]**2+diff[1]**2return sum

我们再看下sse和k值的变化关系图:

如上图所示,k=2 是手肘的肘部。我们可以看出来,但是怎么让计算机能看懂呢?我又想了一个土办法:让i+1的sse值除i的sse值,这样会得出一个小于1的值,比值越小,变化越巨大。
然后我们在判断前后两个比值的差,差最大的就是肘部。我们默认会设置8个分类。所以可以很好的划分。
如果差值一最大,但是他在差值列表中的索引为0,但是肘部的k值为2。所以索引和k值得差为2。

def find_k_index(lis_k):ratios=[ lis_k[i+1]/lis_k[i] for i in range(len(lis_k)-1)]diff=[ ratios[i+1]-ratios[i] for i in range(len(ratios)-1)]return np.argmax(diff)+2

.接下来,我们要穷尽k值,进行聚类,然后找出效果最好的,并将图画出来。

def Confirm_K_value(data_set):# 计算最优的k值#在一个范围内随机生成矩阵,lis_k=[]for k in range(1,MAX_CLUSTER):centroids_list,cluster,state=K_means(k,data_set)if not state:# 如果发现有空值,那么直接终止增加Kbreaklis_k.append(calculate_sse(centroids_list[-1],cluster,data_set))print("SSE值",lis_k)max_=find_k_index(lis_k)print("共发现{}个簇".format(max_))centroids_list,cluster,state=K_means(max_,data_set)if not state:print("erro")return# 展示数据show(centroids_list,cluster,data_set)

下面描述下画图函数,代码略有繁琐,后期优化。

def show(centroids_list,cluster,data_set):#簇中心的位置变化情况,样本对应的簇编号,样本集合centroid_x = []centroid_y = []for centroid in centroids_list:centroid_x.append(centroid[:,0])centroid_y.append(centroid[:,1])plt.plot(centroid_x, centroid_y, 'r*--',c="blue", markersize=14)# 接下来是画点lis_x=[[] for i in range(np.unique(cluster).shape[0])]lis_y=[[] for i in range(np.unique(cluster).shape[0])]for i in range(len(data_set)):lis_x[cluster[i]].append(data_set[i][0])lis_y[cluster[i]].append(data_set[i][1])colors=['red','brown','orange','green','cyan','purple','pink','blue','#FFA07A','#20B2AA','#87CEFA','#9ACD32']for i in range(np.unique(cluster).shape[0]):plt.scatter(lis_x[i],lis_y[i],alpha=0.5,c=colors[i])#画一个散点图plt.show()

【注这个数据集在这个文件同目录的data目录下】,下载后需要放到相同位置。数据集

源代码:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
df= pd.read_csv('./data/两坨散点.csv')
data_set=np.array(df)def findClosestCentroids(data,centroids):cluster_indexs=[]for i  in range(len(data)):diff=data[i]-centroids# 下面求欧氏距离dist=0for j in range(len(diff[0])):dist+=diff[:,j]**2min_index=np.argmin(dist)cluster_indexs.append(min_index)return np.array(cluster_indexs)# print(np.unique(clustering))# 使用unique可以进行去重复操作。
# 根据聚类重新计算中心点函数:
def computMeans(Datas, clustering):centroids = []# print(np.unique(clustering))for i in range(len(np.unique(clustering))):  # np.unique计算聚类个数u_k = np.mean(Datas[clustering==i], axis=0)  # 求每列的平均值centroids.append(u_k)return np.array(centroids)def create_centroids(k,data_set):min_x,max_x=min(data_set[:,0]),max(data_set[:,0])min_y,max_y=min(data_set[:,1]),max(data_set[:,1])centroids =np.random.random((k,2))centroids_x=min_x+centroids[:,0]*(max_x-min_x)centroids_y=min_y+centroids[:,1]*(max_y-min_y)centroids=[]for i in range(len(centroids_x)):centroids.append([centroids_x[i],centroids_y[i]])return np.array(centroids)def calculate_sse(centroids,clustering,data_set):sum=0for i in range(data_set.shape[0]):diff= data_set[i]-centroids[clustering[i]]sum+=diff[0]**2+diff[1]**2return sum
def K_means(k,data_set):# 随机生成k矩阵centroids =create_centroids(k,data_set)centroids_list=[]for i in range(30):clustering=findClosestCentroids(data_set,centroids)centroids1=computMeans(data_set,clustering)centroids_list.append(centroids1)if centroids.shape!=centroids1.shape:print("出现未选中样本")# 说明 有点,没有被任何样本选中,说明中心点已经多了return centroids_list,[],False #返回中心点的移动轨迹和最终的簇和退出的状态if np.max(centroids-centroids1)==0:# print("find")breakcentroids=centroids1# print(centroids)cluster=findClosestCentroids(data_set,centroids)return centroids_list,cluster,True #返回中心点的移动轨迹和最终的簇和退出的状态def show(centroids_list,cluster,data_set):centroid_x = []centroid_y = []for centroid in centroids_list:centroid_x.append(centroid[:,0])centroid_y.append(centroid[:,1])plt.plot(centroid_x, centroid_y, 'r*--',c="blue", markersize=14)lis_x=[[] for i in range(np.unique(cluster).shape[0])]lis_y=[[] for i in range(np.unique(cluster).shape[0])]for i in range(len(data_set)):lis_x[cluster[i]].append(data_set[i][0])lis_y[cluster[i]].append(data_set[i][1])colors=['red','brown','orange','green','cyan','purple','pink','blue','#FFA07A','#20B2AA','#87CEFA','#9ACD32']for i in range(np.unique(cluster).shape[0]):plt.scatter(lis_x[i],lis_y[i],alpha=0.5,c=colors[i])#画一个散点图plt.show()
MAX_CLUSTER=8
def find_k_index(lis_k):ratios=[ lis_k[i+1]/lis_k[i] for i in range(len(lis_k)-1)]diff=[ ratios[i+1]-ratios[i] for i in range(len(ratios)-1)]return np.argmax(diff)+2
def Confirm_K_value(data_set):# 计算最优的k值#在一个范围内随机生成矩阵,lis_k=[]for k in range(1,MAX_CLUSTER):centroids_list,cluster,state=K_means(k,data_set)if not state:breaklis_k.append(calculate_sse(centroids_list[-1],cluster,data_set))print("SSE值",lis_k)max_=find_k_index(lis_k)print("共发现{}个簇".format(max_))centroids_list,cluster,state=K_means(max_,data_set)if not state:print("erro")return# 展示数据show(centroids_list,cluster,data_set)Confirm_K_value(data_set)

测试结果

  1. 两个簇情况下:
    SSE值 [897.1674886919908, 156.58714828390853, 119.03323146797621, 93.86925045546373, 81.41218770686388]
    共发现2个簇

  2. 三个簇情况下:

    SSE值 [2668.4175360883137, 1139.838296827762, 238.7100927716138, 206.82792773131195, 192.94401994160415]
    共发现3个簇

  3. 五个簇情况下:

    SSE值 [54790.63847837179, 19023.567985249738, 10365.750809115409, 6078.364568727137, 3892.1958462961006, 3695.49840412411, 3418.4365266913137]
    共发现5个簇
    【存在问题】本次实现的聚类算法,在五个簇的数据集中表现十分不稳定,有待改进。

K-means学习笔记及简易代码实现相关推荐

  1. 吴恩达《机器学习》学习笔记十一——神经网络代码

    吴恩达<机器学习>学习笔记十一--神经网络代码 数据准备 神经网络结构与代价函数· 初始化设置 反向传播算法 训练网络与验证 课程链接:https://www.bilibili.com/v ...

  2. 【学习笔记】低代码平台(LCAP:Low-Code Application Platform)

    学习笔记:低代码平台(LCAP:Low-Code Application Platform) [概念] 开发者写很少的代码,通过低代码平台提供的界面.逻辑.对象.流程等可视化编排工具来完成大量的开发工 ...

  3. Unity学习笔记1 简易2D横版RPG游戏制作(一)

    这个教程是参考一个YouTube上面的教程做的,原作者的教程做得比较简单,我先参考着做一遍,毕竟我也只是个初学者,还没办法完全自制哈哈.不过我之前也看过一个2D平台游戏的系列教程了,以后会整合起来,做 ...

  4. Unity学习笔记2 简易2D横版RPG游戏制作(二)

    十二.敌人受攻击时的闪烁和Player的生命值的修正 上一篇中,我们利用Controller2D中的IEnumerator TakenDamage接口,使得我们的Player受到攻击时会进行闪烁,我们 ...

  5. 0037 Java学习笔记-多线程-同步代码块、同步方法、同步锁

    什么是同步 在上一篇0036 Java学习笔记-多线程-创建线程的三种方式示例代码中,实现Runnable创建多条线程,输出中的结果中会有错误,比如一张票卖了两次,有的票没卖的情况,因为线程对象被多条 ...

  6. 【学习笔记】结合代码理解设计模式 —— 代理模式(静态代理、动态代理、延伸)

    文章目录 什么是代理模式 一. 代理模式简介 二. 静态代理模式 三. 动态代理模式 万能模版 前言:笔记基于狂神设计模式视频.<大话设计模式>观后而写 (最近一直在更新之前的刷题博客,今 ...

  7. C++ Qt学习笔记 (1) 简易计算器设计

    最近开始学习c++ qt, 按照教材上的例程设计一个简易的桌面计算器: Qt是一个基于C++语言的跨平台应用程序和UI开发框架,主要包含一个类库,和跨平台开发及国际化的工具,最初由挪威的Trollte ...

  8. fmri学习笔记|SPM 代码 循环

    目录 SPM 用代码跑操作(上) 第一步生成脚本 第二步 看一下生成的脚本结构 temp_smooth_job temp_smooth SPM 用代码跑操作(下) 下面记录怎么修改输入input和fo ...

  9. 卷起来了,写了一套计算机视觉学习笔记(20G/代码/PPT/视频)

    AI 显然是最近几年非常火的一个新技术方向,从几年前大家认识到 AI 的能力,到现在产业里已经在普遍的探讨 AI 如何落地了. 我们可以预言未来在很多的领域,很多的行业,AI 都会在里边起到重要的作用 ...

最新文章

  1. ios获取新数据要不要关_iOS开发之数据读写
  2. 检查网站是否可以正常访问
  3. 微软对学生免费提供Visual Studio等开发软件(包括中国)
  4. Spark性能相关参数配置详解
  5. python学习-模块和包
  6. 综合知识计算机类编制,天津事业编综合知识是什么
  7. 【2015沈阳区域赛F=HDU5514】Frogs(圆上n个青蛙跳统计跳劲哪些点---欧拉函数求和+思维)
  8. 尼康单反AF自动对焦模式与AF区域模式详解
  9. windows10查看桌面壁纸路径
  10. 【xlwings api语言参考】Worksheet.Cells 属性
  11. 电商网络推广是干什么,电商网络营销做什么
  12. 微信小程序开发—背景图片全屏(无白条)
  13. Linux 复制、粘贴快捷键
  14. 学习笔记之——DCDC降压芯片基本原理及选型主要参数介绍
  15. 8寸ndows平板,三款8英寸Windows平板对比体验
  16. DBSCAN聚类分析在图像分割的应用
  17. phpstorm常见问题
  18. 简单的golang游戏服务器框架《railgun》的文档目录索引
  19. linux类似vc的软件,linux下能否使用VC之类的软件?
  20. 使用itune备份iPhone

热门文章

  1. Java并发编程78讲--29 第29讲:HahMap 为什么是线程不安全的?
  2. adb 命令行获取android数据库文件并在可视化工具下显示
  3. window的tar.gz 到ubuntu上tar -zxf 文件名乱码
  4. kinmall:哪些国家适合证券型通证发行(STO)?
  5. 何为分度翩翩的程序员?
  6. 简单易懂的单纯形法理解
  7. 东软mysql期末题库_东软大三上学期实训笔记-mysql篇Day5完结篇
  8. 全美猎头公司排名 2005
  9. 织梦cms/dedecms清理冗余废弃未引用图片方法
  10. 《黑马头条》SpringBoot+SpringCloud+ Nacos等企业级微服务架构项目