深度学习之生成对抗网络(7)WGAN原理

  • 1. JS散度的缺陷
  • 2. EM距离
  • 3. WGAN-GP

 WGAN算法从理论层面分析了GAN训练不稳定的原因,并提出了有效的解决方法。那么是什么原因导致了GAN训练如此不稳定呢?WGAN提出是因为JS散度在不重叠的分布 ppp和 qqq上的梯度曲面是恒定为0的。如下图所示。当分布p和q不重叠时,JS散度的梯度值始终为0,从而导致此时GAN的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。
图1. JS散度出现梯度弥散现象

 接下来我们将详细阐述JS散度的缺陷以及怎么解决此缺陷。

1. JS散度的缺陷

为了避免过多的理论推导,我们这里通过一个简单的分布实例来解释JS散度的缺陷。
考虑完全不重叠(θ≠0θ≠0θ​=0)的两个分布ppp和qqq,其中ppp为:
∀(x,y)∈p,x=0,y∼U(0,1)∀(x,y)∈p,x=0,y\sim\text{U}(0,1)∀(x,y)∈p,x=0,y∼U(0,1)
分布qqq为:
∀(x,y)∈q,x=θ,y∼U(0,1)∀(x,y)∈q,x=θ,y\sim\text{U}(0,1)∀(x,y)∈q,x=θ,y∼U(0,1)
其中θ∈Rθ∈Rθ∈R,当θ=0θ=0θ=0时,分布ppp和qqq重叠,两者相等;当θ≠0θ≠0θ​=0时,分布ppp和qqq不重叠。

图2. 分布$p$和$q$示意图

 我们来分析上述分布ppp和qqq之间的JS散度随θθθ的变化情况。根据KL散度与JS散度的定义,计算θ=0θ=0θ=0时的JS散度DJS(p∣∣q)D_{JS} (p||q)DJS​(p∣∣q):
DKL(p∣∣q)=∑x=0,y∼U(0,1)1⋅log⁡10=+∞D_{KL} (p||q)=∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}⁡\frac{1}{0}=+∞DKL​(p∣∣q)=x=0,y∼U(0,1)∑​1⋅log⁡01​=+∞
DKL(q∣∣p)=∑x=θ,y∼U(0,1)1⋅log⁡10=+∞D_{KL} (q||p)=∑_{x=θ,y\sim\text{U}(0,1)}1\cdot\text{log}⁡\frac{1}{0}=+∞DKL​(q∣∣p)=x=θ,y∼U(0,1)∑​1⋅log⁡01​=+∞
DJS(p∣∣q)=12(∑x=0,y∼U(0,1)1⋅log11/2+∑x=0,y∼U(0,1)1⋅log11/2)=log⁡2D_{JS} (p||q)=\frac{1}{2} \bigg(∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}+∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}\bigg)=\text{log}⁡2DJS​(p∣∣q)=21​(x=0,y∼U(0,1)∑​1⋅log1/21​+x=0,y∼U(0,1)∑​1⋅log1/21​)=log⁡2
 当θ=0θ=0θ=0时,两个分布完全重叠,此时的JS散度和KL散度都取得最小值,即0:
DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0D_{KL} (p||q)=D_{KL} (q||p)=D_{JS} (p||q)=0DKL​(p∣∣q)=DKL​(q∣∣p)=DJS​(p∣∣q)=0
从上面的推导,我们可以得到DJS(p∣∣q)D_{JS} (p||q)DJS​(p∣∣q)随θθθ的变化趋势:
DJS(p∣∣q)={log⁡2θ≠00θ=0D_{JS} (p||q) = \begin{cases} \text{log⁡}2 &\text{} θ≠0 \\ 0 &\text{} θ=0 \end{cases}DJS​(p∣∣q)={log⁡20​θ​=0θ=0​
也就是说,当两个分布完全不重叠时,无论发布之间的距离远近,JS散度为恒定值log⁡2\text{log}⁡2log⁡2,此时JS散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS散度采会平滑变动,产生有效梯度信息;当完全重合后,JS散度取得最小值0.如下图所示,红色的曲线分割两个正态分布,由于两个分布没有重叠,生成样本位置处的梯度值始终为0,无法更新生成网络的参数,从而出现网络训练困难的现象。

图3. JS散度出现梯度弥散现象

 因此,JS散度在分布ppp和qqq不重叠时是无法平滑地衡量分布之间的距离,从而导致此位置上无法产生有效梯度信息,出现GAN训练不稳定的情况。要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布ppp和qqq不重叠时,也能平滑反映分布之间的真实距离变化。

2. EM距离

 WGAN论文发现了JS散度导致GAN训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein距离,也叫推土机距离(Earth-Mover Distance,简称EM距离),它表示了从一个分布变换到另一个分布的最小代价,定义为:
W(p,q)=infγ∼∏(p,q)E(x,y)∼γ[∥x−y∥]W(p,q)=\underset{γ\sim∏(p,q)}{\text{inf}}\mathbb E_{(x,y)\simγ} [\|x-y\|]W(p,q)=γ∼∏(p,q)inf​E(x,y)∼γ​[∥x−y∥]
其中∏(p,q)∏(p,q)∏(p,q)是分布ppp和qqq组合起来的所有可能的联合分布的集合,对于每个可能的联合分布γ∼∏(p,q)γ\sim∏(p,q)γ∼∏(p,q),计算距离∥x−y∥\|x-y\|∥x−y∥的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ​[∥x−y∥],其中(x,y)(x,y)(x,y)采样自联合分布γγγ。不同的联合分布γγγ由不同的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ​[∥x−y∥],这些期望中的下确界即定义为分布ppp和qqq的Wasserstein距离。其中inf⁡{⋅}\text{inf}⁡\{\cdot\}inf⁡{⋅}表示集合的下确界,例如{x∣1<x<3,x∈R}\{x|1<x<3,x∈R\}{x∣1<x<3,x∈R}的下确界为1。

 继续考虑图2中的例子,我们直接给出分布ppp和qqq之间的EM距离的表达式:
W(p,q)=∣θ∣W(p,q)=|θ|W(p,q)=∣θ∣
绘制出JS散度和EM距离的曲线,如下图所示,可以看到,JS散度在θ=0θ=0θ=0处不连续,其他位置导数均为0,而EM距离总能够产生有效的导数信息,因此EM距离相对于JS散度更适合直到GAN网络的训练。

图4. JS散度和EM距离随$θ$变换曲线

3. WGAN-GP

 考虑到几乎不可能遍历所有的联合分布γγγ去计算距离∥x−y∥\|x-y\|∥x−y∥的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ​[∥x−y∥],因此直接计算生成网络分布pgp_gpg​与真实数据数据分布prp_rpr​的距离W(pr,pg)W(p_r,p_g )W(pr​,pg​)距离是不现实的,WGAN作者基于Kantorchovich-Rubin对偶性将直接求W(pr,pg)W(p_r,p_g )W(pr​,pg​)转换为求:
W(pr,pg)=1Ksup∥f∥L≤KEx∼pr[f(x)]−Ex∼pg[f(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|f\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [f(x)]-\mathbb E_{x\sim p_g} [f(x)]W(pr​,pg​)=K1​∥f∥L​≤Ksup​Ex∼pr​​[f(x)]−Ex∼pg​​[f(x)]
其中sup⁡{⋅}\text{sup}⁡\{\cdot\}sup⁡{⋅}表示集合的上确界,∥f∥L≤K\|f\|_L≤K∥f∥L​≤K表示函数f:R→Rf:R→Rf:R→R满足K阶-Lipschitz连续性,即满足
∣f(x1)−f(x2)∣≤K⋅∣x1−x2∣|f(x_1 )-f(x_2)|≤K\cdot|x_1-x_2 |∣f(x1​)−f(x2​)∣≤K⋅∣x1​−x2​∣
 于是,我们使用判别网络Dθ(x)D_θ (\boldsymbol x)Dθ​(x)参数化f(x)f(\boldsymbol x)f(x)函数,在DθD_θDθ​满足1阶-Lipschitz约束条件下,即K=1K=1K=1,此时:
W(pr,pg)=1Ksup∥Dθ∥L≤KEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|D_θ\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]W(pr​,pg​)=K1​∥Dθ​∥L​≤Ksup​Ex∼pr​​[Dθ​(x)]−Ex∼pg​​[Dθ​(x)]
因此求解W(pr,pg)W(p_r,p_g )W(pr​,pg​)的问题可以转化为:
max⁡θEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]\underset{θ}{\text{max}⁡}\ \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]θmax⁡​ Ex∼pr​​[Dθ​(x)]−Ex∼pg​​[Dθ​(x)]
这就是判别器D的优化目标。判别网络函数D_θ (x)需要满足1阶-Lipschitz约束:
∇x^D(x^)≤1∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})≤1∇x^​D(x^)≤1
 在WGAN-GP论文中,作者提出采用增加梯度惩罚项(Gradient Penalty)方法来迫使判别网络满足1阶-Lipschitz函数约束,同时作者发现将梯度值约束在1周围时工程效果更好,因此梯度惩罚项定义为:
GP≜Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]GP≜\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]GP≜Ex^∼Px^​​[(∥∇x^​D(x^)∥2​−1)2]
因此WGAN的判别器D的训练目标为:
maxθL(G,D)=Exr∼pr[D(xr)]−Exf∼pg[D(xf)]⏟EM距离−λEx^∼Px^[(∥∇x^D(x^)∥2−1)2]⏟GP惩罚项\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}-\underbrace{λ\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]}_{GP惩罚项}θmax​L(G,D)=EM距离Exr​∼pr​​[D(xr​)]−Exf​∼pg​​[D(xf​)]​​−GP惩罚项λEx^∼Px^​​[(∥∇x^​D(x^)∥2​−1)2]​​
其中x^\hat{\boldsymbol x}x^来自于xr\boldsymbol x_rxr​与xf\boldsymbol x_fxf​的线性差值:
x^=txr+(1−t)xf,t∈[0,1]\hat{\boldsymbol x}=t\boldsymbol x_r+(1-t) \boldsymbol x_f,t∈[0,1]x^=txr​+(1−t)xf​,t∈[0,1]
判别器D的优化目标是最小化上述的误差L(G,D)\mathcal L(G,D)L(G,D),即迫使生成器G的分布pgp_gpg​与真实分布prp_rpr​之间的EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]Exr​∼pr​​[D(xr​)]−Exf​∼pg​​[D(xf​)]项尽可能大,∥∇x^D(x^)∥2\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2∥∇x^​D(x^)∥2​逼近于1。

 WGAN的生成器G的训练目标为:
maxθL(G,D)=Exr∼pr[D(xr)]−Exf∼pg[D(xf)]⏟EM距离\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}θmax​L(G,D)=EM距离Exr​∼pr​​[D(xr​)]−Exf​∼pg​​[D(xf​)]​​
即使得生成器的分布pgp_gpg​与真实分布prp_rpr​之间的EM距离越小越好。考虑到Exr∼pr[D(xr)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]Exr​∼pr​​[D(xr​)]一项与生成器无关,因此生成器的训练目标简写为:
maxθL(G,D)=−Exf∼pg[D(xf)]=−Ez∼pz(⋅)[D(G(z))]\begin{aligned}\underset{θ}{\text{max}} \mathcal L(G,D)&=-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]\\ &=-E_{\boldsymbol z\sim p_\boldsymbol z (\cdot)} [D(G(\boldsymbol z))]\end{aligned}θmax​L(G,D)​=−Exf​∼pg​​[D(xf​)]=−Ez∼pz​(⋅)​[D(G(z))]​
 从现实来看,判别网络D的输出不需要添加Sigmoid激活函数,这是因为原始版本的判别器的功能是作为二分类网络,添加Sigmoid函数获得类别的概率;而WGAN中判别器作为EM距离的度量网络,其目标是衡量生成网络的分布pgp_gpg​和真实分布prp_rpr​之间的EM距离,属于实数空间,因此不需要添加Sigmoid激活函数。在误差函数计算时,WGAN也没有log\text{log}log函数存在。在训练WGAN时,WGAN作者推荐使用RMSProp或SGD等不带动量的优化器。

 WGAN从理论层面发现了原始GAN容易出现训练不稳定的原因,并给出了一种新的距离度量标准和工程实现解决方案,取得了较好的效果。WGAN还在一定程度上缓解了模式崩塌的问题,使用WGAN的模型不容易出现模式崩塌的现象。需要注意的是,WGAN一般并不能提升模型的生成效果,仅仅是保证了模型训练的稳定性。当然,保证模型能够稳定地训练也是取得良好效果的前提。如图5所示,原始版本的DCGAN在不使用BN层等设定时出现了训练不稳定的现象,在同样设定下,使用WGAN来训练判别器可以避免此现象,如图6所示。

图5. 不带BN层的DCGAN生成器效果
图6. 不带BN层的WGAN生成效果

深度学习之生成对抗网络(7)WGAN原理相关推荐

  1. 深度学习之生成对抗网络(8)WGAN-GP实战

    深度学习之生成对抗网络(8)WGAN-GP实战 代码修改 完整代码 WGAN WGAN_train 代码修改  WGAN-GP模型可以在原来GAN代码实现的基础上仅做少量修改.WGAN-GP模型的判别 ...

  2. 深度学习之生成对抗网络(6)GAN训练难题

    深度学习之生成对抗网络(6)GAN训练难题 1. 超参数敏感 2. 模式崩塌  尽管从理论层面分析了GAN网络能够学习到数据的真实分布,但是在工程实现中,常常出现GAN网络训练困难的问题,主要体现在G ...

  3. 深度学习之生成对抗网络(4)GAN变种

    深度学习之生成对抗网络(4)GAN变种 1. DCGAN 2. InfoGAN 3. CycleGAN 4. WGAN 5. Equal GAN 6. Self-Attention GAN 7. Bi ...

  4. 深度学习之生成对抗网络(2)GAN原理

    深度学习之生成对抗网络(2)GAN原理 1. 网络结构 生成网络G(z)\text{G}(\boldsymbol z)G(z) 判别网络D(x)\text{D}(\boldsymbol x)D(x) ...

  5. 深度学习之生成对抗网络(1)博弈学习实例

    深度学习之生成对抗网络(1)博弈学习实例 博弈学习实例  在 生成对抗网络(Generative Adversarial Network,简称GAN)发明之前,变分自编码器被认为是理论完备,实现简单, ...

  6. 【深度学习】生成对抗网络(GAN)的tensorflow实现

    [深度学习]生成对抗网络(GAN)的tensorflow实现 一.GAN原理 二.GAN的应用 三.GAN的tensorflow实现 参考资料 GAN( Generative Adversarial ...

  7. 深度学习之生成对抗网络(5)纳什均衡

    深度学习之生成对抗网络(5)纳什均衡 1. 判别器状态 2. 生成器状态 3. 纳什均衡点  现在我们从理论层面进行分析,通过博弈学习的训练方式,生成器G和判别器D分别会达到什么平衡状态.具体地,我们 ...

  8. 人工智障学习笔记——深度学习(4)生成对抗网络

    概念 生成对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discrimi ...

  9. 【深度学习】生成对抗网络

    下文以图片作为数据举例介绍. 生成网络(生成器)–>以假乱真 生成网络的职责是把随机点模仿成与真实数据集相似的图片,这些随机点是从一个潜在空间中随机抽取的.它可以看作一个实现"点对点变 ...

最新文章

  1. C4D运动图形基本训练学习教程
  2. mathematica练习程序(图像取反)
  3. 2020中国教育行业生存实录
  4. redis主从和集群搭建
  5. JavaScript : 基本的处理事件
  6. 号外!德国惊现大罢工--要求每周上班28小时
  7. 火山PC模拟键盘操作
  8. 多因子量化投资模型策略深度研究
  9. 免费网课python_Python网课推荐——免费学习Python编程
  10. FPGA基础知识(FPGA芯片结构)
  11. BZOJ3876支线剧情
  12. EFR32--如何在EFR32BG22透传中添加AT指令控制
  13. 分布式数据库发展历程SequoiaDB 简介
  14. 一种可以复制钟表零件的石膏模具
  15. SVN+Gitee配置版本控制库
  16. Swift - 设置预编译宏
  17. 如何将xls批量转换成xlsx
  18. tvoc气体传感器哪家做的比较好
  19. C++实现职工管理系统
  20. 导入SpringBoot项目时突然遇到无法找到入口类的情况

热门文章

  1. python span.string函数_【转】python f-string
  2. python自动答题软件_广东开放大学(广开)线上作业自动答题python-selenium
  3. const int是什么类型_C++的const语义
  4. com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected a string but was BEGI
  5. alpine linux 源码安装,关于docker:如何安装Go in alpine linux
  6. matlab计算截断误差,Matlab相位截断误差仿真综述.doc
  7. Google、Facebook、GitHub、Babel核心成员齐聚,第13届D2前端技术论坛正式启动
  8. 竞赛准备篇---(一)抽签问题
  9. node异步非阻塞的杂谈
  10. [转]各种字符集和编码详解