生成对抗网络 Generative Adversarial Nets(GAN)详解
生成对抗网络 Generative Adversarial Nets(GAN)详解
近几年的很多算法创新,尤其是生成方面的task,很大一部分的文章都是结合GAN来完成的,比如,图像生成、图像修复、风格迁移等等。今天主要聊一聊GAN的原理和推导。
github: http://www.github.com/goodfeli/adversarial
论文: https://arxiv.org/abs/1406.2661
背景介绍
在GAN算法出来之前,关于生成的task表现一直都不太好,因为之前的方法由于在最大似然估计和相关策略中出现的许多棘手的概率计算难以近似。而GAN呢,直接绕过了这些问题,用两个对抗的网络对数据进行生成对抗学习,通过零和游戏的形式来直接用网络来拟合数据的分布,从而对数据进行生成。(统计学里,任何数据都可以看成分布,任何数据都是同不同的分布中进行采样得到)
零和游戏的概念主要出现在博弈论里面,核心思想就是,两个玩家进行对抗游戏,最后会陷入一个那纳什均衡的状态。什么是纳什均衡?就是说最后两个玩家,他们都会选择最优的策略进行游戏,但是当两个玩家只要有一个人的决策从这个“最优决策”里跳出来,那么他就得不到最多的好处。通过这种方式,实现最终的平衡。而这一思想也深深的渗透到这个算法之中。
这两个网络,一个是生成器(generator)一个是判别器(discriminator),生成器G的角色就是要学习到数据的分布,能够实现生成以假乱真的数据。而鉴别器D的目的就是要鉴别数据的来源是真实的样本分布还是通过生成得到的。G和D之间的对抗,主要是:1.G想要骗过D,就是让D分辨不出来这个数据是从真实样本里来的还是从G生成的;2.D的目标就是分辨这个数据到底是G生成的还是真实的样本。
目标函数
生成对抗网络的核心就是极小极大游戏( two-player minimax game),这里G是生成器,生成假数据,D是鉴别器,可以理解为二分类的分类器(当然后面有很多变体,这里不讨论),对应的目标函数如下:
minGmaxDV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]minGmaxDV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
Ex∼pdata (x)[logD(x)]\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]Ex∼pdata (x)[logD(x)] 的意思就是对真实的数据分布进行采样,然后求函数logD(x)\log D(\boldsymbol{x})logD(x)的期望。
Ez∼pz(z)[log(1−D(G(z)))]\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]Ez∼pz(z)[log(1−D(G(z)))]同理,从高斯噪声中进行采样,并且求log(1−D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1−D(G(z)))的期望。
整个目标函数是通过交替迭代优化的形式,来对G和D进行更新,很像就是两个人下棋一样,你一下我一下。这样想就会比较好理解。
这里首先更新D,所以我们将G的参数都看成常数,那么目标函数就变成:
maxDV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]\max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]maxDV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))],
目的是优化能够让这个目标函数最大化的D的参数,所以是 maxD\max _{D}maxD。那既然是max这个目标函数,那么我们就可以知道D(x)D(\boldsymbol{x})D(x)越接近1越好,对应的D(G(z))D(G(\boldsymbol{z}))D(G(z))就是越接近0越好。那么咱就可以理解了,这个时候G是不动的,看成常数,对应的就是更新D的参数,D呢就是要分出真实数据和生成的数据,分的越开越好。
然后我们再更新G,那么对应的目标函数,现在变成这样:
minGV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]\min _{G} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]minGV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
这里同样的,就把D看成常数,咱只更新G的参数,所以这回G是主角。这个目标函数就是说优化出最小化目标函数下,对应G的参数。由于D在这是常数,所以咱就把前面半个公式撇了,可得:
minGV(D,G)=Ez∼pz(z)[log(1−D(G(z)))]\min _{G} V(D, G)=\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]minGV(D,G)=Ez∼pz(z)[log(1−D(G(z)))]
为了最小化这个目标函数,这个时候,G就要优化自己,让G(z)G(\boldsymbol{z})G(z)的生成结果尽可能的得到较高的分数,及D(G(z))D(G(\boldsymbol{z}))D(G(z))越接近1越好,所以对应的你会发现log(1−D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1−D(G(z)))被最小化,如果D(G(z))D(G(\boldsymbol{z}))D(G(z))趋近于1。
然后两个模型继续循环往复迭代,你一下,我一下…
那么如果是这样训练下来,这个过程中他们的分布是怎样变化的呢?
理想状态下,他们的关系如下:
这里直接放了论文里的图,z\boldsymbol{z}z是高斯噪声的域domain,x\boldsymbol{x}x是真实样本的域。黑色的点表示真实样本的分布pxp_{\boldsymbol{x}}px,蓝色的点表示分类器分类的表现;绿色的线表示生成数据的分布pgp_{\boldsymbol{g}}pg,由噪声z\boldsymbol{z}z通过x=G(z)\boldsymbol{x}=G(\boldsymbol{z})x=G(z)映射而来。(a)一开始分类器的表现,好像能够大概分开两个分布的样子。(b)对D进行优化,那么D达到了最优的分类状态D∗(x)=pdata (x)pdata (x)+pg(x)D^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}D∗(x)=pdata (x)+pg(x)pdata (x).(c)这个时候G通过D的参数,进一步对自身优化,然后往真实样本进行靠近。(d)最终,生成模型生成的数据和真实样本一致,鉴别器D无法准确划分,D∗(x)=12D^{*}(\boldsymbol{x})=\frac{1}{2}D∗(x)=21.
训练伪代码
这里的伪代码已经在目标函数上进行了解释,主要这里更加详细,实际上就是目标函数部分解释的内容,∇θd\nabla_{\theta_{d}}∇θd就是对应鉴别器参数的梯度,这里用随机梯度下降法对模型参数进行更新,根据目标函数。∇θg\nabla_{\theta_{g}}∇θg同理,是生成器G的参数梯度,结合上述目标函数 进行理解。
理论推导
咱这里直接顺着论文的顺序,进行进一步的梳理。首先是上面提到的,为什么当G固定的时候,最优的D应该是如下形式:
DG∗(x)=pdata (x)pdata (x)+pg(x)D_{G}^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}DG∗(x)=pdata (x)+pg(x)pdata (x)
论文也给出了对应的证明,这里着重给出对应的解释:
首先我们回看目标函数,求期望实际上也可以写成这种积分的形式:
V(G,D)=∫xpdata (x)log(D(x))dx+∫zpz(z)log(1−D(g(z)))dz=∫xpdata (x)log(D(x))+pg(x)log(1−D(x))dx\begin{aligned} V(G, D) &=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x})) d x+\int_{z} p_{\boldsymbol{z}}(\boldsymbol{z}) \log (1-D(g(\boldsymbol{z}))) d z \\ &=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x})) d x \end{aligned}V(G,D)=∫xpdata (x)log(D(x))dx+∫zpz(z)log(1−D(g(z)))dz=∫xpdata (x)log(D(x))+pg(x)log(1−D(x))dx
这里可以看到后半部分稍稍改动了一下,就是说直接将 x=g(z)x=g(\boldsymbol{z})x=g(z)进行替代,对pzp_{\boldsymbol{z}}pz进行采样等价于对x\boldsymbol{x}x进行采样,所以就可以写成下面的式子,这里需要注意的是我们并不知道pgp_{\boldsymbol{g}}pg和pdatap_{\boldsymbol{data}}pdata的真实分布。不过我们能够知道pgp_{\boldsymbol{g}}pg和pdatap_{\boldsymbol{data}}pdata都是属于[0,1]之间的值,我们要求的就是D(x)D(\boldsymbol{x})D(x)的值,这个公式下,我们将其替换成如下形式:
f(y)=alog(y)+blog(1−y)f(y) = a \log (y)+b \log (1-y)f(y)=alog(y)+blog(1−y)
我们把a=pdata,b=pg,y=D(x)a=p_{data},b=p_{g},y=D(x)a=pdata,b=pg,y=D(x),然后这里就变成关于yyy的一个函数。然后我们要记得的是,在G固定的时候,为了求得最大化下的D的参数,所以就是最大化这个f(y)f(y)f(y)。可以知道这个f(y)f(y)f(y)是一个凸函数:
f′(y)=ay−b1−yf^{\prime}(y)=\frac{a}{y}-\frac{b}{1-y}f′(y)=ya−1−yb
f′′(y)=−ay2−b(1−y)2≤0f^{\prime\prime}(y)=-\frac{a}{y^{2}}-\frac{b}{(1-y)^{2}} \leq 0f′′(y)=−y2a−(1−y)2b≤0
有个定理就是说,如果f′′(y)f^{\prime\prime}(y)f′′(y)是正定矩阵,那么他就一定是凸函数。这里就是把向量简化成了标量。注意因为我们这里是求极大值,加上一个负号就是求极小,这个二阶导加一个负号就是正定矩阵(原本是负定)。说了那么多直接画个图好了:
这里可以看到假设a=0.8,b=1a=0.8,b=1a=0.8,b=1出来的图就是这副样子。
令f′(y)=0f^{\prime}(y)=0f′(y)=0,可得y=aa+by=\frac{a}{a+b}y=a+ba,相当于极值就是这个。这里就是说明了当给定任意的G,在这一步最优的D的最优解就应该是y=aa+by=\frac{a}{a+b}y=a+ba,即D(x)=pdatapdata+pgD(x)=\frac{p_{data}}{p_{data}+p_{g}}D(x)=pdata+pgpdata.
好的,这里已经证明在优化D的时候,D应该是什么样子,接着我们再看看如果是再最小化目标函数是,固定D的时候,G应该是多少呢?先给出公式
C(G)=maxDV(G,D)=Ex∼pdata [logDG∗(x)]+Ez∼pz[log(1−DG∗(G(z)))]=Ex∼pdata [logDG∗(x)]+Ex∼pg[log(1−DG∗(x))]=Ex∼pdata [logpdata (x)Pdata (x)+pg(x)]+Ex∼pg[logpg(x)pdata (x)+pg(x)]\begin{aligned} C(G) &=\max _{D} V(G, D) \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log D_{G}^{*}(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}\left[\log \left(1-D_{G}^{*}(G(\boldsymbol{z}))\right)\right] \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log D_{G}^{*}(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \left(1-D_{G}^{*}(\boldsymbol{x})\right)\right] \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log \frac{p_{\text {data }}(\boldsymbol{x})}{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \frac{p_{g}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right] \end{aligned}C(G)=DmaxV(G,D)=Ex∼pdata [logDG∗(x)]+Ez∼pz[log(1−DG∗(G(z)))]=Ex∼pdata [logDG∗(x)]+Ex∼pg[log(1−DG∗(x))]=Ex∼pdata [logPdata (x)+pg(x)pdata (x)]+Ex∼pg[logpdata (x)+pg(x)pg(x)]
这里公式的意思就是假定maxDV(G,D)\max _{D} V(G, D)maxDV(G,D)已经算完了,那么接下来要继续最小化C(G)C(G)C(G),G的应该是什么样子的呢?
上面的公式已经将DG∗(x)D_{G}^{*}(\boldsymbol{x})DG∗(x)给带掉了,就是假定我们已经优化过了D,直接将上面的结果带入就可以得到最后一条公式
C(G)=Ex∼pdata [logpdata (x)Pdata (x)+pg(x)]+Ex∼pg[logpg(x)pdata (x)+pg(x)]C(G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log \frac{p_{\text {data }}(\boldsymbol{x})}{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \frac{p_{g}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right]C(G)=Ex∼pdata [logPdata (x)+pg(x)pdata (x)]+Ex∼pg[logpdata (x)+pg(x)pg(x)]
咱现在就是要最小化这个公式(优化G),这里再把它写成积分的形式:
C(G)=∫xpdata (x)logpdata (x)Pdata (x)+pg(x)+∫xpg(x)logpdata (x)Pdata (x)+pg(x)dxC(G)=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log \frac{p_{\text {data }}(\boldsymbol{x})}{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}+\int_{\boldsymbol{x}}p_{g}(\boldsymbol{x}) \log \frac{p_{\text {data }}(\boldsymbol{x})}{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})} d xC(G)=∫xpdata (x)logPdata (x)+pg(x)pdata (x)+∫xpg(x)logPdata (x)+pg(x)pdata (x)dx 因为Pdata (x)+pg(x){P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}Pdata (x)+pg(x)并不是一个分布,求积分=2,所以这里要改成如下形式,才能套到KL散度的公式中:
C(G)=∫xpdata (x)logpdata (x)12(Pdata (x)+pg(x))12+∫xpg(x)logpdata (x)12(Pdata (x)+pg(x))12dxC(G)=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log \frac{p_{\text {data }}(\boldsymbol{x})}{\frac{1}{2}(P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x}))}\frac{1}{2}+\int_{\boldsymbol{x}}p_{g}(\boldsymbol{x}) \log \frac{p_{\text {data }}(\boldsymbol{x})}{\frac{1}{2}(P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x}))}\frac{1}{2} d xC(G)=∫xpdata (x)log21(Pdata (x)+pg(x))pdata (x)21+∫xpg(x)log21(Pdata (x)+pg(x))pdata (x)21dx 然后把12\frac{1}{2}21提出来:
C(G)=−log(4)+KL(pdata ∥pdata +pg2)+KL(pg∥pdata +pg2)C(G)=-\log (4)+K L\left(p_{\text {data }} \| \frac{p_{\text {data }}+p_{g}}{2}\right)+K L\left(p_{g} \| \frac{p_{\text {data }}+p_{g}}{2}\right)C(G)=−log(4)+KL(pdata ∥2pdata +pg)+KL(pg∥2pdata +pg)这里就可以看到两个KL散度,KL散度是大于等于0的,只有当两个分布完全一样的时候,KL散度才等于0。这里又可以进一步转化变成对称的JS散度:
C(G)=−log(4)+2⋅JSD(pdata ∥pg)C(G)=-\log (4)+2 \cdot J S D\left(p_{\text {data }} \| p_{g}\right)C(G)=−log(4)+2⋅JSD(pdata ∥pg)
只有当pdata=pgp_{data}=p_{g}pdata=pg的时候,我们才能得到这个式子的最小化结果就是−log(4)-\log (4)−log(4)。
所以可以看到最终这个目标函数,能够让G不断拟合分布,使得生成的数据的分布接近真实样本的分布。
当然GAN现在发展已经好几年了,已经有很多优化的方法被提出,下次再讨论~
如果觉得不错,记得关注哟!一起来学习深度学习,机器学习等前沿算法!!
转载请注明出处,尊重劳动成果,维护美好社区谢谢!
生成对抗网络 Generative Adversarial Nets(GAN)详解相关推荐
- 生成对抗网络 – Generative Adversarial Networks | GAN
生成对抗网络 – Generative Adversarial Networks | GAN 生成对抗网络 – GAN 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频.我们手 ...
- 生成对抗网络Generative Adversarial Nets(译)
仅供参考,如有翻译不到位的地方敬请指出. 论文地址:Generative Adversarial Nets 论文翻译:XlyPb(http://blog.csdn.net/wspba/article ...
- 生成对抗网络(Generative Adversarial Networks)
参考 生成对抗网络(Generative Adversarial Networks) - 云+社区 - 腾讯云 目录 一.生成对抗网络原理 1.模型的起源 2.模型的结构和损失函数 二.对GAN的改 ...
- 生成对抗网络(Generative Adversial Network,GAN)原理简介
生成对抗网络(GAN)是深度学习中一类比较大的家族,主要功能是实现图像.音乐或文本等生成(或者说是创作),生成对抗网络的主要思想是:通过生成器(generator)与判别器(discriminator ...
- 与判别网络对抗的生成网络 (Generative Adversarial Nets)
Generative Adversarial Nets (GAN) 主线为 Ian J. Goodfellow 的论文 ("Generative Adversarial Nets" ...
- 深度学习之生成对抗网络(4)GAN变种
深度学习之生成对抗网络(4)GAN变种 1. DCGAN 2. InfoGAN 3. CycleGAN 4. WGAN 5. Equal GAN 6. Self-Attention GAN 7. Bi ...
- 深度学习之生成对抗网络(6)GAN训练难题
深度学习之生成对抗网络(6)GAN训练难题 1. 超参数敏感 2. 模式崩塌 尽管从理论层面分析了GAN网络能够学习到数据的真实分布,但是在工程实现中,常常出现GAN网络训练困难的问题,主要体现在G ...
- 深度学习之生成对抗网络(2)GAN原理
深度学习之生成对抗网络(2)GAN原理 1. 网络结构 生成网络G(z)\text{G}(\boldsymbol z)G(z) 判别网络D(x)\text{D}(\boldsymbol x)D(x) ...
- 生成式对抗网络Generative Adversarial Networks(GANs)
1. 前言 2. 参考资料 3. 研究背景 4. GAN的框架 4.1 对抗框架(Adeversarial Nets) *关于"无监督学习"的说明 4.2 Minimax two- ...
- Generative Adversarial Nets(GAN)
一:概述 下图是Generative Adversarial Nets这篇文章的摘要部分,对GAN进行一个整体介绍:GAN包括两个模型,捕获数据分布的生成模型G,以及一个判别模型D.这个框架对应于一个 ...
最新文章
- 生物信息学常见数据格式 • fasta • fastq • gff/gtf
- python什么是高阶函数_说说 Python 中的高阶函数
- 共享一个简单的 Javacript Helper library
- 【机器学习】SVM学习(三):线性分类器的求解
- srwebsocket 服务器过段时间会关闭_王者荣耀:伽罗大招遭到暗改?开启后直接冷却时间,无法手动关闭...
- 如何修改TextView链接点击实现(包含链接生成与点击原理分析)
- sqlserver 插入数据时异常,仅当使用了列列表并且 IDENTITY_INSERT 为 ON 时,才能为表'XXXXX.dbo.XXXXXXXXX'中的标识列指定显式值。...
- spring中bean属性scope
- Vue3开发教程(全)
- 泛微OA常用js代码块
- 什么是中药药浴?中药药浴的操作方法和注意事项
- 三角网格剖分工具 Triangle 安装及使用
- 手把手教你用JAVA实现“语音合成”功能(文字转声音)标贝科技
- javascript错误_JavaScript开发人员最常犯的10个错误
- Dockerhub最新的toomanyrequests问题
- C语言中 srand()函数和rand()函数
- 1月重磅福利——Softing在线培训课程上线
- python3idl下载_选择Python还是IDL?
- python学习——tsv文件批量转为csv文件、csv文件列合并
- matlab 伯德图 横坐标步长_【龙腾原创】教您使用matlab画伯德图(看了你就学会了,比mathcad好用多了。谁用谁知道!)...
热门文章
- 关于诡辩--偷换概念
- [HTML+Bootstrap+CSS+jQuery] 时差计算器(计算时差、验证格式、当前时间、历史记录……)
- python pyplot颜色_matplotlib制图——颜色和样式
- mysql alert on delete cascade_mysql-我的“ ON DELETE CASCADE”不起作用
- MFS详解(一)——MFS介绍
- 人工智能发展神速?37年前的尘封档案告诉你并没有
- Linux优化学习之Load Average (平均负载)
- HTML_水平线详解
- 期货与期权的主要区别与联系?
- 用p5.js绘制创意自画像