机器学习笔记之变分推断——随机梯度变分推断

  • 引言
    • 回顾:基于平均场假设的变分推断
      • 经典变分推断的问题
    • 随机梯度变分推断的求解过程

引言

上一节介绍了基于平均场假设的变分推断与广义EM算法的关系,本节将介绍通过随机梯度的思想实现变分推断

回顾:基于平均场假设的变分推断

基于平均场假设的变分推断通常称为经典变分推断(Classical Variational Inference)。其核心自然是 平均场假设:将隐变量Z\mathcal ZZ的概率分布Q(Z)\mathcal Q(\mathcal Z)Q(Z)看做M\mathcal MM个独立的子概率分布
Q(Z)=∏i=1MQi(Z(i))\mathcal Q(\mathcal Z) = \prod_{i=1}^{\mathcal M} \mathcal Q_i(\mathcal Z^{(i)})Q(Z)=i=1∏M​Qi​(Z(i))
其迭代过程的思想是坐标上升法(Coordinate Ascent):

  • 求解Qj(Z(j))\mathcal Q_j(\mathcal Z^{(j)})Qj​(Z(j)),固定除Qj(Z(j))\mathcal Q_j(\mathcal Z^{(j)})Qj​(Z(j))外的所有分布,并将求解出的Q^i(Z(i))\hat {\mathcal Q}_i(\mathcal Z^{(i)})Q^​i​(Z(i))替换原始的Qj(Z(j))\mathcal Q_j(\mathcal Z^{(j)})Qj​(Z(j))
    Q^j(Z(j))=arg⁡max⁡Qj(Z(j)){−KL[ϕ^(X,Z(j))∣∣Qj(Z(j))]}Q(Z)=Q1(Z(1))×⋯×Q^j(Z(j))×⋯×QM(Z(M))\hat {\mathcal Q}_j (\mathcal Z^{(j)}) = \mathop{\arg\max}\limits_{\mathcal Q_j(\mathcal Z^{(j)})} \left\{-\mathcal K\mathcal L \left[\hat \phi (\mathcal X,\mathcal Z^{(j)}) || \mathcal Q_j(\mathcal Z^{(j)})\right]\right\} \\ \mathcal Q(\mathcal Z) = \mathcal Q_1(\mathcal Z^{(1)}) \times \cdots \times \hat {\mathcal Q}_j(\mathcal Z^{(j)}) \times \cdots\times \mathcal Q_{\mathcal M}(\mathcal Z^{(\mathcal M)})Q^​j​(Z(j))=Qj​(Z(j))argmax​{−KL[ϕ^​(X,Z(j))∣∣Qj​(Z(j))]}Q(Z)=Q1​(Z(1))×⋯×Q^​j​(Z(j))×⋯×QM​(Z(M))
  • 重复上述步骤,最终第一次迭代结果得到如下形式:
    Q(Z)=Q^1(Z(1))×⋯×Q^M(Z(M))\mathcal Q(\mathcal Z) = \hat {\mathcal Q}_1(\mathcal Z^{(1)}) \times \cdots \times \hat {\mathcal Q}_{\mathcal M}(\mathcal Z^{(\mathcal M)})Q(Z)=Q^​1​(Z(1))×⋯×Q^​M​(Z(M))
  • 继续迭代,直到Q(Z)\mathcal Q(\mathcal Z)Q(Z)结果稳定且收敛。

经典变分推断的问题

虽然通过坐标上升法能够近似求解隐变量Z\mathcal ZZ的最优后验概率分布P(Z∣X)P(\mathcal Z \mid \mathcal X)P(Z∣X),但 经典变分推断 的问题也是显而易见的:平均场假设这个假设本身过于苛刻

平均场假设要保证隐变量各分组之间相互独立。而隐变量本身就是基于真实情况人为定义的变量
实际情况中,定义的隐变量满足平均场假设是极为困难的,因此,经典变分推断基本无法使用于真实任务

至此,我们在近似求解后验概率分布P(Z∣X)P(\mathcal Z \mid \mathcal X)P(Z∣X),就需要对 P(Z∣X)P(\mathcal Z \mid \mathcal X)P(Z∣X)整体进行求解。
本节将从梯度角度对P(Z∣X)P(\mathcal Z \mid \mathcal X)P(Z∣X)进行求解。

随机梯度变分推断的求解过程

回顾变分推断的推导过程,基于隐变量Z\mathcal ZZ的最优近似分布Q^(Z)\hat {\mathcal Q}(\mathcal Z)Q^​(Z) 可进行如下表示:
Q^(Z)=arg⁡max⁡Q(Z)L[Q(Z)]⇒Q^(Z)≈P(Z∣X)L[Q(Z)]=∫ZQ(Z)⋅log⁡[P(X,Z)Q(Z)]dZ\hat {\mathcal Q}(\mathcal Z) = \mathop{\arg\max}\limits_{\mathcal Q(\mathcal Z)} \mathcal L[\mathcal Q(\mathcal Z)] \Rightarrow \hat {\mathcal Q}(\mathcal Z) \approx P(\mathcal Z \mid \mathcal X) \\ \mathcal L[\mathcal Q(\mathcal Z)] = \int_{\mathcal Z} \mathcal Q(\mathcal Z) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z)}\right] d \mathcal ZQ^​(Z)=Q(Z)argmax​L[Q(Z)]⇒Q^​(Z)≈P(Z∣X)L[Q(Z)]=∫Z​Q(Z)⋅log[Q(Z)P(X,Z)​]dZ
既然是 通过调整Q(Z)\mathcal Q(\mathcal Z)Q(Z)的最值,使得L[Q(Z)]\mathcal L[\mathcal Q(\mathcal Z)]L[Q(Z)]达到最大,因此可以尝试使用 梯度上升法(Gradient Ascent) 进行求解。

这里需要进行一些假设
既然要求解最优的Q(Z)\mathcal Q(\mathcal Z)Q(Z),根据梯度上升法,自然要求解Q(Z)\mathcal Q(\mathcal Z)Q(Z)的梯度。

而Q(Z)\mathcal Q(\mathcal Z)Q(Z)本身是一个分布,也可以看作成一个概率模型。而概率模型本身可以看作是关于该模型参数的一个函数。因此:定义概率模型Q(Z)\mathcal Q(\mathcal Z)Q(Z)的模型参数为ϕ\phiϕ,最终将求解Q(Z)\mathcal Q(\mathcal Z)Q(Z)的梯度转化为求解模型参数ϕ\phiϕ的梯度
Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)写法是保留之前对概率模型的表达。例如P(X∣θ)P(\mathcal X \mid \theta)P(X∣θ),对应的L[Q(Z)]\mathcal L[\mathcal Q(\mathcal Z)]L[Q(Z)]公式也需要进行修改。
Q(Z)→Q(Z∣ϕ)L[Q(Z)]=∫Z∣ϕQ(Z∣ϕ)⋅log⁡[P(X,Z)Q(Z∣ϕ)]dZ=EQ(Z∣ϕ)[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]=L(ϕ)\mathcal Q(\mathcal Z) \to \mathcal Q(\mathcal Z \mid \phi) \\ \begin{aligned} \mathcal L[\mathcal Q(\mathcal Z)] & = \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z \mid \phi)}\right] d\mathcal Z \\ & = \mathbb E_{\mathcal Q(\mathcal Z \mid \phi)} \left[\log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] \\ & = \mathcal L(\phi) \end{aligned} Q(Z)→Q(Z∣ϕ)L[Q(Z)]​=∫Z∣ϕ​Q(Z∣ϕ)⋅log[Q(Z∣ϕ)P(X,Z)​]dZ=EQ(Z∣ϕ)​[logP(X,Z)−logQ(Z∣ϕ)]=L(ϕ)​
与此同时,L[Q(Z)]\mathcal L[\mathcal Q(\mathcal Z)]L[Q(Z)]中的变量由Q(Z)\mathcal Q(\mathcal Z)Q(Z)变为ϕ\phiϕ,即L(ϕ)\mathcal L(\phi)L(ϕ)。从而将求解最优Q^(Z)\hat {\mathcal Q}(\mathcal Z)Q^​(Z)转化为求解最优参数ϕ^\hat \phiϕ^​
ϕ^=arg⁡max⁡ϕL(ϕ)\hat \phi = \mathop{\arg\max}\limits_{\phi} \mathcal L(\phi)ϕ^​=ϕargmax​L(ϕ)
梯度∇ϕL(ϕ)\nabla_{\phi}\mathcal L(\phi)∇ϕ​L(ϕ)进行表示:
∇ϕL(ϕ)=∇ϕ∫Z∣ϕQ(Z∣ϕ)⋅log⁡[P(X,Z)Q(Z∣ϕ)]dZ=∇ϕ∫Z∣ϕQ(Z∣ϕ)⋅[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ\begin{aligned} \nabla_{\phi}\mathcal L(\phi) & = \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z \mid \phi)}\right] d\mathcal Z \\ & = \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z \end{aligned}∇ϕ​L(ϕ)​=∇ϕ​∫Z∣ϕ​Q(Z∣ϕ)⋅log[Q(Z∣ϕ)P(X,Z)​]dZ=∇ϕ​∫Z∣ϕ​Q(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]dZ​
根据牛顿-莱布尼兹公式,将积分号∫\int∫与梯度∇\nabla∇进行交换
乘法求导~
∫Z∣ϕ∇ϕQ(Z∣ϕ)⋅[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ+∫Z∣ϕQ(Z∣ϕ)⋅∇ϕ[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ\int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right]d\mathcal Z + \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z∫Z∣ϕ​∇ϕ​Q(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]dZ+∫Z∣ϕ​Q(Z∣ϕ)⋅∇ϕ​[logP(X,Z)−logQ(Z∣ϕ)]dZ

观察第二项:∫Z∣ϕQ(Z∣ϕ)⋅∇ϕ[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ\int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z∫Z∣ϕ​Q(Z∣ϕ)⋅∇ϕ​[logP(X,Z)−logQ(Z∣ϕ)]dZ:

  • 由于ϕ\phiϕ是概率模型Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)的模型参数,而P(X,Z)P(\mathcal X,\mathcal Z)P(X,Z)是X,Z\mathcal X,\mathcal ZX,Z的联合概率分布,因此与ϕ\phiϕ无关。因此第二项可变化为
    −∫Z∣ϕQ(Z∣ϕ)⋅∇ϕlog⁡Q(Z∣ϕ)dZ=−∫Z∣ϕ1Q(Z∣ϕ)⋅Q(Z∣ϕ)⋅∇ϕQ(Z∣ϕ)dZ=−∫Z∣ϕ∇ϕQ(Z∣ϕ)dZ\begin{aligned} & - \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z \\ & = -\int_{\mathcal Z \mid \phi} \frac{1}{\mathcal Q(\mathcal Z \mid \phi)} \cdot \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi)d\mathcal Z \\ & = - \int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi)d\mathcal Z \end{aligned}​−∫Z∣ϕ​Q(Z∣ϕ)⋅∇ϕ​logQ(Z∣ϕ)dZ=−∫Z∣ϕ​Q(Z∣ϕ)1​⋅Q(Z∣ϕ)⋅∇ϕ​Q(Z∣ϕ)dZ=−∫Z∣ϕ​∇ϕ​Q(Z∣ϕ)dZ​
  • 再次使用牛顿-莱布尼兹公式,将梯度符号∇\nabla∇还原位置:
    −∇ϕ∫Z∣ϕQ(Z∣ϕ)dZ- \nabla_{\phi} \int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z−∇ϕ​∫Z∣ϕ​Q(Z∣ϕ)dZ
  • 根据概率密度积分,∫Z∣ϕQ(Z∣ϕ)dZ=1\int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) d\mathcal Z = 1∫Z∣ϕ​Q(Z∣ϕ)dZ=1,第二项相当于对常数1求偏导,最后结果为0。即:
    第二项被完整地消掉了~
    ∫Z∣ϕQ(Z∣ϕ)⋅∇ϕ[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ=−∇ϕ1=0\int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi}\left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z = -\nabla_{\phi} 1 = 0∫Z∣ϕ​Q(Z∣ϕ)⋅∇ϕ​[logP(X,Z)−logQ(Z∣ϕ)]dZ=−∇ϕ​1=0

至此,∇ϕL(ϕ)\nabla_{\phi} \mathcal L(\phi)∇ϕ​L(ϕ)可表示为:
只剩下了第一项~
∇ϕL(ϕ)=∫Z∣ϕ∇ϕQ(Z∣ϕ)⋅[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ\nabla_{\phi} \mathcal L(\phi) = \int_{\mathcal Z \mid \phi} \nabla_{\phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right]d\mathcal Z∇ϕ​L(ϕ)=∫Z∣ϕ​∇ϕ​Q(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]dZ
观察:∇ϕQ(Z∣ϕ)\nabla_{\phi}\mathcal Q(\mathcal Z \mid \phi)∇ϕ​Q(Z∣ϕ)它并不是概率分布,而是概率分布的梯度。因此没有办法将上式写成期望形式
但是这里通过技巧 将Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)还原出来
可以自己反过来推一下~
∇ϕQ(Z∣ϕ)=Q(Z∣ϕ)⋅∇ϕlog⁡Q(Z∣ϕ)\nabla_{\phi}\mathcal Q(\mathcal Z \mid \phi) = \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi)∇ϕ​Q(Z∣ϕ)=Q(Z∣ϕ)⋅∇ϕ​logQ(Z∣ϕ)
将上式带入,∇ϕL(ϕ)\nabla_{\phi} \mathcal L(\phi)∇ϕ​L(ϕ)可以表示为:
∫Z∣ϕQ(Z∣ϕ)⋅∇ϕlog⁡Q(Z∣ϕ)⋅[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]dZ\int_{\mathcal Z \mid \phi} \mathcal Q(\mathcal Z \mid \phi) \cdot \nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot \left[ \log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)\right] d\mathcal Z∫Z∣ϕ​Q(Z∣ϕ)⋅∇ϕ​logQ(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]dZ
可以将上述积分看作 Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)分布的期望形式
∇ϕL(ϕ)=EQ(Z∣ϕ){∇ϕlog⁡Q(Z∣ϕ)⋅[log⁡P(X,Z)−log⁡Q(Z∣ϕ)]}\nabla_{\phi} \mathcal L(\phi) =\mathbb E_{\mathcal Q(\mathcal Z \mid \phi)}\left\{\nabla_{\phi} \log \mathcal Q(\mathcal Z \mid \phi) \cdot [\log P(\mathcal X,\mathcal Z) - \log \mathcal Q(\mathcal Z \mid \phi)]\right\}∇ϕ​L(ϕ)=EQ(Z∣ϕ)​{∇ϕ​logQ(Z∣ϕ)⋅[logP(X,Z)−logQ(Z∣ϕ)]}
至此,将梯度∇ϕL(ϕ)\nabla_{\phi}\mathcal L(\phi)∇ϕ​L(ϕ)使用期望形式表示出来。后续可以使用蒙特卡洛采样方法对该期望进行近似求解

至此,每求解一个∇ϕL(ϕ)\nabla_{\phi} \mathcal L(\phi)∇ϕ​L(ϕ),都可以对Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)概率分布的模型参数ϕ\phiϕ 更新一次,以此类推。
最终可以近似得到概率模型Q(Z∣ϕ)\mathcal Q(\mathcal Z \mid \phi)Q(Z∣ϕ)的最优模型参数ϕ^\hat \phiϕ^​,从而求解概率模型Q(Z∣ϕ^)\mathcal Q(\mathcal Z \mid \hat \phi)Q(Z∣ϕ^​)。

下一节将介绍 随机梯度变分推断的问题及其他衍生方法

相关参考:
机器学习-变分推断4(随机梯度变分推断-SGVI-1)

机器学习笔记之变分推断(四)随机梯度变分推断(SGVI)相关推荐

  1. 李弘毅机器学习笔记:第十四章—Why deep?

    李弘毅机器学习笔记:第十四章-Why deep? 问题1:越深越好? 问题2:矮胖结构 v.s. 高瘦结构 引入模块化 深度学习 使用语音识别举例 语音辨识: 传统的实现方法:HMM-GMM 深度学习 ...

  2. 机器学习笔记之受限玻尔兹曼机(四)推断任务——边缘概率

    机器学习笔记之受限玻尔兹曼机--推断任务[边缘概率] 引言 回顾:场景构建 推断任务--边缘概率求解 边缘概率与Softplus函数 引言 上一节介绍了受限玻尔兹曼机中随机变量节点的后验概率,本节将介 ...

  3. 机器学习笔记之配分函数(一)对数似然梯度

    机器学习笔记之配分函数--对数似然梯度 引言 回顾:过去介绍配分函数的相关结点 配分函数介绍 配分函数在哪些情况下会"直面"到? 场景构建 包含配分函数的极大似然估计 引言 从本节 ...

  4. 机器学习笔记之狄利克雷过程(四)从概率图角度认识狄利克雷过程

    机器学习笔记之狄利克雷过程--从概率图角度认识狄利克雷过程 引言 关于迪利克雷混合模型 关于后验概率的求解过程 引言 上一节从随机测度 G ( i ) \mathcal G^{(i)} G(i)生成过 ...

  5. 机器学习笔记之集成学习(四)Gradient Boosting

    机器学习笔记之集成学习--Gradient Boosting 引言 回顾: Boosting \text{Boosting} Boosting算法思想与 AdaBoost \text{AdaBoost ...

  6. 《机器学习实战》第5章 随机梯度上升算法

    #!/usr/bin/env python # _*_coding:utf-8 _*_ #@Time :2018/4/9 7:56 #@Author :niutianzhuang #@FileName ...

  7. 机器学习:从感知机模型体会随机梯度下降

    文章目录 感知机模型: 感知机模型的随机梯度下降: 感知机模型的算法描述: 感知机的代码实现: 感知机模型: 寻找一个超平面使数据集线性可分,寻找超平面的过程可以转化为最小化一个损失函数的过程: 如何 ...

  8. 机器学习笔记之概率图模型(四)基于贝叶斯网络的模型概述

    机器学习笔记之概率图模型--基于贝叶斯网络的模型概述 引言 基于贝叶斯网络的模型 场景构建 朴素贝叶斯分类器 混合模型 基于时间变化的模型 特征是连续型随机变量的贝叶斯网络 动态概率图模型 总结 引言 ...

  9. python 决策树和随机森林_【python机器学习笔记】使用决策树和随机森林预测糖尿病...

    决策树:一种有监督的机器学习分类算法,可以训练已知数据,做出对未知数据的预测. 机器学习中的分类任务殊途同归,都是要根据已知的数据特征获得最佳的分类方法.对于一个有多个特征的数据,我们需要知道根据哪些 ...

最新文章

  1. oracle用户权限的基本查询
  2. [转] ios学习--openURL的使用方法
  3. Python编程:Tkinter图形界面设计(2)
  4. 聊聊spring security的permitAll以及webIgnore
  5. Exchange2007 申请安装证书
  6. matlab引擎函数,Matlab引擎库函数
  7. 无线传感器网络与数据交换解析
  8. 四、瞰景Smart3D创建工程
  9. 系泊系统悬链线matlab,基于悬链线方程的系泊系统状态分析
  10. Android开发笔记(一百四十一)读取PPT和PDF文件
  11. 2007年牛人牛语录
  12. Professional Microsoft Office SharePoint Designer 2007
  13. 快速接入高德地图SDK(地图+定位+标记+路线规划+搜索)
  14. Android 解决帧动画卡顿问题
  15. unicode,UTF-8,UTF-16,UTF-32是什么,各有什么关系
  16. SAP MB51选择界面配置
  17. 不能错过2016中国IoT大会的十个理由
  18. Arduino Software (IDE) 开发环境配置
  19. Box-Muller 变换
  20. python画樱桃小丸子_学python画图最快的方式——turtle小海龟画图

热门文章

  1. 想知道如何将PDF合并成一个文件?一分钟教会你
  2. C语言 八进制数转换为四进制
  3. [全国计算机二级]基础知识汇总(一)
  4. 第5章 LinearR/PLR/SVR/KNN/DTR/RFR(测算房价)
  5. DNS 区域传送漏洞(dns-zone-tranfer)学习
  6. ASP.NET2.0数据操作之母板页和站点导航
  7. 新版本微信分享sdk(1.8.3)踩坑实录
  8. ardupilot相机拍照控制
  9. 安全警告——“Windows已经阻止此软件因为无法验证发行者”解决办法
  10. 数据中心机房监控室效果图