本文主要涉及GAN网络的直观理解和其背后的数学原理。
参考课程:
计算机视觉与深度学习 北京邮电大学 鲁鹏

概述


在所有生成模型中,GAN属于 “密度函数未知,直接硬train” 的那一类,和密度函数可定义的PixelRNN/CNN以及变分自编码器VAE有本质区别。

假设现在我们想做人脸的生成任务。我们希望能找到人脸图像的真实分布,这样直接在这个分布上随便取点,得到的都是人脸的图像。但是分布非常复杂,且无法知道

所以,我们考虑用一个简单的分布和一个映射,将这个简单的分布映射到真实的分布。我们使用神经网络来学习这个映射的过程。

GAN的直观理解

目标函数

GAN网络的设计思路类似玩家博弈的过程,其主要优化的目标为:

符号说明:P_data是真实数据的分布,P(z)是噪声分布(可以是均匀分布、高斯分布等), theta(g)是生成器的参数, theta(d)是判别器的参数。
公式中
x^=Gθg(z)\hat{x} = G_\theta{_g}(z) x^=Gθ​g​(z)
表示生成器生成的样本,而

Dθd(x)D_\theta{_d}(x) Dθ​d​(x)
输出一个0-1之间的数,表示判别器对输入的判断,1表示是真实数据,0表示是生成的数据。

我们先看内侧max,调整theta_d(判别器的参数),使得后面式子最大。对于真实样本(Ex~data那一项),希望判别器生成1;对假样本x_hat,希望D_theta(d)把他输出成0,这样1减去之后最大。

【注意!!在讨论max的时候调整d,此时生成器g的参数是固定的!!反之亦然。】

再来看min的时候,学的是生成器g的参数。此时,前面那一项无所谓(与g无关)。此时希望
Dθd(Gθg(z))D_\theta{_d}(G_\theta{_g(z)}) Dθ​d​(Gθ​g​(z))
趋近于1,此时theta_d不变,我们 希望生成的样本被判别器判断成1. 也就是固定D的前提下,让G尽量欺骗D。


theta_d想让表达式越大越好,所以是梯度上升

因为判别器最后输出是(0, 1)的值,所以最后一层是一个sigmoid. 想让正样本越大越好,负样本越小越好,可以用一个二分类交叉熵损失(BCE)监督。【这里体会到:“似然越大越好” 等价于 “交叉熵损失越小越好”,因为那个max里面是一个概率/似然。下文会详细说明。

但是实际这样梯度会出现问题——

这样训练的效果很差。因为刚开始生成的烂,梯度还小,学不动;后来生成的好,不太需要变化了,梯度反而很大。
G+D是一个网络,D在G后面。优化的时候,是冻结一个,训练另一个。而梯度回传会首先经过D,再回传到G。

实际实现时,我们会将min换成max,使得梯度问题得以解决。

用下面这张图总结下GAN网络的学习过程。

【这里(a)表示的意思是:一开始,判别器没有学好,无法区分真实和生成的分布。】

数学推导

JS散度

在开始之前,先给出JS散度的定义。
JS散度度量了两个概率分布的相似度,是基于KL散度的变体,解决了KL散度非对称的问题。一般地,JS散度是对称的,其取值是0到1之间。定义如下:

JS散度是可以理解为“距离”的,因为是对称的,而KL散度不行,只能说是一种“相似程度”。

极大似然估计 VS KL散度

一般的,我们要选取一个theta,使得似然值最大。

先放结论:

最大化似然 = 最小化KL散度。

【这是一个贯穿机器学习过程的关键理解】。

以下是每一步化简的过程:

回到GAN


Z是噪声服从的分布,这里可以取均匀分布或高斯分布。我们使用神经网络建模,学习了一个G,将Z映射到了一个密度分布P_G.
我们希望调整生成器的参数,使得G的密度分布与真实数据的密度分布接近(其中的Div表示散度,不一定是KL散度)。

但是,P_G是神经网络拟合的,Pdata是未知的,表达式我们根本写不出来,怎么优化?

以下是解决方法。
1、虽然我们不知道这两个分布的具体表达式,但是我们可以从中获取样本

2、接着,我们把GAN的目标式子中的z统一换成G(因为样本是从G的分布里取出来的嘛)。

V(G,D)=Ex−Pdata[logD(x)]+Ex−PG[log(1−D(x))]V(G, D) = E_{x - P_{data}}[logD(x)] + E_{x - P_{G}}[log(1-D(x))] V(G,D)=Ex−Pdata​​[logD(x)]+Ex−PG​​[log(1−D(x))]
3、与上面类似,我们先考虑优化判别器(对应max的部分)。
这里先给出结论:

最大化maxV(D, G)等价于度量P_data和P_G之间的JS散度


我们不是没法度量Div(Pg,Pdata)嘛?现在找到度量方式了!

只需要最大化V(D, G),便可以度量Pg和Pdata之间的JS散度。

先忽略结论的证明,我们绕开了Pdata和Pg数学表达式无法获得的问题,解决了度量两个密度分布的方法。因为maxV的时候,只需要把训练样本输入到神经网络中即可训练theta_G!

换言之,训练神经网络,实际就是在度量Pdata和Pg之间的JS散度

直观理解

关于结论的证明,先从直观的角度来进行。

如果生成的和样本很像,判别器判别很困难,V(G,D)小【因为判断困难,真实数据得不到1,生成的假数据也得不到0,V值自然不高】;反之V(G,D)大 ==》 这不就类似在刻画“散度”嘛?

越好分,值越小,证明他们的距离越小;越难分,值越大,证明他们距离越大!

理论推导


这里用了一个结论:如果想要最大化积分,那么如果对于每个x,f(x)都是最大的,那积分出来的结果也最大,这样我们就去掉了积分符号。
在x给定的情况下,我们要找到最大的D’,对D求导即可。

前面求出的D’带入V(G, D), 并人为加入1/2的因子,朝着JS散度的方向化简。

最后我们便会发现,把最优参数带入后,此时的V(G, D)取到max值,也就是在度量Pdata与PG的JS散度。所以,判别器的输出值就代表了Pdata和Pg的差异!判别器输出值越大,表示Pdata和Pg分的越开;输出值越小,表示他们离得越近。

手动推导及每一步化简的过程:

再看目标式


我们已经证明了,最大化V(D, G)就等价于计算了JS散度。所以对于上面的3个G,在固定G的情况下,我们可以得到D’为图中红色竖线的值(这时V最大)。
而生成器的优化目标为:找到一个最优参数G,使得生成的P_G的概率分布和真实数据的概率分布之间的差异越小越好
假设我们现在G的候选参数就这三个,那就是从三个值里选择最小的值,G3就是最后学到的结果(因为他的V最小,而V是JS散度的刻画,生成器希望差异小)。

“判别器,最大化V(G, D)”可以理解为在蓝色的线上找最大值
“生成器,最小化Div”可以理解为从所有红线中找出最小值

关键的桥梁“距离”,就是通过maxV(G, D)实现的

实际操作

实际做的时候,可以用BCE做损失函数监督。【再次体现最大化似然等价最小化交叉熵】

Summary


但是其实GAN还是有很多问题的,这也是为什么后来出现了WGAN等,这个在这里就按下不表了。

一个小问题

在训练的过程中,我们往往对于判别器训练多次,而生成器只训练一次。这是为什么呢?
一个直观的理解是“判别器如果训练不好,那生成器训练多次也没什么用”,但这么理解只是流于表面。
可以从上面数学推导的理论来考虑。
优化D是为了是discriminator对应的目标函数最大,也就是在整个数据分布上,尽力做到正确区分,这个需要多轮过程做到,且优化D不会改变Pdata和Pg;但是对于generator,一次优化后,很可能此时此数据分布上的discriminator所最优的区分能力并不适合你已经改变之后的generator,导致不符合理论上的推导(也就是我们要最小化的JSD)【只有在固定D局部优化G时,才能看成近似优化两个分布的JS散度。 优化G之后,Pg已经变了,此时如果D不动而还训练G,就不符合理论了】。

GAN(对抗生成网络)原理及数学推导相关推荐

  1. 对抗生成网络原理和作用

    我们通过一个demo(gan.py )来讲解对抗生成网络的原理和作用 1.创建真实数据 2.使用GAN训练噪声数据 3.通过1200次的训练使得生成的数据的分布跟真实数据的分布差不多 4.通过debu ...

  2. GAN——对抗生成网络

    GAN的基本思想 作为现在最火的深度学习模型之一,GAN全称对抗生成网络,顾名思义是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的.它使用两个神经网络,将一个神经网络与另一个神经网络进行对抗. ...

  3. GAN 对抗生成网络代码实现

    作报告写了ppt,这里po上 更完整的介绍关注专栏生成对抗网络Generative Adversarial Network 本篇的同名博客[生成对抗网络GAN入门指南](3)GAN的工程实践及基础代码 ...

  4. GAN对抗生成网络学习笔记(三)DCGAN原理

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.DCGAN简介 1.1 DCGAN的特点 二.几个重要概念 2.1 下采样(SubSampled) 2.2 上采样 ...

  5. GAN对抗生成网络原始论文理解笔记

    文章目录 论文:Generative Adversarial Nets 符号意义 生成器(Generator) 判别器(Discriminator) 生成器和判别器的关系 GAN的训练流程简述 论文中 ...

  6. Tensorflow GAN对抗生成网络实战

    这一节的回顾主要针对使用JS散度得DCGAN和基于GP理论和Wasserstein Distance理论的WGAN首先是DCGAN 我们的训练数据集是一堆这种二次元的动漫头像的图片,那么我们就是要训练 ...

  7. 对抗生成网络(GAN)学习笔记

    生成模型与判别模型 判别模型:由数据直接学习决策函数Y=f(X)或条件概率分布P(Y|X)作为预测模型,即判别模型.判别方法关心的是对于给定的输入X,应该预测什么样的输出Y. 生成模型:由数据学习联合 ...

  8. 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

    图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow) 文章目录 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网 ...

  9. 以假乱真的对抗生成网络(GAN)

    本期课程到这里,博主就默认大家已经对BP.CNN.RNN等基本的神经网络属性以及训练过程都有相应的认知了,如果还未了解最基本的知识,可以翻看博主制作的深度学习的学习路线,按顺序阅读即可. 深度学习的学 ...

  10. 悉尼大学陶大程:遗传对抗生成网络有效解决GAN两大痛点

    来源:新智元 本文共7372字,建议阅读10分钟. 本文为你整理了9月20日的AI WORLD 2018 世界人工智能峰会上陶大程教授的演讲内容. [ 导读 ]悉尼大学教授.澳大利亚科学院院士.优必选 ...

最新文章

  1. 影像组学视频学习笔记(11)-支持向量机(SVM)(理论)、Li‘s have a solution and plan.
  2. Linux下使用NTFS格式移动硬盘
  3. 逆向了一下hh.exe
  4. java audiorecord_Android 录音实现(AudioRecord)
  5. [译]如何在.NET Core中使用System.Drawing?
  6. ubuntu14.04安装 R16 Tina Linux SDK
  7. java test circle_TestCircle.java
  8. Flink java作为消费者连接虚拟机中的kafka/或本地的kafka,并解决java.net.UnknownHostException报错
  9. 通过抓包工具抓包APP就连不上网的解决方案
  10. 一位全减器逻辑电路图_一种一位全减器电路的制作方法
  11. 【薪酬调研报告】2019TMT标杆企业高管薪酬与激励调研报告—德勤管理咨询
  12. 使用fopen/fwrite/fread/fseek/fclose对文件从头读写整型数
  13. 计算机编程入门先学什么最好?
  14. HTML DOM中的根节点是______,HTML DOM 学习
  15. element-ui组件的下载与安装
  16. android禁止弹出保存此图片,安卓手机相册总是出现陌生图片?教你彻底清除!...
  17. 3dsmax动画渲染速度慢,渲染结果高糊的解决,图片渲染清晰,但变成动画就糊(Quicksilver硬件渲染器)
  18. python - TypeError: combat(sume,sumu) missing 2 required positional arguments: sume,sumu
  19. Logback 学习笔记
  20. VMware如何监测性能问题

热门文章

  1. DITA达尔文信息类型化体系结构相关总结
  2. centos7下使用yum安装ifconfig
  3. Ubuntu安装Qt教程
  4. 趋势信息整合(01) 谷歌google开发者 那些事儿
  5. C# 屏幕控件截屏 屏幕截屏 截屏
  6. windows 解压缩命令
  7. Windows Hello FIDO2 认证让您更接近无密码
  8. areas ajax路由,Areas(区域)
  9. 【安全知识分享】五星级酒店员工入职安全培训.ppt(附下载)
  10. 7个实用网站,总有一个你能用到的!