wgan 不理解 损失函数_WGAN
GAN-QP 写到一半发现关于 WGAN 以及它相关约束部分之前没有完全读懂,需要重读,那顺手也把笔记给谢了吧
WGAN 在之前阅读的背景是 GAN 在许多条件下有比较严重的不稳定性,在寻找解决方案的过程中发现了 WGAN 的文章,当时对这篇文章的理解为它提出了一种新的散度衡量函数,使用 W 距离,简洁明了,解决了无重叠概率分布情况下概率的衡量问题,对于相关约束没有深刻的理解,这次带着对于约束的疑问重新读这篇文章。
1)为什么约束是对于判别器的
2)约束的原因
介绍
本文解决的是非监督学习的问题。首先,何为学习概率分布呢?最经典的解释是学习到一种概率密度函数,通常通过一系列的参数构造概率密度函数 P,然后通过调整参数使得函数最大化的与真实数据的相似性。对于一个真实分布 x 而言,我们需要解决的问题是
当我们确定真实的分布具有概率密度函数 Pr,参数拟合的分布有密度函数 Pθ,那么拉近两个分布的方法自然是拉近他们的 KL 散度。
为了让这个散度有意义,首先需要做的是构建 Pθ,在我们生活中处理的分布函数的支撑集(定义:在概率论中,一个概率分布的支撑集是随机变量的所有可能值组成的集合的闭包)是低维度的,也就是重叠部分测度为 0,可忽略不计,那么 KL 散度的作用就不够了。
下面这个理解会更佳的精准
因为真实样本的概率分布 Pr 与生成器生成的样本概率分布 Pg 的支撑集不同,又由于两者的流型(Manifold)的维度皆小于样本空间的维度,因而两者的流型基本上是不可能完全对齐的,因而即便有少量相交的点,它们在两个概率流型上的测度为0,可忽略,因而可以将两个概率的流型看成是可分离的,因而若是一个最优的判别器去判断则一定可以百分百将这两个流型分开,即无论我们的生成器如何努力皆获得不了分类误差的信息,这便是GAN训练困难的重要原因。
最典型的解决方案是对于模型分布加上一个噪声模块,这也是为什么所有的生成模型都包含一个噪声组件。最简单的做法是加上一个维度相当高的高斯噪声,能覆盖所有样本。这样做会导致生成的图片模糊不清,这样的做法并不够好。
与其去顾及一个可能不存在的真实数据的分布密度,我们选择定义一个符合 p(z) 分布的随机变量 z,将其传给参数构造的函数 gθ,z —> 产生了一种确定的分布 Pθ,通过构建参数 θ 我们可以改变 Pθ 的分布使其靠近真实样本的分布。这种做法的好处是
1)不像概率密度,这种表示可以直接在一个低维度流型内。
2)这种方法比起知道概率密度的具体数值能更好的生成样本。
GAN 和 VAE 都是这种方法的很好例子,VAE 更专注于逼近样本似然,生成样本与真实样本需要受到同样的模型限制,且需要增加额外的噪声;GAN 在定义目标函数的时候更佳的灵活(JS 散度,f-散度等),但是在训练层面,我们也知道 GAN 是不稳定的。
本文的目的在与讨论目标函数的定义,如何衡量生成样本与真实样本的距离,这些距离函数之间最基础的差异在于对序列分布概率的收敛上,判断一个概率分布收敛的方式是 ρ(Pt, P∞) —> 0,除了概率函数的收敛以外,很重要的一点是参数的收敛。
在参数收敛的过程中,依赖于收敛的路径,收敛的路径取决于散度函数的定义,这个距离的定义越 weak(?),θ 到 Pθ 之间的映射就越容易定义,从而就越容易收敛。
我们关心 θ 到 Pθ 映射连续的主要原因总结为:
我们需要一个衡量距离的损失函数为连续的,
那么本文的主要内容来了,主要总结如下:
1)从理论的角度来分析为什么 EM 距离可行,并且与其他主流的距离定义进行比较。
2)根据 EM 距离提出了 WGAN结构,并且阐述训练优化相关的问题。
3)展示 WGAN 解决的主要问题,尤其是无需在判别器和生成器之间维护一个平衡,也无需精心设计网络结构,在 GAN 中常见的模式丢失现象也得到了很好的改善。在 WGAN 中最引人注目的亮点是通过训练最佳判别器来实现的连续 EM 距离估计。
不同的距离函数
(偷个懒,背景定义部分懒得翻译)
Total Variation
说烂了的 KL 散度
以及它兄弟 JS 散度
重点来了,EM (Earth-Mover)距离,也称为 W 距离(Wasserstein-1)
Π(Pr,Pg) 表示所有的联合分布 γ(x,y) ,其边界为 Pr 和 Pg,直觉上,我们认为这是衡量从 x 到 y 所需的“质量”,需要花费多少力气才可以将分布 Pg 拉近 Pr。
下面将举出一些在 EM 距离下收敛,其他距离下不收敛的简单例子
令 Z ~ U[0,1],P0 为 (0,Z),x 轴与 y 轴上的分布,gθ(z) = (θ, z),θ 是一个参数,我们可以得到:
W(P0,Pθ) = |θ|
JS(P0,Pθ) = log2 (θ != 0) or 0 (θ == 0)
KL(Pθ,P0) = +∞ (θ != 0) or 0 (θ == 0)
δ(P0,Pθ) = 1 (θ != 0) or 0 (θ == 0)
明显的说明,只有在 W 距离下,参数会收敛到 0,g 收敛到 P0
通过以上例子我们可以在低维度流型空间内使用梯度下降来优化模型,在其他散度损失函数下,这是没有梯度甚至是不连续的。
在原文第 5 页开始有两个定理文末附有证明,简洁明了,不照搬了。
定理1
1)如果 θ—> g 连续,那么 W 距离连续
2)如果满足局部 Lipschitz 条件,那么 W 距离处处连续且可微
3)以上两点在 JS、KL 系列的散度定义中不存在
定理2
在对比下,W 距离比 JS、KL 等散度要敏感。
WGAN
通过理论的分析,可以知道 W 距离在性质上的优越性,然而,W 距离中有一个十分棘手的问题就是最大下界(?),同时,Kantorovich-Rubinstein 二重性也告诉我们
Inf 下界,inf 上界
其梯度表示
weight clipping 不是一个好的选择,但是他能强制的使得满足 Lipschitz 约束,如果约束参数过大,那么收敛会慢,如果约束过小,那么会出现梯度弥散(如果未使用 BN 或者网络巨大)。
算法流程
与传统相比较,
1)网络输出不加 sigmoid
2)优化时使用不基于动量的优化方法,如 RMS
3)使用 EM 距离取代之前的损失函数
最终解决的问题:
1)训练不稳定,不收敛
2)模型坍塌问题,(G倾向于生成一些有把握但相似的图片,而不敢轻易地尝试去生成没把握的新图片,即所谓的mode collapse问题)
3)生成多样性问题。
后面是一些文章中的实验结果展示,就过了。。。
存有疑问
* 文中提出关于距离函数的 strength(weak,strong)是什么
解决的问题
* 为什么是判别器具需要 weight clipping?
因为 EM 距离对于生成器而言是拉近距离,对于判别器而言是拉远距离,存在上界无限的情况,所以为了保障其正常训练,必须要有一个稳定的上界,故需要满足 Lipschitz 限制,所以需要对判别器具加以限制。
PS:
看别人笔记的过程中有一段关于改进原版 G 的问题部分写的比较好,作为保留
(但是最终自己没推出这个公式 ,不知从何而来)
(终于想通了)
核心损失函数部分代码
g = generator(noise)
d_real = discriminator(X)
d_fake = discriminator(g, reuse=True)
loss_d_real = -tf.reduce_mean(d_real)
loss_d_fake = tf.reduce_mean(d_fake)
loss_g = -tf.reduce_mean(d_fake)
loss_d = loss_d_real + loss_d_fake
alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * X + (1 - alpha) * g
grad = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
slop = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1]))
gp = tf.reduce_mean((slop - 1.) ** 2)
loss_d += LAMBDA * gp
vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
wgan 不理解 损失函数_WGAN相关推荐
- wgan 不理解 损失函数_WGAN学习笔记
GAN自从被提出之后就受到了广泛的关注,GAN也被逐渐用于各种有趣的应用之中.虽然GAN的idea对研究者们有着巨大的吸引力,但是GAN的训练却不像普通DNN那样简单,generator和discri ...
- wgan 不理解 损失函数_WGAN源码解读
作者的代码包括两部分:models包下包含dcgan.py和mlp.py, 这两个py文件是两种不同的网络结构,在dcgan.py中判别器和生成器都含有卷积网络,而mlp.py中判别器和生成器都只是全 ...
- wgan 不理解 损失函数_AI初识:深度学习中常用的损失函数有哪些?
加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动! 同时提供每月大咖直播分享.真实项目需求对接.干货资讯汇总 ...
- wgan 不理解 损失函数_[图像盲去噪与GAN]GCBD翻译理解
图像去噪是low-level视觉问题中的一个经典的话题.其退化模型为 y=x+v,图像去噪的目标就是通过减去噪声 v,从含噪声的图像 y 中得到干净图像 x .在很多情况下,因为各种因素的影响,噪声的 ...
- wgan 不理解 损失函数_GAN:「太难的部分我就不生成了,在下告退」
选自 arXiv 作者:David Bau, Jun-Yan Zhu等 机器之心编译 参与:Panda W 生成对抗网络(GAN)现在已经能合成极具真实感的图像了,但 MIT.IBM 和香港中文大学的 ...
- YOLOV5代码理解——损失函数的计算
YOLOV5代码理解--损失函数的计算 摘要: 神经网络的训练的主要流程包括图像输入神经网络, 得到模型的输出结果,计算模型的输出与真实值的损失, 计算损失值的梯度,最后用梯度下降算法更新模型参数.损 ...
- 可视化深入理解损失函数与梯度下降 | 技术头条
作者 | Hugegene 译者 | 刘畅 责编 | Rachel 出品 | AI科技大本营(id:rgznai100) [导语]本文对梯度函数和损失函数间的关系进行了介绍,并通过可视化方式进行了详细 ...
- “损失函数”是如何设计出来的?直观理解“最小二乘法”和“极大似然估计法”
[本文内容是自对视频:"损失函数"是如何设计出来的?的整理.补充和修正] 在大多数课程,尤其是帮助大家快速掌握深度学习的课程,损失函数似乎并不是一个需要额外关心的问题.因为它往往都 ...
- 【白话理解神经网络中的“损失函数”——最小二乘法和极大似然估计法】
目录 写在前面的话 理解损失函数 最小二乘法 最大似然估计法(统计方法) 写在前面的话 "损失函数"是如何设计出来的?直观理解"最小二乘法"和"极大似 ...
最新文章
- 批归一化和Dropout不能共存?这篇研究说可以
- “sockaddr_in”:“struct”类型重定义
- mysql 存储过程项目小结
- 6 频率_六级连续6年出现频率最高的200个词组【pdf版本】
- python list tuple 打包 解包_python的打包与解包
- [转] PyTorch 0.4新版本 升级指南 no_grad
- HTTP缓存策略 304
- 我靠ppt做兼职副业,1月还清2W贷款成功上岸!
- spring-boot-route(十九)spring-boot-admin监控服务
- 单片机c语言取反符号怎么打,arduino取反怎么写
- 吉他录音混音教程入门|连这些录音知识都不懂,以后还怎么“混”?| MZD Studios
- 手持式频谱分析仪怎么选择
- 解决sqliteman创建失败的一种方法
- Chrome断点调试
- 网络拓扑结构的优缺点分析
- 2022.04.14【读书笔记】|WGCNA分析原理和数据挖掘技巧
- 省市区三级级联JSON解析打印各级key及value
- html 可脱机浏览,如何脱机浏览Web页面
- VUE 引用腾讯地图
- C++题目:新的篮球队(题集)
热门文章
- listview的动态加载数据问题
- 常见的服务器响应状态码
- centos7-aliyun
- Win10 家庭版 升级到 专业版 的流程
- java long 对应mybati类型_MyBatis常用的jdbcType数据类型
- Wireshark 合并数据包
- 基于深度学习的海洋动物检测系统(Python+YOLOv5+清新界面)
- php 点击表头排序,点击表头切换升降序排序方式
- python request大批量发送请求调用接口时,报错:[WinError 10048] 通常每个套接字地址(协议/网络地址/端口)只允许使用一次。
- 常熟理工php实验三_西普学院(实验吧)Web题解