在学习EM时,介绍了使用EM算法求解高斯混合模型(GMM:Gaussian Mixture Model,http://blog.csdn.net/foreseerwang/article/details/75222522),从而进行聚类的过程,并与k-means算法进行了对比,可以看到GMM模型的优势。

但是,GMM模型仍存在一些问题,譬如:必须要先知道类别数K。此时,可以通过引入GMM模型参数先验的方式,自动确定K。这就需要用到本文提到的VBEM算法了,详见PRML第10.2节和MLaPP第21.6节均有详细论述。

此外,在实践中发现,某些初始化下,简单的K-means和GMM/EM无法获得准确的聚类结果,估计可能和数据均衡度有关。请见下图,小图1是生成的GMM模型原始数据,每个高斯过程的参数与之前文章一致,但三类数据的比例是随机产生的。可以看到,在这种情况下,三类的比例失调,有一类特别少。

小图2和小图3分别是kmeans算法和GMM/EM算法聚类的结果,可以看到,与原始数据不符。

在这种情况下, 同样可以看到VBEM算法的优势。第二行的三个图是VBEM的结果。按照PRML和MLaPP书中所说,VBEM的类别数K需要选取大于实际类别数的值,这里选取K=6,小图4是直接聚6类的结果,用作VBEM的初始值。小图5是经过100次迭代的结果,仍然还有4类。小图6是迭代稳定后的结果。可以看到:

1. 自动生成了3类数据;

2. 3类的分布与原始数据非常一致。

这就是VBEM。世界正在变得越来越清晰...

代码如下:

clear all;
close all;
rng(2);%% Parameters
N = 1000;                                   % 总数据量
D = 2;                                      % 数据维度
K = 3;                                      % 类别数目
Pz = rand([K,1]);                           % 随机生成各类比例
Pz = Pz/sum(Pz);% 数据初始化,与之前的EM聚类程序相同
mu = [1 2; -6 2; 7 1];
sigma=zeros(K,D,D);
sigma(1,:,:)=[2 -1.5; -1.5 2];
sigma(2,:,:)=[5 -2.; -2. 3];
sigma(3,:,:)=[1 0.1; 0.1 2];%% Data Generation and display
x = zeros(N,D);
PzCDF1 = 0;
figure(1); subplot(2,3,1); hold on;
figure(2); hold on;
for ii = 1:K,PzCDF2 = PzCDF1 + Pz(ii);PzIdx1 = round(PzCDF1*N);PzIdx2 = round(PzCDF2*N);x(PzIdx1+1:PzIdx2,:) = mvnrnd(mu(ii,:), squeeze(sigma(ii,:,:)), PzIdx2-PzIdx1);PzCDF1 = PzCDF2;figure(1); subplot(2,3,1); hold on;plot(x(PzIdx1+1:PzIdx2,1),x(PzIdx1+1:PzIdx2,2),'o');
end;
[~, tmpidx] = sort(rand(N,1));
x = x(tmpidx,:);                        % shuffle datafigure(1); subplot(2,3,1);
plot(mu(:,1),mu(:,2),'k*');
axis([-10,10,-4,8]);
title('1.Generated Data (original)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');figure(2);
plot(x(:,1),x(:,2),'o');
figure(2);
plot(mu(:,1),mu(:,2),'k*');
axis([-10,10,-4,8]);
title('Generated Data (original)');
xlabel('x1');
ylabel('x2');fprintf('\n$$ Data generation and display completed...\n');%% clustering: Matlab k-means
k_idx=kmeans(x,K);                  % 使用Matlab现有k-means算法
figure(1); subplot(2,3,2); hold on;
for ii=1:K,idx=(k_idx==ii);plot(x(idx,1),x(idx,2),'o');center = mean(x(idx,:));plot(center(1),center(2),'k*');
end;
axis([-10,10,-4,8]);
title('2.Clustering: Matlab k-means', 'fontsize', 20);
xlabel('x1');
ylabel('x2');fprintf('\n$$ K-means clustering completed...\n');%% clustering: EM
% Refer to pp.351, MLaPP
% Pw: weight
% mu: u of Gaussion distribution
% sigma: Covariance matrix of Gaussion distribution
% r(i,k): responsibility; rk: sum of r over i
% px: p(x|mu,sigma)% 上面的聚类结果作为EM算法的初始值
Pw=zeros(K,1);
for ii=1:K,idx=(k_idx==ii);Pw(ii)=sum(idx)*1.0/N;mu(ii,:)=mean(x(idx,:));sigma(ii,:,:)=cov(x(idx,1),x(idx,2));
end;px=zeros(N,K);
for jj=1:100, % 简单起见,直接循环,不做结束判断for ii=1:K,px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));% 使用Matlab自带的mvnpdf,有时会出现sigma非正定的错误end;% E steptemp=px.*repmat(Pw',N,1);r=temp./repmat(sum(temp,2),1,K);% M steprk=sum(r);Pw=rk'/N;mu=r'*x./repmat(rk',1,D);for ii=1:Ksigma(ii,:,:)=x'*(repmat(r(:,ii),1,D).*x)/rk(ii)-mu(ii,:)'*mu(ii,:);end;
end;% display
[~,clst_idx]=max(px,[],2);
figure(1); subplot(2,3,3); hold on;
for ii=1:K,idx=(clst_idx==ii);plot(x(idx,1),x(idx,2),'o');center = mean(x(idx,:));sigma(ii,:,:)=cov(x(idx,1),x(idx,2));plot(center(1),center(2),'k*');
end;axis([-10,10,-4,8]);
title('3.Clustering: GMM/EM', 'fontsize', 20);
xlabel('x1');
ylabel('x2');fprintf('\n$$ Gaussian Mixture using EM completed...\n');%% Variational Bayes EM
% Refer to ch.10.2, PRML
% x: visible variable, N * D
% z: latent variable, N * K% z: Pz, Ppi, alp0, alpk
%    Pz = P(z|pi);                                          PRML(10.37)
%    Ppi = Dir(pi|alp0)                                     PRML(10.39)
% x: Px, Pz, Ppi, mu, lambda, m0, beta0, W0, nu0
%    Px = P(x|z, mu, lambda);        高斯分布               PRML(10.38)
%    P(mu, lambda) = P(mu|lambda)*P(lambda)                PRML(10.40)
%        = N(mu|m0, (beta0*lambda)^-1) * Wi(lambda|W0, nu0)% rho: N*K,定义参见PRML(10.46)
% r: N*K, responsibility; 归一化之后的rho,定义参见PRML(10.49)
% N_k: sum of r over n                    定义参见PRML(10.51)
% xbar_k:                                 定义参见PRML(10.52)
% S_k                                     定义参见PRML(10.53)K = 6;                  % 增加分类数,利用VBEM自动选择分类数
k_idx=kmeans(x,K);      % 使用Matlab自带的k-means聚类,结果作为VBEM的初始值figure(1); subplot(2,3,4); hold on;
for ii=1:K,idx=(k_idx==ii);plot(x(idx,1),x(idx,2),'o');center = mean(x(idx,:));plot(center(1),center(2),'k*');mu(ii,:) = mean(x(idx,:));sigma(ii,:,:)=cov(x(idx,1),x(idx,2));px(:,ii)=GaussPDF(x,mu(ii,:),squeeze(sigma(ii,:,:)));% 使用Matlab自带的mvnpdf,有时会出现sigma非正定的错误,特使用自编函数GaussPDF
end;
axis([-10,10,-4,8]);
title('4.Clustering: VBEM (initial)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');% 初始化,具体定义参见PRML式(10.40)
alp0 = 0.0001;          % alpha0,应<<1,以实现类别数自动筛选
m0 = 0;
beta0 = rand()+0.5;         % 拍脑袋初始化
W0 = squeeze(mean(sigma));
W0inv = pinv(W0);
nu0 = D*2;                  % 拍脑袋初始化S_k = zeros(K,D,D);
W_k = zeros(K,D,D);
E_mu_lmbd = zeros(N,K);     % 即PRML中式(10.64)的等号左侧r = px./repmat(sum(px,2),1,K);                  % N*K
N_k = ones(1,K)*(-100);
fprintf('\n');
for ii = 1:1000,% M-stepN_k_new = sum(r);                           % 1*K,式(11.51)N_k_new(N_k_new<N/1000.0)=1e-4;             % 避免出现特别小或为零的Nkif sum(abs(N_k_new-N_k))<0.001,              break;  % early stop,如果Nk基本没变化了,则停止迭代elseN_k = N_k_new;end;xbar_k = r'*x./repmat(N_k', 1, D);          % K*D,PRML式(10.52)for jj = 1:K,dx = x-repmat(xbar_k(jj,:), N, 1);      % N*DS_k(jj,:,:) = dx'*(dx.*repmat(r(:,jj),1,D))/N_k(jj); % D*D,PRML式(10.53)end;alp_k = alp0 + N_k;         % PRML式(10.58)beta_k = beta0 + N_k;       % PRML式(10.60)m_k = (beta0*m0 + repmat(N_k',1,D).*xbar_k)./...repmat(beta_k',1,D);    % K*D,PRML式(10.61)for jj = 1:K,dxm = xbar_k(jj,:)-m0;Wkinv = W0inv + N_k(jj)*squeeze(S_k(jj,:,:)) + ...dxm'*dxm*beta0*N_k(jj)/(beta0+N_k(jj));W_k(jj,:,:) = pinv(Wkinv);           % K*D*D,PRML式(10.62)end;nu_k = nu0 + N_k;                        % 1*K,PRML式(10.63)% E-step: 迭代计算ralp_tilde = sum(alp_k);E_ln_pi = psi(alp_k) - psi(alp_tilde);      % PRML式(10.66)E_ln_lambda = D*log(2)*ones(1,K);           for jj = 1:D,E_ln_lambda = E_ln_lambda + psi((nu_k+1-jj)/2); end;for jj = 1:K,E_ln_lambda(jj) = E_ln_lambda(jj) + ...log(det(squeeze(W_k(jj,:,:))));     % PRML式(10.65)dxm = x-repmat(m_k(jj,:),N,1);          % N*DDbeta = D/beta_k(jj);for nn = 1:N,E_mu_lmbd(nn,jj) = Dbeta+nu_k(jj)*(dxm(nn,:)*...squeeze(W_k(jj,:,:))*dxm(nn,:)');   % PRML式(10.64)end;end;rho = exp(repmat(E_ln_pi,N,1)+repmat(E_ln_lambda,N,1)/2-...E_mu_lmbd/2);                           % PRML式(10.46)r = rho./repmat(sum(rho,2),1,K);            % PRML式(10.49)if mod(ii,10)==0,fprintf('%3d loops finished.\n', ii);end;if ii == 100,[~,clst_idx]=max(r,[],2);figure(1); subplot(2,3,5); hold on;for kk=1:K,idx=(clst_idx==kk);if sum(idx)/N>0.01,plot(x(idx,1),x(idx,2),'o');center = mean(x(idx,:));plot(center(1),center(2),'k*');end;end;axis([-10,10,-4,8]);title('5.Clustering: VBEM (100 iter)', 'fontsize', 20);xlabel('x1');ylabel('x2');end;
end;[~,clst_idx]=max(r,[],2);
figure(1); subplot(2,3,6); hold on;
Nclst = 0;
for ii=1:K,idx=(clst_idx==ii);if sum(idx)/N>0.01,Nclst = Nclst+1;plot(x(idx,1),x(idx,2),'o');center = mean(x(idx,:));plot(center(1),center(2),'k*');end;
end;
fprintf('\n$$ Using VBEM, totally %d clusters found.\n\n', Nclst);
axis([-10,10,-4,8]);
title('6.Clustering: VBEM (final)', 'fontsize', 20);
xlabel('x1');
ylabel('x2');

其中,使用了自己编写的高斯随机变量pdf计算函数GaussPDF,主要原因是matlab自带的mvnpdf函数有时会报sigma非正定的错误,但实际sigma是正定的。

GaussPDF代码如下:

function p = GaussPDF(x, mu, sigma)[N, D] = size(x);x_u = x-repmat(mu, N, 1);
p = zeros(N,1);
for ii=1:N,p(ii) = exp(-0.5*x_u(ii,:)*pinv(sigma)*x_u(ii,:)')/...sqrt(det(sigma)*(2*pi)^D);
end;end

Variational Inference入门:variational bayesian EM相关推荐

  1. 变分推断(variational inference)/variational EM

    诸神缄默不语-个人CSDN博文目录 由于我真的,啥都不会,所以本文基本上就是,从0开始. 我看不懂的博客就是写得不行的博客.所以我只写我看得懂的部分. 持续更新. 文章目录 1. 琴生不等式 2. 香 ...

  2. 机器学习课程 Variational Inference

    [视频课程:徐亦达讲变分推断]<机器学习课程 Variational Inference>by 徐亦达 Variational Inference Basics/Variational I ...

  3. 模型汇总-9 Variational AutoEncoder_VAE基础:LVM、MAP、EM、MCMC、Variational Inference(VI)

    Kingma et al和Rezende et al在2013年提出了变分自动编码器(Variational AutoEncoders,VAEs)模型,仅仅三年的时间,VAEs就成为一种最流行的生成模 ...

  4. Bayesian Convolution Neural Networks with Bernoulli Approximate Variational Inference

    <Bayesian Convolution Neural Networks with Bernoulli Approximate Variational Inference> https: ...

  5. GAUSSIAN MIXTURE VAE: LESSONS IN VARIATIONAL INFERENCE, GENERATIVE MODELS, AND DEEP NETS

    Not too long ago, I came across this paper on unsupervised clustering with Gaussian Mixture VAEs. I ...

  6. 变分推断(Variational Inference)最新进展简述

    动机 变分推断(Variational Inference, VI)是贝叶斯近似推断方法中的一大类方法,将后验推断问题巧妙地转化为优化问题进行求解,相比另一大类方法马尔可夫链蒙特卡洛方法(Markov ...

  7. Collapsed Variational Inference(Collapsed变分推断)算法以LDA推导为例

    本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 内容可能有不到之处,欢迎交流. 未经本人允许禁止转载. 文章目录 简介 LDA变分推断 LDA的Collapse ...

  8. 变分贝叶斯、Variational Inference

    不是大功告成了吗?通常情况下,上式是很难计算的,直观上看,需要考虑所有的都已比较困难了,更不用说能不能积分了,尤其是维度较高的情况,是需要多重积分的.当然,我们可以用Monte Carlo 的方法,不 ...

  9. EM and Variational Inference Derivation

    https://chrischoy.github.io/research/Expectation-Maximization-and-Variational-Inference/

最新文章

  1. RocketMQ命令整理
  2. yii2-按需加载并管理CSS样式/JS脚本
  3. 卢卡斯定理及其卢卡斯定理的拓展
  4. 取成本中心-生产订单
  5. [云炬python3玩转机器学习笔记] 3-11Matplotlib数据可视化基础
  6. Android之Intent.ACTION_MEDIA_SCANNER_SCAN_FILE:扫描指定文件
  7. python-flask-Flask-SQLAlchemy与Flask-Migrate联合进行数据化迁移
  8. 中国服务业发展的轨迹、逻辑与战略转变——改革开放40年来的经验分析
  9. 通过案例学习调优之--Oracle ASH
  10. 你使用过哪些数据分析的方法?
  11. 传输信道加密Stunnel配置
  12. 【优化算法】多目标跟踪优化算法(MTOA)【含Matlab源码 1466期】
  13. 西威变频器图纸 SIEI电路图 西威原理图avy-L 原厂图纸PDF格式 主板21页,底座驱动板7页
  14. 基于python 爬虫的数据库设计开题报告_爬虫开题报告
  15. 你的 Mac 用对了吗?推荐一些 Mac 上比较好用的软件
  16. 最小采样频率计算公式_音频文件大小计算公式-好文转载
  17. 经纬能源安全稳定怎样理财收益最大?怎样理财才干收益最大?
  18. 基于语音的疲劳度检测算法研究
  19. rhel8安装docker-ce
  20. github.com/stretchr/testify/suite

热门文章

  1. 微医在港招股书失效:曾多次喊话上市,注册用户达2.2亿
  2. 石油工程课程设计c语言,东北石油大学-石油工程抽油设计C语言编程.doc
  3. esp32与0.96寸屏幕实现信息传输
  4. 关于核磁共振图像的命名原则及含义(总结自用)
  5. 20 个短小精悍的 pandas 骚操作
  6. Mysql中WhereIn和Join的性能比对
  7. Android编译中m、mm、mmm的区别
  8. android学习总结(16.08.29)进度条控件ProgressBar和ProgressDialog
  9. SCI 投稿全过程信件模板一览
  10. fleck 客户端_C#中使用Fleck实现WebSocket通信简例