生成对抗网络主要分为生成器网络判别器网络。

  • 生成器网络:他以一个随机向量(潜在空间的一个随机点)作为输入,并将其解码成一张合成图像。
  • 判别器网络:以一张图像(真实的或合成的均可)作为输入,并预测该图像是来自训练集还是生成器网络创建。

本节将会介绍如何用 Keras 来实现形式最简单的 GAN。GAN 属于高级应用,所以本书不会深入介绍其技术细节。我们具体实现的是一个深度卷积生成式对抗网络(DCGAN,deep convolutional GAN),即生成器和判别器都是深度卷积神经网络的 GAN。特别地,它在生成器中使用 Conv2DTranspose 层进行图像上采样。 我们将在 CIFAR10 数据集的图像上训练 GAN,这个数据集包含 50 000 张 32×32 的 RGB图像,这些图像属于 10 个类别(每个类别 5000 张图像)。为了简化,我们只使用属于“frog”(青蛙)类别的图像。

GAN 的简要实现流程如下所示。

  • (1) generator 网络将形状为 (latent_dim,) 的向量映射到形状为 (32, 32, 3) 的图像。
  • (2) discriminator 网络将形状为 (32, 32, 3) 的图像映射到一个二进制分数,用于评估图像为真的概率。
  • (3) gan 网络将 generator 网络和 discriminator 网络连接在一起:gan(x) = discriminator (generator(x))。生成器将潜在空间向量解码为图像,判别器对这些图像的真实性进 行评估,因此这个 gan 网络是将这些潜在向量映射到判别器的评估结果。
  • (4) 我们使用带有“真”/“假”标签的真假图像样本来训练判别器,就和训练普通的图像 分类模型一样。
  • (5) 为了训练生成器,我们要使用 gan 模型的损失相对于生成器权重的梯度。这意味着,在每一步都要移动生成器的权重,其移动方向是让判别器更有可能将生成器解码的图像划分为“真”。换句话说,我们训练生成器来欺骗判别器。

训练gan的技巧

训练 GAN 和调节 GAN 实现的过程非常困难。你应该记住一些公认的技巧。与深度学习中的大部分内容一样,这些技巧更像是炼金术而不是科学,它们是启发式的指南,并没有理论 的支持。这些技巧得到了一定程度的来自对现象的直观理解的支持,经验告诉我们,它们的效果都很好,但不一定适用于所有情况。

下面是本节实现 GAN 生成器和判别器时用到的一些技巧。这里并没有列出与 GAN 相关的 全部技巧,更多技巧可查阅关于 GAN 的文献。

  • 我们使用 tanh 作为生成器最后一层的激活,而不用 sigmoid,后者在其他类型的模型中更加常见。
  • 我们使用正态分布(高斯分布)对潜在空间中的点进行采样,而不用均匀分布。
  • 随机性能够提高稳健性。训练 GAN 得到的是一个动态平衡,所以 GAN 可能以各种方式“卡住”。在训练过程中引入随机性有助于防止出现这种情况。我们通过两种方式引入随机性: 一种是在判别器中使用 dropout,另一种是向判别器的标签添加随机噪声。
  • 稀疏的梯度会妨碍 GAN 的训练。在深度学习中,稀疏性通常是我们需要的属性,但在GAN 中并非如此。有两件事情可能导致梯度稀疏:最大池化运算和 ReLU 激活。我们推荐使用步进卷积代替最大池化来进行下采样,还推荐使用 LeakyReLU 层来代替 ReLU 激活。LeakyReLU 和 ReLU 类似,但它允许较小的负数激活值,从而放宽了稀疏性限制。
  • 在生成的图像中,经常会见到棋盘状伪影,这是由生成器中像素空间的不均匀覆盖导致的。为了解决这个问题,每当在生成器和判别器中都使用步进的 Conv2DTranpose 或 Conv2D 时,使用的内核大小要能够被步幅大小整除。

生成器

首先,我们来开发 generator 模型,它将一个向量(来自潜在空间,训练过程中对其随机采样)转换为一张候选图像。GAN 常见的诸多问题之一,就是生成器“卡在”看似噪声的生成图像上。一种可行的解决方案是在判别器和生成器中都使用 dropout。

import keras
from keras import layers
import numpy as nplatent_dim = 32
height = 32
width = 32
channels = 3generator_input = keras.Input(shape=(latent_dim,))# First, transform the input into a 16x16 128-channels feature map
# 将输入转换为大小为 16×16 的128 个通道的特征图x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)# Then, add a convolution layer(添加一个卷积层)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)# Upsample to 32x32(上采样为 32×32)
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)# Few more conv layers
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)# Produce a 32x32 1-channel feature map(生成一个大小为 32×32 的单通道特征图(即 CIFAR10 图像的形状))
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
#将生成器模型实例化,它将形状为 (latent_dim,)的输入映射到形状为 (32, 32, 3) 的图像
generator.summary()

判别器

接下来,我们来开发 discriminator 模型,它接收一张候选图像(真实的或合成的)作为输入,并将其划分到这两个类别之一:“生成图像”或“来自训练集的真实图像”

discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)# One dropout layer - important trick!
# 一个 dropout 层:这是很重要的技巧
x = layers.Dropout(0.4)(x)# Classification layer(判别层)
x = layers.Dense(1, activation='sigmoid')(x)discriminator = keras.models.Model(discriminator_input, x)
discriminator.summary()# To stabilize training, we use learning rate decay
# and gradient clipping (by value) in the optimizer.
# 将判别器模型实例化,它将形状为 (32, 32, 3)的输入转换为一个二进制分类决策(真 / 假)discriminator_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)# clipvalue=1.0:在优化器中使用梯度裁剪(限制梯度值的范围)
# decay=1e-8:为了稳定训练过程,使用学习率衰减
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')

对抗网络

最后,我们要设置 GAN,将生成器和判别器连接在一起。训练时,这个模型将让生成器向某个方向移动,从而提高它欺骗判别器的能力。这个模型将潜在空间的点转换为一个分类决策(即 “真”或“假”),它训练的标签都是“真实图像”。因此,训练 gan 将会更新 generator 的权重, 使得 discriminator 在观察假图像时更有可能预测为“真”。请注意,有一点很重要,就是在训练过程中需要将判别器设置为冻结(即不可训练),这样在训练 gan 时它的权重才不会更新。 如果在此过程中可以对判别器的权重进行更新,那么我们就是在训练判别器始终预测“真”,但这并不是我们想要的!

# 将判别器权重设置为不可训练(仅应用于 gan 模型)
discriminator.trainable = Falsegan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

如何训练DCGAN

现在开始训练。再次强调一下,训练循环的大致流程如下所示。每轮都进行以下操作。

  • (1) 从潜在空间中抽取随机的点(随机噪声)。
  • (2) 利用这个随机噪声用 generator 生成图像。
  • (3) 将生成图像与真实图像混合。
  • (4) 使用这些混合后的图像以及相应的标签(真实图像为“真”,生成图像为“假”)来训练discriminator。
  • (5) 在潜在空间中随机抽取新的点。
  • (6) 使用这些随机向量以及全部是“真实图像”的标签来训练 gan。这会更新生成器的权重(只更新生成器的权重,因为判别器在 gan 中被冻结),其更新方向是使得判别器能够将生成图像预测为“真实图像”。这个过程是训练生成器去欺骗判别器。
import os
from keras.preprocessing import image# Load CIFAR10 data(加载 CIFAR10数据)
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()# Select frog images (class 6)(选择青蛙图像(类别编号为 6))
x_train = x_train[y_train.flatten() == 6]# Normalize data(数据标准化)
x_train = x_train.reshape((x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.iterations = 10000
batch_size = 20
save_dir = 'data/gan_images/'
# 指定保存生成 图像的目录# Start training loop(开始循环)
start = 0
for step in range(iterations):# Sample random points in the latent space(在潜在空间中采样随机点)random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))# Decode them to fake images(将这些点解码为虚假图像)generated_images = generator.predict(random_latent_vectors)# Combine them with real images(将这些虚假图像与真实图像合在一起)stop = start + batch_sizereal_images = x_train[start: stop]combined_images = np.concatenate([generated_images, real_images])# Assemble labels discriminating real from fake images#合并标签,区分真实和虚假的图像labels = np.concatenate([np.ones((batch_size, 1)),np.zeros((batch_size, 1))])# Add random noise to the labels - important trick!#向标签中添加随机噪声,这是一个很重要的技巧labels += 0.05 * np.random.random(labels.shape)# Train the discriminator(训练判别器)d_loss = discriminator.train_on_batch(combined_images, labels)# sample random points in the latent space(在潜在空间中采样随机点)random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))# Assemble labels that say "all real images"(合并标签,全部是“真实图像”(这是在撒谎,通过 gan 模型)misleading_targets = np.zeros((batch_size, 1))# Train the generator (via the gan model,# where the discriminator weights are frozen)#来训练生成器(此时冻结判别器权重)a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)start += batch_sizeif start > len(x_train) - batch_size:start = 0# Occasionally save / plotif step % 100 == 0:# Save model weights(保存模型权重)gan.save_weights('gan.h5')# Print metrics(将指标打印出来)print('discriminator loss at step %s: %s' % (step, d_loss))print('adversarial loss at step %s: %s' % (step, a_loss))# Save one generated image(保存一张生成图像)img = image.array_to_img(generated_images[0] * 255., scale=False)img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))# Save one real image, for comparison(保存一张真实图像,用于对比)img = image.array_to_img(real_images[0] * 255., scale=False)img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

下面展示生成的图像

import matplotlib.pyplot as plt# sample random points in the latent space(在潜在空间中采样随机点)
random_latent_vectors = np.random.normal(size=(10, latent_dim))# Decode them to fake images(将这些点解码为虚假图像)
generated_images = generator.predict(random_latent_vectors)for i in range(generated_images.shape[0]):img = image.array_to_img(generated_images[i] * 255., scale=False)plt.figure()plt.imshow(img)plt.show()

从零开始学keras之生成对抗网络GAN相关推荐

  1. 『一起学AI』生成对抗网络(GAN)原理学习及实战开发

     参考并翻译教程:https://d2l.ai/chapter_generative-adversarial-networks/gan.html,加入笔者的理解和心得 1.生成对抗网络原理 在Col ...

  2. Keras实现生成对抗网络(GAN)(生成二维平面上服从某一分布的点)

    GAN原理 相关数学推导可参考 李宏毅https://www.bilibili.com/video/av36779967/?p=4 通俗的比喻:制造假钞(G)和警察(D)对抗的过程.假钞制造者制造假钞 ...

  3. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)

     不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN) 生成对抗网络(Generative Adversarial Networks,GAN)最早由 Ian Goodfello ...

  4. 基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字

    目录 0.前言 一.GAN结构 二.函数代码 2.1 生成器Generator 2.2 判别器Discriminator 2.3 train函数 三.结果演示 四.完整代码 五.常见问题汇总 0.前言 ...

  5. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  6. 生成对抗网络——GAN(一)

    Generative adversarial network 据有关媒体统计:CVPR2018的论文里,有三分之一的论文与GAN有关 由此可见,GAN在视觉领域的未来多年内,将是一片沃土(CVer们是 ...

  7. 生成对抗网络(GAN)资料打包

    进入正文 全文 摘要 生成式对抗网络,即所谓的GAN是近些年来最火的无监督学习方法之一,模型由Goodfellow等人在2014年首次提出,将博弈论中非零和博弈思想与生成模型结合在一起,巧妙避开了传统 ...

  8. 生成对抗网络gan原理_中国首个“芯片大学”即将落地;生成对抗网络(GAN)的数学原理全解...

    开发者社区技术周刊又和大家见面了,萌妹子主播为您带来第三期"开发者技术联播".让我们一起听听,过去一周有哪些值得我们开发者关注的重要新闻吧. 中国首个芯片大学,南京集成电路大学即将 ...

  9. 【GAN优化】长文综述解读如何定量评价生成对抗网络(GAN)

    欢迎大家来到<GAN优化>专栏,这里将讨论GAN优化相关的内容,本次将和大家一起讨论GAN的评价指标. 作者&编辑 | 小米粥 编辑 | 言有三 在判别模型中,训练完成的模型要在测 ...

最新文章

  1. Linux Ubuntu从零开始部署web环境及项目 -----tomcat+jdk+mysql (二)
  2. 地理位置经纬度在Mysql中用什么字段类型
  3. 《明解C语言 入门篇》第4章 程序的循环控制 练习题解答
  4. Python根据字幕文件自动给视频添加字幕
  5. 崔荣容,英语如法入门1-50讲
  6. win 64 安装 sql server 2000、出现挂起 解决
  7. Python【一点号】短视频的自动上传与发布实例演示,同时支持抖音、快手、哔哩哔哩、小红书、微视、西瓜视频、微信视频号等平台的视频自动化同步发布
  8. Troubleshooting: WAITED TOO LONG FOR A ROW CACHE ENQUEUE LOCK!
  9. IE10以上input自带的叉号和眼睛
  10. AES算法在Wi-Fi加密中的应用
  11. unity中3D数学相关类、属性、方法、用途总结+超级综合的案例
  12. 设计模式二:创建型-工厂模式
  13. 2013驾考科目一理论知识重点归纳
  14. PHPstorm 函数或者方法的注释的时间和用户名,PHPstorm里函数方法的注释是没有动态时间设置的,但是看了PHP file里面有时间日期的注释,而PHP Function Doc Commen
  15. html5 css背景图片满,css background-size与背景图片填满div
  16. CyberLink PowerDVD Ultra v19.0.2005.62极致中文破解版
  17. 什么是最好的UML在线免费软件
  18. iPhone 13 支持卫星上网?没那么简单
  19. 0240 计算机维修技术,0240.2016《计算机维修技术》西南大学网上作业题和答案.doc...
  20. 常见linux命令介绍-ps

热门文章

  1. 金蝶云星空使用WebAPI来新增单据
  2. 解决 pycharm can not save setting
  3. Linux服务器系统备份还原
  4. requests获取响应时间(elapsed)与超时(timeout)
  5. python sort dict 总结
  6. 乌版图 read-only file system
  7. 只能由中文、字母、数字、下划线组成的字符串
  8. IIS下PHP的ISAPI和FastCGI比较
  9. c语言字符串提取第二个字符,c语言如何复制字符串(取前n个字符)strncpy()函数的应用实例...
  10. java运用ascii实现动画效果_安卓开发20:动画之Animation 详细使用-主要通过java代码实现动画效果...