白板推导系列Pytorch-期望最大(EM)算法
白板推导系列Pytorch-期望最大(EM)算法
EM算法介绍直接看这篇博客-如何通俗理解EM算法,讲的非常好,里面也有算法的公式推导。当然白板推导的视频里面公式推导已经讲的很清楚了,就是缺少应用实例。这篇博客用三个很通俗的例子引入了极大似然估计和EM算法。美中不足的是并没有详细说明极大似然估计并不是一定陷入鸡生蛋蛋生鸡的循环而没有办法处理隐变量问题,而是由于计算复杂从而摒弃了这个方法。
当我们能知道z的分布的时候,其实也是可以用极大似然估计表示的
但是,很多时候,我们很难获得Z的分布,除非我们事先对Z已经很有了解,比如我们如果能够确定Z是一个伯努利分布(比如三硬币模型),那么对Z的分布估计问题就转化成了一个P参数的估计问题
虽然但是,即便知道z的分布,使用极大似然估计也不一定能求。
下面我们一起看看三硬币模型的例子
三硬币模型
有三枚硬币(ABC)正面向上的概率分别为π,p,q\pi,p,qπ,p,q。进行如下试验——先掷A,如果A正面向上则掷B,如果A反面向上则掷C。如此独立地重复做n次试验;记录B和C的结果,正面向上记为1,观测结果为:1,1,0,0,...1,1,0,0,...1,1,0,0,...
若只能观测到结果,不能观测到掷硬币过程,即每一次的观测结果(1或0)由B或C中的哪枚硬币掷出的是未知的。如此情况下估计三枚硬币正面向上的概率π,p,q\pi,p,qπ,p,q。观测数据表示为Y=(Y1,Y2,...,Yn)TY = (Y_1,Y_2,...,Y_n)^TY=(Y1,Y2,...,Yn)T,未观测数据表示为Z=(Z1,Z2,...,Zn)Z=(Z_1,Z_2,...,Z_n)Z=(Z1,Z2,...,Zn)。
一次试验
P(y∣θ)=πpy(1−p)(1−y)+(1−π)qy(1−q)(1−y)P(y \mid \theta)=\pi p^{y}(1-p)^{(1-y)}+(1-\pi) q^{y}(1-q)^{(1-y)} P(y∣θ)=πpy(1−p)(1−y)+(1−π)qy(1−q)(1−y)
n次试验
P(Y∣θ)=∑ZP(Z∣θ)P(Y∣Z,θ)=∏j=1n[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]\begin{aligned} P(Y \mid \theta) &=\sum_{Z} P(Z \mid \theta) P(Y \mid Z, \theta) \\ &=\prod_{j=1}^{n}\left[\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}\right] \end{aligned} P(Y∣θ)=Z∑P(Z∣θ)P(Y∣Z,θ)=j=1∏n[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]
极大似然估计
从极大似然的角度,我们显然是要找到最合适的π,p,q\pi,p,qπ,p,q使得P(Y∣θ)P(Y \mid \theta)P(Y∣θ)最大。如下
θ=argmaxθlogP(Y∣θ)=argmaxθ∑j=1nlog[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]\begin{aligned} \theta &= \underset{\theta}{argmax}\ log\ P(Y|\theta) \\ &= \underset{\theta}{argmax}\sum_{j=1}^{n}log\left[\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}\right] \end{aligned} θ=θargmax log P(Y∣θ)=θargmaxj=1∑nlog[πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)]
然后如果我们对π,p,q\pi,p,qπ,p,q求偏导
∂L∂π=∑j=1npyj(1−p)(1−yj)−qyj(1−q)(1−yj)πpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)=0\frac{\partial L}{\partial \pi} = \sum_{j=1}^{n}\frac{p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}-q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}}{\pi p^{y_{j}}(1-p)^{\left(1-y_{j}\right)}+(1-\pi) q^{y_{j}}(1-q)^{\left(1-y_{j}\right)}} = 0 ∂π∂L=j=1∑nπpyj(1−p)(1−yj)+(1−π)qyj(1−q)(1−yj)pyj(1−p)(1−yj)−qyj(1−q)(1−yj)=0
其它都不用求了,就看上面这个式子,你说这要怎么求?我是没有办法
EM求解
以下内容参考EM算法公式推导 (三硬币模型)
现在我们再来看看EM算法会怎么处理这个事情
我们先把EM算法的公式列出来
E-step
E=∫zlogP(y,z∣θ)P(z∣y,θt)dzE = \int_z logP(y,z|\theta)\ P(z|y,\theta^t)dz E=∫zlogP(y,z∣θ) P(z∣y,θt)dz
M-step
θt+1=argmaxθE\theta^{t+1} = \underset{\theta}{argmax}\ E θt+1=θargmax E
对于离散的情况,计算E的时候,我们通常会转化成这样的式子(至于怎么转换的,白板视频中有推导,然后我贴的原博客中也有推导)
∑j=1N[∑zjP(zj∣yj,θt)log[P(yj,zj∣θ)]]\sum_{j=1}^{N}\left[\sum_{z_{j}} P\left(z_{j} \mid y_{j}, \theta^{t}\right)\ log[P\left(y_{j}, z_{j} \mid \theta\right)]\right] j=1∑N⎣⎡zj∑P(zj∣yj,θt) log[P(yj,zj∣θ)]⎦⎤
然后利用贝叶斯定理和全概率公式得到
P(zj∣yj,θt)=P(yj∣zj,θt)∗P(zj∣θt)∑k=12P(yj∣zk,θt)∗P(zk∣θt)\begin{aligned} P(z_j \mid y_j,\theta^t) = \frac{P(y_j|z_j,\theta^t)*P(z_j|\theta^t)}{\sum_{k=1}^{2}P(y_j|z_k,\theta^t)*P(z_k|\theta^t)} \end{aligned} P(zj∣yj,θt)=∑k=12P(yj∣zk,θt)∗P(zk∣θt)P(yj∣zj,θt)∗P(zj∣θt)
P(yj,zj∣θ)=P(yj∣zj,θ)∗P(zj∣θ)P(y_j,z_j|\theta) = P(y_j|z_j,\theta)*P(z_j|\theta) P(yj,zj∣θ)=P(yj∣zj,θ)∗P(zj∣θ)
定义
P(zj=1∣yj,θt)=ptyj∗(1−pt1−yj)∗πtptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)=μjtP(zj=0∣yj,θt)=qtyj∗(1−qt1−yj)∗(1−πt)ptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)=1−μjt\begin{aligned} P(z_j=1|y_j,\theta^t) &= \frac{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t}{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t+q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)} = \mu_j^{t} \\ \\ P(z_j=0|y_j,\theta^t) &= \frac{q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)}{p^{t^{y_j}}*(1-p^{t^{1-y_j}})*\pi^t+q^{t^{y_j}}*(1-q^{t^{1-y_j}})*(1-\pi^t)} = 1-\mu_j^{t} \end{aligned} P(zj=1∣yj,θt)P(zj=0∣yj,θt)=ptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)ptyj∗(1−pt1−yj)∗πt=μjt=ptyj∗(1−pt1−yj)∗πt+qtyj∗(1−qt1−yj)∗(1−πt)qtyj∗(1−qt1−yj)∗(1−πt)=1−μjt
P(yj,zj=1∣θ)=pyj∗(1−p1−yj)∗πP(yj,zj=0∣θ)=qyj∗(1−q1−yj)∗(1−π)\begin{aligned} P(y_j,z_j=1|\theta) = p^{y_j}*(1-p^{1-y_j})*\pi \\ P(y_j,z_j=0|\theta) = q^{y_j}*(1-q^{1-y_j})*(1-\pi) \end{aligned} P(yj,zj=1∣θ)=pyj∗(1−p1−yj)∗πP(yj,zj=0∣θ)=qyj∗(1−q1−yj)∗(1−π)
得到
Q(θ,θt)=∑j=1N{μjt⋅ln[π⋅pyj(1−p)1−yj]+(1−μjt)⋅ln[(1−π)qyj(1−q)1−yj]}Q\left(\theta, \theta^{t}\right)=\sum_{j=1}^{N}\left\{\mu_{j}^{t} \cdot \ln \left[\pi \cdot p^{y_{j}}(1-p)^{1-y_{j}}\right]+\left(1-\mu_{j}^{t}\right) \cdot \ln \left[(1-\pi) q^{y_{j}}(1-q)^{1-y_{j}}\right]\right\} Q(θ,θt)=j=1∑N{μjt⋅ln[π⋅pyj(1−p)1−yj]+(1−μjt)⋅ln[(1−π)qyj(1−q)1−yj]}
对π\piπ求偏导
∂Q∂π=∑j=1N{μjt⋅1π+(1−μjt)⋅−qyj(1−q)1−yj(1−π)qyj(1−q)1−yj}=∑j=1Nμjt−ππ⋅(1−π)=0\begin{aligned} \frac{\partial Q}{\partial \pi} &= \sum_{j=1}^{N}\{\mu_{j}^{t} \cdot\frac{1}{ \pi}+\left(1-\mu_{j}^{t}\right) \cdot \frac{-q^{y_{j}}(1-q)^{1-y_{j}}}{(1-\pi) q^{y_{j}}(1-q)^{1-y_{j}}} \} \\ &=\sum_{j=1}^{N}\frac{\mu_{j}^{t}-\pi}{\pi \cdot (1-\pi)} = 0 \end{aligned} ∂π∂Q=j=1∑N{μjt⋅π1+(1−μjt)⋅(1−π)qyj(1−q)1−yj−qyj(1−q)1−yj}=j=1∑Nπ⋅(1−π)μjt−π=0
得到
πt+1=1N∑j=1Nμjt\pi^{t+1} = \frac{1}{N}\sum_{j=1}^{N}\mu_{j}^{t} πt+1=N1j=1∑Nμjt
对ppp求偏导
∂Q∂p=∑j=1Nμjt⋅∂∂p[lnπ+ylnp+(1−y)ln(1−p)]=∑j=1Nujt⋅(yp+1−y1−p)=∑j=1Nujt⋅(yj−pp⋅(1−p))=0⇒∑j=1Nujt⋅yj−∑j=1Nujt⋅p=0\begin{aligned} \frac{\partial Q}{\partial p} &= \sum_{j=1}^{N} \mu_j^{t} \cdot \frac{\partial}{\partial p}[\ln \pi+y \ln p+(1-y) \ln (1-p)]\\ &= \sum_{j=1}^{N}u_j^t\cdot(\frac{y}{p}+\frac{1-y}{1-p}) \\ &= \sum_{j=1}^{N}u_j^t\cdot(\frac{y_j-p}{p\cdot(1-p)}) = 0 \\ &\Rightarrow \sum_{j=1}^{N}u_j^t\cdot y_j-\sum_{j=1}^{N}u_j^t\cdot p = 0 \end{aligned} ∂p∂Q=j=1∑Nμjt⋅∂p∂[lnπ+ylnp+(1−y)ln(1−p)]=j=1∑Nujt⋅(py+1−p1−y)=j=1∑Nujt⋅(p⋅(1−p)yj−p)=0⇒j=1∑Nujt⋅yj−j=1∑Nujt⋅p=0
得到
pt+1=∑j=1Nujt⋅yj∑j=1Nujtp^{t+1} = \frac{\sum_{j=1}^{N}u_j^t\cdot y_j}{\sum_{j=1}^{N}u_j^t} pt+1=∑j=1Nujt∑j=1Nujt⋅yj
对q求偏导(略),q和p是类似的,只要换掉权重ujtu_j^tujt为1−ujt1-u_j^t1−ujt即可
得到
qt+1=∑j=1N(1−ujt)⋅yj∑j=1N(1−ujt)q^{t+1} = \frac{\sum_{j=1}^{N}(1-u_j^t)\cdot y_j}{\sum_{j=1}^{N}(1-u_j^t)} qt+1=∑j=1N(1−ujt)∑j=1N(1−ujt)⋅yj
至此,我们已经获得了π,p,q\pi,p,qπ,p,q的递推式
πt+1=1N∑j=1Nμjtpt+1=∑j=1Nujt⋅yj∑j=1Nujtqt+1=∑j=1N(1−ujt)⋅yj∑j=1N(1−ujt)\begin{aligned} &\pi^{t+1} = \frac{1}{N}\sum_{j=1}^{N}\mu_{j}^{t} \\ &p^{t+1} = \frac{\sum_{j=1}^{N}u_j^t\cdot y_j}{\sum_{j=1}^{N}u_j^t} \\ &q^{t+1} = \frac{\sum_{j=1}^{N}(1-u_j^t)\cdot y_j}{\sum_{j=1}^{N}(1-u_j^t)} \end{aligned} πt+1=N1j=1∑Nμjtpt+1=∑j=1Nujt∑j=1Nujt⋅yjqt+1=∑j=1N(1−ujt)∑j=1N(1−ujt)⋅yj
得到递推式后,我们就可以反复迭代这几个式子,直到
πt+1,pt+1,qt+1≈πt,pt,qt\pi^{t+1},p^{t+1},q^{t+1} \approx \pi^{t},p^{t},q^{t} πt+1,pt+1,qt+1≈πt,pt,qt
EM算法实现
下面我们用代码实现一下三硬币模型
首先我们先定义数据集
假设我们投了20次A硬币,得到B、C的投掷结果如下
observations = [1,0,0,1,1,1,0,1,1,1,0,0,0,0,1,1,0,1,0,0]
初始A,B,C正面朝上的概率分别为0.4,0.6,0.5
pi,p,q = 0.4,0.6,0.5
遍历观测序列
def em(pi,p,q,max_iter = 100,toler = 0.001):for epoch in range(max_iter):p_up,p_down,q_up,q_down = .0,.0,.0,.0for observation in observations:if observation==1:ut = p*pi/(p*pi+q*(1-pi))else:ut = (1-p)*pi/((1-p)*pi+(1-q)*(1-pi))p_up += ut*observationp_down += utq_up += (1-ut)*observationq_down += 1-utpi_next = p_down/len(observations)p_next = p_up/p_downq_next = q_up/q_down# if np.abs(pi-pi_next)<toler and np.abs(p-p_next)<toler and np.abs(q-q_next)<toler:if pi==pi_next and p==p_next and q==q_next:pi = pi_nextp = p_nextq = q_nextbreakelse:pi = pi_nextp = p_nextq = q_nextprint("epoch %s:%s %s %s"%(epoch+1,pi,p,q))return pi,p,q
执行em函数
em(pi,p,q)
输出
epoch 1:0.41711229946524064 0.435897435897436 0.5458715596330275
epoch 2:0.4171122994652407 0.43589743589743596 0.5458715596330275
epoch 3:0.4171122994652408 0.43589743589743596 0.5458715596330275
(0.4171122994652408, 0.43589743589743596, 0.5458715596330275)
白板推导系列Pytorch-期望最大(EM)算法相关推荐
- 机器学习-白板推导-系列(十)笔记:EM算法
文章目录 0 笔记说明 1 算法收敛性证明 2 公式导出 2.1 ELBO+KL Divergence 2.2 ELBO+Jensen Inequlity 2.3 最后的工作 3 从狭义EM到广义EM ...
- 机器学习-白板推导系列笔记(十二)-变分推断(VI)
此文章主要是结合哔站shuhuai008大佬的白板推导视频: VI变分推断_126min 全部笔记的汇总贴:机器学习-白板推导系列笔记 一.背景 对于概率模型 从频率派角度来看就会是一个优化问题 从贝 ...
- 机器学习-白板推导系列笔记(十三)-MCMC
此文章主要是结合哔站shuhuai008大佬的白板推导视频:MCMC_218min 全部笔记的汇总贴:机器学习-白板推导系列笔记 一.蒙特卡洛方法 蒙特卡洛方法(Monte Carlo Method) ...
- 机器学习-白板推导系列(三十)-生成模型(Generative Model)
机器学习-白板推导系列(三十)-生成模型(Generative Model) 30.1 生成模型的定义 前面所详细描述的模型以浅层的机器学习为主.本章将承上启下引出后面深度机器学习的部分.本小节,主要 ...
- 机器学习-白板推导-系列(五)笔记:降维(PCA/SVD/PCoA/PPCA)
文章目录 0 笔记说明 1 背景 1.1 样本均值 1.2 样本协方差矩阵 2 主成分分析PCA 2.1 最大投影方差 2.2 最小重构距离 2.3 总结 3 SVD分解HX 4 主坐标分析PCoA ...
- 机器学习-白板推导-系列(九)笔记:概率图模型: 贝叶斯网络/马尔可夫随机场/推断/道德图/因子图
文章目录 0 笔记说明 1 背景介绍 1.1 概率公式 1.2 概率图简介 1.2.1 表示 1.2.2 推断 1.2.3 学习 1.2.4 决策 1.3 图 2 贝叶斯网络 2.1 条件独立性 2. ...
- 机器学习-白板推导系列笔记(三十四)-MDP
此文章主要是结合哔站shuhuai008大佬的白板推导视频:马尔科夫决策过程_107min 全部笔记的汇总贴:机器学习-白板推导系列笔记 一.背景介绍 Random Variable:XYX⊥YX\; ...
- 机器学习-白板推导-系列(八)笔记:指数族分布/充分统计量/对数配分函数/最大熵
文章目录 0 笔记说明 1 背景 1.1 指数族分布的一般形式 1.2 共轭先验 2 高斯分布的指数族形式 3 对数配分函数与充分统计量 4 极大似然估计与充分统计量 5 熵 5.1 最大熵⇔x服从均 ...
- 机器学习-白板推导系列笔记(二十八)-BM
此文章主要是结合哔站shuhuai008大佬的白板推导视频:玻尔兹曼机_147min 全部笔记的汇总贴:机器学习-白板推导系列笔记 参考花书20.1 一.介绍 玻尔兹曼机连接的每个节点都是离散的二值分 ...
- 机器学习-白板推导系列笔记(二十一)-RBM
此文章主要是结合哔站shuhuai008大佬的白板推导视频:受限玻尔兹曼机_155min 全部笔记的汇总贴:机器学习-白板推导系列笔记 玻尔兹曼机介绍:白板推导系列笔记(二十八)-玻尔兹曼机 一.背景 ...
最新文章
- 王朝阳:河北高校邀请赛 -- 二手车交易价格预测决赛答辩
- [首次分析]PHP写框架
- 一个完整的schema验证xml的样例
- Java Enumeration接口与Iterator接口
- Python二级笔记(9)
- 软件测试工程师阶段_软件工程测试阶段
- 2016 ICPC 北京
- 做项目时的几个感慨(持续更新...)
- 无人车飞速狂飙,黑科技如何为其加油打气?
- iOS程序员必须知道的Android要点
- 大学生创新创业 /互联网+ 大赛 商业计划书目录(模板)
- 做完一个网站重构项目的总结以及感想!
- linux yum安装驱动,centos8安装alsa驱动
- 关于报 程序包com.jt.pojo不存在、报 Process terminated、Failed to execute goal on project jt-manage: 的问题,已解决
- 【论文笔记】Character-Aware Neural Language Models
- DEVC++第五人格V2.0
- 有一个棋盘,有64个方格,在第一个方格里面放1粒芝麻重量是0.00001kg,第二个里面放2粒,第三个里面放4,棋盘上放的所有芝麻的重量。
- hdu1151Air Raid poj2594Treasure Exploration题解
- ZOJ 1217 Eight(单向BFS+map)
- npj | 王德华/张学英等揭示荒漠啮齿动物通过“菌群-肠-肾”轴耐受高盐的机制...