原理

大家自行百度吧,我懒得码字了

推荐一下原理原理https://blog.csdn.net/jinshengtao/article/details/30258833

代码

直接上代码了,看不懂,就参照一下原理

# author: wdq
# contact: 1920132572@qq.com
# datetime:2022/3/15 17:40
# software: PyCharm
import random
from collections import Counter
from typing import Listimport numpy as np
from numpy import ndarrayclass MeanShift:def __init__(self, nums: ndarray, band_width: float):""":param nums: 要划分的ndarray:param band_width: 窗口大小"""# 要划分的ndarrayself.__nums = nums# 窗口大小self.__band_width = band_width# 停止步长self.__stop_band_width = 10 ** -4 * self.__band_width# 访问数组self.__is_visited = [False] * self.__nums.shape[0]# 聚类中心self.__cluster_centers = []# 聚类self.__cluster = []def mean_shift(self) -> List[List[List[int]]]:# 判断是否所有点都被访问过while not self.__is_all_visited():my_member = []# 在没被访问的点随机选一个点start_point = random.choice([i for i in range(self.__nums.shape[0]) if not self.__is_visited[i]])my_mean = self.__nums[start_point]while True:# 得到到各点的距离,以及权重distance, gaussian = self.__get_shift(my_mean)# 找到在窗口的点in_the_area = self.__find__points(distance)# 保留当前的位置old_mean = my_mean.copy()# 得到新的位置my_mean = self.__get_new_mean(gaussian, in_the_area)# 将范围的点划到当次的聚类my_member.extend(in_the_area)# 更新当前的访问数组self.__update_visited(in_the_area)# 判断是否小于停止步长if self.__get_distance(old_mean, my_mean) < self.__stop_band_width:merge_width = None# 遍历当前聚类for i in range(len(self.__cluster_centers)):# 判断中心点离得太近if self.__get_distance(my_mean, self.__cluster_centers[i]) < self.__band_width / 2:merge_width = ibreak# 如果太近了就合并这2个聚类if merge_width is not None:# 合并中心点self.__cluster_centers[merge_width] = self.__get_new_center(my_mean,self.__cluster_centers[merge_width])# 合并聚类中的点self.__cluster[merge_width].extend(my_member)# 否则就添加一个聚类else:self.__cluster_centers.append(my_mean.tolist())self.__cluster.append(my_member)break# 返回分好类的结果return self.__get_result()def __is_all_visited(self) -> bool:""":return: 是否全部访问"""# 遍历访问数组for i in self.__is_visited:if not i:return Falsereturn Truedef __get_distance(self, start: any, end: any) -> float:""":param start: 起始点:param end: 终点:return: 两点之间的距离"""# 类型转换if type(start) != ndarray:start = np.array(start)if type(end) != ndarray:end = np.array(end)# 返回欧式距离return np.linalg.norm(start - end)def __get_shift(self, start: ndarray) -> (ndarray, ndarray):""":param start: 开始的点:return: 计算滑动的距离"""# 距离distance = np.zeros((self.__nums.shape[0], 1))# 权重gaussian = np.zeros((self.__nums.shape[0], 1))for i in range(distance.shape[0]):temp = self.__get_distance(start, self.__nums[i])gaussian[i] = self.__gaussian_kernel(temp, self.__band_width)distance[i] = tempreturn distance, gaussiandef __gaussian_kernel(self, distance: float, bandwidth: float) -> float:"""高斯核函数:param distance: 距离:param bandwidth: 窗口大小:return: 权重"""return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * (distance / bandwidth) ** 2)def __get_new_mean(self, gaussian: ndarray, in_the_area: List[int]) -> ndarray:""":param gaussian: 权重:param in_the_area: 在区域的点:return:"""# 权重weight = 0.# 在范围的点new_mean = np.array([self.__nums[i].tolist() for i in in_the_area])for i in range(len(in_the_area)):new_mean[i] = new_mean[i] * gaussian[in_the_area[i]]weight += gaussian[in_the_area[i]]# 对范围的点进行加权,并算出漂移到的点return np.sum(new_mean, axis=0) / weight if weight != 0 else np.sum(new_mean, axis=0)def __find__points(self, distance: ndarray) -> List[int]:""":param distance: 距离ndarray:return: 在窗口大小内的点"""return [i for i, j in enumerate(distance) if j < self.__band_width ** 2]def __update_visited(self, in_the_area: List[int]) -> None:"""更新访问过的点:param in_the_area: 在窗口大小内的点:return:"""for i in in_the_area:self.__is_visited[i] = Truedef __get_new_center(self, mymean: ndarray, old_center: List[int]) -> List[int]:"""合并中心点:param mymean: 现在的中心点:param old_center: 以前的中心点:return:"""return [(i + j) / 2 for i, j in zip(mymean.tolist(), old_center)]def __get_result(self) -> List[List[List[int]]]:"""将结果分好类并返回这段代码比较丑陋,将就看看,不看也行,我自己都不想看大致意思就是找这些点应该分到那个类:return:"""count = []result = [[] for i in range(len(self.__cluster))]# 计数,计出每个点到每个聚类的次数for i in self.__cluster:count.append(dict(Counter(i)))belong = []# 遍历找出每个点到到那个聚类的最大值,那我们就可以认为它在那个聚类for num in range(len(self.__nums)):# 最大次数的索引index = 0for i in range(1, len(count)):if count[i].get(num, 0) > count[index].get(num, 0):index = ibelong.append(index)# 分类for i in range(len(self.__nums)):result[belong[i]].append(self.__nums[i].tolist())# 把空的聚类移除return [i for i in result if i]

测试代码

import matplotlib
from matplotlib import pyplot as plt
from sklearn import datasetsfrom MeanShift import MeanShiftmatplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
iris = datasets.load_iris()  # 引入数据集
# 分的类不好就重新分,多试一哈
mean_shift = MeanShift(nums=iris.data, band_width=1.34)  # 对于iris,窗口大小为1.34,别问为什么,别问,问就是好用
colors = ['red', 'green', 'blue', 'black', 'yellow']
a = mean_shift.mean_shift()
for i in range(len(a)):for j in a[i]:plt.scatter(j[0], j[1], c=colors[i])
plt.title("Mean-Shift")
plt.xlabel('萼片长度')
plt.ylabel('萼片宽度')
plt.show()"""___________.__                   __               _____                 _____  .__      ._.\__    ___/|  |__ _____    ____ |  | __  ______ _/ ____\___________    /     \ |__| ____| ||    |   |  |  \\__  \  /    \|  |/ / /  ___/ \   __\/  _ \_  __ \  /  \ /  \|  |/    \ ||    |   |   Y  \/ __ \|   |  \    <  \___ \   |  | (  <_> )  | \/ /    Y    \  |   |  \||____|   |___|  (____  /___|  /__|_ \/____  >  |__|  \____/|__|    \____|__  /__|___|  /_\/     \/     \/     \/     \/                               \/        \/\/      """

运行结果

标准答案

MeanShift算法

只用来学习,借鉴,错的话,欢迎批评和指导!

邮箱:cse.dqwu19@gzu.edu.cn

Python 实现MeanShift算法相关推荐

  1. 基于python的mean-shift算法

    一.Mean Shift算法概述 Mean Shift算法又称均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来又由Yzong Cheng对其进行扩充,主要提出 ...

  2. OpenCV中MeanShift算法视频移动对象分析

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 MeanShift算法 Mean Shift是一种聚类算法,在数据 ...

  3. 基于Mean-shift算法跟踪对象

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 跟踪对象是计算机视觉领域的重要应用.这在监控系统.国防.自动驾驶汽 ...

  4. 传统目标跟踪——MeanShift算法

    目录 一.均值漂移(MeanShift) 二.流程 三.代码 3.1 meanshift+固定框的代码 3.2 优化:meanshift+鼠标选择 3.3 meanshift+自己实现函数 四.补充知 ...

  5. python多维向量聚类_机器学习:Python实现聚类算法(三)之总结

    考虑到学习知识的顺序及效率问题,所以后续的几种聚类方法不再详细讲解原理,也不再写python实现的源代码,只介绍下算法的基本思路,使大家对每种算法有个直观的印象,从而可以更好的理解函数中参数的意义及作 ...

  6. python数据结构与算法总结

    python常用的数据结构与算法就分享到此处,本月涉及数据结构与算法的内容有如下文章: <数据结构和算法对python意味着什么?> <顺序表数据结构在python中的应用> ...

  7. 数学推导+纯Python实现机器学习算法:GBDT

    Datawhale推荐 作者:louwill,Machine Learning Lab 时隔大半年,机器学习算法推导系列终于有时间继续更新了.在之前的14讲中,笔者将监督模型中主要的单模型算法基本都过 ...

  8. 以图搜图Python实现Hash算法

    以图搜图(一):Python实现dHash算法 http://yshblog.com/blog/43 以图搜图(二):Python实现pHash算法 http://yshblog.com/blog/4 ...

  9. em算法python代码_EM 算法求解高斯混合模型python实现

    注:本文是对<统计学习方法>EM算法的一个简单总结. 1. 什么是EM算法? 引用书上的话: 概率模型有时既含有观测变量,又含有隐变量或者潜在变量.如果概率模型的变量都是观测变量,可以直接 ...

最新文章

  1. 刚进来的小伙伴说Nginx只能做负载均衡,还是太年轻了
  2. linux下,在挂载设备之前,查看设备的文件系统类型
  3. python3操作MySQL:insert插入数据
  4. 嵌入式Linux裸机开发(六)——S5PV210时钟系统
  5. oracle10g数据库复制,oracle -10g 中Duplicate 复制数据库
  6. python bokeh 示例_Python bokeh.plotting.figure.arc()用法及代码示例
  7. Docker下部署wordpress
  8. IT不是技术,IT是一个世界
  9. mysql无法修改表字段
  10. 上海证券交易所云平台移动行情服务测试项目
  11. 视频教程-C语言-从汇编角度理解C语言的本质-C/C++
  12. XSSFWorkbook Excel导出导入
  13. Phobos Runtime Library
  14. 手握N段大厂实习经历的人生有多爽?
  15. 常见电路结构分析七:三相电的使用与接法
  16. 百度高层调整:沈抖领军智能云 打造第二增长曲线
  17. 2014年终总结 --量变到质变,是一个过程!
  18. Python处理ISO 8601日期时间
  19. es 简单实现增加,查询,分词 热词
  20. 前端——css 背景background

热门文章

  1. Excel一个单元格中输入度分秒转换成小数(如256.3246(读256度32分46秒))
  2. 渗透测试之红队项目日常渗透笔记
  3. android 热点 连接电脑上网,电脑没有网络,用手机数据线,就能让电脑快速上网...
  4. 基于51单片机的波形发生器proteus仿真数码管LCD12864显示
  5. 服务器系统导致无盘客户机usb失灵,无盘客户机无法启动/故障排查过程
  6. WIN10电脑手动抓蓝屏dump
  7. [C/C++语言基础] —函数
  8. Chrome-谷歌浏览器多开教程
  9. DEEP COMPRESSION(深度学习网络参数压缩)
  10. PyCharm免费使用,没有校园邮箱也可以(PyCharm学生认证邮箱失效同样可以)