作者:郑华滨
链接:https://www.zhihu.com/question/52602529/answer/158727900
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

前段时间,Wasserstein GAN以其精巧的理论分析、简单至极的算法实现、出色的实验效果,在GAN研究圈内掀起了一阵热潮(对WGAN不熟悉的读者,可以参考我之前写的介绍文章:令人拍案叫绝的Wasserstein GAN - 知乎专栏)。但是很多人(包括我们实验室的同学)到了上手跑实验的时候,却发现WGAN实际上没那么完美,反而存在着训练困难、收敛速度慢等问题。其实,WGAN的作者Martin Arjovsky不久后就在reddit上表示他也意识到了这个问题,认为关键在于原设计中Lipschitz限制的施加方式不对,并在新论文中提出了相应的改进方案:

  • 论文:[1704.00028] Improved Training of Wasserstein GANs
  • Tensorflow实现:igul222/improved_wgan_training

首先回顾一下WGAN的关键部分——Lipschitz限制是什么。WGAN中,判别器D和生成器G的loss函数分别是:

(公式1)

(公式2)

公式1表示判别器希望尽可能拉高真样本的分数,拉低假样本的分数,公式2表示生成器希望尽可能拉高假样本的分数。

Lipschitz限制则体现为,在整个样本空间 上,要求判别器函数D(x)梯度的Lp-norm不大于一个有限的常数K:

(公式3)

直观上解释,就是当输入的样本稍微变化后,判别器给出的分数不能发生太过剧烈的变化。在原来的论文中,这个限制具体是通过weight clipping的方式实现的:每当更新完一次判别器的参数之后,就检查判别器的所有参数的绝对值有没有超过一个阈值,比如0.01,有的话就把这些参数clip回 [-0.01, 0.01] 范围内。通过在训练过程中保证判别器的所有参数有界,就保证了判别器不能对两个略微不同的样本给出天差地别的分数值,从而间接实现了Lipschitz限制。

然而weight clipping的实现方式存在两个严重问题:

第一,如公式1所言,判别器loss希望尽可能拉大真假样本的分数差,然而weight clipping独立地限制每一个网络参数的取值范围,在这种情况下我们可以想象,最优的策略就是尽可能让所有参数走极端,要么取最大值(如0.01)要么取最小值(如-0.01)!为了验证这一点,作者统计了经过充分训练的判别器中所有网络参数的数值分布,发现真的集中在最大和最小两个极端上:

<img src="https://pic2.zhimg.com/v2-7a3aedf9fa60ce660bff9f03935d8f15_b.jpg" data-rawwidth="636" data-rawheight="541" class="origin_image zh-lightbox-thumb" width="636" data-original="https://pic2.zhimg.com/v2-7a3aedf9fa60ce660bff9f03935d8f15_r.jpg">

这样带来的结果就是,判别器会非常倾向于学习一个简单的映射函数(想想看,几乎所有参数都是正负0.01,都已经可以直接视为一个二值神经网络了,太简单了)。而作为一个深层神经网络来说,这实在是对自身强大拟合能力的巨大浪费!判别器没能充分利用自身的模型能力,经过它回传给生成器的梯度也会跟着变差。

在正式介绍gradient penalty之前,我们可以先看看在它的指导下,同样充分训练判别器之后,参数的数值分布就合理得多了,判别器也能够充分利用自身模型的拟合能力:

<img src="https://pic3.zhimg.com/v2-27afb895eea82f5392b19ca770865b96_b.jpg" data-rawwidth="1303" data-rawheight="543" class="origin_image zh-lightbox-thumb" width="1303" data-original="https://pic3.zhimg.com/v2-27afb895eea82f5392b19ca770865b96_r.jpg">

第二个问题,weight clipping会导致很容易一不小心就梯度消失或者梯度爆炸。原因是判别器是一个多层网络,如果我们把clipping threshold设得稍微小了一点,每经过一层网络,梯度就变小一点点,多层之后就会指数衰减;反之,如果设得稍微大了一点,每经过一层网络,梯度变大一点点,多层之后就会指数爆炸。只有设得不大不小,才能让生成器获得恰到好处的回传梯度,然而在实际应用中这个平衡区域可能很狭窄,就会给调参工作带来麻烦。相比之下,gradient penalty就可以让梯度在后向传播的过程中保持平稳。论文通过下图体现了这一点,其中横轴代表判别器从低到高第几层,纵轴代表梯度回传到这一层之后的尺度大小(注意纵轴是对数刻度),c是clipping threshold:

<img src="https://pic2.zhimg.com/v2-34114a10c56518d606c1b5dd77f64585_b.jpg" data-rawwidth="723" data-rawheight="546" class="origin_image zh-lightbox-thumb" width="723" data-original="https://pic2.zhimg.com/v2-34114a10c56518d606c1b5dd77f64585_r.jpg">

说了这么多,gradient penalty到底是什么?

前面提到,Lipschitz限制是要求判别器的梯度不超过K,那我们何不直接设置一个额外的loss项来体现这一点呢?比如说:

(公式4)

不过,既然判别器希望尽可能拉大真假样本的分数差距,那自然是希望梯度越大越好,变化幅度越大越好,所以判别器在充分训练之后,其梯度norm其实就会是在K附近。知道了这一点,我们可以把上面的loss改成要求梯度norm离K越近越好,效果是类似的:

(公式5)

究竟是公式4好还是公式5好,我看不出来,可能需要实验验证,反正论文作者选的是公式5。接着我们简单地把K定为1,再跟WGAN原来的判别器loss加权合并,就得到新的判别器loss:

(公式6)

这就是所谓的gradient penalty了吗?还没完。公式6有两个问题,首先是loss函数中存在梯度项,那么优化这个loss岂不是要算梯度的梯度?一些读者可能对此存在疑惑,不过这属于实现上的问题,放到后面说。

其次,3个loss项都是期望的形式,落到实现上肯定得变成采样的形式。前面两个期望的采样我们都熟悉,第一个期望是从真样本集里面采,第二个期望是从生成器的噪声输入分布采样后,再由生成器映射到样本空间。可是第三个分布要求我们在整个样本空间 上采样,这完全不科学!由于所谓的维度灾难问题,如果要通过采样的方式在图片或自然语言这样的高维样本空间中估计期望值,所需样本量是指数级的,实际上没法做到。

所以,论文作者就非常机智地提出,我们其实没必要在整个样本空间上施加Lipschitz限制,只要重点抓住生成样本集中区域、真实样本集中区域以及夹在它们中间的区域就行了。具体来说,我们先随机采一对真假样本,还有一个0-1的随机数:

(公式7)

然后在 的连线上随机插值采样:

(公式8)

把按照上述流程采样得到的 所满足的分布记为 ,就得到最终版本的判别器loss:

(公式9)

这就是新论文所采用的gradient penalty方法,相应的新WGAN模型简称为WGAN-GP。我们可以做一个对比:

  • weight clipping是对样本空间全局生效,但因为是间接限制判别器的梯度norm,会导致一不小心就梯度消失或者梯度爆炸;
  • gradient penalty只对真假样本集中区域、及其中间的过渡地带生效,但因为是直接把判别器的梯度norm限制在1附近,所以梯度可控性非常强,容易调整到合适的尺度大小。

论文还讲了一些使用gradient penalty时需要注意的配套事项,这里只提一点:由于我们是对每个样本独立地施加梯度惩罚,所以判别器的模型架构中不能使用Batch Normalization,因为它会引入同个batch中不同样本的相互依赖关系。如果需要的话,可以选择其他normalization方法,如Layer Normalization、Weight Normalization和Instance Normalization,这些方法就不会引入样本之间的依赖。论文推荐的是Layer Normalization。

实验表明,gradient penalty能够显著提高训练速度,解决了原始WGAN收敛缓慢的问题:

<img src="https://pic4.zhimg.com/v2-5b01ef93f60a14e7fa10dbea2b620627_b.jpg" data-rawwidth="1255" data-rawheight="479" class="origin_image zh-lightbox-thumb" width="1255" data-original="https://pic4.zhimg.com/v2-5b01ef93f60a14e7fa10dbea2b620627_r.jpg">

虽然还是比不过DCGAN,但是因为WGAN不存在平衡判别器与生成器的问题,所以会比DCGAN更稳定,还是很有优势的。不过,作者凭什么能这么说?因为下面的实验体现出,在各种不同的网络架构下,其他GAN变种能不能训练好,可以说是一件相当看人品的事情,但是WGAN-GP全都能够训练好,尤其是最下面一行所对应的101层残差神经网络:

<img src="https://pic2.zhimg.com/v2-e0a3d86ccfa101a4d3fee1c0cef96a81_b.jpg" data-rawwidth="835" data-rawheight="1279" class="origin_image zh-lightbox-thumb" width="835" data-original="https://pic2.zhimg.com/v2-e0a3d86ccfa101a4d3fee1c0cef96a81_r.jpg">

剩下的实验结果中,比较厉害的是第一次成功做到了“纯粹的”的文本GAN训练!我们知道在图像上训练GAN是不需要额外的有监督信息的,但是之前就没有人能够像训练图像GAN一样训练好一个文本GAN,要么依赖于预训练一个语言模型,要么就是利用已有的有监督ground truth提供指导信息。而现在WGAN-GP终于在无需任何有监督信息的情况下,生成出下图所示的英文字符序列:

<img src="https://pic1.zhimg.com/v2-33c3af150f9bd52485b800948d3cb700_b.jpg" data-rawwidth="1056" data-rawheight="769" class="origin_image zh-lightbox-thumb" width="1056" data-original="https://pic1.zhimg.com/v2-33c3af150f9bd52485b800948d3cb700_r.jpg">

它是怎么做到的呢?我认为关键之处是对样本形式的更改。以前我们一般会把文本这样的离散序列样本表示为sequence of index,但是它把文本表示成sequence of probability vector。对于生成样本来说,我们可以取网络softmax层输出的词典概率分布向量,作为序列中每一个位置的内容;而对于真实样本来说,每个probability vector实际上就蜕化为我们熟悉的onehot vector。

但是如果按照传统GAN的思路来分析,这不是作死吗?一边是hard onehot vector,另一边是soft probability vector,判别器一下子就能够区分它们,生成器还怎么学习?没关系,对于WGAN来说,真假样本好不好区分并不是问题,WGAN只是拉近两个分布之间的Wasserstein距离,就算是一边是hard onehot另一边是soft probability也可以拉近,在训练过程中,概率向量中的有些项可能会慢慢变成0.8、0.9到接近1,整个向量也会接近onehot,最后我们要真正输出sequence of index形式的样本时,只需要对这些概率向量取argmax得到最大概率的index就行了。

新的样本表示形式+WGAN的分布拉近能力是一个“黄金组合”,但除此之外,还有其他因素帮助论文作者跑出上图的效果,包括:

  • 文本粒度为英文字符,而非英文单词,所以字典大小才二三十,大大减小了搜索空间
  • 文本长度也才32
  • 生成器用的不是常见的LSTM架构,而是多层反卷积网络,输入一个高斯噪声向量,直接一次性转换出所有32个字符

上面第三点非常有趣,因为它让我联想到前段时间挺火的语言学科幻电影《降临》:

<img src="https://pic4.zhimg.com/v2-be3b9c574b9c1b1ac05d0d462a04acd7_b.jpg" data-rawwidth="1280" data-rawheight="1998" class="origin_image zh-lightbox-thumb" width="1280" data-original="https://pic4.zhimg.com/v2-be3b9c574b9c1b1ac05d0d462a04acd7_r.jpg">

里面的外星人“七肢怪”所使用的语言跟人类不同,人类使用的是线性的、串行的语言,而“七肢怪”使用的是非线性的、并行的语言。“七肢怪”在跟主角交流的时候,都是一次性同时给出所有的语义单元的,所以说它们其实是一些多层反卷积网络进化出来的人工智能生命吗?

<img src="https://pic3.zhimg.com/v2-f45cc71a24451e1c8d650dba9001a406_b.png" data-rawwidth="2000" data-rawheight="838" class="origin_image zh-lightbox-thumb" width="2000" data-original="https://pic3.zhimg.com/v2-f45cc71a24451e1c8d650dba9001a406_r.png">
<img src="https://pic1.zhimg.com/v2-1ff948ae23ed8c19bfb4bedf6f466744_b.jpg" data-rawwidth="1622" data-rawheight="690" class="origin_image zh-lightbox-thumb" width="1622" data-original="https://pic1.zhimg.com/v2-1ff948ae23ed8c19bfb4bedf6f466744_r.jpg">

开完脑洞,我们回过头看,不得不承认这个实验的setup实在过于简化了,能否扩展到更加实际的复杂场景,也会是一个问题。但是不管怎样,生成出来的结果仍然是突破性的。

最后说回gradient penalty的实现问题。loss中本身包含梯度,优化loss就需要求梯度的梯度,这个功能并不是现在所有深度学习框架的标配功能,不过好在Tensorflow就有提供这个接口——tf.gradients。开头链接的GitHub源码中就是这么写的:

# interpolates就是随机插值采样得到的图像,gradients就是loss中的梯度惩罚项
gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]

对于我这样的PyTorch党就非常不幸了,高阶梯度的功能还在开发,感兴趣的PyTorch党可以订阅这个GitHub的pull request:Autograd refactor,如果它被merged了话就可以在最新版中使用高阶梯度的功能实现gradient penalty了。但是除了等待我们就没有别的办法了吗?其实可能是有的,我想到了一种近似方法来实现gradient penalty,只需要把微分换成差分:

(公式10)

也就是说,我们仍然是在分布 上随机采样,但是一次采两个,然后要求它们的连线斜率要接近1,这样理论上也可以起到跟公式9一样的效果,我自己在MNIST+MLP上简单验证过有作用,PyTorch党甚至Tensorflow党都可以尝试用一下

原文地址: https://www.zhihu.com/question/52602529/answer/158727900

Wasserstein GAN最新进展:从weight clipping到gradient penalty,更加先进的Lipschitz限制手法相关推荐

  1. GAN最新进展:8大技巧提高稳定性

    生成对抗网络GAN很强大,但也有很多造成GAN难以使用的缺陷.本文介绍了可以克服GAN训练缺点的一些解决方案,有助于提高GAN性能. 生成对抗网络 (GAN) 是一类功能强大的神经网络,具有广泛的应用 ...

  2. W-GAN系 (Wasserstein GAN、 Improved WGAN)

    习总结于国立台湾大学 :李宏毅老师 Wasserstein GAN  .  Improved Training of Wasserstein GANs 本文outline 一句话介绍WGAN: Usi ...

  3. 计算机视觉(CV)-生成模型:WGAN【Lipschitz:Weight Clipping】--> WGAN-GP【Lipschitz:Gradient Penalty】

    深度学习-生成模型:WGAN[Lipschitz:Weight Clipping]--> WGAN-GP[Lipschitz:Gradient Penalty] 一.WGAN相比较GAN的改进点 ...

  4. 经典论文复现 | ICML 2017大热论文:Wasserstein GAN

    过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含"伪代码".这是今年 AAAI ...

  5. 变分推断(Variational Inference)最新进展简述

    动机 变分推断(Variational Inference, VI)是贝叶斯近似推断方法中的一大类方法,将后验推断问题巧妙地转化为优化问题进行求解,相比另一大类方法马尔可夫链蒙特卡洛方法(Markov ...

  6. WGAN(Wasserstein GAN)看这一篇就够啦,WGAN论文解读

    WGAN论文地址:[1701.07875] Wasserstein GAN (arxiv.org) WGAN解决的问题 原始GAN训练过程中经常遇到的问题: 模式崩溃,生成器生成非常窄的分布,仅覆盖数 ...

  7. 深度CTR预估模型的演化之路2019最新进展

    作者 | 锅逗逗 来源 | 深度传送门(ID: deep_deliver) 导读:本文主要介绍深度CTR经典预估模型的演化之路以及在2019工业界的最新进展. 介绍 在计算广告和推荐系统中,点击率(C ...

  8. 还记得Wasserstein GAN吗?

    ICML 2017 仍然在悉尼火热进行中,Facebook 研究院今天也发文介绍了自己的 ICML 论文.Facebook有9篇论文被 ICML 2017接收,这些论文的主题包括语言建模.优化和图像的 ...

  9. 收敛速度更快更稳定的Wasserstein GAN(WGAN)

    生成对抗网络(GANs)是一种很有力的生成模型,它解决生成建模问题的方式就像在两个对抗式网络中进行比赛:给出一些噪声源,生成器网络能够产生合成的数据,鉴别器网络在真实数据和生成器的输出中进行鉴别.GA ...

最新文章

  1. python生成饼图文件_python使用HTMLTestRunner导出饼图分析报告的方法
  2. linux共享比windows好处,开源Linux虚拟化优势比Windows有何特点?
  3. P1339 [USACO09OCT]热浪Heat Wave(SPFA)
  4. Python引起的混乱解决之道——感悟
  5. 科幻画图片大全浇水机器人_从机器人到智能机器人,谭建荣院士为温肯师生揭秘新科技...
  6. 消除ubuntu16.04自带的alt快捷键
  7. vim 配置_模块化你的vim配置文件
  8. 高中教师计算机技能大赛,计算机科学与技术学院第九届教师技能大赛初赛圆满举行...
  9. JavaScript函数节流和函数防抖之间的区别
  10. MySQL的主动优化和被动优化_MySQL“被动”性能优化汇总
  11. oem718d 基准站设置_诺瓦泰NovAtel OEM718D全系统多频单机测向板卡
  12. lnmp无法删除.user.ini文件的解决办法
  13. 数据库驱动加载失败,数据库连接失败
  14. 2022-2028年中国生物质颗粒行业市场行情动态及竞争战略分析报告
  15. android 对焦,Android相机对焦模式
  16. FPGA学习:Verilog基本语法
  17. Andorid 拍照、从相册中选择图片兼容7.0uri
  18. PDF Reader Pro,功能强大的 PDF 阅读编辑器
  19. Unsupervised Question Answering by Cloze Translation
  20. win7 设定固定的ip地址

热门文章

  1. Nginx设置目录浏览并配置验证
  2. 编程方法学13:字符串处理
  3. 6、VTK基本数据结构
  4. ImportError: No module named sklearn.metrics
  5. 【Paddle】解压文件到指定文件夹
  6. 十一、“由专入分易,由分入专难。”(2020.12.18)
  7. 如何将ipynb转换为html,md,pdf等格式
  8. 连表查询使用in_SQL 组合查询
  9. [OS复习]进程互斥与同步1
  10. css中超级链接样式的设置顺序