一、Mean Shift算法概述

Mean Shift算法,又称为均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来由Yizong Cheng对其进行扩充,主要提出了两点的改进:

  • 定义了核函数;
  • 增加了权重系数。

核函数的定义使得偏移值对偏移向量的贡献随之样本与被偏移点的距离的不同而不同。权重系数使得不同样本的权重不同。Mean Shift算法在聚类,图像平滑、分割以及视频跟踪等方面有广泛的应用。

二、Mean Shift算法的核心原理

2.1、核函数

在Mean Shift算法中引入核函数的目的是使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同。核函数是机器学习中常用的一种方式。核函数的定义如下所示:

X\mathbf{X}表示一个dd维的欧式空间,xx是该空间中的一个点x={x1,x2,x3⋯,xd}x=\left \{ x_1,x_2,x_3\cdots ,x_d \right \},其中,xx的模∥x∥2=xxT\left \| x \right \|^2=xx^T,R\mathbf{R}表示实数域,如果一个函数K:X→RK:\mathbf{X}\rightarrow \mathbf{R}存在一个剖面函数k:[0,∞]→Rk:\left [ 0,\infty \right ]\rightarrow \mathbf{R},即

K(x)=k(∥x∥2)

K\left ( x \right )=k\left ( \left \| x \right \|^2 \right )
并且满足:
(1)、kk是非负的
(2)、kk是非增的
(3)、kk是分段连续的
那么,函数K(x)K\left ( x \right )就称为核函数。

常用的核函数有高斯核函数。高斯核函数如下所示:

N(x)=12π−−√he−x22h2

N\left ( x \right )=\frac{1}{\sqrt{2\pi }h}e^{-\frac{x^2}{2h^2}}

其中,hh称为带宽(bandwidth),不同带宽的核函数如下图所示:

上图的画图脚本如下所示:

'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import mathdef cal_Gaussian(x, h=1):molecule = x * xdenominator = 2 * h * hleft = 1 / (math.sqrt(2 * math.pi) * h)return left * math.exp(-molecule / denominator)x = []for i in xrange(-40,40):x.append(i * 0.5);score_1 = []
score_2 = []
score_3 = []
score_4 = []for i in x:score_1.append(cal_Gaussian(i,1))score_2.append(cal_Gaussian(i,2))score_3.append(cal_Gaussian(i,3))score_4.append(cal_Gaussian(i,4))plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()

2.2、Mean Shift算法的核心思想

2.2.1、基本原理

对于Mean Shift算法,是一个迭代的步骤,即先算出当前点的偏移均值,将该点移动到此偏移均值,然后以此为新的起始点,继续移动,直到满足最终的条件。此过程可由下图的过程进行说明(图片来自参考文献3):

  • 步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

  • 步骤2:移动该点到偏移均值点处

  • 步骤3: 重复上述的过程(计算新的偏移均值,移动)

  • 步骤4:满足了最终的条件,即退出

从上述过程可以看出,在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。

2.2.2、基本的Mean Shift向量形式

对于给定的dd维空间RdR^d中的nn个样本点xi,i=1,⋯,nx_i, i=1,\cdots , n,则对于xx点,其Mean Shift向量的基本形式为:

Mh(x)=1k∑xi∈Sh(xi−x)

M_h\left ( x \right )=\frac{1}{k}\sum_{x_i\in S_h}\left ( x_i-x \right )

其中,ShS_h指的是一个半径为hh的高维球区域,如上图中的蓝色的圆形区域。ShS_h的定义为:

Sh(x)=(y∣(y−x)(y−x)T⩽h2)

S_h\left ( x \right )=\left ( y\mid \left ( y-x \right )\left ( y-x \right )^T\leqslant h^2 \right )

这样的一种基本的Mean Shift形式存在一个问题:在ShS_h的区域内,每一个点对xx的贡献是一样的。而实际上,这种贡献与xx到每一个点之间的距离是相关的。同时,对于每一个样本,其重要程度也是不一样的。

2.2.3、改进的Mean Shift向量形式

基于以上的考虑,对基本的Mean Shift向量形式中增加核函数和样本权重,得到如下的改进的Mean Shift向量形式:

Mh(x)=∑ni=1GH(xi−x)w(xi)(xi−x)∑ni=1GH(xi−x)w(xi)

M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G_H\left ( x_i-x \right )w\left ( x_i \right )\left ( x_i-x \right )}{\sum_{i=1}^{n}G_H\left ( x_i-x \right )w\left ( x_i \right )}

其中:

GH(xi−x)=|H|−12G(H−12(xi−x))

G_H\left ( x_i-x \right )=\left | H \right |^{-\frac{1}{2}}G\left ( H^{-\frac{1}{2}}\left ( x_i-x \right ) \right )

G(x)G\left ( x\right )是一个单位的核函数。HH是一个正定的对称d×dd\times d矩阵,称为带宽矩阵,其是一个对角阵。w(xi)⩾0w\left ( x_i \right )\geqslant 0是每一个样本的权重。对角阵HH的形式为:

H=⎛⎝⎜⎜⎜⎜⎜h210⋮00h22⋮0⋯⋯⋯00⋮h2d⎞⎠⎟⎟⎟⎟⎟d×d

H=\begin{pmatrix} h_1^2 & 0 & \cdots & 0\\ 0 & h_2^2 & \cdots & 0\\ \vdots & \vdots & & \vdots \\ 0 & 0 & \cdots & h_d^2 \end{pmatrix}_{d\times d}

上述的Mean Shift向量可以改写成:

Mh(x)=∑ni=1G(xi−xhi)w(xi)(xi−x)∑ni=1G(xi−xhi)w(xi)

M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h_i} \right )w\left ( x_i \right )\left ( x_i-x \right )}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h_i} \right )w\left ( x_i \right )}

Mean Shift向量Mh(x)M_h\left ( x \right )是归一化的概率密度梯度。

2.3、Mean Shift算法的解释

在Mean Shift算法中,实际上是利用了概率密度,求得概率密度的局部最优解。

2.3.1、概率密度梯度

对一个概率密度函数f(x)f\left ( x \right ),已知dd维空间中nn个采样点xi,i=1,⋯,nx_i,i=1,\cdots ,n,f(x)f\left ( x \right )的核函数估计(也称为Parzen窗估计)为:

f^(x)=∑ni=1K(xi−xh)w(xi)hd∑ni=1w(xi)

\hat{f}\left ( x \right )=\frac{\sum_{i=1}^{n}K\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{h^d\sum_{i=1}^{n}w\left ( x_i \right )}
其中
w(xi)⩾0w\left ( x_i \right )\geqslant 0是一个赋给采样点xix_i的权重
K(x)K\left ( x \right )是一个核函数

概率密度函数f(x)f\left ( x \right )的梯度▽f(x)\bigtriangledown f\left ( x \right )的估计为

▽f^(x)=2∑ni=1(x−xi)k′(∥∥xi−xh∥∥2)w(xi)hd+2∑ni=1w(xi)

\bigtriangledown \hat{f}\left ( x \right )=\frac{2\sum_{i=1}^{n}\left ( x-x_i \right ){k}'\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )}{h^{d+2}\sum_{i=1}^{n}w\left ( x_i \right )}

令g(x)=−k′(x)g\left ( x \right )=-{k}'\left ( x \right ),G(x)=g(∥x∥2)G\left ( x \right )=g\left ( \left \| x \right \|^2 \right ),则有:

▽f^(x)=2∑ni=1(xi−x)G(∥∥xi−xh∥∥2)w(xi)hd+2∑ni=1w(xi)=2h2⎡⎣⎢∑ni=1G(xi−xh)w(xi)hd∑ni=1w(xi)⎤⎦⎥⎡⎣⎢∑ni=1(xi−x)G(∥∥xi−xh∥∥2)w(xi)∑ni=1G(xi−xh)w(xi)⎤⎦⎥

\begin{align*} \bigtriangledown \hat{f}\left ( x \right ) &= \frac{2\sum_{i=1}^{n}\left ( x_i-x \right )G\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )}{h^{d+2}\sum_{i=1}^{n}w\left ( x_i \right )}\\ &= \frac{2}{h^2}\left [ \frac{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}{h^d\sum_{i=1}^{n}w\left ( x_i \right )} \right ]\left [ \frac{\sum_{i=1}^{n}\left ( x_i-x \right )G\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )} \right ] \end{align*}

其中,第二个方括号中的就是Mean Shift向量,其与概率密度梯度成正比。

2.3.2、Mean Shift向量的修正

Mh(x)=∑ni=1G(∥∥xi−xh∥∥2)w(xi)xi∑ni=1G(xi−xh)w(xi)−x

M_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )x_i}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )}-x

记:mh(x)=∑ni=1G(∥∥xi−xh∥∥2)w(xi)xi∑ni=1G(xi−xh)w(xi)m_h\left ( x \right )=\frac{\sum_{i=1}^{n}G\left ( \left \| \frac{x_i-x}{h} \right \|^2 \right )w\left ( x_i \right )x_i}{\sum_{i=1}^{n}G\left ( \frac{x_i-x}{h} \right )w\left ( x_i \right )},则上式变成:

Mh(x)=mh(x)+x

M_h\left ( x \right )=m_h\left ( x \right )+x

这与梯度上升的过程一致。

2.4、Mean Shift算法流程

Mean Shift算法的算法流程如下:

  • 计算mh(x)m_h\left ( x \right )
  • 令x=mh(x)x=m_h\left ( x \right )
  • 如果∥mh(x)−x∥<ε\left \| m_h\left ( x \right )-x \right \|,结束循环,否则,重复上述步骤

三、实验

3.1、实验数据

实验数据如下图所示(来自参考文献1):

画图的代码如下:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as pltf = open("data")
x = []
y = []
for line in f.readlines():lines = line.strip().split("\t")if len(lines) == 2:x.append(float(lines[0]))y.append(float(lines[1]))
f.close()  plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()

3.2、实验的源码

#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''import math
import sys
import numpy as npMIN_DISTANCE = 0.000001#mini errordef load_data(path, feature_num=2):f = open(path)data = []for line in f.readlines():lines = line.strip().split("\t")data_tmp = []if len(lines) != feature_num:continuefor i in xrange(feature_num):data_tmp.append(float(lines[i]))data.append(data_tmp)f.close()return datadef gaussian_kernel(distance, bandwidth):m = np.shape(distance)[0]right = np.mat(np.zeros((m, 1)))for i in xrange(m):right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)right[i, 0] = np.exp(right[i, 0])left = 1 / (bandwidth * math.sqrt(2 * math.pi))gaussian_val = left * rightreturn gaussian_valdef shift_point(point, points, kernel_bandwidth):points = np.mat(points)m,n = np.shape(points)#计算距离point_distances = np.mat(np.zeros((m,1)))for i in xrange(m):point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)#计算高斯核      point_weights = gaussian_kernel(point_distances, kernel_bandwidth)#计算分母all = 0.0for i in xrange(m):all += point_weights[i, 0]#均值偏移point_shifted = point_weights.T * points / allreturn point_shifteddef euclidean_dist(pointA, pointB):#计算pointA和pointB之间的欧式距离total = (pointA - pointB) * (pointA - pointB).Treturn math.sqrt(total)def distance_to_group(point, group):min_distance = 10000.0for pt in group:dist = euclidean_dist(point, pt)if dist < min_distance:min_distance = distreturn min_distancedef group_points(mean_shift_points):group_assignment = []m,n = np.shape(mean_shift_points)index = 0index_dict = {}for i in xrange(m):item = []for j in xrange(n):item.append(str(("%5.2f" % mean_shift_points[i, j])))item_1 = "_".join(item)print item_1if item_1 not in index_dict:index_dict[item_1] = indexindex += 1for i in xrange(m):item = []for j in xrange(n):item.append(str(("%5.2f" % mean_shift_points[i, j])))item_1 = "_".join(item)group_assignment.append(index_dict[item_1])return group_assignmentdef train_mean_shift(points, kenel_bandwidth=2):#shift_points = np.array(points)mean_shift_points = np.mat(points)max_min_dist = 1iter = 0m, n = np.shape(mean_shift_points)need_shift = [True] * m#cal the mean shift vectorwhile max_min_dist > MIN_DISTANCE:max_min_dist = 0iter += 1print "iter : " + str(iter)for i in range(0, m):#判断每一个样本点是否需要计算偏置均值if not need_shift[i]:continuep_new = mean_shift_points[i]p_new_start = p_newp_new = shift_point(p_new, points, kenel_bandwidth)dist = euclidean_dist(p_new, p_new_start)if dist > max_min_dist:#record the max in all pointsmax_min_dist = distif dist < MIN_DISTANCE:#no need to moveneed_shift[i] = Falsemean_shift_points[i] = p_new#计算最终的groupgroup = group_points(mean_shift_points)return np.mat(points), mean_shift_points, groupif __name__ == "__main__":#导入数据集path = "./data"data = load_data(path, 2)#训练,h=2points, shift_points, cluster = train_mean_shift(data, 2)for i in xrange(len(cluster)):print "%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])

3.3、实验的结果

经过Mean Shift算法聚类后的数据如下所示:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as pltf = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}for line in f.readlines():lines = line.strip().split("\t")if len(lines) == 3:label = int(lines[2])if label == 0:data_1 = lines[0].strip().split(",")cluster_x_0.append(float(data_1[0]))cluster_y_0.append(float(data_1[1]))if label not in center_dict:center_dict[label] = 1data_2 = lines[1].strip().split(",")center_x.append(float(data_2[0]))center_y.append(float(data_2[1]))elif label == 1:data_1 = lines[0].strip().split(",")cluster_x_1.append(float(data_1[0]))cluster_y_1.append(float(data_1[1]))if label not in center_dict:center_dict[label] = 1data_2 = lines[1].strip().split(",")center_x.append(float(data_2[0]))center_y.append(float(data_2[1]))else:data_1 = lines[0].strip().split(",")cluster_x_2.append(float(data_1[0]))cluster_y_2.append(float(data_1[1]))if label not in center_dict:center_dict[label] = 1data_2 = lines[1].strip().split(",")center_x.append(float(data_2[0]))center_y.append(float(data_2[1]))
f.close()plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')
#plt.legend(loc="best")
plt.show()

参考文献

  1. Mean Shift Clustering

  2. Meanshift,聚类算法

  3. meanshift算法简介

简单易学的机器学习算法——Mean Shift聚类算法相关推荐

  1. 机器学习(十)Mean Shift 聚类算法

    Mean Shift 聚类算法 原文地址:http://blog.csdn.net/hjimce/article/details/45718593  作者:hjimce 一.mean shift 算法 ...

  2. 简单易学的机器学习算法——梯度提升决策树GBDT

    梯度提升决策树(Gradient Boosting Decision Tree,GBDT)算法是近年来被提及比较多的一个算法,这主要得益于其算法的性能,以及该算法在各类数据挖掘以及机器学习比赛中的卓越 ...

  3. 简单易学的机器学习算法——Metropolis-Hastings算法

    在简单易学的机器学习算法--马尔可夫链蒙特卡罗方法MCMC中简单介绍了马尔可夫链蒙特卡罗MCMC方法的基本原理,介绍了Metropolis采样算法的基本过程,这一部分,主要介绍Metropolis-H ...

  4. 《菜菜的机器学习sklearn课堂》聚类算法Kmeans

    聚类算法 聚类算法 无监督学习与聚类算法 sklearn中的聚类算法 KMeans KMeans是如何工作的 簇内误差平方和的定义和解惑 sklearn.cluster.KMeans 重要参数 n_c ...

  5. python机器学习案例系列教程——聚类算法总结

    全栈工程师开发手册 (作者:栾鹏) python教程全解 一.什么是聚类? 聚类(Clustering):聚类是一个人们日常生活的常见行为,即所谓"物以类聚,人以群分",核心的思想 ...

  6. 【火炉炼AI】机器学习023-使用层次聚类算法构建模型

    [火炉炼AI]机器学习023-使用层次聚类算法构建模型 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotli ...

  7. 机器学习笔记--聚类算法 k-means--31省市消费水平聚类

    参考文章:https://blog.csdn.net/rankiy/article/details/99843363 1.数据集 数据介绍: 现有1999年全国31个省份城镇居民家庭平均每月全年消费性 ...

  8. 【聚类算法】常见聚类算法总结

    转自:https://blog.csdn.net/u010062386/article/details/82499777 感谢博主 1.常见算法 1.原型聚类 "原型"是指样本空间 ...

  9. ML之Clustering之普聚类算法:普聚类算法的相关论文、主要思路、关键步骤、代码实现等相关配图之详细攻略

    ML之Clustering之普聚类算法:普聚类算法的相关论文.主要思路.关键步骤.代码实现等相关配图之详细攻略 目录 普聚类算法的相关论文 普聚类算法的主要思路 普聚类算法的关键步骤 普聚类算法的代码 ...

  10. k中心点聚类算法伪代码_聚类算法之——K-Means、Canopy、Mini Batch K-Means

    K-Means||算法 K-Means||算法是为了解决K-Means++算法缺点而产生的一种算法: 主要思路是改变每次遍历时候的取样规则,并非按照K-Means++算法每次遍历只获取一个样本,而是每 ...

最新文章

  1. 各种光学仪器成像技术(上)
  2. Java归去来第2集:利用Eclipse创建Maven Web项目
  3. 利用循环打印杨辉三角形
  4. Markdown语法整理
  5. (Spring)依赖注入
  6. Wannafly summer camp
  7. CORS 跨域-哪些操作受到同源限制
  8. python错误怎么处理_python报的错误怎么处理
  9. 在浏览器控制台输出内容 console.log(string);
  10. jdbc 批量insert_JDBC相关知识解答
  11. 公众号文章折叠点击后展开案例_(案例)蜂窝纸板在包装中的应用内衬
  12. Redis入门到入土教程_1
  13. mysql 敏感词_过滤敏感词方式
  14. DB2 SQLCODE 异常大全编辑(三)
  15. tensorflow-ckpt2npy
  16. it行业se是_CS、IT、SE到底有什么区别?
  17. 20211101bugku_re_mountain_climbing
  18. Every Document Owns Its Structure: Inductive Text Classification via GNN (TextING)
  19. 网络工程师就业前景、职业规划和工资待遇
  20. day1-python基础1

热门文章

  1. JS使用递归遍历json对象进行操作
  2. ESXI安装网卡或HBA卡驱动
  3. 技术研发团队管理计划方案书
  4. Quartz定时任务框架(二):Trigger触发器详解
  5. Linux安装/升级/卸载pip3
  6. 51单片机串口输出某些汉字乱码修复补丁(支持所有keil版本)
  7. 微信公众平台 使用JS-SDK实现拍照上传功能
  8. 进销存管理系统——商品管理
  9. java jbutton数组_java-JButton需要显示图像数组
  10. 第四章 维纳滤波原理及自适应算法