0. 简介

GMM和Kmeans一样也属于聚类,其算法训练流程也十分相似,Kmeans可认为是“硬聚类”,GMM是“软聚类”。

给定数据集X,Kmeans算法流程是这样的----- a 初始化:随机初始k个中心(即k个点,记为μ);b 矫正数据归属:计算X中每个点与k个中心的距离,并将其归为相距最近的那个中心;c 矫正中心:计算每个中心(共k个)所有点的均值,并将其更新为中心值;d 完成整体训练:循环b和c,直到聚类到“足够好”。

GMM算法流程和Kmeans基本一致,区别在于:a 除了初始化k个中心(μ)外,每个中心还对应一个协方差矩阵(Σ)和混合概率(π),其中μ代表高斯分布的中心,Σ代表高斯分布形状,π代表高斯函数值的大小;b 矫正数据归属,GMM中每个数据点并不完全归属某个中心,而是归属每个中心,只是归属的概率不同;c 矫正中心,每个中心矫正更新时考虑数据集X中的所有点,而非某一部分数据点。

以下使用鸢尾花数据集按照a~d的流程解析GMM;导入鸢尾花数据集如下;

from sklearn import datasets
import numpy as npiris = datasets.load_iris()
X = iris.data
N, D = X.shape
display(X.shape, X[:10])
(150, 4)
array([[5.1, 3.5, 1.4, 0.2],[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],[4.6, 3.1, 1.5, 0.2],[5. , 3.6, 1.4, 0.2],[5.4, 3.9, 1.7, 0.4],[4.6, 3.4, 1.4, 0.3],[5. , 3.4, 1.5, 0.2],[4.4, 2.9, 1.4, 0.2],[4.9, 3.1, 1.5, 0.1]])

即鸢尾花数据集是一个150行4列的矩阵。

1. 初始化

定义聚类数量为3类,每一类都初始化一个中心μ、一个协方差矩阵Σ和混合概率π;

mus = X[np.random.choice(X.shape[0], 3, replace=False)]
covs = [np.identity(4) for i in range(3)]
pis = [1/3] * 3

2. 矫正数据归属

普通高斯概率函数(只有一个中心)如下,其中D是数据维度,此处D=4;

表示数据点x归属该中心(μ、Σ)的概率,代码如下;

def gaussian(X, mu, cov):diff = X - mureturn 1 / ((2 * np.pi) ** (D / 2) * np.linalg.det(cov) ** 0.5) * np.exp(-0.5 * np.dot(np.dot(diff, np.linalg.inv(cov)), diff))

当用混合高斯函数(即多个中心)时,表示一个数据点n归属中心k(μ_k、Σ_k)的概率函数如下,其中k表示第k个中心,此处总共有3个中心,即K=3;

将上式定义为γ_z_nk,即

代码如下:

gammas = []
for mu_, cov_, pi_ in zip(mus, covs, pis):# loop each centergamma_ = [[pi_ * gaussian(x_, mu_, cov_)] for x_ in X]# loop each pointgammas.append(gamma_)
gammas = np.array(gammas)
gamma_total = gammas.sum(0)
gammas /= gamma_total

3. 矫正中心

根据2.中的gammas值,更新μ、Σ和π值,公式如下,其中N表示数据总个数,此处N=150;

代码如下;

mus, covs, pis = [], [], []
for gamma_ in gammas:#loop each centergamma_sum = gamma_.sum()pi_ = gamma_sum / Nmu_ = (gamma_ * X).sum(0) / gamma_sumcov_ = []for x_, gamma_i in zip(X, gamma_):diff = (x_ - mu_).reshape(-1, 1)cov_.append(gamma_i * np.dot(diff, diff.T))cov_ = np.sum(cov_, axis=0) / gamma_sumpis.append(pi_)mus.append(mu_)covs.append(cov_)

4. 完成整体训练

将2.~3.作为一个循环单元,写成一个函数;

def train_step(X, mus, covs, pis):gammas = []for mu_, cov_, pi_ in zip(mus, covs, pis):# loop each centergamma_ = [[pi_ * gaussian(x_, mu_, cov_)] for x_ in X]# loop each pointgammas.append(gamma_)gammas = np.array(gammas)gamma_total = gammas.sum(0)gammas /= gamma_totalmus, covs, pis = [], [], []for gamma_ in gammas:#loop each centergamma_sum = gamma_.sum()pi_ = gamma_sum / Nmu_ = (gamma_ * X).sum(0) / gamma_sumcov_ = []for x_, gamma_i in zip(X, gamma_):diff = (x_ - mu_).reshape(-1, 1)cov_.append(gamma_i * np.dot(diff, diff.T))cov_ = np.sum(cov_, axis=0) / gamma_sumpis.append(pi_)mus.append(mu_)covs.append(cov_)return mus, covs, pis

训练50次;

for _ in range(50):mus, covs, pis = train_step(X, mus, covs, pis)

训练完成后,会得到3个中心,可计算每个点归属这三个中心的概率(即γ_z_nk,第n个点归属第k个中心的概率),并将其归属于概率最大的那个中心;因为数据集是4维,无法可视化,仅选择前两维度进行可视化展示如下;

整个训练过程的动态图如下;

完整代码如下;

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as pltiris = datasets.load_iris()
X = iris.data
N, D = X.shapemus = X[np.random.choice(X.shape[0], 3, replace=False)]
covs = [np.identity(4) for i in range(3)]
pis = [1/3] * 3def gaussian(X, mu, cov):diff = X - mureturn 1 / ((2 * np.pi) ** (D / 2) * np.linalg.det(cov) ** 0.5) * np.exp(-0.5 * np.dot(np.dot(diff, np.linalg.inv(cov)), diff))def get_likelihood(gamma_total):return np.log(gamma_total).sum()def train_step(X, mus, covs, pis):gammas = []for mu_, cov_, pi_ in zip(mus, covs, pis):# loop each centergamma_ = [[pi_ * gaussian(x_, mu_, cov_)] for x_ in X]# loop each pointgammas.append(gamma_)gammas = np.array(gammas)gamma_total = gammas.sum(0)gammas /= gamma_totalmus, covs, pis = [], [], []for gamma_ in gammas:#loop each centergamma_sum = gamma_.sum()pi_ = gamma_sum / Nmu_ = (gamma_ * X).sum(0) / gamma_sumcov_ = []for x_, gamma_i in zip(X, gamma_):diff = (x_ - mu_).reshape(-1, 1)cov_.append(gamma_i * np.dot(diff, diff.T))cov_ = np.sum(cov_, axis=0) / gamma_sumpis.append(pi_)mus.append(mu_)covs.append(cov_)return mus, covs, pis, gamma_totallog_LL = []
for _ in range(50):mus, covs, pis, gamma_total = train_step(X, mus, covs, pis)log_LL.append(get_likelihood(gamma_total))
plt.plot(log_LL)
plt.grid()

高斯混合模型聚类_GMM: Gaussian Mixed Model(高斯混合模型)相关推荐

  1. 【机器学习之高斯混合模型(Gaussian Mixed Model,GMM) 】

    文章目录 前言 一.高斯混合模型(Gaussian Mixed Model,GMM) 是什么? 二.详解GMM 1.初步原理 2.EM算法 3.深读原理 3.GMM(高斯混合模型)和K-means的比 ...

  2. 高斯混合模型--GMM(Gaussian Mixture Model)

    参考:http://blog.sina.com.cn/s/blog_54d460e40101ec00.html 概率指事件随机发生的机率,对于均匀分布函数,概率密度等于一段区间(事件的取值范围)的概率 ...

  3. GMM(Gaussian mixture model, 高斯混合模型)

    GMM全称是Gaussian mixture model (高斯混合模型).与k-means算法类似,GMM也是一种常见的聚类算法,它与k-means区别主要在于,GMM是一种"软聚类&qu ...

  4. ML之GMM:Gaussian Mixture Model高斯混合模型相关论文、算法步骤相关配图

    Gaussian Mixture Model高斯混合模型相关概念及配图 目录 GMM相关论文 GMM算法步骤相关配图 GMM相关论文 更新-- GMM算法步骤相关配图

  5. 一维(多维)高斯模型(One(Multi)-dimensional Gaussian Model) 高斯混合模型GMM(Gaussian Mixture Model)

    一维高斯模型(One-dimensional Gaussian Model) 若随机变量X服从一个数学期望为,标准方差为的高斯分布,记为: x~N(,). 则概率密度函数为: 高斯分布的期望值决定了其 ...

  6. 机器学习算法(二十九):高斯混合模型(Gaussian Mixed Model,GMM)

    目录 1 混合模型(Mixture Model) 2 高斯模型 2.1 单高斯模型 2.2 高斯混合模型 3 模型参数学习 3.1 单高斯模型 3.2 高斯混合模型 4 高斯混合模型与K均值算法对比 ...

  7. 十一、高斯混合模型(Gaussian Mixed Model, GMM)

    1 高斯模型 1.1 单高斯模型 当样本数据 X X X 是一维数据时, X X X 服从高斯分布是指其概率密度函数(Probability Density Function)可以用下面的式子表示: ...

  8. 前景背景分离方法(二)高斯混合模型法GMM(Gaussian Mixture Model)

    int main() {VideoCapture capture("D:/videos/shadow/use3.MPG");if( !capture.isOpened() ){co ...

  9. GMM高斯混合模型聚类的EM估计过程matlab仿真

    目录 1.算法概述 2.仿真效果 3.MATLAB源码 1.算法概述 高斯混合模型(Gaussian Mixed Model)指的是多个高斯分布函数的线性组合,理论上GMM可以拟合出任意类型的分布,通 ...

最新文章

  1. htop 比top更好用的top
  2. “A and B ...”,谓语动词必以复数形式呈现?
  3. 悬挑脚手架卸载钢丝绳要求_安全不可忽视!脚手架搭设彩色图集,动画展示施工全过程,抠细节...
  4. Linux free指令查看内存使用情况
  5. 【dlib opencv - detector landmark】 ubuntu上针对dlib-hog和opencv haar人脸检测与landmar-68在不同平台上运行时间实验结果汇总
  6. pycharm如何更改python项目环境_PyCharm如何导入python项目,并配置虚拟环境
  7. 读后感《习惯的力量》
  8. python钉钉机器人发送excel附件_Python自动化办公|如何在钉钉上自动发送定制消息或通知给同事...
  9. 飞秋常见文件解决方案
  10. C#使用DevExpress中的chartcontrol
  11. 2016年考研数学一真题pdf ​​​
  12. mvn命令启动Spring boot项目
  13. Vue搭脚手架及创建项目
  14. java 视频提取音频 | Java工具类
  15. Android之数据统计TalkingData集成
  16. VMware 虚拟机 Nat 模式无法上网
  17. 利用Python+xarray+cartopy+matplotlib 实现遥感地形图制图绘制 —— xarray 学习文档01
  18. java中to date_Java Date toInstant()用法及代码示例
  19. 计算机通信网络扫描版,2015计算机通信与网络作业.pdf
  20. 好记性不如烂笔头(一)——局域网可以Ping通,但Socket无法连接

热门文章

  1. 会做饭的机器人曰记_颜真卿《麻姑仙坛记》:苍劲古朴,体态沉雄,气象宏大...
  2. 解决去除“请输入有效值。两个最接近的有效值分别为1和2“提示
  3. [react] react组件间的通信有哪些?
  4. [react] react中什么是受控组件?
  5. React开发(125):ant design学习指南之form中的hasFeedback
  6. [html] 跨标签页的通讯方式有哪些
  7. [css] 你有使用过哪些栅格系统?都有什么区别呢?
  8. [js] 请使用js实现一个秒表计时器的程序
  9. 前端学习(2766):生命周期函数
  10. 工作215:打印出父子组件的this