机器学习算法与Python实践这个系列主要是参考《机器学习实战》这本书。因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学习算法。恰好遇见这本同样定位的书籍,所以就参考这本书的过程来学习了。

机器学习中有两类的大问题,一个是分类,一个是聚类。分类是根据一些给定的已知类别标号的样本,训练某种学习机器,使它能够对未知类别的样本进行分类。这属于supervised learning(监督学习)。而聚类指事先并不知道任何样本的类别标号,希望通过某种算法来把一组未知类别的样本划分成若干类别,这在机器学习中被称作 unsupervised learning (无监督学习)。在本文中,我们关注其中一个比较简单的聚类算法:k-means算法。

一、k-means算法

通常,人们根据样本间的某种距离或者相似性来定义聚类,即把相似的(或距离近的)样本聚为同一类,而把不相似的(或距离远的)样本归在其他类。

我们以一个二维的例子来说明下聚类的目的。如下图左所示,假设我们的n个样本点分布在图中所示的二维空间。从数据点的大致形状可以看出它们大致聚为三个cluster,其中两个紧凑一些,剩下那个松散一些。我们的目的是为这些数据分组,以便能区分出属于不同的簇的数据,如果按照分组给它们标上不同的颜色,就是像下图右边的图那样:

如果人可以看到像上图那样的数据分布,就可以轻松进行聚类。但我们怎么教会计算机按照我们的思维去做同样的事情呢?这里就介绍个集简单和经典于一身的k-means算法。

k-means算法是一种很常见的聚类算法,它的基本思想是:通过迭代寻找k个聚类的一种划分方案,使得用这k个聚类的均值来代表相应各类样本时所得的总体误差最小。

k-means算法的基础是最小误差平方和准则。其代价函数是:

式中,μc(i)表示第i个聚类的均值。我们希望代价函数最小,直观的来说,各类内的样本越相似,其与该类均值间的误差平方越小,对所有类所得到的误差平方求和,即可验证分为k类时,各聚类是否是最优的。

上式的代价函数无法用解析的方法最小化,只能有迭代的方法。k-means算法是将样本聚类成 k个簇(cluster),其中k是用户给定的,其求解过程非常直观简单,具体算法描述如下:

1、随机选取 k个聚类质心点

2、重复下面过程直到收敛  {

对于每一个样例 i,计算其应该属于的类:

对于每一个类 j,重新计算该类的质心:

}

下图展示了对n个样本点进行K-means聚类的效果,这里k取2。

其伪代码如下:

********************************************************************

创建k个点作为初始的质心点(随机选择)

当任意一个点的簇分配结果发生改变时

对数据集中的每一个数据点

对每一个质心

计算质心与数据点的距离

将数据点分配到距离最近的簇

对每一个簇,计算簇中所有点的均值,并将均值作为质心

********************************************************************

二、Python实现

我使用的Python是2.7.5版本的。附加的库有Numpy和Matplotlib。具体的安装和配置见前面的博文。在代码中已经有了比较详细的注释了。不知道有没有错误的地方,如果有,还望大家指正(每次的运行结果都有可能不同)。里面我写了个可视化结果的函数,但只能在二维的数据上面使用。直接贴代码:

kmeans.py

[python] view plaincopy
  1. #################################################
  2. # kmeans: k-means cluster
  3. # Author : zouxy
  4. # Date   : 2013-12-25
  5. # HomePage : http://blog.csdn.net/zouxy09
  6. # Email  : zouxy09@qq.com
  7. #################################################
  8. from numpy import *
  9. import time
  10. import matplotlib.pyplot as plt
  11. # calculate Euclidean distance
  12. def euclDistance(vector1, vector2):
  13. return sqrt(sum(power(vector2 - vector1, 2)))
  14. # init centroids with random samples
  15. def initCentroids(dataSet, k):
  16. numSamples, dim = dataSet.shape
  17. centroids = zeros((k, dim))
  18. for i in range(k):
  19. index = int(random.uniform(0, numSamples))
  20. centroids[i, :] = dataSet[index, :]
  21. return centroids
  22. # k-means cluster
  23. def kmeans(dataSet, k):
  24. numSamples = dataSet.shape[0]
  25. # first column stores which cluster this sample belongs to,
  26. # second column stores the error between this sample and its centroid
  27. clusterAssment = mat(zeros((numSamples, 2)))
  28. clusterChanged = True
  29. ## step 1: init centroids
  30. centroids = initCentroids(dataSet, k)
  31. while clusterChanged:
  32. clusterChanged = False
  33. ## for each sample
  34. for i in xrange(numSamples):
  35. minDist  = 100000.0
  36. minIndex = 0
  37. ## for each centroid
  38. ## step 2: find the centroid who is closest
  39. for j in range(k):
  40. distance = euclDistance(centroids[j, :], dataSet[i, :])
  41. if distance < minDist:
  42. minDist  = distance
  43. minIndex = j
  44. ## step 3: update its cluster
  45. if clusterAssment[i, 0] != minIndex:
  46. clusterChanged = True
  47. clusterAssment[i, :] = minIndex, minDist**2
  48. ## step 4: update centroids
  49. for j in range(k):
  50. pointsInCluster = dataSet[nonzero(clusterAssment[:, 0].A == j)[0]]
  51. centroids[j, :] = mean(pointsInCluster, axis = 0)
  52. print 'Congratulations, cluster complete!'
  53. return centroids, clusterAssment
  54. # show your cluster only available with 2-D data
  55. def showCluster(dataSet, k, centroids, clusterAssment):
  56. numSamples, dim = dataSet.shape
  57. if dim != 2:
  58. print "Sorry! I can not draw because the dimension of your data is not 2!"
  59. return 1
  60. mark = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
  61. if k > len(mark):
  62. print "Sorry! Your k is too large! please contact Zouxy"
  63. return 1
  64. # draw all samples
  65. for i in xrange(numSamples):
  66. markIndex = int(clusterAssment[i, 0])
  67. plt.plot(dataSet[i, 0], dataSet[i, 1], mark[markIndex])
  68. mark = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']
  69. # draw the centroids
  70. for i in range(k):
  71. plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize = 12)
  72. plt.show()

三、测试结果

测试数据是二维的,共80个样本。有4个类。如下:

testSet.txt

[python] view plaincopy
  1. 1.658985    4.285136
  2. -3.453687   3.424321
  3. 4.838138    -1.151539
  4. -5.379713   -3.362104
  5. 0.972564    2.924086
  6. -3.567919   1.531611
  7. 0.450614    -3.302219
  8. -3.487105   -1.724432
  9. 2.668759    1.594842
  10. -3.156485   3.191137
  11. 3.165506    -3.999838
  12. -2.786837   -3.099354
  13. 4.208187    2.984927
  14. -2.123337   2.943366
  15. 0.704199    -0.479481
  16. -0.392370   -3.963704
  17. 2.831667    1.574018
  18. -0.790153   3.343144
  19. 2.943496    -3.357075
  20. -3.195883   -2.283926
  21. 2.336445    2.875106
  22. -1.786345   2.554248
  23. 2.190101    -1.906020
  24. -3.403367   -2.778288
  25. 1.778124    3.880832
  26. -1.688346   2.230267
  27. 2.592976    -2.054368
  28. -4.007257   -3.207066
  29. 2.257734    3.387564
  30. -2.679011   0.785119
  31. 0.939512    -4.023563
  32. -3.674424   -2.261084
  33. 2.046259    2.735279
  34. -3.189470   1.780269
  35. 4.372646    -0.822248
  36. -2.579316   -3.497576
  37. 1.889034    5.190400
  38. -0.798747   2.185588
  39. 2.836520    -2.658556
  40. -3.837877   -3.253815
  41. 2.096701    3.886007
  42. -2.709034   2.923887
  43. 3.367037    -3.184789
  44. -2.121479   -4.232586
  45. 2.329546    3.179764
  46. -3.284816   3.273099
  47. 3.091414    -3.815232
  48. -3.762093   -2.432191
  49. 3.542056    2.778832
  50. -1.736822   4.241041
  51. 2.127073    -2.983680
  52. -4.323818   -3.938116
  53. 3.792121    5.135768
  54. -4.786473   3.358547
  55. 2.624081    -3.260715
  56. -4.009299   -2.978115
  57. 2.493525    1.963710
  58. -2.513661   2.642162
  59. 1.864375    -3.176309
  60. -3.171184   -3.572452
  61. 2.894220    2.489128
  62. -2.562539   2.884438
  63. 3.491078    -3.947487
  64. -2.565729   -2.012114
  65. 3.332948    3.983102
  66. -1.616805   3.573188
  67. 2.280615    -2.559444
  68. -2.651229   -3.103198
  69. 2.321395    3.154987
  70. -1.685703   2.939697
  71. 3.031012    -3.620252
  72. -4.599622   -2.185829
  73. 4.196223    1.126677
  74. -2.133863   3.093686
  75. 4.668892    -2.562705
  76. -2.793241   -2.149706
  77. 2.884105    3.043438
  78. -2.967647   2.848696
  79. 4.479332    -1.764772
  80. -4.905566   -2.911070

测试代码:

test_kmeans.py

[python] view plaincopy
  1. #################################################
  2. # kmeans: k-means cluster
  3. # Author : zouxy
  4. # Date   : 2013-12-25
  5. # HomePage : http://blog.csdn.net/zouxy09
  6. # Email  : zouxy09@qq.com
  7. #################################################
  8. from numpy import *
  9. import time
  10. import matplotlib.pyplot as plt
  11. ## step 1: load data
  12. print "step 1: load data..."
  13. dataSet = []
  14. fileIn = open('E:/Python/Machine Learning in Action/testSet.txt')
  15. for line in fileIn.readlines():
  16. lineArr = line.strip().split('\t')
  17. dataSet.append([float(lineArr[0]), float(lineArr[1])])
  18. ## step 2: clustering...
  19. print "step 2: clustering..."
  20. dataSet = mat(dataSet)
  21. k = 4
  22. centroids, clusterAssment = kmeans(dataSet, k)
  23. ## step 3: show the result
  24. print "step 3: show the result..."
  25. showCluster(dataSet, k, centroids, clusterAssment)

运行的前后结果是:

不同的类用不同的颜色来表示,其中的大菱形是对应类的均值质心点。

四、算法分析

k-means算法比较简单,但也有几个比较大的缺点:

(1)k值的选择是用户指定的,不同的k得到的结果会有挺大的不同,如下图所示,左边是k=3的结果,这个就太稀疏了,蓝色的那个簇其实是可以再划分成两个簇的。而右图是k=5的结果,可以看到红色菱形和蓝色菱形这两个簇应该是可以合并成一个簇的:

(2)对k个初始质心的选择比较敏感,容易陷入局部最小值。例如,我们上面的算法运行的时候,有可能会得到不同的结果,如下面这两种情况。K-means也是收敛了,只是收敛到了局部最小值:

(3)存在局限性,如下面这种非球状的数据分布就搞不定了:

(4)数据库比较大的时候,收敛会比较慢。

k-means老早就出现在江湖了。所以以上的这些不足也被世人的目光敏锐的捕捉到,并融入世人的智慧进行了某种程度上的改良。例如问题(1)对k的选择可以先用一些算法分析数据的分布,如重心和密度等,然后选择合适的k。而对问题(2),有人提出了另一个成为二分k均值(bisecting k-means)算法,它对初始的k个质心的选择就不太敏感,这个算法我们下一个博文再分析和实现。

五、参考文献

[1] K-means聚类算法

[2] 漫谈 Clustering (1): k-means

机器学习K均值聚类 python相关推荐

  1. 机器学习-*-K均值聚类及代码实现

    KMeans聚类 在聚类算法中,最出名的应该就是k均值聚类(KMeans)了,几乎所有的数据挖掘/机器学习书籍都会介绍它,有些初学者还会将其与KNN等混淆.k均值是一种聚类算法,属于无监督学习的一种, ...

  2. 机器学习--K均值聚类

    机器学习--聚类 一.无监督学习 二.KMeans聚类 2.1 概览 2.2 理论介绍 2.2.1 模型 2.2.2 策略 2.2.3 算法 2.3 案例讲解 2.4 Python实现 2.4.1 导 ...

  3. 机器学习-K均值聚类(python3代码实现)

    K均值聚类 哈尔滨工程大学-537 算法原理: K均值是发现给定数据集的 k k k个簇的算法.簇个数k" role="presentation" style=" ...

  4. 机器学习--k均值聚类(K-means)

    数据分析入门与实战  公众号: weic2c 1.摘要 分类作为一种监督学习方法,要求必须事先明确知道各个类别的信息,并且断言所有待分类项都有一个类别与之对应.但是很多时候上述条件得不到满足,尤其是在 ...

  5. 机器学习 K均值聚类(K-means) 鸢尾花数据集

    聚类的目标是使聚类后的各个簇,具有簇内聚合,簇间分离的特点. 如何度量簇之间,簇内样本之间的差异度?常用距离计算,最常用的是"闵可夫斯基距离"(Minkowski distance ...

  6. 机器学习-k均值聚类算法-k_means原理14

    非监督学习

  7. python机器学习库sklearn——k均值聚类

    全栈工程师开发手册 (作者:栾鹏) python数据挖掘系列教程 k均值聚类的相关的知识内容可以参考 http://blog.csdn.net/luanpeng825485697/article/de ...

  8. 机器学习算法与Python实践之k均值聚类(k-means)

    机器学习算法与Python实践之(五)k均值聚类(k-means) zouxy09@qq.com http://blog.csdn.net/zouxy09 机器学习算法与Python实践这个系列主要是 ...

  9. python机器学习案例系列教程——k均值聚类、k中心点聚类

    全栈工程师开发手册 (作者:栾鹏) python数据挖掘系列教程 上一篇我们学习了层次聚类.层次聚类只是迭代的把最相近的两个聚类匹配起来.并没有给出能给出多少的分组.今天我们来研究一个K均值聚类.就是 ...

最新文章

  1. C# 读取 Access
  2. Linux总线驱动-02: struct bus_type 结构体
  3. python换中包_在Linux中替换已安装的python包中的源代码
  4. java postconstruct_spring框架中@PostConstruct的实现原理
  5. oracle中的sga和pga
  6. nginx 非socket代理配置
  7. C# 中 for和foreach 性能比较,提高编程性能
  8. insert into user mysql root_跳过授权表登录后使用insert into创建root权限用户
  9. java多线程中出现的异常分别有哪些_java多线程试题
  10. python-pandas 高级功能(通过学习kaggle案例总结)
  11. 【转-整理】Struts2中package,action,result,method配置详解
  12. 12/27复习有感--整环的整除性
  13. 文献阅读——金属伪影减少MAR问题
  14. mysql 100个标题_100个超强吸引人的标题100个吸引人的标题
  15. 通过js实现页面布局
  16. 微服务电商实战(十一)搭建vue项目对接注册登陆接口,解决跨域问题,使用七牛云实现头像上传
  17. data.frame中的NA值怎么替换成0 把na变为0 把na变为想要的数 改变na 是否为na is.na()是否为null is.null() is.null() 删除去掉NA的行
  18. 为什么要标准化用户故事格式?
  19. 细说2类骗人的物联卡,满满的都是套路!
  20. 【UCIe】UCIe 软件配置

热门文章

  1. nginx 502错误原因和解决办法总结
  2. 502 Bad Gateway 错误的可能原因
  3. PB反编译大师,PB反编译升级版本
  4. 专访Pebble智能手表创始人:Pebble是这样诞生的
  5. 使用 SpringBoot + Mybatis 实现的仿思否风格论坛
  6. 空间三角形_教师招聘试讲-小学数学 三角形内角和 教案
  7. mysql 打开mdf文件怎么打开_mdf 数据库文件怎么打开
  8. java sdk 安装_Java SDK下载安装环境配置+AndroidKiller配置教程
  9. 管理者应该掌握的八项基本技能
  10. 【转】CAN总线终端电阻的作用?为什么是120Ω?为什么是0.25W?*****