GAN最不好理解的就是Loss函数的定义和训练过程,这里用一段代码来辅助理解,就能明白到底是怎么回事。其实GAN的损失函数并没有特殊之处,就是常用的binary_crossentropy,关键在于训练过程中存在两个神经网络和两个损失函数。

np.random.seed(42)
tf.random.set_seed(42)codings_size = 30generator = keras.models.Sequential([keras.layers.Dense(100, activation="selu", input_shape=[codings_size]),keras.layers.Dense(150, activation="selu"),keras.layers.Dense(28 * 28, activation="sigmoid"),keras.layers.Reshape([28, 28])
])
discriminator = keras.models.Sequential([keras.layers.Flatten(input_shape=[28, 28]),keras.layers.Dense(150, activation="selu"),keras.layers.Dense(100, activation="selu"),keras.layers.Dense(1, activation="sigmoid")
])
gan = keras.models.Sequential([generator, discriminator])discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

这里generator并不用compile,因为gan网络已经compile了。具体原因见下文。

训练过程的代码如下

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):generator, discriminator = gan.layersfor epoch in range(n_epochs):print("Epoch {}/{}".format(epoch + 1, n_epochs))              # not shown in the bookfor X_batch in dataset:# phase 1 - training the discriminatornoise = tf.random.normal(shape=[batch_size, codings_size])generated_images = generator(noise)X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)discriminator.trainable = Truediscriminator.train_on_batch(X_fake_and_real, y1)# phase 2 - training the generatornoise = tf.random.normal(shape=[batch_size, codings_size])y2 = tf.constant([[1.]] * batch_size)discriminator.trainable = Falsegan.train_on_batch(noise, y2)plot_multiple_images(generated_images, 8)                     # not shownplt.show()                                                    # not shown

第一阶段(discriminator训练)

# phase 1 - training the discriminator
noise = tf.random.normal(shape=[batch_size, codings_size])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real, y1)

这个阶段首先生成数量相同的真实图片和假图片,concat在一起,即X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)。然后是label,真图片的label是1,假图片的label是0。

然后是迅速阶段,首先将discrinimator设置为可训练,discriminator.trainable = True,然后开始阶段。第一个阶段的训练过程只训练discriminator,discriminator.train_on_batch(X_fake_and_real, y1),而不是整个GAN网络gan

第二阶段(generator训练)

# phase 2 - training the generator
noise = tf.random.normal(shape=[batch_size, codings_size])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y2)

在第二阶段首先生成假图片,但是不再生成真图片。把假图片的label全部设置为1,并把discriminator的权重冻结,即discriminator.trainable = False。这一步很关键,应该这么理解:

前面第一阶段的是discriminator的训练,使真图片的预测值尽量接近1,假图片的预测值尽量接近0,以此来达到优化损失函数的目的。现在将discrinimator的权重冻结,网络中输入假图片,并故意把label设置为1。

注意,在整个gan网络中,从上向下的顺序是先通过geneartor,再通过discriminator,即gan = keras.models.Sequential([generator, discriminator])。第二个阶段将discrinimator冻结,并训练网络gan.train_on_batch(noise, y2)。如果generator生成的图片足够真实,经过discrinimator后label会尽可能接近1。由于故意把y2的label设置为1,所以如果genrator生成的图片足够真实,此时generator训练已经达到最优状态,不会大幅度更新权重;如果genrator生成的图片不够真实,经过discriminator之后,预测值会接近0,由于y2的label是1,相当于预测值不准确,这时候gan网络的损失函数较大,generator会通过更新generator的权重来降低损失函数。

之后,重新回到第一阶段训练discriminator,然后第二阶段训练generator。假设整个GAN网络达到理想状态,这时候generator产生的假图片,经过discriminator之后,预测值应该是0.5。假如这个值小于0.5,证明generator不是特别准确,在第二阶段训练过程中,generator的权重会被继续更新。假如这个值大于0.5,证明discriminator不是特别准确,在第一阶段训练中,discriminator的权征会被继续更新。

简单说,对于一张generator生成的假图片,discriminator会尽量把预测值拉下拉,generator会尽量把预测值往上扯,类似一个拔河的过程,最后达到均衡状态,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。

理解GAN对抗神经网络的损失函数和训练过程相关推荐

  1. Keras深度学习实战(1)——神经网络基础与模型训练过程详解

    Keras深度学习实战(1)--神经网络基础与模型训练过程详解 0. 前言 1. 神经网络基础 1.1 简单神经网络的架构 1.2 神经网络的训练 1.3 神经网络的应用 2. 从零开始构建前向传播 ...

  2. 条件生成对抗神经网络,生成对抗网络gan原理

    关于GAN生成式对抗网络中判别器的输出的问题 . ...摘要生成式对抗网络GAN(Generativeadversarialnetworks)目前已经成为人工智能学界一个热门的研究方向.GAN的基本思 ...

  3. 对抗神经网络算法 应用,对抗神经网络算法 英文

    深度学习什么是对抗式神经网络? 对抗式神经网络GAN让机器学会"左右互搏"GAN网络的原理本质上就是这两篇小说中主人公练功的人工智能或机器学习版本. 一个网络中有两个角色,修炼的过 ...

  4. 练习推导一个最简单的BP神经网络训练过程【个人作业/数学推导】

    写在前面: 各式资料中关于BP神经网络的讲解已经足够全面详尽,故不在此过多赘述.本文重点在于由一个"最简单"的神经网络练习推导其训练过程,和大家一起在练习中一起更好理解神经网络训练 ...

  5. 深入理解生成对抗网络(GAN 基本原理,训练崩溃,训练技巧,DCGAN,CGAN,pix2pix,CycleGAN)

    文章目录 GAN 基本模型 模型 GAN 的训练 模式崩溃 训练崩溃 图像生成中的应用 DCGAN:CNN 与 GAN 的结合 转置卷积 DCGAN CGAN:生成指定类型的图像 图像翻译中的应用 p ...

  6. 【深度学习理论】通俗理解生成对抗网络GAN

    作者 | 陈诚 来源 | 机器学习算法与自然语言处理 ▌1. 引言 自2014年Ian Goodfellow提出了GAN(Generative Adversarial Network)以来,对GAN的 ...

  7. 通俗理解生成对抗网络GAN

    0. 引言 自2014年Ian Goodfellow提出了GAN(Generative Adversarial Network)以来,对GAN的研究可谓如火如荼.各种GAN的变体不断涌现,下图是GAN ...

  8. 【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度

    1 对抗神经简介 1.1 对抗神经网络的基本组成 1.1.1 基本构成 对抗神经网络(即生成式对抗网络,GAN)一般由两个模型组成: 生成器模型(generator):用于合成与真实样本相差无几的模拟 ...

  9. 赠书 | 读懂生成对抗神经网络 GAN,看这文就够了

    生成对抗神经网络(Generative Adversarial Nets,GAN)是一种深度学习的框架,它是通过一个相互对抗的过程来完成模型训练的.典型的GAN包含两个部分,一个是生成模型(Gener ...

最新文章

  1. Java中迭代列表中数据时几种循环写法的效率比较
  2. 警告:隐式声明与内建函数‘exit‘不兼容解决方案
  3. 进阶学习(3.2)Factory Method Pattern 工厂方法模式
  4. [备忘]silverlight中关于“复制到输出目录”和“生成操作”
  5. c语言小饭店等位就餐程序,C语言程序设计 C语言程序设计 3.C语言程序设计教案全部.doc...
  6. VUE安装依赖命令总结
  7. pythonrequests说明_解决Python requests 报错方法集锦
  8. call()函数、apply()函数区别与意义
  9. C Tricks(十五)—— 算符优先级的表示
  10. Processing-基础小坑-
  11. 模块ntdll中出现异常eaccessviolation_SAP ERP软件中的物料凭证 MIGO
  12. 高校机房建设 云服务器 终端,学校云机房建设使用NComputing微型终端机解决方案...
  13. linux 多线程编程-互斥锁问题之tpp.c:63: __pthread_tpp_change_priority failed 问题解决
  14. 【知识图谱】Neo4j 导入数据构建知识图谱的三种方法
  15. #186-[栈]法力水晶
  16. 全球智慧城市IOT市场规模报告
  17. 个人笔记——PointNet初学
  18. 计算机英文收集(二)
  19. SQLPro Studio for Mac(可视化数据库管理工具)
  20. 【区块链 | 前端】前端开发人员入门区块链的最佳实践

热门文章

  1. Redis数据类型-Hash-基本使用
  2. 【正点原子MP157连载】第四十一章 RGB转HDMI实验-摘自【正点原子】STM32MP1嵌入式Linux驱动开发指南V1.7
  3. MATLAB通信工具箱仿真16QAM系统
  4. 中国最初开始发展计算机是在哪一年,中国从哪一年开始有手机了,手机出现最早的城市在哪里...
  5. 四位行波进位加法器_【HDL系列】超前进位加法器原理与设计
  6. 【板栗糖GIS】——如何下载哔哩哔哩的视频CC字幕为不带时间节点的纯文字
  7. hp服务器重装系统按什么键,惠普重装系统按什么键|惠普u盘装系统按哪个键
  8. -- str --() 方法
  9. 上极限与上确界有什么区别
  10. Mysql第三方备份工具Xtrabackup使用说明