文章目录

  • 前言
  • 1. K-Means算法的不足之二
  • 2. K-Means++算法
  • 3. 代码实现
  • 4. sklearn实现K-Means++
  • 结束语

前言

  K-Means算法的原理很简单:初始化kkk个聚类中心之后,不断计算样本与kkk个聚类中心的距离,离哪个聚类中心最近,相应的该样本就属于这个聚类中心所属的类别,然后重新计算聚类中心,直至其不再发生变化,具体步骤请参阅我的上篇博客。
  当然,此算法也有一定的局限和不足之处,想必其中一点已经很清楚了,那就是需要首先确定聚类中心的个数kkk,即需要先验知识。如果有一大批数据,我事先不知道其有几个类别,那kkk的大小该如何确定呢?这个问题我想放在下篇博客分享。
  本篇博客主要分享一下如何解决K-Means的另一个不足之处:聚类中心的初始化

1. K-Means算法的不足之二

  K-Means在初始化聚类中心时是在最小值和最大值之间随机取一个值作为其聚类中心,这样的随机取值会导致聚类中心可能选择的不好,最终对结果会产生很大的影响。经过测试,如果样本类别区分度较明显,按照K-Means初始化聚类中心,对结果的影响并不大;反之,如果样本的类别区分度不大,聚类结果会有较大的不同。下面用周志华老师的《机器学习》这本书里的西瓜数据集来说明。

  上述两张图可以说明,对于同一数据集,不同的初始聚类中心其产生的结果会有较大的不同。因此,K-Means++算法被提出了。

2. K-Means++算法

  K-Means++算法是K-Means算法的改进版,主要是为了选择出更优的初始聚类中心。其基本思路如下:

  • 在数据集中随机选择一个样本作为第一个初始聚类中心;
  • 选择出其余的聚类中心:
    • 计算数据集中的每一个样本与已经初始化的聚类中心之间的距离,并选择其中最短的距离,记为did_idi​;
    • 以概率选择距离最大的样本作为新的聚类中心,重复上述过程,直到kkk个聚类中心都被确定。
  • 对kkk个初始的聚类中心,利用K-Means算法计算出最终的聚类中心。

对“以概率选择距离最大的样本作为新的聚类中心”的理解:
  即初始的聚类中心之间的相互距离应尽可能的远。假如有3、5、15、10、23、5、15、10、23、5、15、10、2这五个样本的最小距离did_idi​,则其和sumsumsum为353535,然后乘以一个取值在[0,1)[0, 1)[0,1)范围的值,即概率,也可以称其为权重,然后这个结果不断减去样本的距离did_idi​,直到某一个距离使其小于等于0,这个距离对应的样本就是新的聚类中心。比如上述的例子,假设sumsumsum乘以0.50.50.5得到结果17.517.517.5,17.5−3=14.5>0,14.5−5=9.5>0,9.5−15=−5.5<017.5-3=14.5>0,14.5-5=9.5>0,9.5-15=-5.5<017.5−3=14.5>0,14.5−5=9.5>0,9.5−15=−5.5<0,则距离did_idi​为15的样本点距离最大,做为新的的聚类中心。

3. 代码实现

  代码和上篇博客里里面的代码基本一致,不同之处就是更换了初始化聚类中心的函数。

import numpy as np
import random
import matplotlib.pyplot as pltdef load_data(file_path):"""将txt里面的数据转换成矩阵:param file_path::return:"""data_list = []with open(file_path, 'r') as f:lines = f.readlines()for line in lines:data_row = []line = line.strip().split('\t')for x in line:data_row.append(float(x))data_list.append(data_row)data_arr = np.array(data_list)return data_arrdef o2_distance(vecA, vecB):"""计算向量vecA和向量vecB之间的欧氏距离的平方:param vecA: 向量vecA的坐标:param vecB: 向量vecB的坐标:return:"""# .T 对一个矩阵转置distance = np.dot((vecA - vecB), (vecA - vecB).T)return distancedef nearest_distance(data_arr, cluster_centers):min_distance = 100# 当前已经初始化的聚类中心之间的距离dim = np.shape(cluster_centers)[0]for i in range(dim):# 计算point与每个聚类中心之间的距离distance = o2_distance(data_arr, cluster_centers[i])# 选择最短距离if distance < min_distance:min_distance = distancereturn min_distancedef get_centroids(data_arr, k):dim_m, dim_n = np.shape(data_arr)cluster_centers = np.array(np.zeros(shape=(k, dim_n)))# 随机选择一个样本点为第一个聚类中心index = np.random.randint(0, dim_m)cluster_centers[0] = data_arr[index]# 初始化一个距离的序列distance = [0.0 for _ in range(dim_m)]for i in range(1, k):sum_all = 0for j in range(dim_m):# 对每一个样本找到最近的聚类中心点distance[j] = nearest_distance(data_arr[j], cluster_centers[0:i])# 将所有的最短距离相加sum_all += distance[j]# 取得sum_all之间的随机值sum_all *= random.random()# 以概率获得距离最远的样本点作为聚类中心for id, dist in enumerate(distance):sum_all -= distif sum_all > 0:continuecluster_centers[i] = data_arr[id]breakreturn cluster_centersdef kmeans(data_arr, k, centroids):"""聚类计算:param data_arr::param k::param centroids::return: sub_centroids [类别, 最小距离]"""# (样本个数, 特征维度)dim_m, dim_n = np.shape(data_arr)# 初始化每一个样本所属的类别sub_center = np.array(np.zeros(shape=(dim_m, 2)))# 更新标志flag = Truewhile flag:flag = Falsefor i in range(dim_m):# 设置样本与聚类中心之间的初始最小距离, 初始值为无穷min_distance = np.inf# 设置所属的初始类别min_index = 0for j in range(k):# 计算样本i和每个聚类中心之间的距离distance = o2_distance(data_arr[i], centroids[j])if distance < min_distance:min_distance = distancemin_index = jif sub_center[i, 0] != min_index:flag = Truesub_center[i] = np.array([min_index, min_distance])# 重新计算聚类中心for j in range(k):sum_all = np.array(np.zeros(shape=(1, dim_n)))# 每个类别中的样本个数counter = 0for i in range(dim_m):# 计算第j个类别if sub_center[i, 0] == j:sum_all += data_arr[i]counter += 1for t in range(dim_n):try:centroids[j, t] = sum_all[0, t] / counterexcept Exception as err:print('样本数为0')return sub_centerdef draw_picture(data_arr, sub_center, centroids):x = sub_center[:, 0]dots1 = data_arr[x == 0.0]dots2 = data_arr[x == 1.0]dots3 = data_arr[x == 2.0]dots4 = data_arr[x == 3.0]plt.figure()plt.scatter(dots1[:, 0], dots1[:, 1], marker='o',color='blue', alpha=0.7, label='dots1 samples')plt.scatter(dots2[:, 0], dots2[:, 1], marker='o',color='green', alpha=0.7, label='dots2 samples')plt.scatter(dots3[:, 0], dots3[:, 1], marker='o',color='red', alpha=0.7, label='dots3 samples')plt.scatter(dots4[:, 0], dots4[:, 1], marker='o',color='purple', alpha=0.7, label='dots4 samples')plt.scatter(centroids[:, 0], centroids[:, 1], marker='x',color='black', alpha=0.7, label='centroids')plt.savefig('./result_plus.png')plt.show()if __name__ == '__main__':k = 4file_path = './西瓜数据集.txt'data_arr = load_data(file_path)centroids = get_centroids(data_arr, k)sub_center = kmeans(data_arr, k, centroids)draw_picture(data_arr, sub_center, centroids)

  运行结果如图所示:

4. sklearn实现K-Means++

本代码主要使用sklearn.cluster.KMeans,部分参数已注释,更详细的说明,请参考官方文档。

from sklearn.cluster import KMeans
import numpy as np
import matplotlib.pyplot as pltdef load_data(file_path):"""将txt里面的数据转换成矩阵:param file_path::return:"""data_list = []with open(file_path, 'r') as f:lines = f.readlines()for line in lines:data_row = []line = line.strip().split('\t')for x in line:data_row.append(float(x))data_list.append(data_row)data_arr = np.array(data_list)return data_arrdef draw_picture(data_arr, cluster_centers, labels):dots1 = data_arr[labels == 0]dots2 = data_arr[labels == 1]dots3 = data_arr[labels == 2]dots4 = data_arr[labels == 3]plt.figure()plt.scatter(dots1[:, 0], dots1[:, 1], marker='o',color='blue', alpha=0.7, label='dots1 samples')plt.scatter(dots2[:, 0], dots2[:, 1], marker='o',color='green', alpha=0.7, label='dots2 samples')plt.scatter(dots3[:, 0], dots3[:, 1], marker='o',color='red', alpha=0.7, label='dots3 samples')plt.scatter(dots4[:, 0], dots4[:, 1], marker='o',color='purple', alpha=0.7, label='dots4 samples')plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], marker='x',color='black', alpha=0.7, label='centroids')plt.savefig('./result_plus_skl.png')plt.show()if __name__ == '__main__':file_path = './西瓜数据集.txt'data_arr = load_data(file_path)# n_clusters 聚类中心数目k, 默认为8# init 聚类中心的初始化方法, 默认为k-means++# n_init 算法执行的次数, 默认为10# max_iter 最大迭代次数, 默认300# tol 算法收敛的阈值, 默认0.0001# verbose 0不打印日志, 1打印日志# random_state 随机数生成器的种子# n_jobs 任务使用的CPU数量kmeans = KMeans(n_clusters=4, init='k-means++', n_init=10, max_iter=300,tol=1e4, verbose=0, random_state=1024, n_jobs=1)kmeans.fit(X=data_arr)# 聚类中心cluster_centers = kmeans.cluster_centers_# 各个样本的类别labels = kmeans.labels_draw_picture(data_arr, cluster_centers, labels)

  运行结果如图所示:

结束语

  本篇博客主要分享了K-Means++算法的原理及两种代码实现,能力有限,如有错误的地方,欢迎交流哦,O(∩_∩)O哈哈~

聚类之K-Means++算法相关推荐

  1. kmeans改进 matlab,基于距离函数的改进k―means 算法

    摘要:聚类算法在自然科学和和社会科学中都有很普遍的应用,而K-means算法是聚类算法中经典的划分方法之一.但如果数据集内相邻的簇之间离散度相差较大,或者是属性分布区间相差较大,则算法的聚类效果十分有 ...

  2. k means算法C语言伪代码,K均值算法(K-Means)

    1. K-Means算法步骤 算法步骤 收敛性定义,畸变函数(distortion function): 伪代码: 1) 创建k个点作为K个簇的起始质心(经常随机选择) 2) 当任意一个点的蔟分配结果 ...

  3. 机器学习——聚类之k近邻算法及python使用

    聚类算法之k近邻及python使用 什么是k近邻算法 k近邻算法流程 使用sklearn进行代码实现 数据集介绍 标准化 代码实现 写在开头,套用我的老师的一句话目前所有自然学科的前沿都是在研究数学, ...

  4. k均值聚类算法(K Means)及其实战案例

    算法说明 K均值聚类算法其实就是根据距离来看属性,近朱者赤近墨者黑.其中K表示要聚类的数量,就是说样本要被划分成几个类别.而均值则是因为需要求得每个类别的中心点,比如一维样本的中心点一般就是求这些样本 ...

  5. (学习笔记)聚类算法 - k均值算法(k-means)

    聚类之K均值算法 聚类介绍 k均值算法步骤 Python实现 参考文献 聚类介绍 聚类是一种经典的无监督学习方法. 聚类的目的是将数据集中的样本划分为若干个通常是不相交的子集,每个子集被称为簇,以此来 ...

  6. K means 图片压缩

    k-means的基本原理较为清晰,这里不多赘述,本次博客主要通过基础的k means算法进行图像的压缩处理. 原理分析 在彩色图像中,每个像素的大小为3字节(RGB),可以表示的颜色总数为256 * ...

  7. 基于改进的k最近邻算法的单体型重建问题An Improved KNN Algorithm for Haplotype Reconstruction Problem

    基于改进的k最近邻算法的单体型重建问题 An Improved KNN Algorithm for Haplotype Reconstruction Problem DOI: 10.12677/csa ...

  8. OpenCV的k - means聚类 -对图片进行颜色量化

    OpenCV的k - means聚类 目标 学习使用cv2.kmeans()数据聚类函数OpenCV 理解参数 输入参数 样品:它应该的np.float32数据类型,每个特性应该被放在一个单独的列. ...

  9. OpenCV官方文档 理解k - means聚类

    理解k - means聚类 目标 在这一章中,我们将了解k - means聚类的概念,它是如何工作等. 理论 我们将这个处理是常用的一个例子. t恤尺寸问题 考虑一个公司要发布一个新模型的t恤. 显然 ...

  10. 机器学习-聚类之K均值(K-means)算法原理及实战

    K-means算法 前言 机器学习方法主要分为监督学习和非监督学习两种.监督学习方法是在样本标签类别已知的情况下进行的,可以统计出各类样本的概率分布.特征空间分布区域等描述量,然后利用这些参数进行分类 ...

最新文章

  1. linux工程师前景_小猿圈预测2019年Linux云计算发展前景
  2. 为你揭示最危害程序员职业生涯的三大观念
  3. java 百分比相加_2019年Java面试题基础系列228道(5),快看看哪些你还不会?
  4. 数据之“星”冉冉升起,“星斗奖”申报正式启动!
  5. Java批量完成对文件夹下全部的css与js压缩,利用yuicompressor
  6. const 和readonly
  7. hapi返回xml格式 微信开发 node
  8. python 跳过计算错误_在python中计算分数时如何跳过被0除的错误?
  9. 1万吨猪肉只够吃1.5小时,中国人是有多爱猪肉?
  10. vue引用electron_前端跨平台桌面开发技术:Electron 快速起步
  11. Flex4中使用WCF
  12. 当VS2005卸载不干净时
  13. kotlin android获取按钮,Kotlin Android按钮
  14. jxl读数据库数据生成xls 并下载
  15. project甘特图导出图片_Project将任务及甘特图导出Excel的方法
  16. 联想微型计算机设置键盘开机,联想台式机怎么样设置键盘开机
  17. Linux操作系统原理
  18. 了解嵌入式软件开发周期
  19. arm板gdb调试移植
  20. 玩乐|杭州夏季纳凉好去处

热门文章

  1. 穆利堂推荐 新周刊,当下中国的12中孤单
  2. 计算机网卡更改mac地址,苹果MAC地址怎么改?MAC网卡物理地址修改的详细方法
  3. 中国中产阶级“被中产”?
  4. Latex中斜线表头的制作方法
  5. 微信web开发者工具-移动调试iphone端的调试
  6. 05-现代威胁环境下的10个SIEM用例
  7. Flume从入门实战到精通再到面试一文搞定
  8. php开发俄罗斯方块,HTML5+JS实现俄罗斯方块原理及具体步骤_html5教程技巧
  9. 《游戏学习》Java版俄罗斯方块小游戏源码实战
  10. 华为《悟空》刷屏:愿你如少年,永不知天高地厚