来源:机器之心

本文长度为3071字,建议阅读6分钟

本文在 MNIST 上对VAE和GAN这两类生成模型的性能进行了对比测试。

项目链接:https://github.com/kvmanohar22/ Generative-Models

变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法。

本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能。你可能会问:我们已经有了数百万张图像,为什么还要从给定数据分布中生成图像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那样,实际上有很多应用。我觉得比较有趣的一种是使用 GAN 模拟可能的未来,就像强化学习中使用策略梯度的智能体那样。

本文组织架构:

  • 变分自编码器(VAE)

  • 生成对抗网络(GAN)

  • 训练普通 GAN 的难点

  • 训练细节

  • 在 MNIST 上进行 VAE 和 GAN 对比实验

    • 在无标签的情况下训练 GAN 判别器

    • 在有标签的情况下训练 GAN 判别器

  • 在 CIFAR 上进行 VAE 和 GAN 实验

  • 延伸阅读

VAE

变分自编码器可用于对先验数据分布进行建模。从名字上就可以看出,它包括两部分:编码器和解码器。编码器将数据分布的高级表征映射到数据的低级表征,低级表征叫作本征向量(latent vector)。解码器吸收数据的低级表征,然后输出同样数据的高级表征。

从数学上来讲,让 X 作为编码器的输入,z 作为本征向量,X′作为解码器的输出。

图 1 是 VAE 的可视化图。

图 1:VAE 的架构

这与标准自编码器有何不同?关键区别在于我们对本征向量的约束。如果是标准自编码器,那么我们主要关注重建损失(reconstruction loss),即:

而在变分自编码器的情况中,我们希望本征向量遵循特定的分布,通常是单位高斯分布(unit Gaussian distribution),使下列损失得到优化:

p(z′)∼N(0,I) 中 I 指单位矩阵(identity matrx),q(z∣X) 是本征向量的分布,其中由神经网络来计算。KL(A,B) 是分布 B 到 A 的 KL 散度。

由于损失函数中还有其他项,因此存在模型生成图像的精度,同本征向量的分布与单位高斯分布的接近程度之间存在权衡(trade-off)。这两部分由两个超参数λ_1 和λ_2 来控制。

GAN

GAN 是根据给定的先验分布生成数据的另一种方式,包括同时进行的两部分:判别器和生成器。

判别器用于对「真」图像和「伪」图像进行分类,生成器从随机噪声中生成图像(随机噪声通常叫作本征向量或代码,该噪声通常从均匀分布(uniform distribution)或高斯分布中获取)。生成器的任务是生成可以以假乱真的图像,令判别器也无法区分出来。也就是说,生成器和判别器是互相对抗的。判别器非常努力地尝试区分真伪图像,同时生成器尽力生成更加逼真的图像,目的是使判别器将这些图像也分类为「真」图像。

图 2 是 GAN 的典型结构。

图 2:GAN

生成器包括利用代码输出图像的解卷积层。图 3 是生成器的架构图。

图 3:典型 GAN 的生成器图示(图像来源:OpenAI)

训练 GAN 的难点

训练 GAN 时我们会遇到一些挑战,我认为其中最大的挑战在于本征向量/代码的采样。代码只是从先验分布中对本征变量的噪声采样。有很多种方法可以克服该挑战,包括:使用 VAE 对本征变量进行编码,学习数据的先验分布。这听起来要好一些,因为编码器能够学习数据分布,现在我们可以从分布中进行采样,而不是生成随机噪声。

训练细节

我们知道两个分布 p(真实分布)和 q(估计分布)之间的交叉熵通过以下公式计算:

  • 对于二元分类


  • 对于 GAN,我们假设分布的一半来自真实数据分布,一半来自估计分布,因此:

训练 GAN 需要同时优化两个损失函数。

按照极小极大值算法:

这里,判别器需要区分图像的真伪,不管图像是否包含真实物体,都没有注意力。当我们在 CIFAR 上检查 GAN 生成的图像时会明显看到这一点。

我们可以重新定义判别器损失目标,使之包含标签。这被证明可以提高主观样本的质量。如:在 MNIST 或 CIFAR-10(两个数据集都有 10 个类别)。

上述 Python 损失函数在 TensorFlow 中的实现:

def VAE_loss(true_images, logits, mean, std):"""Args:true_images : batch of input imageslogits      : linear output of the decoder network (the constructed images)mean        : mean of the latent codestd         : standard deviation of the latent code"""imgs_flat    = tf.reshape(true_images, [-1, img_h*img_w*img_d])encoder_loss = 0.5 * tf.reduce_sum(tf.square(mean)+tf.square(std)-tf.log(tf.square(std))-1, 1)decoder_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=img_flat), 1)return tf.reduce_mean(encoder_loss + decoder_loss)
  def GAN_loss_without_labels(true_logit, fake_logit):"""Args:true_logit : Given data from true distribution,`true_logit` is the output of Discriminator (a column vector)fake_logit : Given data generated from Generator,`fake_logit` is the output of Discriminator (a column vector)"""true_prob = tf.nn.sigmoid(true_logit)fake_prob = tf.nn.sigmoid(fake_logit)d_loss = tf.reduce_mean(-tf.log(true_prob)-tf.log(1-fake_prob))g_loss = tf.reduce_mean(-tf.log(fake_prob))return d_loss, g_loss  
  def GAN_loss_with_labels(true_logit, fake_logit):"""Args:true_logit : Given data from true distribution,`true_logit` is the output of Discriminator (a matrix now)fake_logit : Given data generated from Generator,`fake_logit` is the output of Discriminator (a matrix now)"""d_true_loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.labels, logits=self.true_logit, dim=1)d_fake_loss = tf.nn.softmax_cross_entropy_with_logits(labels=1-self.labels, logits=self.fake_logit, dim=1)g_loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.labels, logits=self.fake_logit, dim=1)d_loss = d_true_loss + d_fake_loss      return tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)

在 MNIST 上进行 VAE 与 GAN 对比实验

1. 不使用标签训练判别器

我在 MNIST 上训练了一个 VAE。代码地址:https://github.com/kvmanohar22/Generative-Models

实验使用了 MNIST 的 28×28 图像,下图中:

  • 左侧:数据分布的 64 张原始图像

  • 中间:VAE 生成的 64 张图像

  • 右侧:GAN 生成的 64 张图像

第 1 次迭代:

第 2 次迭代:

第 3 次迭代:

第 4 次迭代:

第 100 次迭代:

VAE(125)和 GAN(368)训练的最终结果:

根据GAN迭代次数生成的gif图:

显然,VAE 生成的图像与 GAN 生成的图像相比,前者更加模糊。这个结果在预料之中,因为 VAE 模型生成的所有输出都是分布平均。为了减少图像的模糊度,我们可以使用 L1 损失来代替 L2 损失。

在第一个实验后,作者还将在近期研究使用标签训练判别器,并在 CIFAR 数据集上测试 VAE 与 GAN 的性能。

使用

  • 下载 MNIST 和 CIFAR 数据集

使用 MNIST 训练 VAE 请运行:

python main.py --train --model vae --dataset mnist

使用 MNIST 训练 GAN 请运行:

python main.py --train --model gan --dataset mnist

想要获取完整的命令行选项,请运行:

python main.py --help

该模型由 generate_frq 决定生成图片的频率,默认值为 1。

GAN 在 MNIST 上的训练结果

MNIST 数据集中的样本图像:

上方是 VAE 生成的图像,下方的图展示了 GAN 生成图像的过程:

延伸阅读


  • Tutorial on Variational Autoencoders by Carl Doersch

  • NIPS 2016 Tutorial: Generative Adversarial Networks (pdf) by Ian Goodfellow

  • NIPS 2016 - Generative Adversarial Networks (video) by Ian Goodfellow

  • NIPS 2016 Workshop on Adversarial Training - How to train a GAN by Soumith Chintala

原文链接:https://kvmanohar22.github.io/Generative-Models/

校对:朱江华峰

为保证发文质量、树立口碑,数据派现设立“错别字基金”,鼓励读者积极纠错

若您在阅读文章过程中发现任何错误,请在文末留言,或到后台反馈,经小编确认后,数据派将向检举读者发8.8元红包

同一位读者指出同一篇文章多处错误,奖金不变。不同读者指出同一处错误,奖励第一位读者。

感谢一直以来您的关注和支持,希望您能够监督数据派产出更加高质的内容。

在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)相关推荐

  1. 数据仓库中的两大经典模型

    在数据分析相关内容中,包括两大重要内容:一是底层数据系统建设内容,二是业务报表相关内容梳理.一是系统基础,二是基础之上的业务逻辑衍生. 在番茄风控之前的数据分析课程中,主要集中在以上的第二点即业务相关 ...

  2. 生成模型VAE、GAN和基于流的模型详细对比

    在Ian Goodfellow和其他研究人员在一篇论文中介绍生成对抗网络两年后,Yann LeCun称对抗训练是"过去十年里ML最有趣的想法".尽管GANs很有趣,也很有前途,但它 ...

  3. 过拟合和欠拟合_现代深度学习解决方案中的两大挑战:拟合和欠拟合

    全文共2306字,预计学习时长5分钟 对机器学习模型而言,最糟糕的两种情况无非是构建无用的知识体系,或是从训练数据集中一无所获.在机器学习理论中,这两种现象分别被称为过拟合和欠拟合,是现代深度学习解决 ...

  4. 怎么在excel中对比两列数据并查找重复项

    怎么在excel中对比两列数据并查找重复项 方法一: 方法二: 方法三: Excel查找2列相同的数据,并且返回对应列的另1列数据: IF函数语法格式: 方法一: =MATCH(A1,D$1:D95, ...

  5. 【GAN优化外篇】详解生成模型VAE的数学原理

    最近在学习生成模型的相关知识,这篇文章将介绍一下变分自编码器(Variational Auto-encoder),本文只介绍一些粗浅内容,不会涉及比较深刻的问题. 作者&编辑 | 小米粥 1. ...

  6. Python教学 | Python 中的分支结构(判断语句)【附本文代码和数据】

    查看原文:[数据seminar]Python教学 | Python 中的分支结构(判断语句)[附本文代码和数据] Part1引言 上期文章我们学习了组合数据类型字典以及元组,这标志着 Python 基 ...

  7. 三种方式获取大疆照片的EXIF/XMP信息(附测试代码)

    目录 软件方式 在线方式 Python方式 第一种:pyexiv2 第二种:pyexif 测试代码:三种方式获取大疆照片的EXIF/XMP信息(附测试代码) - 小锋学长生活大爆炸 (xfxuezha ...

  8. excel中对比两个sheet,找出匹配不上的

    问题描述:数据的特点是,在同一个excel文件中存在两个sheet,他们的数据结构是一样的,其中一个中的数据是另一个的子集,目的是要找出他们的不同,即找出在那张大些的sheet中存在,但在那张小些的s ...

  9. Excel中对比两列数据的不同并做特殊标记

    最近在处理一批自然保护区的数据,数据的来源不同,要对比两种数据的区别,当然要用Excel进行处理了.如下图: 有A和C两列数据,A列数据是中国林业科学院的网站上爬取到的:C列数据是国家环保部公布的PD ...

最新文章

  1. C语言函数集(十三)
  2. 关于学习新技术的方法
  3. 微服务怎么部署到服务器的_浅谈微服务部署方案
  4. 安装完最小化 RHEL/CentOS 7 后需要做的 30 件事情(三)码农网
  5. python爬虫实战教程分享 或许你可以看一下这篇文章
  6. android textview字体贴底部,在android中底部设置textview
  7. 2个比较经典的PHP加密解密函数分享
  8. addEventListener监听
  9. 计算机新建没有文本文档,我的电脑新建文本文档没有显示TXT,为什么?
  10. TeamViewer远程唤醒主机实战教程(多图)
  11. Clang 10.0 手写静态分析器Checker
  12. 为什么你写了一万小时的代码,却没能成为架构师?| 程序员有话说
  13. 爬虫项目实操五、用Scrapy爬取当当图书榜单
  14. 批量删除微博的js代码
  15. unity中导入的角色没有Avatar
  16. 信用评分卡DAY8-9
  17. Mybatis generator mapper文件重新生成不会覆盖原文件
  18. 骑士CMS模版注入+文件包含getshell复现
  19. GameofMir引擎架设传奇服务器【1:架设服务端】
  20. python内容推荐理由_好书推荐~第5期 | Python 数据可视化

热门文章

  1. 理解和解决Java并发修改异常ConcurrentModificationException(转载)
  2. 使用Identity Server 4建立Authorization Server (3)
  3. 电信应在短时间内放弃CDMA网络
  4. 【Unity3D】 KeyCode 键码
  5. Ext JS 6正式版的GPL版本下载地址
  6. OJ在线编程----常见输入输出练习场
  7. 医院选址问题--数据结构课程设计
  8. python中True 为1 ,False为0
  9. Spring集成MyBatis完整示例
  10. java 中的进制转换