深度学习之生成对抗网络(7)WGAN原理
深度学习之生成对抗网络(7)WGAN原理
- 1. JS散度的缺陷
- 2. EM距离
- 3. WGAN-GP
WGAN算法从理论层面分析了GAN训练不稳定的原因,并提出了有效的解决方法。那么是什么原因导致了GAN训练如此不稳定呢?WGAN提出是因为JS散度在不重叠的分布 ppp和 qqq上的梯度曲面是恒定为0的。如下图所示。当分布p和q不重叠时,JS散度的梯度值始终为0,从而导致此时GAN的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。
图1. JS散度出现梯度弥散现象
接下来我们将详细阐述JS散度的缺陷以及怎么解决此缺陷。
1. JS散度的缺陷
为了避免过多的理论推导,我们这里通过一个简单的分布实例来解释JS散度的缺陷。
考虑完全不重叠(θ≠0θ≠0θ=0)的两个分布ppp和qqq,其中ppp为:
∀(x,y)∈p,x=0,y∼U(0,1)∀(x,y)∈p,x=0,y\sim\text{U}(0,1)∀(x,y)∈p,x=0,y∼U(0,1)
分布qqq为:
∀(x,y)∈q,x=θ,y∼U(0,1)∀(x,y)∈q,x=θ,y\sim\text{U}(0,1)∀(x,y)∈q,x=θ,y∼U(0,1)
其中θ∈Rθ∈Rθ∈R,当θ=0θ=0θ=0时,分布ppp和qqq重叠,两者相等;当θ≠0θ≠0θ=0时,分布ppp和qqq不重叠。
图2. 分布$p$和$q$示意图
我们来分析上述分布ppp和qqq之间的JS散度随θθθ的变化情况。根据KL散度与JS散度的定义,计算θ=0θ=0θ=0时的JS散度DJS(p∣∣q)D_{JS} (p||q)DJS(p∣∣q):
DKL(p∣∣q)=∑x=0,y∼U(0,1)1⋅log10=+∞D_{KL} (p||q)=∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{0}=+∞DKL(p∣∣q)=x=0,y∼U(0,1)∑1⋅log01=+∞
DKL(q∣∣p)=∑x=θ,y∼U(0,1)1⋅log10=+∞D_{KL} (q||p)=∑_{x=θ,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{0}=+∞DKL(q∣∣p)=x=θ,y∼U(0,1)∑1⋅log01=+∞
DJS(p∣∣q)=12(∑x=0,y∼U(0,1)1⋅log11/2+∑x=0,y∼U(0,1)1⋅log11/2)=log2D_{JS} (p||q)=\frac{1}{2} \bigg(∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}+∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}\bigg)=\text{log}2DJS(p∣∣q)=21(x=0,y∼U(0,1)∑1⋅log1/21+x=0,y∼U(0,1)∑1⋅log1/21)=log2
当θ=0θ=0θ=0时,两个分布完全重叠,此时的JS散度和KL散度都取得最小值,即0:
DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0D_{KL} (p||q)=D_{KL} (q||p)=D_{JS} (p||q)=0DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0
从上面的推导,我们可以得到DJS(p∣∣q)D_{JS} (p||q)DJS(p∣∣q)随θθθ的变化趋势:
DJS(p∣∣q)={log2θ≠00θ=0D_{JS} (p||q) = \begin{cases} \text{log}2 &\text{} θ≠0 \\ 0 &\text{} θ=0 \end{cases}DJS(p∣∣q)={log20θ=0θ=0
也就是说,当两个分布完全不重叠时,无论发布之间的距离远近,JS散度为恒定值log2\text{log}2log2,此时JS散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS散度采会平滑变动,产生有效梯度信息;当完全重合后,JS散度取得最小值0.如下图所示,红色的曲线分割两个正态分布,由于两个分布没有重叠,生成样本位置处的梯度值始终为0,无法更新生成网络的参数,从而出现网络训练困难的现象。
图3. JS散度出现梯度弥散现象
因此,JS散度在分布ppp和qqq不重叠时是无法平滑地衡量分布之间的距离,从而导致此位置上无法产生有效梯度信息,出现GAN训练不稳定的情况。要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布ppp和qqq不重叠时,也能平滑反映分布之间的真实距离变化。
2. EM距离
WGAN论文发现了JS散度导致GAN训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein距离,也叫推土机距离(Earth-Mover Distance,简称EM距离),它表示了从一个分布变换到另一个分布的最小代价,定义为:
W(p,q)=infγ∼∏(p,q)E(x,y)∼γ[∥x−y∥]W(p,q)=\underset{γ\sim∏(p,q)}{\text{inf}}\mathbb E_{(x,y)\simγ} [\|x-y\|]W(p,q)=γ∼∏(p,q)infE(x,y)∼γ[∥x−y∥]
其中∏(p,q)∏(p,q)∏(p,q)是分布ppp和qqq组合起来的所有可能的联合分布的集合,对于每个可能的联合分布γ∼∏(p,q)γ\sim∏(p,q)γ∼∏(p,q),计算距离∥x−y∥\|x-y\|∥x−y∥的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ[∥x−y∥],其中(x,y)(x,y)(x,y)采样自联合分布γγγ。不同的联合分布γγγ由不同的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ[∥x−y∥],这些期望中的下确界即定义为分布ppp和qqq的Wasserstein距离。其中inf{⋅}\text{inf}\{\cdot\}inf{⋅}表示集合的下确界,例如{x∣1<x<3,x∈R}\{x|1<x<3,x∈R\}{x∣1<x<3,x∈R}的下确界为1。
继续考虑图2中的例子,我们直接给出分布ppp和qqq之间的EM距离的表达式:
W(p,q)=∣θ∣W(p,q)=|θ|W(p,q)=∣θ∣
绘制出JS散度和EM距离的曲线,如下图所示,可以看到,JS散度在θ=0θ=0θ=0处不连续,其他位置导数均为0,而EM距离总能够产生有效的导数信息,因此EM距离相对于JS散度更适合直到GAN网络的训练。
图4. JS散度和EM距离随$θ$变换曲线
3. WGAN-GP
考虑到几乎不可能遍历所有的联合分布γγγ去计算距离∥x−y∥\|x-y\|∥x−y∥的期望E(x,y)∼γ[∥x−y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)∼γ[∥x−y∥],因此直接计算生成网络分布pgp_gpg与真实数据数据分布prp_rpr的距离W(pr,pg)W(p_r,p_g )W(pr,pg)距离是不现实的,WGAN作者基于Kantorchovich-Rubin对偶性将直接求W(pr,pg)W(p_r,p_g )W(pr,pg)转换为求:
W(pr,pg)=1Ksup∥f∥L≤KEx∼pr[f(x)]−Ex∼pg[f(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|f\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [f(x)]-\mathbb E_{x\sim p_g} [f(x)]W(pr,pg)=K1∥f∥L≤KsupEx∼pr[f(x)]−Ex∼pg[f(x)]
其中sup{⋅}\text{sup}\{\cdot\}sup{⋅}表示集合的上确界,∥f∥L≤K\|f\|_L≤K∥f∥L≤K表示函数f:R→Rf:R→Rf:R→R满足K阶-Lipschitz连续性,即满足
∣f(x1)−f(x2)∣≤K⋅∣x1−x2∣|f(x_1 )-f(x_2)|≤K\cdot|x_1-x_2 |∣f(x1)−f(x2)∣≤K⋅∣x1−x2∣
于是,我们使用判别网络Dθ(x)D_θ (\boldsymbol x)Dθ(x)参数化f(x)f(\boldsymbol x)f(x)函数,在DθD_θDθ满足1阶-Lipschitz约束条件下,即K=1K=1K=1,此时:
W(pr,pg)=1Ksup∥Dθ∥L≤KEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|D_θ\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]W(pr,pg)=K1∥Dθ∥L≤KsupEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]
因此求解W(pr,pg)W(p_r,p_g )W(pr,pg)的问题可以转化为:
maxθEx∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]\underset{θ}{\text{max}}\ \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]θmax Ex∼pr[Dθ(x)]−Ex∼pg[Dθ(x)]
这就是判别器D的优化目标。判别网络函数D_θ (x)需要满足1阶-Lipschitz约束:
∇x^D(x^)≤1∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})≤1∇x^D(x^)≤1
在WGAN-GP论文中,作者提出采用增加梯度惩罚项(Gradient Penalty)方法来迫使判别网络满足1阶-Lipschitz函数约束,同时作者发现将梯度值约束在1周围时工程效果更好,因此梯度惩罚项定义为:
GP≜Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]GP≜\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]GP≜Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]
因此WGAN的判别器D的训练目标为:
maxθL(G,D)=Exr∼pr[D(xr)]−Exf∼pg[D(xf)]⏟EM距离−λEx^∼Px^[(∥∇x^D(x^)∥2−1)2]⏟GP惩罚项\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}-\underbrace{λ\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]}_{GP惩罚项}θmaxL(G,D)=EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)]−GP惩罚项λEx^∼Px^[(∥∇x^D(x^)∥2−1)2]
其中x^\hat{\boldsymbol x}x^来自于xr\boldsymbol x_rxr与xf\boldsymbol x_fxf的线性差值:
x^=txr+(1−t)xf,t∈[0,1]\hat{\boldsymbol x}=t\boldsymbol x_r+(1-t) \boldsymbol x_f,t∈[0,1]x^=txr+(1−t)xf,t∈[0,1]
判别器D的优化目标是最小化上述的误差L(G,D)\mathcal L(G,D)L(G,D),即迫使生成器G的分布pgp_gpg与真实分布prp_rpr之间的EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]Exr∼pr[D(xr)]−Exf∼pg[D(xf)]项尽可能大,∥∇x^D(x^)∥2\|∇_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2∥∇x^D(x^)∥2逼近于1。
WGAN的生成器G的训练目标为:
maxθL(G,D)=Exr∼pr[D(xr)]−Exf∼pg[D(xf)]⏟EM距离\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距离}θmaxL(G,D)=EM距离Exr∼pr[D(xr)]−Exf∼pg[D(xf)]
即使得生成器的分布pgp_gpg与真实分布prp_rpr之间的EM距离越小越好。考虑到Exr∼pr[D(xr)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]Exr∼pr[D(xr)]一项与生成器无关,因此生成器的训练目标简写为:
maxθL(G,D)=−Exf∼pg[D(xf)]=−Ez∼pz(⋅)[D(G(z))]\begin{aligned}\underset{θ}{\text{max}} \mathcal L(G,D)&=-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]\\ &=-E_{\boldsymbol z\sim p_\boldsymbol z (\cdot)} [D(G(\boldsymbol z))]\end{aligned}θmaxL(G,D)=−Exf∼pg[D(xf)]=−Ez∼pz(⋅)[D(G(z))]
从现实来看,判别网络D的输出不需要添加Sigmoid激活函数,这是因为原始版本的判别器的功能是作为二分类网络,添加Sigmoid函数获得类别的概率;而WGAN中判别器作为EM距离的度量网络,其目标是衡量生成网络的分布pgp_gpg和真实分布prp_rpr之间的EM距离,属于实数空间,因此不需要添加Sigmoid激活函数。在误差函数计算时,WGAN也没有log\text{log}log函数存在。在训练WGAN时,WGAN作者推荐使用RMSProp或SGD等不带动量的优化器。
WGAN从理论层面发现了原始GAN容易出现训练不稳定的原因,并给出了一种新的距离度量标准和工程实现解决方案,取得了较好的效果。WGAN还在一定程度上缓解了模式崩塌的问题,使用WGAN的模型不容易出现模式崩塌的现象。需要注意的是,WGAN一般并不能提升模型的生成效果,仅仅是保证了模型训练的稳定性。当然,保证模型能够稳定地训练也是取得良好效果的前提。如图5所示,原始版本的DCGAN在不使用BN层等设定时出现了训练不稳定的现象,在同样设定下,使用WGAN来训练判别器可以避免此现象,如图6所示。
图5. 不带BN层的DCGAN生成器效果
图6. 不带BN层的WGAN生成效果
深度学习之生成对抗网络(7)WGAN原理相关推荐
- 深度学习之生成对抗网络(8)WGAN-GP实战
深度学习之生成对抗网络(8)WGAN-GP实战 代码修改 完整代码 WGAN WGAN_train 代码修改 WGAN-GP模型可以在原来GAN代码实现的基础上仅做少量修改.WGAN-GP模型的判别 ...
- 深度学习之生成对抗网络(6)GAN训练难题
深度学习之生成对抗网络(6)GAN训练难题 1. 超参数敏感 2. 模式崩塌 尽管从理论层面分析了GAN网络能够学习到数据的真实分布,但是在工程实现中,常常出现GAN网络训练困难的问题,主要体现在G ...
- 深度学习之生成对抗网络(4)GAN变种
深度学习之生成对抗网络(4)GAN变种 1. DCGAN 2. InfoGAN 3. CycleGAN 4. WGAN 5. Equal GAN 6. Self-Attention GAN 7. Bi ...
- 深度学习之生成对抗网络(2)GAN原理
深度学习之生成对抗网络(2)GAN原理 1. 网络结构 生成网络G(z)\text{G}(\boldsymbol z)G(z) 判别网络D(x)\text{D}(\boldsymbol x)D(x) ...
- 深度学习之生成对抗网络(1)博弈学习实例
深度学习之生成对抗网络(1)博弈学习实例 博弈学习实例 在 生成对抗网络(Generative Adversarial Network,简称GAN)发明之前,变分自编码器被认为是理论完备,实现简单, ...
- 【深度学习】生成对抗网络(GAN)的tensorflow实现
[深度学习]生成对抗网络(GAN)的tensorflow实现 一.GAN原理 二.GAN的应用 三.GAN的tensorflow实现 参考资料 GAN( Generative Adversarial ...
- 深度学习之生成对抗网络(5)纳什均衡
深度学习之生成对抗网络(5)纳什均衡 1. 判别器状态 2. 生成器状态 3. 纳什均衡点 现在我们从理论层面进行分析,通过博弈学习的训练方式,生成器G和判别器D分别会达到什么平衡状态.具体地,我们 ...
- 人工智障学习笔记——深度学习(4)生成对抗网络
概念 生成对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discrimi ...
- 【深度学习】生成对抗网络
下文以图片作为数据举例介绍. 生成网络(生成器)–>以假乱真 生成网络的职责是把随机点模仿成与真实数据集相似的图片,这些随机点是从一个潜在空间中随机抽取的.它可以看作一个实现"点对点变 ...
最新文章
- C4D运动图形基本训练学习教程
- mathematica练习程序(图像取反)
- 2020中国教育行业生存实录
- redis主从和集群搭建
- JavaScript : 基本的处理事件
- 号外!德国惊现大罢工--要求每周上班28小时
- 火山PC模拟键盘操作
- 多因子量化投资模型策略深度研究
- 免费网课python_Python网课推荐——免费学习Python编程
- FPGA基础知识(FPGA芯片结构)
- BZOJ3876支线剧情
- EFR32--如何在EFR32BG22透传中添加AT指令控制
- 分布式数据库发展历程SequoiaDB 简介
- 一种可以复制钟表零件的石膏模具
- SVN+Gitee配置版本控制库
- Swift - 设置预编译宏
- 如何将xls批量转换成xlsx
- tvoc气体传感器哪家做的比较好
- C++实现职工管理系统
- 导入SpringBoot项目时突然遇到无法找到入口类的情况
热门文章
- python span.string函数_【转】python f-string
- python自动答题软件_广东开放大学(广开)线上作业自动答题python-selenium
- const int是什么类型_C++的const语义
- com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected a string but was BEGI
- alpine linux 源码安装,关于docker:如何安装Go in alpine linux
- matlab计算截断误差,Matlab相位截断误差仿真综述.doc
- Google、Facebook、GitHub、Babel核心成员齐聚,第13届D2前端技术论坛正式启动
- 竞赛准备篇---(一)抽签问题
- node异步非阻塞的杂谈
- [转]各种字符集和编码详解