一、meanshift

均值漂移就是把指定的样本点沿着密度上升的方向移向高密度区域。这里可以用矢量加法的几何意义来理解。参考博文Mean Shift 聚类算法
meanshift为
Mr(x)=1k∑xi∈Sr(x)(xi−x)M_r(x)=\frac{1}{k}\sum\limits_{x_i\in S_r(x)}(x_i-x)Mr​(x)=k1​xi​∈Sr​(x)∑​(xi​−x)
其中Sr(x)={y:∥y−x∥<=r}S_r(x)=\{y:\|y-x\|<=r\}Sr​(x)={y:∥y−x∥<=r},kkk是Sr(x)S_r(x)Sr​(x)中点的个数。
更新x=x+Mr(x)更新x=x+M_r(x)更新x=x+Mr​(x)

实现上图的python代码:

from sklearn.datasets import make_blobsX1,y1=make_blobs(n_samples=200, n_features=2, centers=[[1.2, 1.2]],cluster_std=[[.1]], random_state=9)
plt.scatter(X1[:,0],X1[:,1],c=y1)
def meanshift(point,X,r,eps):pointNeigh=X[np.linalg.norm(X-point,axis=1)<=r]shift=np.sum(pointNeigh-point,axis=0)/len(pointNeigh)points=[point]while np.linalg.norm(shift)>eps:point=point+shiftpointNeigh=X[np.linalg.norm(X-point,axis=1)<=r]shift=np.sum(pointNeigh-point,axis=0)/len(pointNeigh)points.append(point)return pointspoints=meanshift(np.array([1,1]),X1,0.1,0.000001)
points=np.array(points)
plt.figure(figsize=(10,6))
plt.scatter(X1[:,0],X1[:,1],c=y1)
plt.plot(points[:,0],points[:,1],'r<--',markersize=8)
plt.annotate(r'$start$', xy = (1, 1), xytext = (1, 0.9),arrowprops = {'headwidth': 10, # 箭头头部的宽度'headlength': 5, # 箭头头部的长度'width': 4, # 箭头尾部的宽度'facecolor': 'r', # 箭头的颜色'shrink': 0.1, # 从箭尾到标注文本内容开始两端空隙长度},family='Times New Roman',  # 标注文本字体为Times New Romanfontsize=18,  # 文本大小为18fontweight='bold',  # 文本为粗体color='green',  # 文本颜色为红色# ha = 'center' # 水平居中
)
plt.annotate(r'$end$', xy = (points[-1][0],points[-1][1] ), xytext = (points[-1][0], points[-1][1]-0.1),arrowprops = {'headwidth': 10, # 箭头头部的宽度'headlength': 5, # 箭头头部的长度'width': 4, # 箭头尾部的宽度'facecolor': 'r', # 箭头的颜色'shrink': 0.1, # 从箭尾到标注文本内容开始两端空隙长度},family='Times New Roman',  # 标注文本字体为Times New Romanfontsize=18,  # 文本大小为18fontweight='bold',  # 文本为粗体color='green',  # 文本颜色为红色# ha = 'center' # 水平居中)

二、meanshift聚类

1.算法流程

需要给定的参数
bandwidth----带宽
Mindist—漂移均值收敛的阈值
center_distance----合并簇的阈值

第一阶段—聚类

需要初始化的集合
创建一个空的中心点集centers,用于存放各个簇所对应的中心点
创建一个空的集合results,用于存放各个簇所包含的点
1.将数据集X的点都标记为未访问unvisited;
2.从数据集X中取出一个点,记为point,判断它是否属于unvisited,如果属于,将其从unvisited删除,并进行第3步,否则重新从X取点;
3.创建一个空的集合result_point,用来存放point对应的簇中所包含的点;
4.从X中找到位于以point为中心,带宽为bandwidth之内的点,用neighbor表示;
5.判断neighbor是否为空集,则返回第2步,否则将neighbor中全部点加入到result_point中,并将这些点从unvisited中删除;
6.计算point在neighbor上的漂移均值meanshift;
7.判断shift是否大于给定的阈值Min-dist,如果大于,更新点point=point+shift并返回第4步,否则将point加入到centers,result_point加入到results中,再返回第2步//

第二阶段–合并

由第一阶段得到centers和相应的results。由于centers中有一些中心点之间的距离可能很小,我们需要将其所对应的簇合并成一个簇,并更新中心点。

第三阶段–分组

由于X中的点可能位于多个簇(result_point)中,我们需要确定其到底属于哪一个簇。
统计每个点在各个簇中所出现点次数,次数最高的簇就是该点最终所属的簇。

2.python代码

import math
def euclidean_dist(self,pointA,pointB):""""""if pointA.shape==pointB.shape:##pointA和pointB是两个点return np.linalg.norm(pointA-pointB)else:##pointA和pointB中有一个是点集return np.linalg.norm(pointA-pointB,axis=1)def gaussian_kernel(dist,bandwidth):"""dist---欧式距离bandwidth---带宽"""weight=(1/(np.sqrt(2*math.pi)*bandwidth))*np.exp(-(dist**2)/(2*np.power(bandwidth,2)))return weight
def compute_shift(pointNeigh,point,bandwidth,kernel):if kernel==False:shift=np.sum(pointNeigh-point,axis=0)/len(pointNeigh)else:dists=np.linalg.norm(pointNeigh-point,axis=1)point_weight=gaussian_kernel(dists,bandwidth)point_weight=point_weight.reshape(len(point_weight),1)shift=np.sum((pointNeigh-point)*point_weight,axis=0)/np.sum(point_weight) return shiftdef Clustering(X,bandwidth,MinDist,kernel=False):"""init_result---各个簇所包含的点的索引,init_centers--中心点"""unvisited=list(np.arange(len(X))) #未访问点的索引init_result=[]  #用于存放结果init_centers=[]for i in range(len(X)):point=X[i]if i in unvisited:unvisited.remove(i) #删除以访问点c_i=[]indexs=np.where(np.linalg.norm(X-point,axis=1)<=bandwidth)[0]pointNeigh=X[indexs]   #点point点bandwidth内搭点c_i.extend(indexs)  #把这些点加入point为中心点簇中for j  in list(indexs):if j in unvisited:unvisited.remove(j)shift=compute_shift(pointNeigh,point,bandwidth,kernel)while np.linalg.norm(shift)>MinDist: #判断shift的大小point=point+shiftindexs=np.where(np.linalg.norm(X-point,axis=1)<=bandwidth)[0]if len(indexs)==0:breakpointNeigh=X[indexs]c_i.extend(indexs)for k  in list(indexs):if k in unvisited:unvisited.remove(k)shift=compute_shift(pointNeigh,point,bandwidth,kernel)init_centers.append(point)init_result.append(c_i)return init_result,init_centersdef merge(init_centers,init_result,center_distance):final_centers=[init_centers[0]]final_result=[init_result[0]]k=len(init_centers)i=1stats=Truewhile i<k:for j in range(len(final_centers)):if np.linalg.norm(init_centers[i]-final_centers[j])<=center_distance:final_centers[j]=(init_centers[i]+final_centers[j])/2final_result[j].extend(init_result[i])stats=Falsebreakif stats==True:final_centers.append(init_centers[i])final_result.append(init_result[i])i+=1stats=Truereturn final_result,final_centers   def groupPoint(X,final_result,final_centers):result_table=pd.DataFrame(np.zeros((len(X),len(final_centers))),index=range(len(X)),columns=range(len(final_centers)))for i in range(len(final_centers)):clusterI_index=final_result[i]for j in range(len(clusterI_index)):result_table.iloc[clusterI_index[j],i]+=1result_id=np.argmax(result_table.values,axis=1)   return result_id
def plot(X,result_id,final_centers):for i in range(len(final_centers)):plt.scatter(X[result_id==i][:,0],X[result_id==i][:,1])

三、测试

1.数据集一

from sklearn.datasets import make_blobs
X,y=make_blobs(n_samples=400, n_features=2, centers=[[1.2, 1.2],[-1,-1]],cluster_std=[0.2,0.3], random_state=9)
plt.scatter(X[:,0],X[:,1])

init_result,init_centers=Clustering(X,0.4,0.00001,kernel=True)
final_result,final_centers=merge(init_centers,init_result,1.2)
result_id=groupPoint(X,final_result,final_centers)
plot(X,result_id,final_centers)

2.数据集二

X_d=np.array(data)
init_result,init_centers=Clustering(X_d,4,0.00001,kernel=True)
final_result,final_centers=merge(init_centers,init_result,1.2)
result_id=groupPoint(X_d,final_result,final_centers)
plot(X_d,result_id,final_centers)

3. 数据集三


import matplotlib.pyplot as plt
from sklearn.datasets import make_moons x,y = make_moons(n_samples=1500, shuffle=True,noise=0.06, random_state=None)
plt.scatter(x[:,0], x[:,1], c=y, s=7)
plt.show()

init_result,init_centers=Clustering(x,0.5,0.00001,kernel=True)
final_result,final_centers=merge(init_centers,init_result,1.)
result_id=groupPoint(x,final_result,final_centers)
plot(x,result_id,final_centers)


经过反复的调center_distance这个参数,都没有达到理想的聚类结果。

三、带有核函数的meanshift聚类

带有核函数的meanshift
m(x)=∑s∈Sg(∥s−xh∥2)(s−x)∑s∈Sg(∥s−xh∥2)m(x)=\frac{\sum\limits_{s\in S}g(\|\frac{s-x}{h}\|^2)(s-x)}{\sum\limits_{s\in S}g(\|\frac{s-x}{h}\|^2)}m(x)=s∈S∑​g(∥hs−x​∥2)s∈S∑​g(∥hs−x​∥2)(s−x)​
更新中心坐标:
x=m(x)+xx=m(x)+xx=m(x)+x

四、疑点

如何能把用于合并簇的阈值参数取消掉。

Meanshift均值漂移聚类算法相关推荐

  1. 【主色提取】模糊C均值(FCM )聚类算法和彩色图像快速模糊C均值( CIQFCM )聚类算法

    系列文章目录 第一章 主色提取入门之FCM 和 CIQFCM 目录 系列文章目录 前言 一.FCM 聚类算法 1 基本思想 ​编辑 2 FCM 的缺陷 二.CIQFCM 聚类算法 1 集群空间映射 1 ...

  2. K均值(K-means)聚类算法原理与代码详解

    0. 算法原理: 上述过程简单描述: a: 初始数据 b: 选择质点 c: 根据质点划分 d: 求均值,更新质心点 e: 划分 f: 更新质心点 1. 代码实现: # K means 教程# 0. 引 ...

  3. 均值漂移(mean shift )聚类算法Matlab实现详解

    Mean shift 算法是基于核密度估计的爬山算法,可用于聚类.图像分割.跟踪等,其在声呐图像数据处理也有广泛的应用,笔者在网上找了一遍也没有找到关于Mean shift的matlab实现代码,找到 ...

  4. [Python从零到壹] 十三.机器学习之聚类算法四万字总结全网首发(K-Means、BIRCH、树状聚类、MeanShift)

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  5. a*算法的优缺点_五种聚类算法一览与python实现

    大家晚上好,我是阿涛. 今天的主题是聚类算法,小结一下,也算是炒冷饭了,好久不用真忘了. 小目录: 1.K-means聚类2.Mean-Shift聚类3.Dbscan聚类4.层次聚类5.GMM_EM聚 ...

  6. 机器学习7—聚类算法之K-means算法

    K-均值算法(K-means) 前言 聚类算法模型 常见的聚类算法 一.K-means算法描述 二.示例说明K-means算法流程 三.K-means算法中Kmean()函数说明 四.K-means算 ...

  7. 【聚类算法】常见的六大聚类算法

    转自:https://blog.csdn.net/Katherine_hsr/article/details/79382249 感谢博主小姐姐 算法步骤: (1) 首先我们选择一些类/组,并随机初始化 ...

  8. 聚类算法的缺点_常用聚类算法

    一.K-Means 算法步骤: (1) 首先我们选择一些类/组,并随机初始化它们各自的中心点.中心点是与每个数据点向量长度相同的位置.这需要我们提前预知类的数量(即中心点的数量). (2) 计算每个数 ...

  9. dbscan算法中 参数的意义_常用聚类算法

    一.K-Means 算法步骤: (1) 首先我们选择一些类/组,并随机初始化它们各自的中心点.中心点是与每个数据点向量长度相同的位置.这需要我们提前预知类的数量(即中心点的数量). (2) 计算每个数 ...

  10. 常见的六大聚类算法:转自:https://blog.csdn.net/Katherine_hsr/article/details/79382249

    1.K-Means(K均值)聚类 算法步骤: (1) 首先我们选择一些类/组,并随机初始化它们各自的中心点.中心点是与每个数据点向量长度相同的位置.这需要我们提前预知类的数量(即中心点的数量). (2 ...

最新文章

  1. 检察机关计算机涉密案例,三原县检察院:扎实做好新时代检察机关保密工作
  2. 为什么PUE只说明了数据中心能效的一部分?
  3. androidsdktools安装_如何命令行安装Android SDK Build Tools(构建工具)?
  4. Pip安装加速-解决pip下载速度慢的问题
  5. 体制内工作了十几年,跳出来学嵌入式合适吗?
  6. Android studio http 代理设置
  7. bootstrap插件(对话框)bootbox参数和自定义弹出框宽度设置
  8. 30 年间,软件开发行业为何 Bug 纷飞?
  9. dev代码拷贝中文乱码的解决方案
  10. c语言名著摘抄——语法及实例
  11. 程序包androidx.support.annotation不存在/import android.support.annotation.NonNull;报错
  12. oracle 审计变换表空间_Oracle审计日志和审计策略数据表迁移到新表空间
  13. 无需Root实现Android手机屏幕流畅投影到电脑进行演示(附软件下载)
  14. 更新至2021各省份上传服务器地址(航信、百旺、税务UK)
  15. 牛客练习赛73 B.遥远的记忆(并查集)
  16. draw.io 插入 word
  17. 项目管理中的成本绩效方法
  18. Cosmos 与 PolkaDot 互解
  19. 这份1658页的Java面试核心突击讲,成功让我上岸阿里
  20. 参会指南!POW'ER 2020上海峰会完整议程周边活动

热门文章

  1. java 遍历json串_Java遍历json字符串取值的实例
  2. Ubuntu安装Jenkins
  3. Tableau入门(一):条形图、堆积图、直方图绘制
  4. 蛇形天线设计和分析(转)
  5. 牛客网Python笔试技巧、单行多行输入方法以及代码调试技巧
  6. 自定义更改虚拟机中Ubuntu的ip地址
  7. 惠普103a微信打印服务器,惠普发布微信打印小程序,丰富云打印解决方案
  8. DWL文件能改成DWG文件打开吗?
  9. S7-1500PLC仿真
  10. 图形学卡通人物绘制以及交互操作