This article is based on shuhuai008's whiteboard-derivation video series on Bilibili: VI变分推断_126min

Index of all notes in this series: 机器学习-白板推导系列笔记

1. Background

For a probabilistic model:

  • From the frequentist view, inference becomes an optimization problem
  • From the Bayesian view, inference becomes an integration problem

From the Bayesian view, prediction integrates over the posterior (the second-to-last step uses $\hat{x}\perp x\mid\theta$):

$$p(\hat{x}|x)=\int_{\theta}p(\hat{x},\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta,x)\,p(\theta|x)\,\mathrm{d}\theta=\int_{\theta}p(\hat{x}|\theta)\,p(\theta|x)\,\mathrm{d}\theta=E_{\theta|x}[p(\hat{x}|\theta)]$$
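As a concrete illustration of the "integration problem", the predictive expectation above can be approximated by Monte Carlo: sample $\theta$ from the posterior and average the likelihood. A minimal Python sketch under an assumed toy Beta-Bernoulli model (the model, data, and all names are illustrative, not from the original notes):

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed toy model: Bernoulli likelihood, Beta(1, 1) prior on theta,
# so the posterior p(theta | x) is Beta(1 + #ones, 1 + #zeros).
x = np.array([1, 0, 1, 1, 0, 1, 1, 1])        # observed data
a_post = 1 + x.sum()
b_post = 1 + len(x) - x.sum()

# p(x_hat = 1 | x) = E_{theta|x}[p(x_hat = 1 | theta)] = E_{theta|x}[theta]
theta_samples = rng.beta(a_post, b_post, size=100_000)
mc_estimate = theta_samples.mean()

exact = a_post / (a_post + b_post)            # closed form, for comparison
print(f"Monte Carlo: {mc_estimate:.4f}  exact: {exact:.4f}")
```

For conjugate models like this one the integral is tractable in closed form; variational inference targets the common case where the posterior itself has no tractable form.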

Inference divides into:

  • Exact inference
  • Approximate inference (deterministic approximation: VI; stochastic approximation: MCMC, MH, Gibbs sampling)

Optimization problems, by example:

  • Regression. Model: $f(w)=w^{T}x$

Loss function (unconstrained):
$$L(w)=\sum_{i=1}^{N}\left\|w^{T}x_{i}-y_{i}\right\|^{2},\qquad \hat{w}=\arg\min_{w}L(w)$$
Solutions (a worked sketch follows this list):
1. Analytic: set the derivative to $0$, giving $w^{*}=(X^{T}X)^{-1}X^{T}Y$
2. Numerical: GD, SGD

  • SVM (classification). Model: $f(w)=\mathrm{sign}(w^{T}x+b)$

Loss function (constrained):
$$\min\frac{1}{2}w^{T}w\quad s.t.\ y_{i}(w^{T}x_{i}+b)\geq 1,\ i=1,2,\cdots,N$$
Solved as a convex optimization problem via its dual.

  • EM

$$\hat{\theta}=\arg\max_{\theta}\log p(x|\theta)$$
$$\theta^{(t+1)}=\underset{\theta}{\arg\max}\int\log p(x,z|\theta)\cdot p(z|x,\theta^{(t)})\,\mathrm{d}z$$
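To ground the regression example, here is a minimal sketch comparing the analytic solution $w^{*}=(X^{T}X)^{-1}X^{T}Y$ with plain gradient descent on $L(w)$. The synthetic data and every name in it are assumptions made for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data: y = X @ w_true + noise (shapes and values are illustrative)
N, D = 200, 3
X = rng.normal(size=(N, D))
w_true = np.array([2.0, -1.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=N)

# 1. Analytic solution: set the gradient of sum ||w^T x_i - y_i||^2 to zero
w_star = np.linalg.inv(X.T @ X) @ X.T @ y

# 2. Numerical solution: plain gradient descent on the same loss
w = np.zeros(D)
lr = 0.01
for _ in range(2000):
    grad = 2 * X.T @ (X @ w - y)   # gradient of the squared loss
    w -= lr * grad / N             # averaged over samples for a stable step

print("analytic :", np.round(w_star, 3))
print("grad desc:", np.round(w, 3))
```

Both routes converge to essentially the same $w$; the closed form exists only because the loss is quadratic, which is exactly why the numerical route matters everywhere else.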

2. Formulas

Data:

$x$: observed variable $\rightarrow X:\left\{x_{i}\right\}_{i=1}^{N}$
$z$: latent variable + parameters $\rightarrow Z:\left\{z_{i}\right\}_{i=1}^{N}$
$(X,Z)$: complete data

Introduce a distribution $q(z)$:

$$\log p(x)=\log p(x,z)-\log p(z|x)=\log\frac{p(x,z)}{q(z)}-\log\frac{p(z|x)}{q(z)}$$

Take the expectation of both sides with respect to $q(z)$:

$$\text{LHS}=\int_{z}q(z)\cdot\log p(x|\theta)\,\mathrm{d}z=\log p(x|\theta)\int_{z}q(z)\,\mathrm{d}z=\log p(x|\theta)$$
$$\text{RHS}=\underbrace{\int_{z}q(z)\log\frac{p(x,z|\theta)}{q(z)}\,\mathrm{d}z}_{ELBO\ (\text{Evidence Lower Bound})}\;\underbrace{-\int_{z}q(z)\log\frac{p(z|x,\theta)}{q(z)}\,\mathrm{d}z}_{KL(q(z)||p(z|x,\theta))}=\underbrace{L(q)}_{\text{variational}}+\underbrace{KL(q||p)}_{\geq 0}$$

When $q$ equals $p$, $KL(q||p)$ attains its minimum value $0$. Since the left-hand side $\log p(x|\theta)$ does not depend on $q$, minimizing the KL divergence amounts to making $L(q)$ as large as possible:

$$\tilde{q}(z)=\underset{q(z)}{\arg\max}\,L(q)\Rightarrow\tilde{q}(z)\approx p(z|x)$$
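The decomposition $\log p(x)=L(q)+KL(q||p)$ is easy to verify numerically on a tiny discrete model, which also makes the lower-bound property tangible. A minimal sketch; the joint table and the choice of $q$ are made up for illustration:

```python
import numpy as np

# Assumed toy joint p(x, z) at a fixed observed x, with z taking 3 values
p_xz = np.array([0.10, 0.25, 0.15])   # p(x, z) for z = 0, 1, 2
p_x = p_xz.sum()                      # evidence p(x)
p_z_given_x = p_xz / p_x              # posterior p(z | x)

q = np.array([0.5, 0.3, 0.2])         # any valid q(z)

elbo = np.sum(q * np.log(p_xz / q))          # E_q[log p(x,z)/q(z)]
kl = np.sum(q * np.log(q / p_z_given_x))     # KL(q || p(z|x))

print(np.log(p_x), elbo + kl)   # identical up to floating-point error
print(elbo <= np.log(p_x))      # the ELBO really is a lower bound: True
```

Setting `q = p_z_given_x` drives the KL term to zero and the ELBO up to $\log p(x)$, matching the argument above.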

We now make the following (mean-field) assumption on $q(z)$: split the dimensions of the multivariate variable into $M$ groups that are mutually independent, so that

$$q(z)=\prod_{i=1}^{M}q_{i}(z_{i})$$

Fixing $q_{i}(z_{i})$ for all $i\neq j$, we solve for $q_{j}(z_{j})$. Write

$$L(q)=\underbrace{\int_{z}q(z)\log p(x,z)\,\mathrm{d}z}_{①}-\underbrace{\int_{z}q(z)\log q(z)\,\mathrm{d}z}_{②}$$

For term ①:

$$\begin{aligned}①&=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log p(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\&=\int_{z_{j}}q_{j}(z_{j})\left(\int_{z-z_{j}}\prod_{i\neq j}^{M}q_{i}(z_{i})\log p(x,z)\prod_{i\neq j}\mathrm{d}z_{i}\right)\mathrm{d}z_{j}\\&=\int_{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[\log p(x,z)]\,\mathrm{d}z_{j}\\&=\int_{z_{j}}q_{j}(z_{j})\cdot\log\hat{p}(x,z_{j})\,\mathrm{d}z_{j}\end{aligned}$$

For term ②:

$$\begin{aligned}②&=\int_{z}q(z)\log q(z)\,\mathrm{d}z=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}\log q_{i}(z_{i})\,\mathrm{d}z\\&=\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\,[\log q_{1}(z_{1})+\log q_{2}(z_{2})+\cdots+\log q_{M}(z_{M})]\,\mathrm{d}z\end{aligned}$$

Taking the first of these terms:

$$\begin{aligned}\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{1}(z_{1})\,\mathrm{d}z&=\int_{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot\log q_{1}(z_{1})\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{M}\\&=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}\cdot\underbrace{\int_{z_{2}}q_{2}(z_{2})\,\mathrm{d}z_{2}}_{=1}\cdot\underbrace{\int_{z_{3}}q_{3}(z_{3})\,\mathrm{d}z_{3}}_{=1}\cdots\underbrace{\int_{z_{M}}q_{M}(z_{M})\,\mathrm{d}z_{M}}_{=1}\\&=\int_{z_{1}}q_{1}(z_{1})\log q_{1}(z_{1})\,\mathrm{d}z_{1}\end{aligned}$$

In general,

$$\int_{z}\prod_{i=1}^{M}q_{i}(z_{i})\log q_{k}(z_{k})\,\mathrm{d}z=\int_{z_{k}}q_{k}(z_{k})\log q_{k}(z_{k})\,\mathrm{d}z_{k}$$

$$②=\sum_{i=1}^{M}\int_{z_{i}}q_{i}(z_{i})\log q_{i}(z_{i})\,\mathrm{d}z_{i}=\int_{z_{j}}q_{j}(z_{j})\log q_{j}(z_{j})\,\mathrm{d}z_{j}+C$$

Subtracting ② from ①, where terms not involving $q_{j}$ are absorbed into the constant $C$:

$$①-②=\int_{z_{j}}q_{j}(z_{j})\cdot\log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\,\mathrm{d}z_{j}+C$$
$$\int_{z_{j}}q_{j}(z_{j})\cdot\log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\,\mathrm{d}z_{j}=-KL(q_{j}(z_{j})||\hat{p}(x,z_{j}))\leq 0$$

The maximum is therefore attained exactly when $q_{j}(z_{j})=\hat{p}(x,z_{j})$.

3. Connection to the EM Algorithm

In the generalized EM algorithm, we first fix $\theta$ and then look for the $q$ closest to $p$, which is precisely where variational inference comes in:

$$\log p_{\theta}(x)=\underbrace{ELBO}_{L(q)}+\underbrace{KL(q||p)}_{\geq 0}\geq L(q)$$

Objective:

$$\hat{q}=\underset{q}{\arg\min}\,KL(q||p)=\underset{q}{\arg\max}\,L(q)$$

$$\begin{aligned}\log q_{j}(z_{j})&=E_{\prod_{i\neq j}^{m}q_{i}(z_{i})}[\log p_{\theta}(x,z)]\\&=\int_{z_{1}}\int_{z_{2}}\cdots\int_{z_{j-1}}\int_{z_{j+1}}\cdots\int_{z_{m}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots\mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots\mathrm{d}z_{m}\end{aligned}$$

$$\begin{aligned}\log\hat{q}_{1}(z_{1})&=\int_{z_{2}}\cdots\int_{z_{m}}q_{2}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{2}\cdots\mathrm{d}z_{m}\\\log\hat{q}_{2}(z_{2})&=\int_{z_{1}}\int_{z_{3}}\cdots\int_{z_{m}}\hat{q}_{1}q_{3}\cdots q_{m}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots\mathrm{d}z_{m}\\&\;\;\vdots\\\log\hat{q}_{m}(z_{m})&=\int_{z_{1}}\cdots\int_{z_{m-1}}\hat{q}_{1}\cdots\hat{q}_{m-1}\cdot\log p_{\theta}(x,z)\,\mathrm{d}z_{1}\cdots\mathrm{d}z_{m-1}\end{aligned}$$

Method: coordinate ascent. Update each factor $\hat{q}_{j}$ in turn, always using the latest versions of the other factors (a concrete sketch follows the formulas below). In ELBO/KL form:

$$ELBO=E_{q(z)}\left[\log\frac{p_{\theta}(x^{(i)},z)}{q(z)}\right]=E_{q(z)}[\log p_{\theta}(x^{(i)},z)]+H[q(z)]$$
$$KL(q||p)=\int q(z)\cdot\log\frac{q(z)}{p_{\theta}(z|x^{(i)})}\,\mathrm{d}z$$
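As a concrete instance of these coordinate-ascent updates, the classic worked example (cf. Bishop, PRML §10.1.2) approximates a correlated 2D Gaussian with a fully factorized $q(z)=q_{1}(z_{1})q_{2}(z_{2})$; there each $\log\hat{q}_{j}$ update has a closed form. A minimal sketch, with all numbers chosen purely for illustration:

```python
import numpy as np

# Assumed target: 2D Gaussian with mean mu and precision matrix Lam
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],
                [0.9, 1.5]])     # precision = inverse covariance

# Mean-field q(z) = q1(z1) q2(z2): each factor turns out Gaussian with
# fixed variance 1 / Lam[j, j]; only the means m[j] need iterating.
m = np.zeros(2)
for _ in range(20):
    # log q1(z1) = E_{q2}[log p(z)] + const  =>  closed-form mean update
    m[0] = mu[0] - (Lam[0, 1] / Lam[0, 0]) * (m[1] - mu[1])
    m[1] = mu[1] - (Lam[1, 0] / Lam[1, 1]) * (m[0] - mu[0])

print("mean-field means:", m, " true means:", mu)
print("factor variances:", 1 / np.diag(Lam),
      " true marginal variances:", np.diag(np.linalg.inv(Lam)))
```

The means converge to the true means, while the factor variances $1/\Lambda_{jj}$ underestimate the true marginal variances: a well-known consequence of the mean-field independence assumption.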

4. Stochastic Gradient Variational Inference (SGVI)

(1) Direct Differentiation

Besides coordinate ascent, we can also optimize by gradient ascent.
Assume $q(Z)=q_{\phi}(Z)$, a probability distribution governed by the parameter $\phi$. Then

$$\underset{q(Z)}{\arg\max}\,L(q)=\underset{\phi}{\arg\max}\,L(\phi)$$

where

$$L(\phi)=E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]$$

Here $x$ denotes a sample. Taking the gradient:

$$\begin{aligned}\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=\nabla_{\phi}\int q_{\phi}(z)[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=\underbrace{\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{①}+\underbrace{\int q_{\phi}(z)\nabla_{\phi}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z}_{②}\end{aligned}$$

where

$$\begin{aligned}②&=\int q_{\phi}(z)\nabla_{\phi}[\underbrace{\log p_{\theta}(x,z)}_{\text{independent of }\phi}-\log q_{\phi}(z)]\,\mathrm{d}z\\&=-\int q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)\,\mathrm{d}z=-\int q_{\phi}(z)\frac{1}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z\\&=-\int\nabla_{\phi}q_{\phi}(z)\,\mathrm{d}z=-\nabla_{\phi}\int q_{\phi}(z)\,\mathrm{d}z=-\nabla_{\phi}1=0\end{aligned}$$

Therefore

$$\begin{aligned}\nabla_{\phi}L(\phi)=①&=\int\nabla_{\phi}q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=\int q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)\cdot[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,\mathrm{d}z\\&=E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\end{aligned}$$

This expectation can be approximated by Monte Carlo sampling, which gives the gradient; gradient ascent then updates the parameters:

$$z^{(l)}\sim q_{\phi}(z),\quad l=1,2,\cdots,L$$
$$E_{q_{\phi}}[(\nabla_{\phi}\log q_{\phi}(z))(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\approx\frac{1}{L}\sum_{l=1}^{L}(\nabla_{\phi}\log q_{\phi}(z^{(l)}))(\log p_{\theta}(x,z^{(l)})-\log q_{\phi}(z^{(l)}))$$
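A minimal sketch of this score-function estimator, under an assumed toy model $p(z)=N(0,1)$, $p(x|z)=N(z,1)$ with a Gaussian $q_{\phi}$, $\phi=(\mu,\log\sigma)$; every concrete value (data, step size, sample count) is an illustrative assumption:

```python
import numpy as np

rng = np.random.default_rng(2)

x_obs = 1.2   # assumed observation; true posterior is N(x/2, 1/2)

def log_p_xz(z):
    # log p(x, z) = log N(z; 0, 1) + log N(x; z, 1); constants dropped,
    # which is harmless because E_q[score function] = 0.
    return -0.5 * z**2 - 0.5 * (x_obs - z)**2

mu, log_std = 0.0, 0.0   # phi = (mu, log_std) of q_phi = N(mu, std^2)
L, lr = 5000, 0.05       # many samples: this estimator is high-variance

for step in range(500):
    std = np.exp(log_std)
    z = mu + std * rng.normal(size=L)                # z^(l) ~ q_phi(z)

    log_q = -0.5 * ((z - mu) / std) ** 2 - log_std   # constants dropped
    score_mu = (z - mu) / std**2                     # d log q / d mu
    score_logstd = ((z - mu) / std) ** 2 - 1.0       # d log q / d log_std

    w = log_p_xz(z) - log_q                          # bracketed term above
    mu += lr * np.mean(score_mu * w)
    log_std += lr * np.mean(score_logstd * w)

print(f"q: mu={mu:.3f}, std={np.exp(log_std):.3f}  "
      f"true: mu={x_obs/2:.3f}, std={np.sqrt(0.5):.3f}")
```

Even on this one-dimensional problem a large $L$ is needed for stable steps, which motivates the variance discussion below.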

Because of the log term, this estimator is fragile: where $q_{\phi}(z)$ is close to $0$, a tiny change produces a huge swing in $\log q_{\phi}(z)$, so the sampled gradients have high variance. The reparameterization trick (Reparameterization Trick) addresses this.

(2) The Reparameterization Trick

Take $z=g_{\phi}(\varepsilon,x)$ with $\varepsilon\sim p(\varepsilon)$. For $z\sim q_{\phi}(z|x)$, this change of variables gives $\left|q_{\phi}(z|x)\,\mathrm{d}z\right|=\left|p(\varepsilon)\,\mathrm{d}\varepsilon\right|$. Substituting into the expression above:

$$\begin{aligned}\nabla_{\phi}L(\phi)&=\nabla_{\phi}E_{q_{\phi}}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,q_{\phi}(z)\,\mathrm{d}z\\&=\nabla_{\phi}\int[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\,p(\varepsilon)\,\mathrm{d}\varepsilon\\&=\nabla_{\phi}E_{p(\varepsilon)}[\log p_{\theta}(x,z)-\log q_{\phi}(z)]\\&=E_{p(\varepsilon)}[\nabla_{\phi}(\log p_{\theta}(x,z)-\log q_{\phi}(z))]\\&=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}z]\\&=E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}(x,z)-\log q_{\phi}(z))\,\nabla_{\phi}g_{\phi}(\varepsilon,x)]\end{aligned}$$

Monte Carlo sampling of $\varepsilon$ then yields the expectation and hence the gradient.
The SGVI iteration is:

$$\phi^{t+1}\leftarrow\phi^{t}+\lambda^{t}\cdot\nabla_{\phi}L(\phi)$$
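For comparison, here is the same assumed toy model with the reparameterized gradient: the noise is drawn from $p(\varepsilon)=N(0,1)$ and the gradient flows through $z=g_{\phi}(\varepsilon,x)=\mu+\sigma\varepsilon$, implementing exactly $E_{p(\varepsilon)}[\nabla_{z}(\log p_{\theta}-\log q_{\phi})\,\nabla_{\phi}g_{\phi}]$ from the derivation above:

```python
import numpy as np

rng = np.random.default_rng(3)

x_obs = 1.2              # same assumed model: p(z)=N(0,1), p(x|z)=N(z,1)
mu, log_std = 0.0, 0.0   # phi = (mu, log_std)
L, lr = 100, 0.05        # far fewer samples than the score estimator needs

for step in range(1000):
    std = np.exp(log_std)
    eps = rng.normal(size=L)      # eps ~ p(eps) = N(0, 1)
    z = mu + std * eps            # z = g_phi(eps, x)

    # grad_z [log p(x, z) - log q_phi(z)] at the sampled z
    dlogp = -z + (x_obs - z)              # grad_z log p(x, z)
    dlogq = -(z - mu) / std**2            # grad_z log q_phi(z)
    dz = dlogp - dlogq

    # chain rule through g_phi: dz/dmu = 1, dz/dlog_std = std * eps
    mu += lr * np.mean(dz)
    log_std += lr * np.mean(dz * std * eps)

print(f"q: mu={mu:.3f}, std={np.exp(log_std):.3f}  "
      f"true: mu={x_obs/2:.3f}, std={np.sqrt(0.5):.3f}")
```

The direct dependence of $\log q_{\phi}$ on $\phi$ is dropped by the formula above; that is legitimate because its expectation under $q_{\phi}$ is zero, exactly as the ② $=0$ derivation showed.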

Next chapter: 白板推导系列笔记(十三)-马尔科夫链蒙特卡洛方法

Reference
变分推断|机器学习推导系列
