MeanShift

  • Mean Shift聚类算法
    • 背景介绍
    • 效果展示
    • 模型概览
      • 模型定义
      • 训练过程
    • 数据集
      • 数据集介绍
    • 训练
      • 01 - 读入数据
      • 02 - 求解meanshift向量
      • 03 - 聚类
      • 04 - 绘图
      • 05 - 主函数
      • 运行结果
    • 总结
    • 参考文献

Mean Shift聚类算法

Mean Shift算法是一种无参密度估计算法,Mean Shift算法在很多领域都有成功应用,例如图像平滑、图像分割、物体跟踪等,这些属于人工智能里面模式识别或计算机视觉的部分,另外也包括常规的聚类应用。

背景介绍

通过名字就可以看到该算法的核心,mean(均值),shift(偏移),Mean Shift算法所做的工作是找到数据概率密度最大的区域。以二维来说明可能更容易理解,下图中的很多的红点就是我们的样本特征点,meanshift会依次选中每一个点为圆心(质心),在选中一个点后,然后以半径R画一个圆,然后落在这个圆中的每一个点与圆心都会构成一个向量,把所有这些向量相加,我们会得到一个向量,就是下图中用黄色箭头表示的向量,这个向量就是meanshift向量。然后再以这个meanshift向量的终点为圆心,继续上述过程,又可以得到一个meanshift向量,然后不断地继续重复这样的过程,我们可以得到很多连续的meanshift向量,这些向量首尾相连,最终迭代到收敛,在某处停下来。最后的那个meanshift向量的终点就是最终得到的结果(最终质心),如下图:

效果展示

我们使用meanshift算法对商城消费者的数据进行聚类,下面的散点图为聚类结果,其中每个点的横坐标代表消费者的消费积分,纵坐标代表消费者的年收入,星形标记代表不同的聚类。

模型概览

模型定义

在商城统计的消费者数据中,与消费者有关的数据有4条,分别为性别,年龄,年收入,消费产生的积分。现对消费者进行聚类,我们选取年收入与消费积分数据作为样本特征点,在2维空间n个样本点中xi,i=1,⋯,nx_{i}, i=1, \cdots, nxi​,i=1,⋯,n,对于其中的一个样本的候选质心xxx,其中N(x)N\left(x\right)N(x)是围绕xxx 周围一个给定距离范围内的样本邻域, 通过计算其他所有样本点yyy与xxx的距离,小于区域半径rrr的点即表示落在区域内,N(xi)N\left(x_{i}\right)N(xi​)定义为下:

Nh(x)=(y∣(y−x)(y−x)T⩽r2)N_{h}(x)=\left(y \mid(y-x)(y-x)^{T} \leqslant r^{2}\right) Nh​(x)=(y∣(y−x)(y−x)T⩽r2)

再通过更新质心的候选位置,到达最终质心,这些侯选位置通常是所选定区域内点的均值,则对于xxx点,其Mean Shift向量的基本形式为:

Mh(x)=1k∑xi∈Nh(xi−x)M_{h}(x)=\frac{1}{k} \sum_{x_{i} \in N_{h}}\left(x_{i}-x\right) Mh​(x)=k1​xi​∈Nh​∑​(xi​−x)

∥mh(x)−x∥<ε\left\|m_{h}(x)-x\right\|<\varepsilon ∥mh​(x)−x∥<ε

MhM_{h}Mh​ 是均值偏移向量(mean shift vector), 该向量是所有质心中指向点密度增加最多的区域的偏移向量,kkk表示区域内的样本点数量。在不断迭代的过程中,质心不断更新,当更新后的质心与原质心变化小于一定阈值ε\varepsilonε时(此值在以下实现算法中定义为区域半径),发生收敛,结束循环。对其他样本点重复以上步骤,可求得每个样本的最终质心。有了质心即可根据样本点密度来进行分类了。

训练过程

通过以下几个步骤进行模型训练

  1. 在未被标记的数据点中依次选择作为起始中心点C。
  2. 以C为质心作半径为radius的圆,得到区域中出现的所有数据点,则设这些点同属于一个聚类A。同时在该聚类中记录数据点出现的频率次数加1。
  3. 以C为中心点,得到从C开始到区域内每个样本的向量,将这些向量相加,得到向量Shift。C沿着Shift的方向移动到Shift的终点。
  4. 重复步骤2、3,直到shift向量不在发生变化(迭代到收敛),记住此时的C。且这个迭代过程中遇到的点都归类到A。
  5. 如果收敛时当前类A的质心与其它已经存在的类A2质心的距离小于半径(也可设为其他阈值),那么把A2和A合并,数据点出现次数也对应合并。否则,把A作为新的类,并保存。
  6. 重复1、2、3、4、5直到所有的点都被标记为已访问。
  7. 分类:根据每个类对每个点的访问频率,取访问频率最大的那个类,作为当前点集的所属类。

数据集

数据集介绍

此数据集共200行,每行包含了消费者的相关信息与消费数据。其各维属性的意义如下:

属性名 解释 类型
Customer 消费者编号 连续值
Gender 消费者性别 离散值
Age 消费者年龄 离散值
Annual Income (k$) 消费者的年收入 离散值
Spending Score (1-100) 消费产生的积分 离散值

训练

首先我们引入必要的库:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

01 - 读入数据

param:

  • file_path:数据存储路径

return:

  • data:样本特征数据
def load_data(file_path):customers_data = pd.read_csv("meanshift_data/Mall_Customers.csv")# 读入样本数据data = customers_data.iloc[:, [3, 4]].values # 获取收入与消费积分数据return data

02 - 求解meanshift向量

param:

  • data:样本特征数据
  • radius:样本区域半径

return:

  • clusters:所有样本点聚类结果
def mean_shift(data, radius=2.5):clusters = []#存储每个样本所属的质心与出现频次for i in tqdm(range(len(data))):# 每个数据点都做为初始聚类的质心cluster_centroid = data[i] # 初始质心cluster_frequence = np.zeros(len(data)) # 初始每个数据点的聚类频率# 遍历数据点while True:temp_data = []#存储半径内的所有数据集for j in range(len(data)): # 每次都遍历所有元素v = data[j] # 获取第j个点if np.linalg.norm(v - cluster_centroid)<= radius:#判断样本是否在圆区域内temp_data.append(v)#  把半径内的所有数据集合起来cluster_frequence[i] += 1 # 在当前聚类中记录数据点出现的次数加1        old_centroid = cluster_centroid # 旧的质心          new_centroid = np.average(temp_data,axis=0)# 新的质心          cluster_centroid = new_centroid # 更新质心# 如果新旧质心一致,出现收敛,则结束if np.array_equal(new_centroid,old_centroid):# 判断是否出现重复聚类has_same_cluster = Falsefor cluster in clusters:                   # 两个质心小于半径,则为同一个聚类if np.linalg.norm(cluster['centroid'] - cluster_centroid)<= radius:has_same_cluster = True#合并,数据点出现次数也对应合并。cluster['frequency'] = cluster['frequency'] + cluster_frequence#出现重复的类,跳出,直接进行下一个样本点计算if has_same_cluster:break                #如果质心不同,保存质心,记录数据频次,并跳出,进行下一个样本点计算if not has_same_cluster:clusters.append({'centroid':cluster_centroid,'frequency':cluster_frequence})breakreturn clusters

03 - 聚类

根据样本点在所有聚类中出现的频率,取对其访问频率最大的那个类,作为当前点的所属类。

param:

  • data:样本特征数据
  • clusters:所有聚类结果

return:

  • index:最终聚类数量
def clustering(data, clusters):t = []index=[]for cluster in clusters:cluster['data'] = []t.append(cluster['frequency'])t = np.array(t)# 聚类for i in range(len(data)):column_frequency = t[:, i]cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0]#得到频率最大的类索引index.append(cluster_index)#记录索引clusters[cluster_index]['data'].append(data[i])#将样本点添加到所属类return np.unique(index)

04 - 绘图

param:

  • clusters:所有聚类结果
  • index:最终聚类数量
def draw(index,clusters):fig = plt.figure(figsize=(20,7))axes = fig.add_subplot(111)colour=['magenta','cyan','pink','red','orange','green','blue']#创建一个颜色库# 画出每个类质心            for i in index:axes.scatter(clusters[i]['centroid'][0], clusters[i]['centroid'][1], marker='*', s=260, linewidths=3, color='black', label='centroid')# 画出样本聚类结果for i in tqdm(index):for j in range(len(clusters[i]['data'])):axes.scatter(clusters[i]['data'][j][0], clusters[i]['data'][j][1], color=colour[i],alpha=1)x_min, x_max = min(data[:,0])-10, max(data[:,0])+10 # 横轴坐标范围y_min, y_max = min(data[:,1])-1, max(data[:,1])+1 # 纵轴坐标范围plt.title("Mean Shift clustering") # 标题plt.show()

05 - 主函数

if __name__ == "__main__":data = load_data("meanshift_data/Mall_Customers.csv")print("1、导入数据:",len(data),"条")print("2、求解MeanShift向量")clusters=mean_shift(data, radius=18)   index = clustering(data, clusters)print("3、聚类数量为:",len(index))print("4、绘图")draw(index,clusters)

运行结果

总结

介绍了meanshift聚类算法的基本概念,算法原理以及具体实现过程。并借助消费者数据集,建立模型,实现了根据消费者的特征进行分类的模型训练过程。

参考文献

  1. http://www.scikitlearn.com.cn/0.21.3/22/#234-mean-shift
  2. Mean shift: A robust approach toward feature space analysis.” D. Comaniciu and P. Meer, IEEE Transactions on Pattern Analysis and Machine Intelligence (2002)

MeanShift- 案例实现(python)相关推荐

  1. arcgis python编程案例-ArcGIS Python编程案例-电子资料链接

    ArcGIS Python编程案例(1)-Python语言基础 https://www.jianshu.com/p/dd90816d019b ArcGIS Python编程案例(2)-使用ArcPy编 ...

  2. python字典导入mongodb_Python语言生成内嵌式字典(dict)-案例从python提取内嵌json写入mongodb...

    本文主要向大家介绍了Python语言生成内嵌式字典(dict)-案例从python提取内嵌json写入mongodb,通过具体的内容向大家展示,希望对大家学习Python语言有所帮助. 从mongo查 ...

  3. 案例:用python将中文翻译的和英文原文合成新的word文档

    案例:用python中文翻译的和英文原文合成新的word文档 一 问题的形成 需求的描述:一个英语翻译专业的研究生同学有一个期末作业.老师给了一个英文的文档,需要同学们翻译成中文.老师给的文档是格式如 ...

  4. 在我的新书里,尝试着用股票案例讲述Python爬虫大数据可视化等知识

    我的新书,<基于股票大数据分析的Python入门实战>,预计将于2019年底在清华出版社出版. 如果大家对大数据分析有兴趣,又想学习Python,这本书是一本不错的选择.从知识体系上来看, ...

  5. Computer:字符编码(ASCII编码/GBK编码/BASE64编码/UTF-8编码)的简介、案例应用(python中的编码格式及常见编码问题详解)之详细攻略

    Computer:字符编码(ASCII编码/GBK编码/BASE64编码/UTF-8编码)的简介.案例应用(python中的编码格式及常见编码问题详解)之详细攻略 目录 符串编码(ASCII编码/GB ...

  6. python数据建模案例源代码_一个完整的数据分析案例 | 用Python建立客户流失预测模型(含源数据+代码)...

    原标题:一个完整的数据分析案例 | 用Python建立客户流失预测模型(含源数据+代码) 来源:数据分析不是个事儿 作者:启方 原文: https://mp.weixin.qq.com/s/_20MN ...

  7. 用通俗易懂的方式讲解:主成分分析(PCA)算法及案例(Python 代码)

    文章目录 知识汇总 加入方式 一.引入问题 二.数据降维 三.PCA基本数学原理 3.1 内积与投影 3.2 基 3.3 基变换的矩阵表示 3.4 协方差矩阵及优化目标 3.5 方差 3.6 协方差 ...

  8. 软件测试案例|Python+Selenium+unittest完成对登录页面的自动化测试

    软件测试案例:Python+Selenium+unittest完成对登录页面的自动化测试 01.实验简介 本实验使用Python语言结合Selenium UI测试工具,利用unittest组织测试用例 ...

  9. 视频教程-跟着王进老师学开发Python篇:基础入门案例讲解-Python

    跟着王进老师学开发Python篇:基础入门案例讲解 教学风格独特,以学员视角出发设计课程,难易适度,重点突出,架构清晰,将实战经验融合到教学中.讲授技术同时传递方法.得到广大学员的高度认可. 王进 ¥ ...

  10. python小项目案例-拯救Python新手的几个项目实战

    原标题:拯救Python新手的几个项目实战 Python 做小游戏 实例一:24点游戏 项目名称:经典趣味24点游戏程序设计(python) 实例二:五子棋游戏 python学习关注我们企鹅qun: ...

最新文章

  1. 主机入侵防御系统(HIPS)分析
  2. 让炼丹师不再为数据集发愁,这家公司建了一个AI公开数据集平台
  3. 对缓存击穿的一点思考
  4. 1.4 计算机系统概述思维导图小结-常见问题和易混淆知识点(组成原理)
  5. 【C#】【Thread】上下文同步域SynchronizationAttribute
  6. YUIDoc example代码高亮错误、生成API文档目录不按源文件注释顺序
  7. iaas层次化结构--从业务需求到设计需求
  8. [转载] java synchronized静态同步方法与非静态同步方法,同步语句块
  9. 怎么知道自己适不适合学计算机专业,不知道自己到底适不适合学习计算机专业...
  10. Codeigniter的一些优秀实践
  11. 吴恩达教授机器学习课程笔记【三】广义线性模型(1)-指数族分布
  12. Agilent函数发生器编程(Agilent IO Suite使用)
  13. 网站跳出率(Bounce Rate)
  14. Git Branching
  15. Win10:电脑共享WIFI
  16. 20175212童皓桢 《Java程序设计》第十周学习总结
  17. qt几种常见的打包安装程序工具
  18. 程控电源CANoe上位机面板(CAN\ETH测试、RS232串口通信、编写设计思路)
  19. 牛客wannaflay挑战赛22 签到题 计数器 Wolf and Rabbit HDU 1222
  20. 等待末日, 一家6口在地窖里住9年? 被警方救出后, 前后太诡异了!

热门文章

  1. R语言建立和可视化混合效应模型mixed effect model
  2. Redis安装启动和配置文件
  3. Vue.js下载与使用
  4. 【遥感】时间分辨率:轨道周期 / 运行周期 / 重复周期 / 轨道重访周期 / 重访周期 概念
  5. 你到底能用Python做什么?下面是Python的三个主要应用程序。
  6. vsftp,lftp
  7. JavaScript的16进制转10进制
  8. linux 设备驱动(一)——字符设备驱动
  9. 基于STM32的ESP8266获取天气数据(HAL库)
  10. 李宏毅深度学习视频摘要