引言

kmeans 是一个简单实用的机器学习算法,今天专门介绍一下。这篇文章主要关注以下几点:

  • kmeans 算法的具体流程是啥?
  • 做一个代码实战,并用可视化的方法来展示kmeans的聚类效果。
  • 训练过程中如何选取最好的聚类结果?
  • 训练好后如何评价训练的的结果。

k-means算法流程

k-means算法的步骤:

  1. 随机选取k个初始中心点。
  2. 计算每个样本点和中心点的距离,各自划分为距离最近的中心点所对应的类别中。
  3. 重新确定中心点: 每个类别中的点的均值作为新的中心点。
  4. 重复步骤2,3 直到某个条件终止(一定的次数或者是每个簇中的点不在改变)

k-means算法存在的问题:

  1. 计算量和样本数量成正比
  2. 初始值对聚类结果影响很大
  3. 不同的初始值可能会有不同的聚类结果。
  4. 对不规则的数据效果不好。
  5. 分成的簇的数量需要人工确认。
  6. 对初始值的稳定性非常差。

kmeans 算法实战

生成数据

我们可以先自己手工构造一个一个的簇,作为聚类的数据集。 这里讲一下生成数据的方法。
要点如下:

  • 指定若干个二维的点作为质心。
  • 指定每个质心周围的点的个数以及离散程度,用标准差数值来表示。
  • 绘制散点图。
    具体的代码如下:
import  numpy as np
import os
%matplotlib inline
import matplotlib
import matplotlib.pyplot as pltfrom sklearn.datasets import make_blobs
# 这是中心点
blob_centers = np.array(
[[2.2,2.3],[-1.5, 2.3],[-2.8, 1.8],[-2.8, 2.8],[-2.8, 1.3]
]
)
# 这是离散程度, 越大则越离散。
blob_std = np.array([1, 0.3, 0.1, 0.1, 0.1])X,y = make_blobs(n_samples=5000, centers = blob_centers, cluster_std = blob_std, random_state=7)
print(X, y)

这里有个小思考: 例子中基本都是使用的二维点来进行展示的, 如果是多维的数据,如果可以图形化展示聚类算法的效果呢?

聚类算法进行聚类

# kmeans算法from sklearn.cluster import KMeans
# 需要人为指定簇的个数
k = 5
# random_stat 指定可以保证出来的结果都是一样的
kmeans = KMeans(n_clusters = k, random_state=42)
# 输入的X是一个二维的数组,每个元素都是一个点。
y_pred= kmeans.fit_predict(X)
# 得到的是每个点的簇的标签
print(y_pred)
print(kmeans.labels_)
# 可以得出簇的中心点
kmeans.cluster_centers_

可视化展示-画等高线

等高线可以看出每个簇所处的区域是啥,具体的函数代码如下:

def plot_data(X):plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)# 画中心点
def plot_centroids(centroids, weights=None, circle_color='w', cross_color='k'):if weights is not None:centroids = cnetroids[weights > weights.max() / 10]plt.scatter(centroids[:, 0], centroids[:, 1], marker= 'o', s = 30, linewidths=8, color=circle_color, zorder=10, alpha=0.9)plt.scatter(centroids[:, 0], centroids[:, 1],marker='x', s=50, linewidths=1,color=cross_color, zorder=11, alpha=1)def plot_decision_boundaries(clusterer, X, resolution=1000, show_centroids=True, show_xlabels=True, show_ylabels=True):# 这个是为了可以显示所有的点mins = X.min(axis=0)maxs = X.max(axis=0)print(mins)print(maxs)# 下面画一个网格xx, yy = np.meshgrid(np.linspace(mins[0], maxs[0], resolution),np.linspace(mins[1], maxs[1], resolution))# 这个ravel函数可以把二维数组拍平Z = clusterer.predict(np.c_[xx.ravel(), yy.ravel()])print(Z)Z = Z.reshape(xx.shape)plt.contourf(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]), cmap="Pastel2")plt.contour(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]), linewidths=1, colors='k')plot_data(X)if show_centroids:plot_centroids(clusterer.cluster_centers_)if show_xlabels:plt.xlabel("$x_1$", fontsize=14)else:plt.tick_params(labelbottom='off')if show_ylabels:plt.ylabel("$x_2$", fontsize=14, rotation=0)else:plt.tick_params(labelleft='off')plt.figure(figsize=(8,4))
plot_decision_boundaries(kmeans, X)
  • 也可以写一个画图函数,将不同簇的点用不同的颜色表示出来。代码如下:
def plot_color(clusterer, X):# 预测一下聚类效果y = clusterer.fit_predict(X)# 这里的c 表示颜色序列plt.scatter(X[:, 0], X[:, 1], c=y,s=2, cmap=plt.cm.Spectral)plot_color(kmeans, X)

kmeans聚类算法使用技巧

我们知道, kmeans聚类算法的结果,会受到初始值中心点的影响,我们怎么保证最后聚类出来的结果是比较合适的呢?
kmeans算法里有几个参数, 需要注意一下。

  • n_init, 例如这个值设定为10的话,会把模型跑10次,最后选取跑的最好的那一次,那么, 如何确定哪个值好呢,先留个疑问。
  • max_iter: 迭代次数, 这里设置最大的迭代次数。

kmeans聚类算法的评估方法

机器学习库中, 有个inertia_属性, 这个就是每个样本与对应质心的距离的平方和的累加值。当设定n_init参数的时候, 会选取多次聚类, inertia_值最小的聚类结果。

如何寻找最合适的k值

  • 首先,选择最小inertia值对应的k值不可靠,原因是,随着k的增加, 聚类的簇就会越来越多, 数据点到质心的距离也会很小, 最后的inertia的值也会变得非常小, 但是聚类效果并不是越来越好的。

  • 推荐方案一:画一个图,表示出k值与对应的inertia值的对应关系曲线,在曲线的下降率发生剧烈变化的时候对应的k值,可以认为是比较合适的k值,不过, 这个点只能作为参考,有可能不是最优的选择。

  • 推荐方案二:使用轮廓系数来选取k值。
    定义两个概念:

  1. a i a_i ai​: 计算样本i到同簇其他样本的平均距离 a i a_i ai​, 这个值越小, 说明越是应该被聚类到这个簇。
  2. b i b_i bi​: 计算样本i 到其他簇的所有样本的平均距离, 然后取一个最小值就是 b i b_i bi​, b i b_i bi​越大越好。

s ( i ) = b ( i ) − a ( i ) m a x ( a ( i ) , b ( i ) ) s(i)=\frac{b(i)-a(i)}{max(a(i),b(i))} s(i)=max(a(i),b(i))b(i)−a(i)​

我们可以计算所有 s ( i ) s(i) s(i)的均值,作为评价指标。
sklearn机器学习库中有一个专门计算函数。具体使用如下:

# kmeans_per_k表示聚类实例的列表。
# 注意实例取值要从1开始, 否则会出错的。
silhouette_scores = [silhouette_score(X,model.labels_) for model in kmeans_per_k[1:]]#绘制图表
plt.figure(figsize=(8,4))
plt.plot(range(2,10),silhouette_scores,'bo-')
plt.show()

这个指标也只能是作为一个参考,最合适的k值还是需要根据自己的需求来决定的。

总结

kmeans 是一个简单实用的算法, 不过缺点也是比较明显的, 对于形状不规则的点状图,kmeans 的聚类算法效果可能会不好,而且对于kmeans 的模型评估比较难, 不容易找到合适的k值。

【机器学习】聚类算法 kmeans相关推荐

  1. 《菜菜的机器学习sklearn课堂》聚类算法Kmeans

    聚类算法 聚类算法 无监督学习与聚类算法 sklearn中的聚类算法 KMeans KMeans是如何工作的 簇内误差平方和的定义和解惑 sklearn.cluster.KMeans 重要参数 n_c ...

  2. 机器学习-Sklearn-07(无监督学习聚类算法KMeans)

    机器学习-Sklearn-07(无监督学习聚类算法KMeans) 学习07 1 概述 1.1 无监督学习与聚类算法 聚类算法又叫做"无监督分类",其目的是将数据划分成有意义或有用的 ...

  3. 机器学习实战-61:K均值聚类算法(K-Means)

    K均值聚类算法(K-Means) 深度学习原理与实践(开源图书)-总目录,建议收藏,告别碎片阅读! 机器学习分为监督学习.无监督学习和半监督学习(强化学习).无监督学习最常应用的场景是聚类(clust ...

  4. 炼数成金数据分析课程---17、机器学习聚类算法(后面要重点看)

    炼数成金数据分析课程---17.机器学习聚类算法(后面要重点看) 一.总结 一句话总结: 大纲+实例快速学习法 主要讲解常用聚类算法(比如K-means等)的原理及python代码实现:后面学习聚类的 ...

  5. 用人话讲明白聚类算法kmeans

    文章目录 1.什么是聚类 2.K-Means步骤 3.K-Means的数学描述 4.初始中心点怎么确定 5.K值怎么确定 6.小结 1.什么是聚类 先来回顾一下本系列第一篇就讲到的机器学习的种类. 监 ...

  6. 基于机器学习聚类算法寻找美国职业篮球联赛NBA中的超级强队

    聚类算法 聚类算法是机器学习中经典的非监督学习算法之一,相比于分类算法,聚类不依赖预定义的样本标签,而是让算法通过对数据的学习从而找到其内部的规律,该算法对有相同特征的样本进行聚类,聚类的时候,我们并 ...

  7. 聚类算法——kmeans和meanshift

    聚类算法--kmeans和meanshift [转] 1. meanshift 转于http://www.cnblogs.com/liqizhou/archive/2012/05/12/2497220 ...

  8. matlab中的聚类算法,kmeans聚类算法matlab matlab 聚类算法silhouette

    怎样用matlab实现多维K-means聚类算法小编觉得一个好的周末应该是这样的:睡到中午醒来,在床上躺着玩两个小时手机,起床随便吃点东西,下午去超市买一大堆零食,五六点的时候去约小伙伴们吃火锅烧烤, ...

  9. K均值聚类算法(Kmeans)讲解及源码实现

    K均值聚类算法(Kmeans)讲解及源码实现 算法核心 K均值聚类的核心目标是将给定的数据集划分成K个簇,并给出每个数据对应的簇中心点.算法的具体步骤描述如下. 数据预处理,如归一化.离群点处理等. ...

  10. 离线轻量级大数据平台Spark之MLib机器学习库聚类算法KMeans实例

    1.KMeans算法 所谓聚类问题,就是给定一个元素集合D,其中每个元素具有n个可观察属性,使用某种算法将D划分成k个子集,要求每个子集内部的元素之间相异度尽可能低,而不同子集的元素相异度尽可能高.其 ...

最新文章

  1. Atomic Integer 原理分析-getAndAddInt
  2. 【C++深度剖析教程29】C++对象模型分析下
  3. php pcntl signal,php – 后续的pcntl_signal信号没有启动处理程序
  4. 面向对象三之对象的使用方法
  5. leetcode880.DecodedStringatIndex
  6. Mac笔记本安装Webstrom
  7. Win32 SDK - 打开文件对话框
  8. Python Tricks(五)—— 计算 list of lists 的长度(元素个数)
  9. GBin1分享:一个漂亮的jQuery页面内容导航插件 - Flexiable Nav
  10. 当当网价格系统架构分析
  11. 项目启动会发言稿(范文二)
  12. Cortex-M3 (NXP LPC1788)之开发环境搭建
  13. 固液分离机市场现状及未来发展趋势
  14. metapath2vec 异构网络表示学习
  15. ctfshow web入门 nodejs 334-341(更新中)
  16. 【Spring】IDEA中创建Spring项目
  17. Ubuntu下安装Luma qq
  18. 手机便签里的文字不小心点了个粘贴就消失了应该怎样复原呢?
  19. [STM32] Mac开发STM32之Makefile
  20. 打印unicode汉字编码字符串为乱码怎么办?

热门文章

  1. 冀教版三年级计算机技术教案,冀教版三年级信息技术教案
  2. OpenGL——二次曲面函数(球面-圆锥面-圆柱面)
  3. 理解UDDI(1):UDDI服务实施的体系架构
  4. android ios 逆向工程,iOS逆向工程(七):使用Theos逆向项目
  5. MVC与三层架构模型笔记
  6. [Luogu P2447] [BZOJ 1923] [SDOI2010]外星千足虫
  7. matlab图形网格线画虚线
  8. python使用while、for及循环嵌套实现直角三角形及正、倒金字塔
  9. 荣耀50和小米civi参数对比
  10. Python-计算md5值对图片去重