EM算法是一种迭代算法,主要适用于概率模型的参数估计,特别适用于含有隐变量的概率模型参数的极大似然估计,或者极大后验概率的估计。EM算法的每次迭代有两步组成:E步,求期望;M步,极大化。所以这一算法称之为期望极大化算法,简称EM算法。可能大家都听过EM算法,也知道有E步(求期望)和M步(极大化),但是求期望是求谁的期望,极大化又该如何极大化呢,单看理论有时不免一头雾水,通过实例可以让我们更好的理解,同时也先跟大家说一下EM算法是对初值敏感的。

首先,给大家先推荐一下理论学习的资料吧,毕竟内功还是要修修的。

1、李航统计学习方法-第九章 EM算法及其推广

2、视频教程:徐亦达|概率机器学习(头几个视频是介绍EM算法的)

3、 What is the expectation maximization algorithm?. Nature biotechnology, 26(8), 897.

网上关于EM的理论介绍还是非常多的,这里我们主要通过程序来让大家更好的理解EM算法,

我下面从二硬币模型开始转向三硬币模型帮助初学者更好的理解EM算法的过程。

示例一:二硬币模型

假设现在有两个硬币A和B,我们想要知道两枚硬币各自为正面的概率啊即模型的参数。我们先随机从A,B中选一枚硬币,然后扔10次并记录下相应的结果,H代表正面T代表反面。对以上的步骤重复进行5次。如果在记录的过程中我们记录下来每次是哪一枚硬币(即知道每次选的是A还是B),那可以直接根据结果进行估计(见下图a)。不含隐变量的参数求解问题

但是如果数据中没记录每次投掷的硬币是A还是B(隐变量),只观测到5次循环共50次投币的结果,这时就没法直接估计A和B的正面概率。这时就该轮到EM算法大显身手了,EM算法特别适用于这种含有隐变量的参数求解问题(见下图b)。含有隐变量的参数求解

先初始化输入参数,如上图1步给了一个初始值0.6(A硬币正面的概率),0.5(B硬币正面的概率)。接下来先进行E步(对隐变量求期望),如上图2步:以第一条数据为例,5H5T,为A的概率为

,为B的概率

,归一化后得P(A)=0.45,P(B)=0.55,剩下几条数据同理可得。而后通过M-step可计算重新迭代的概率值。如上图第一次迭代后

,循环上面的E、M步骤直至收敛我们就可以得到最终的答案,如上图进过10次迭代后得到了最终的结果。

示例二:三硬币模型

现在我们将上面的二硬币模型扩展为三硬币模型,其实原理基本差不多。假设有三枚硬币A、B、C,这些硬币正面出现的概率分别p,q和

。先抛C硬币,如果C硬币为正面则选择硬币A,反之选择硬币B,然后对选出的硬币进行一组实验,独立的抛十次。共做5次实验,每次实验独立的抛十次,结果如图中a所示,例如某次实验产生了H、T、T、T、H、H、T、H、T、H,H代表正面朝上。5次实验结果

本人最近也刚学EM算法,下面代码主要参考EM算法及其推广,这里面作者实现了一个两硬币模型的EM算法。本文对其稍做了一点修改,变成三硬币模型。

EM算法步骤:

E步:计算在当前迭代的模型参数下,观测数据y来自硬币B的概率:

M步:估算下一个迭代的新的模型估算值

对于这个三硬币模型来说,我们先通过E步(对隐变量求期望)来求得隐变量的参数(即属于哪个硬币),然后再通过M-step来重新估算三个硬币的参数,直至收敛(达到要求)为止。下面是实现三硬币模型的EM算法代码,希望可以更好的帮助理解。

# !usr/bin/env python

# -*- coding:utf-8 -*-

import numpy as np

from scipy import stats

def em_single(priors, observations):

"""EM算法单次迭代Arguments---------priors : [theta_A, theta_B,theta_C]observations : [m X n matrix]Returns--------new_priors: [new_theta_A, new_theta_B,new_theta_C]:param priors::param observations::return:"""

counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}

theta_A = priors[0]

theta_B = priors[1]

theta_c=priors[2]

# E step

weight_As=[]

for observation in observations:

len_observation = len(observation)

num_heads = observation.sum()

num_tails = len_observation - num_heads

contribution_A = theta_c*stats.binom.pmf(num_heads, len_observation, theta_A)

contribution_B = (1-theta_c)*stats.binom.pmf(num_heads, len_observation, theta_B) # 两个二项分布

weight_A = contribution_A / (contribution_A + contribution_B)

weight_B = contribution_B / (contribution_A + contribution_B)

# 更新在当前参数下A、B硬币产生的正反面次数

weight_As.append(weight_A)

counts['A']['H'] += weight_A * num_heads

counts['A']['T'] += weight_A * num_tails

counts['B']['H'] += weight_B * num_heads

counts['B']['T'] += weight_B * num_tails

# M step

new_theta_c = 1.0*sum(weight_As)/len(weight_As)

new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])

new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])

return [new_theta_A, new_theta_B,new_theta_c]

def em(observations, prior, tol=1e-6, iterations=10000):

"""EM算法:param observations: 观测数据:param prior: 模型初值:param tol: 迭代结束阈值:param iterations: 最大迭代次数:return: 局部最优的模型参数"""

import math

iteration = 0

while iteration < iterations:

new_prior = em_single(prior, observations)

delta_change = np.abs(prior[0] - new_prior[0])

if delta_change < tol:

break

else:

prior = new_prior

iteration += 1

return [new_prior, iteration]

# 硬币投掷结果观测序列:1表示正面,0表示反面。

observations = np.array([[1, 0, 0, 0, 1, 1, 0, 1, 0, 1],

[1, 1, 1, 1, 0, 1, 1, 1, 1, 1],

[1, 0, 1, 1, 1, 1, 1, 0, 1, 1],

[1, 0, 1, 0, 0, 0, 1, 1, 0, 0],

[0, 1, 1, 1, 0, 1, 1, 1, 0, 1]])

print em(observations, [0.5, 0.8, 0.6])

运行后结果为:

[[0.51392121603987106, 0.79337052912023864, 0.47726196801164544], 42]

从结果我们可以了解到经过42轮迭代,我们最终得出了结果:硬币A正面的概率为0.51392121603987106,硬币B为正面的概率为0.79337052912023864,C硬币正面概率为0.47726196801164544。

至此EM算法的实现就完成了,另外还有一个EM算法求高斯混合模型參数预计 的python实现,大家有兴趣的可以了解一下。

通过以上的例子希望能过帮助大家更好的理解EM算法。本人也初学EM算法,如果有错误的地方还恳请指正。

参考:

2、Do, C. B., & Batzoglou, S. (2008). What is the expectation maximization algorithm?. Nature biotechnology, 26(8), 897.

em算法python代码_EM算法Python实战相关推荐

  1. em算法python代码_EM 算法求解高斯混合模型python实现

    注:本文是对<统计学习方法>EM算法的一个简单总结. 1. 什么是EM算法? 引用书上的话: 概率模型有时既含有观测变量,又含有隐变量或者潜在变量.如果概率模型的变量都是观测变量,可以直接 ...

  2. em算法python代码_EM算法的python实现的方法步骤

    导读热词 前言:前一篇文章大概说了EM算法的整个理解以及一些相关的公式神马的,那些数学公式啥的看完真的是忘完了,那就来用代码记忆记忆吧!接下来将会对python版本的EM算法进行一些分析. EM的py ...

  3. 手眼标定算法Tsai-Lenz代码实现(Python、C++、Matlab)

    你好,我是小智. 上一节介绍了手眼标定算法Tsai的原理,这一节介绍算法的代码实现,分别有Python.C++.Matlab版本的算法实现方式. 该算法适用于将相机装在手抓上和将相机装在外部两种情况 ...

  4. pythonencoding etf-8_etf iopv python 代码30个Python常用小技巧

    1.原地交换两个数字x, y =10, 20 print(x, y) y, x = x, y print(x, y) 10 20 20 10 2.链状比较操作符n = 10 print(1 print ...

  5. 用Python代码自己写Python代码,竟如此简单

    用Python代码自己写Python代码,竟如此简单 Python作为一门功能强大且使用灵活的编程语言,可以应用于各种领域,具有"无所不能"的特质. Python甚至可以代替人,自 ...

  6. knn算法python代码_KNN算法原理(python代码实现)

    kNN(k-nearest neighbor algorithm)算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性 ...

  7. 模块度计算python代码_LPA算法C++实现及模块度计算

    前言 这学期开始看社团检测的东西,了解了一些经典算法.比如GN算法,BGLL算法(又叫Louvain, 因为该算法是作者在Louvain大学时提出的),LPA算法,等等. 我先看的LPA(毕竟算法思想 ...

  8. knn算法python代码iris_KNN算法原理及代码实现

    如何选择K值 首先让我们理解K值到底如何影响KNN算法.如果我们 有很多蓝色点和红色点数据,使用不同K值,最终的分类效果大概如下图.我们发现随着K值的增大,分界面越来越平滑. 一般在机器学习中我们要将 ...

  9. knn算法python代码_KNN 算法原理及代码实现

    在本文中,我们将讨论一种广泛使用的分类技术,称为K最近邻(KNN).我们的重点主要集中在算法如何工作以及输入参数如何影响预测结果. 内容包括: 何时使用KNN算法? KNN算法原理 如何选择K值 KN ...

  10. knn的python代码_《机器学习实战》之一:knn(python代码)

    数据 标称型和数值型 算法 归一化处理:防止数值较大的特征对距离产生较大影响 计算欧式距离:测试样本与训练集 排序:选取前k个距离,统计频数(出现次数)最多的类别 def classify0(inX, ...

最新文章

  1. IBM RSA(Rational Software Architect)试用版下载地址
  2. 新一代的树莓派3版本——Raspberry Pi 3 发布了
  3. CentOS下双网卡单网关路由配置
  4. MyBatis 缓存详解-一级缓存的不足
  5. 30岁学python全栈_知乎热帖!戳痛100万程序员:我30岁了,我还能学Python吗?
  6. 2021教师资格证中学科目二简答汇总分享
  7. 老李分享:5个衡量软件质量的标准
  8. Spark SQL初始化和创建DataFrame的几种方式
  9. 对于stackoverflow的中文翻译的相关问题
  10. Windows系统、下的MySQL、版本升级、实操
  11. revit 转换ifc_Revit官方教程:Revit模型如何导成IFC格式?
  12. 基于51单片机的指纹考勤系统
  13. 【html5插入透明Webm视频】
  14. el-checkbox-group 的坑
  15. Windows系统如何关闭防火墙保姆式教程,超详细
  16. html css 分页样式,css中分页样式
  17. 用 JustTrustMe 干翻 SSL Pinning: 爬尤美 app 付费视频(app.youmei.com)
  18. 参考 | 升级 Win11 移动热点开不了或者开了连不上
  19. Python之qq自动发消息
  20. SpringBoot08:Shiro

热门文章

  1. sqlserver 建表语句
  2. 微信公众号支付 java_微信公众号支付开发全过程(java版)
  3. 超详细 excel 基础知识
  4. Windows桌面美化——记录我的设置
  5. Java多线程面试知识点汇总(超详细总结)
  6. 产品经理面试必问(附解析)
  7. 古今地名对照总表 (按笔划数排序,强烈推荐的资料)
  8. 视频编码格式、视频码率、视频帧率、分辨率的概念
  9. 中望cad自定义快捷键命令_cad中望_中望cad常用快捷键及命令
  10. 罗技 logic C930c 摄像头 驱动 win7 64位 家庭中文版 无法使用