谷歌开源的 GAN 库--TFGAN
本文大约 8000 字,阅读大约需要 12 分钟
第一次翻译,限于英语水平,可能不少地方翻译不准确,请见谅!
最近谷歌开源了一个基于 TensorFlow 的库–TFGAN,方便开发者快速上手 GAN 的训练,其 Github 地址如下:
https://github.com/tensorflow/models/tree/master/research/gan
原文网址:Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)
如果你玩过波斯王子,那你应该知道你需要保护自己不被”影子“所杀掉,但这也是一个矛盾:如果你杀死“影子”,那游戏就结束了;但你不做任何事情,那么游戏也会输掉。
尽管生成对抗网络(GAN)有不少优点,但它也面临着相似的区分问题。大部分支持 GAN 的深度学习专业也是非常谨慎的支持它,并指出它确实存在稳定性的问题。
GAN 的这个问题也可以称做整体收敛性问题。尽管判别器 D 和 生成器 D 相互竞争博弈,但同时也相互依赖对方来达到有效的训练。如果其中一方训练得很差,那整个系统也会很差(这也是之前提到的梯度消失或者模式奔溃问题)。并且你也需要确保他们不会训练太过度,造成另一方无法训练了。因此,波斯王子是一个很有趣的概念。
首先,神经网络的提出就是为了模仿人类的大脑(尽管是人为的)。它们也已经在物体识别和自然语言处理方面取得成功。但是,想要在思考和行为上与人类一致,这还有非常大的差距。
那么是什么让 GANs 成为机器学习领域一个热门话题呢?因为它不仅只是一个相对新的结构,它更加是一个比之前其他模型都能更加准确的对真实数据建模,可以说是深度学习的一个革命性的变化。
最后,它是一个同时训练两个独立的网络的新模型,这两个网络分别是判别器和生成器。这样一个非监督神经网络却能比其他传统网络得到更好性能的结果。
但目前事实是我们对 GANs 的研究还只是非常浅层,仍然有着很多挑战需要解决。GANs 目前也存在不少问题,比如无法区分在某个位置应该有多少特定的物体,不能应用到 3D 物体,以及也不能理解真实世界的整体结构。当然现在有大量研究正在研究如何解决上述问题,新的模型也取得更好的性能。
而最近谷歌为了让 GANs 更容易实现,设计开发并开源了一个基于 TensorFlow 的轻量级库–TFGAN。
根据谷歌的介绍,TFGAN 提供了一个基础结构来减少训练一个 GAN 模型的难度,同时提供非常好测试的损失函数和评估标准,以及给出容易上手的例子,这些例子强调了 TFGAN 的灵活性和易于表现的优点。
此外,还提供了一个教程,包含一个高级的 API 可以快速使用自己的数据集训练一个模型。
上图是展示了对抗损失在图像压缩方面的效果。最上方第一行图片是来自 ImageNet 数据集的图片,也是原始输入图片,中间第二行展示了采用传统损失函数训练得到的图像压缩神经网络的压缩和解压缩效果,最底下一行则是结合传统损失函数和对抗损失函数训练的网络的结果,可以看到尽管基于对抗损失的图片并不像原始图片,但是它比第二行的网络得到更加清晰和细节更好的图片。
TFGAN 既提供了几行代码就可以实现的简答函数来调用大部分 GAN 的使用例子,也是建立在包含复杂 GAN 设计的模式化方式。这就是说,我们可以采用自己需要的模块,比如损失函数、评估策略、特征以及训练等等,这些都是独立的模块。TFGAN 这样的设计方式其实就满足了不同使用者的需求,对于入门新手可以快速训练一个模型来看看效果,对于需要修改其中任何一个模块的使用者也能修改对应模块,而不会牵一发而动全身。
最重要的是,谷歌也保证了这个代码是经过测试的,不需要担心一般的 GAN 库造成的数字或者统计失误。
开始使用
首先添加以下代码来导入 tensorflow 和 声明一个 TFGAN 的实例:
import tensorflow as tf
tfgan = tf.contrib.gan
为何使用 TFGAN
- 采用良好测试并且很灵活的调用接口实现快速训练生成器和判别器网络,此外,还可以混合 TFGAN、原生 TensorFlow以及其他自定义框架代码;
- 使用实现好的GAN 的损失函数和惩罚策略 (比如 Wasserstein loss、梯度惩罚等)
- 训练阶段对 GAN 进行监控和可视化操作,以及评估生成结果
- 使用实现好的技巧来稳定和提高性能
- 基于常规的 GAN 训练例子来开发
- 采用GANEstimator接口里快速训练一个 GAN 模型
- TFGAN 的结构改进也会自动提升你的 TFGAN 项目的性能
- TFGAN 会不断添加最新研究的算法成果
TFGAN 的部件有哪些呢?
TFGAN 是由多个设计为独立的部件组成的,分别是:
- core:提供了一个主要的训练 GAN 模型的结构。训练过程分为四个阶段,每个阶段都可以采用自定义代码或者 调用 TFGAN 库接口来完成;
- features:包含许多常见的 GAN 运算和正则化技术,比如实例正则化(instance normalization)
- losses:包含常见的 GAN 的损失函数和惩罚机制,比如 Wasserstein loss、梯度惩罚、相互信息惩罚等
- evaulation:使用一个预训练好的 Inception 网络来利用
Inception Score
或者Frechet Distance
评估标准来评估非条件生成模型。当然也支持利用自己训练的分类器或者其他方法对有条件生成模型的评估 - examples and tutorial:使用 TFGAN 训练 GAN 模型的例子和教程。包含了使用非条件和条件式的 GANs 模型,比如 InfoGANs 等。
训练一个 GAN 模型
典型的 GAN 模型训练步骤如下:
- 为你的网络指定输入,比如随机噪声,或者是输入图片(一般是应用在图片转换的应用,比如 pix2pixGAN 模型)
- 采用
GANModel
接口定义生成器和判别器网络 - 采用
GANLoss
指定使用的损失函数 - 采用
GANTrainOps
设置训练运算操作,即优化器 - 开始训练
当然,GAN 的设置有多种形式。比如,你可以在非条件下训练生成器生成图片,或者可以给定一些条件,比如类别标签等输入到生成器中来训练。无论是哪种设置,TFGAN 都有相应的实现。下面将结合代码例子来进一步介绍。
实例
非条件 MNIST 图片生成
第一个例子是训练一个生成器来生成手写数字图片,即 MNIST 数据集。生成器的输入是从多变量均匀分布采样得到的随机噪声,目标输出是 MNIST 的数字图片。具体查看论文“Generative Adversarial Networks”。代码如下:
# 配置输入
# 真实数据来自 MNIST 数据集
images = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的输入,从多变量均匀分布采样得到的随机噪声
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 调用 tfgan.gan_model() 函数定义生成器和判别器网络模型
gan_model = tfgan.gan_model(generator_fn=mnist.unconditional_generator, discriminator_fn=mnist.unconditional_discriminator, real_data=images,generator_inputs=noise)# 调用 tfgan.gan_loss() 定义损失函数
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss)# 调用 tfgan.gan_train_ops() 指定生成器和判别器的优化器
train_ops = tfgan.gan_train_ops(gan_model,gan_loss,generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))# tfgan.gan_train() 开始训练,并指定训练迭代次数 num_steps
tfgan.gan_train(train_ops,hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],logdir=FLAGS.train_log_dir)
条件式 MNIST 图片生成
第二个例子同样还是生成 MNIST 图片,但是这次输入到生成器的不仅仅是随机噪声,还会给类别标签,这种 GAN 模型也被称作条件 GAN,其目的也是为了让 GAN 训练不会太过自由。具体可以看论文“Conditional Generative Adversarial Nets”。
代码方面,仅仅需要修改输入和建立生成器与判别器模型部分,如下所示:
# 配置输入
# 真实数据来自 MNIST 数据集,这里增加了类别标签--one_hot_labels
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的输入,从多变量均匀分布采样得到的随机噪声
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 调用 tfgan.gan_model() 函数定义生成器和判别器网络模型
gan_model = tfgan.gan_model(generator_fn=mnist.conditional_generator, discriminator_fn=mnist.conditional_discriminator, real_data=images,generator_inputs=(noise, one_hot_labels)) # 生成器的输入增加了类别标签# 剩余的代码保持一致
...
对抗损失
第三个例子结合了 L1 pixel loss 和对抗损失来学习自动编码图片。瓶颈层可以用来传输图片的压缩表示。如果仅仅使用 pixel-wise loss,网络只回倾向于生成模糊的图片,但 GAN 可以用来让这个图片重建过程更加逼真。具体可以看论文“Full Resolution Image Compression with Recurrent Neural Networks”来了解如何用 GAN 来实现图像压缩,以及论文“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”了解如何用 GANs 来增强生成的图片质量。
代码如下:
# 配置输入
images = image_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.gan_model(generator_fn=nets.autoencoder, # 自定义的 autoencoderdiscriminator_fn=nets.discriminator, # 自定义的 discriminatorreal_data=images,generator_inputs=images)# 建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 结合两个 loss
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代码保持一致
...
图像转换
第四个例子是图像转换,它是将一个领域的图片转变成另一个领域的同样大小的图片。比如将语义分割图变成街景图,或者是灰度图变成彩色图。具体细节看论文“Image-to-Image Translation with Conditional Adversarial Networks”。
代码如下:
# 配置输入,注意增加了 target_image
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.gan_model(generator_fn=nets.generator, discriminator_fn=nets.discriminator, real_data=target_image,generator_inputs=input_image)# 建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.least_squares_generator_loss,discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 结合两个 loss
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代码保持一致
...
InfoGAN
最后一个例子是采用 InfoGAN 模型来生成 MNIST 图片,但是可以不需要任何标签来控制生成的数字类型。具体细节可以看论文“InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets”。
代码如下:
# 配置输入
images = mnist_data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.infogan_model(generator_fn=mnist.infogan_generator, discriminator_fn=mnist.infogran_discriminator, real_data=images,unstructured_generator_inputs=unstructured_inputs, # 自定义输入structured_generator_inputs=structured_inputs) # 自定义# 配置 GAN loss 以及相互信息惩罚
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0,mutual_information_penalty_weight=1.0)# 剩下代码保持一致
...
自定义模型的创建
最后同样是非条件 GAN 生成 MNIST 图片,但利用GANModel
函数来配置更多参数从而更加精确控制模型的创建。
代码如下:
# 配置输入
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 手动定义生成器和判别器模型
with tf.variable_scope('Generator') as gen_scope:generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:discriminator_gen_outputs = discriminator_fn(generated_images)
with variable_scope.variable_scope(dis_scope, reuse=True):discriminator_real_outputs = discriminator_fn(images)
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)# 依赖于你需要使用的 TFGAN 特征,你并不需要指定 `GANModel`函数的每个参数,不过
# 最少也需要指定判别器的输出和变量
gan_model = tfgan.GANModel(generator_inputs,generated_data,generator_variables,gen_scope,generator_fn,real_data,discriminator_real_outputs,discriminator_gen_outputs,discriminator_variables,dis_scope,discriminator_fn)# 剩下代码和第一个例子一样
...
最后,再次给出 TFGAN 的 Github 地址如下:
https://github.com/tensorflow/models/tree/master/research/gan
如果有翻译不当的地方或者有任何建议和看法,欢迎留言交流;也欢迎关注我的微信公众号–机器学习与计算机视觉或者扫描下方的二维码,和我分享你的建议和看法,指正文章中可能存在的错误,大家一起交流,学习和进步!
谷歌开源的 GAN 库--TFGAN相关推荐
- 谷歌开源张量网络库TensorNetwork,GPU处理提升100倍!
编译 | 琥珀 出品 | AI科技大本营(ID:rgznai100) 世界上许多最严峻的科学挑战,如开发高温超导体和理解时空的本质,都涉及处理量子系统的复杂性.然而,这些系统中量子态的数量程指数级增 ...
- 谷歌开源AutoML算法库,自动写出你想要的AI模型
向AI转型的程序员都关注了这个号
- 谷歌开源新模型EfficientNet,或成计算机视觉任务新基础
作者 | Mingxing Tan,Quoc V. Le,Google AI 译者 | 刘畅 责编 | 夕颜 出品 | AI科技大本营(id:rgznai100) 开发一个卷积神经网络(CNN)的成本 ...
- 谷歌开源 TFGAN,让训练和评估 GAN 变得更加简单
作者:思颖 概要:训练神经网络的时候,通常需要定义一个损失函数来告诉网络它离目标还有多远. 三年前,蒙特利尔大学 Ian Goodfellow 等学者提出「生成式对抗网络」(Generative Ad ...
- 谷歌开源 TensorFlow 的简化库 JAX
谷歌开源了一个 TensorFlow 的简化库 JAX. JAX 结合了 Autograd 和 XLA,专门用于高性能机器学习研究. 凭借 Autograd,JAX 可以求导循环.分支.递归和闭包函数 ...
- 各种NLP操作难实现?谷歌开源序列建模框架Lingvo
各种NLP操作难实现?谷歌开源序列建模框架Lingvo 自然语言处理在过去一年取得了很大进步,但直接关注 NLP 或序列建模的框架还很少.本文介绍了谷歌开源的 Lingvo,它是一种建立在 Tenso ...
- 取代C++?谷歌开源编程语言Carbon,网友评价太真实了
目前,Carbon编程语言正处于实验阶段. 在编程语言的世界中,C++的地位举足轻重.在2022年5月的TIOBE编程语言排行榜中,C++位列第四.同样地,谷歌内部也在广泛使用C++. 图源:htt ...
- 最喜欢随机森林?周志华团队 DF21 后,TensorFlow 开源决策森林库 TF-DF
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | 机器之心 TensorFlow 决策森林 (TF-DF) ...
- 活久见!谷歌开源“大杀器”,CV、NLP都能用!
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 萧箫 发自 凹非寺 来自 | 量子位 好消息,谷歌将AutoML算法 ...
最新文章
- 成功网页设计师的七大必备技能
- 回归分析中自变量共线性_具有大特征空间的回归分析中的变量选择
- 基于APMSSGA-LSTM的容器云资源预测
- android网页接口实现方法,Android 程序员搞 web 之 webApi (十 四)
- mysql的面试2_mysql数据库面试题(2)
- 关于代码调试de那些事
- 一键部署Zabbix客户端
- 主席树初探--BZOJ1901: Zju2112 Dynamic Rankings
- redis 经典36问
- python中使用frame需要安装_python – 在SFrame中分组而不安装graphlab
- scroll案例:带有动画的返回顶部
- 证监会计算机类笔试上岸经验,公务员考试笔试166分上岸经验(全干货)
- pscc2018安装服务器无响应,一招解决PSCC2018无法安装扩展插件
- 【报告分享】 2021年天猫618商务合作方案-天猫x阿里妈妈(附下载)
- 23 SpringBoot @Qualifier注解
- 开源协议之Code Project Open License (CPOL)
- 前端AI语音方面的实现
- 会计学原理学习笔记——第一章——总论(1.5会计目标)
- 层压卷轴标签的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
- 实践与共享:可一键自动搜寻添加有效ID并可成功激活 ESET Nod32 的小工具(绝对好用)...
热门文章
- html怎样将单元格的字竖式,数学竖式计算的标准格式是怎样的?需要注意哪些问题?...
- 怎样用c语言定义高幂整数,位操作运算的奇技淫巧!(附源码)
- Python -- post方式上传文件
- Linux内存卡(SD卡、TF卡)作为Swap交换空间
- './mysql-bin.index' not found (Errcode: 13) 的解决方法
- Java 基础——类和对象
- [Redux/Mobx] 你有了解Rxjs是什么吗?它是做什么的?
- 前端学习(3333):ant design介绍按钮类型
- [html] html的属性值有规定要使用单引号还是双引号吗?
- [css] 描述下你所了解的图片格式及使用场景