MeanShift算法原理及其python自定义实现

  • MeanShift算法原理
  • MeanShift python实现
    • 实现思路:
    • 代码:
    • 运行结果:

MeanShift算法原理

Meanshift是聚类中的一种经典方法,思想简单,用途广泛
Meanshift基于这样的事实,一个类的中心处 点的空间密度 是最大的,因此给定一个点,只要沿着密度方向,由稀疏指向稠密就可以找到这个点所在类的中心点。

Meanshift的核心思想是: 给定一个数据点,在其周围一定的Region of interest内,计算这个Region的质心,由原来的点指向这个计算出来的质心的向量被称为Mean Shift vector,如下图中黄色向量表示的那样。
接下来,将原来Region中心点的坐标置为质心的坐标(这个坐标是计算出来的,并不一定恰好落在原来的数据点上),在以质心坐标为中心的Region中继续计算新的质心
直到Mean Shift vector的大小小于设定阈值的时候停止迭代

每一轮迭代中都对每一个点进行上面的操作,等到所有的点都收敛于有限的几个中心时,算法结束。
该算法具有很快的收敛速度。
数学推导参考:http://www.cnblogs.com/liqizhou/archive/2012/05/12/2497220.html
代码实现还可参考:https://blog.csdn.net/jinshengtao/article/details/30258833

MeanShift python实现

实现思路:

  1. 构建距离度量函数
  2. 构建Gaussian概率密度函数,以实现局部Region操作
  3. 构建MeanShift类
    (1) 点移动函数:对输入的一个点,计算在其Gaussian局部范围的点的影响下质心移动的新位置
    (2) 聚类号分配函数:对所有点移动后的结果进行归类
    (3) 入口函数:一些循环控制等

Tips:显然每个点的第一次移动对这个点的类的确定是至关重要的,尤其是那些在类边缘处类别定义比较模糊的位置的点。因为马太效应,在以后的移动中,这个点被质心吸引的力会更大

代码:

'''
#Implement mean-shift algorithm only using basic python
#Author:Leo Ma
#For csmath2019 assignment3,ZheJiang University
#Date:2019.04.23
'''
import numpy as np
import random
DISTANCE_THRESHOLD = 1e-4
CLUSTER_THRESHOLD = 1e-1#define distance metric
def distance(a,b):return np.linalg.norm(np.array(a)-np.array(b))#distance=(x-u)**2
def Gaussian_kernal(distance,sigma):return (1/(sigma*np.sqrt(2*np.pi)))*np.exp(-0.5*distance/(sigma**2))#MeanShift类
class MeanShift(object):def __init__(self,kernal = Gaussian_kernal):self.kernal = kernal##计算center_point点移动后的坐标def shift_points(self,center_point,whole_points,Gaussian_sigma):shifting_px = 0.0shifting_py = 0.0sum_weight = 0.0for each_point in whole_points:#遍历每一个点dis = distance(center_point,each_point)#计算当前点与中心点的距离Gaussian_weight = self.kernal(dis,Gaussian_sigma)#计算当前点距离中心点的高斯权重#所有向量相加shifting_px += Gaussian_weight * each_point[0]shifting_py += Gaussian_weight * each_point[1]sum_weight += Gaussian_weight#归一化shifting_px /= sum_weightshifting_py /= sum_weightreturn [shifting_px,shifting_py]#根据shift之后的点坐标shifting_points获得聚类iddef cluster_points(self,shifting_points):clusterID_points = []#用于存放每一个点的类别号cluster_id=0#聚类号初始化为0cluster_centers = []#聚类中心点for i,each_point in enumerate(shifting_points):#遍历处理每一个点if i==0:#如果是处理的第一个点clusterID_points.append(cluster_id)#将这个点归为初始化的聚类号(0)cluster_centers.append(each_point)#将这个点看作聚类中心点cluster_id+=1#聚类号加1else:#处理的不是第一个点的情况for each_center in cluster_centers:#遍历每一个聚类中心点dis = distance(each_center,each_point)#计算当前点与该聚类中心点的距离if dis < CLUSTER_THRESHOLD:#如果距离小于聚类阈值clusterID_points.append(cluster_centers.index(each_center))#就将当前处理的点归为当前中心点同类(聚类号赋值)if(len(clusterID_points)<i+1):#如果上面那个for,所有的聚类中心点都没能收纳一个点,说明是时候开拓一个新类了clusterID_points.append(cluster_id)#把当前点置为一个新类,因为此时的cluster_idx以前谁都没用过cluster_centers.append(each_point)#将这个点作为这个这个新聚类的中心点cluster_id+=1#聚类号加1以备后用return clusterID_points#whole_points:输入的所有点#Gaussian_sigma:Gaussian核的sigmadef fit(self,whole_points,Gaussian_sigma):shifting_points = np.array(whole_points)need_shifting_flag = [True] * np.shape(whole_points)[0]#每一个点初始都标记为需要shiftingwhile True:distance_max = 0.0#每一轮迭代都对每一个点进行处理for i in range(0,np.shape(whole_points)[0]):if not need_shifting_flag[i]:#如果这个点已经被标记为不需要继续shifting,就continuecontinueshifting_point_init = shifting_points[i].copy()#将初始的第i个点的坐标备份一下#shifting_points[i]由第i个点的坐标更新为第i个点移动后的坐标shifting_points[i] = self.shift_points(shifting_points[i],whole_points,Gaussian_sigma)#计算第i个点移动的距离dis = distance(shifting_point_init,shifting_points[i])#如果该点移动的距离小于停止阈值,标记need_shifting_flag[i]为False,下一轮迭代对该点不做处理need_shifting_flag[i] = dis > DISTANCE_THRESHOLD#本轮迭代中最大的距离存储到distance_max中distance_max = max(distance_max,dis)#如果在一轮迭代中,所有点移动的最大距离都小于停止阈值,就停止迭代if(distance_max < DISTANCE_THRESHOLD):break#根据shift之后的点坐标shift_points获得聚类idcluster_class_id = self.cluster_points(shifting_points.tolist())return shifting_points,cluster_class_idfrom sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt #按照均匀分布随机产生n个颜色,每个颜色都由R、G、B三个分量表示
def colors(n):ret = []for i in range(n):ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))return retdef main():centers = [[0, 1], [-1, 2], [1, 2], [-2.5, 2.5], [2.5,2.5], [-4,1], [4,1], [-3,-1], [3,-1], [-2,-3], [2,-3], [0,-4]]#设置一些中心点X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.3)#产生以这些中心点为中心,一定标准差的n个samplesmean_shifter = MeanShift()shifted_points, mean_shift_result = mean_shifter.fit(X, Gaussian_sigma=0.3)#Gaussian核设置为0.5,对X进行mean_shiftnp.set_printoptions(precision=3)print('input: {}'.format(X))print('assined clusters: {}'.format(mean_shift_result))color = colors(np.unique(mean_shift_result).size)for i in range(len(mean_shift_result)):plt.scatter(X[i, 0], X[i, 1], color = color[mean_shift_result[i]])plt.scatter(shifted_points[i,0],shifted_points[i,1], color = 'r')plt.xlabel("2018.06.13")plt.savefig("result_meanshift.png")plt.show()if __name__ == '__main__':main()

运行结果:

MeanShift算法原理及其python自定义实现相关推荐

  1. LM(Levenberg–Marquardt)算法原理及其python自定义实现

    LM算法原理及其python自定义实现 LM(Levenberg–Marquardt)算法原理 LM算法python实现 实现步骤: 代码: 运行结果: LM(Levenberg–Marquardt) ...

  2. 手把手教你EMD算法原理与Python实现(更新)

    Rose今天主要介绍一下EMD算法原理与Python实现.关于EMD算法之前介绍过<EMD算法之Hilbert-Huang Transform原理详解和案例分析>, SSVEP信号中含有自 ...

  3. python kmeans聚类 对二维坐标点聚类_Kmeans均值聚类算法原理以及Python如何实现

    第一步.随机生成质心 由于这是一个无监督学习的算法,因此我们首先在一个二维的坐标轴下随机给定一堆点,并随即给定两个质心,我们这个算法的目的就是将这一堆点根据它们自身的坐标特征分为两类,因此选取了两个质 ...

  4. 匈牙利算法原理与Python实现

    匈牙利算法原理与Python实现 今天学习一个新的算法-匈牙利算法,用于聚类结果分析,先用图表示我当前遇到的问题: 这两列值是我用不同算法得到的聚类结果,从肉眼可以看出第一列聚类为0的结果在第二列中对 ...

  5. Apriori 算法原理以及python实现详解

    Apriori 算法原理以及python实现 ​ Apriori算法是第一个关联规则挖掘算法,也是最经典的算法.它利用逐层搜索的迭代方法找出数据库中项集的关系,以形成规则,其过程由连接(类矩阵运算)与 ...

  6. PageRank算法原理与Python实现

    本文转载自https://blog.csdn.net/ten_sory/article/details/80927738 PageRank算法原理与Python实现 PageRank算法,即网页排名算 ...

  7. 朴素贝叶斯算法原理以及python实现

    朴素贝叶斯 一.朴素贝叶斯概述 二.概率论知识 三.朴素贝叶斯算法原理 四.参数估计方法 五.示例分析 六.拉普拉斯平滑修正 七.算法优缺点 八.python实现 8.1 sklearn贝叶斯 8.2 ...

  8. KNN算法原理与python实现

    文章目录 KNN算法原理 KNN算法介绍 KNN算法模型 距离度量 k值的选择 分类的决策规则 KNN算法python实现 手写数字识别 sklearn代码实现 参考文献 KNN算法原理 KNN算法介 ...

  9. EMD算法原理与python实现

    目录 简介 EMD算法原理 python实现EMD案例 本教程为脑机学习者Rose发表于公众号:脑机接口社区 .QQ交流群:903290195 简介 SSVEP信号中含有自发脑电和大量外界干扰信号,属 ...

  10. 统计学习方法笔记(一)-k近邻算法原理及python实现

    k近邻法 k近邻算法 算法原理 距离度量 距离度量python实现 k近邻算法实现 案例地址 k近邻算法 kkk近邻法(kkk-NN)是一种基本分类和回归方法. 算法原理 输入:训练集 T={(x1, ...

最新文章

  1. python与excel做数据可视化-python做可视化数据分析,究竟怎么样?
  2. ABP从入门到精通(2):aspnet-zero-core 使用MySql数据库
  3. art-template入门(四)之调试
  4. I/0口输入输出实验 将P1口的某一位作为输入使用,连接一个按键,当按键按下时使发光二极管亮,否则发光二极管熄灭
  5. Linux文件的复制、删除和移动命
  6. ios6.0,程序为横屏,出现闪退
  7. 蚂蚁森林:不存在网友反馈的“没有造林”的情况 干旱造成梭梭矮小
  8. python文件拆分_python – 在几个文件中拆分views.py.
  9. 计算机实战项目之 [含论文+开题报告+源码等]SSM图书馆预约占座系统[包运行成功]
  10. wincc7.4安装记录
  11. SAP ABAP开发入门-徐春波-专题视频课程
  12. imageAI基本使用
  13. 今日头条推荐算法相关博客集合
  14. windows10下F1-F11快捷键及window+Dor+E快捷键打开关闭控制
  15. 从MDK分散加载文件学习STM32启动流程
  16. 下载tomcat最新版本
  17. CSDN学院专属推荐--从Python小白走向Python工程师你只需要它!
  18. 医疗信息化与医院评审
  19. 2022-4-14 基于单片机的汽车灯
  20. (转)协议森林10 魔鬼细节 (TCP滑窗管理)

热门文章

  1. iPhone--卡贴是什么
  2. 影子系统、还原精灵、冰点还原优缺点比较
  3. 穿越Java - 基础篇 第三章 面向对象介绍 | 第4节 成员变量和局部变量
  4. pivotal公司简介
  5. wjh2005:GitHub 上有哪些完整的 iOS-App 源码值得参考?
  6. 阿里巴巴国际站—产品运营工作台操作指南
  7. WorldFirst万里汇推出港币和离岸人民币账户!
  8. matlab 去除水印,初試 Matlab 之去除水印
  9. 聚宽-彼得·林奇的成功投资策略
  10. [网络安全提高篇] 一一六.恶意代码同源分析及BinDiff软件基础用法