GAN 的调参技巧总结

  • 生成器的最后一层不使用 sigmoid,使用 tanh 代替

  • 使用噪声作为生成器的输入时,生成噪声的步骤使用 正态分布 的采样来产生,而不使用均匀分布

  • 训练 discriminator 的时候,将 fake img 的标签设为 1real img 的标签设成 0,这样更有利于其训练

  • 在 generator 和 discriminator 设计的时候都使用 dropout 层来增加随机性;或者在 discriminator 的标签中 添加噪声 来提高随机性;因为随机性对于 GAN 的训练有帮助

  • 在 discriminator 和 generator 中都使用 LeakyRelu 来作为激活函数而不用传统的 Relu

  • Conv2DTranspose,stride=2 来代替上采样操作;用 Conv2D,stride=2 来代替下采样操作(maxpooling)

  • 在生成的图像中,经常会见到棋盘状伪影,这是由于生成器中像素空间的不均匀覆盖造成的(如下图),为了解决这个问题,每当生成器和判别器中都使用步进的Conv2DTranspose或Conv2D时,使用的 内核大小要能够被步幅大小整除。例如 stride=2,kernel=(4,4)

  • discriminator 的容量和能力一定要小于 generator,因为判别远比生成容易,如果 discriminator 太强了,反而不利于 generator 的学习,就像一个太过于严厉的老师是不利于学生大踏步地进行创新和进步的,老师一定要温和。比例大约控制在 generator 的容量是 discriminator 的 10 倍左右。

  • 在训练的时候,如果不想调整网络的参数,那么可以尝试在训练的时候让一个网络训练好几次,然后另外一个网络训练一次,例如,如果 generator 很强,那么久让 discriminator 训练 3 次,generator 更新一次参数,加个 for 循环即可,亲测有效。

  • generator 在训练的过程中,前期大概率会被 discriminator 压制;因为刚开始生成的东西还很简单,因此要保证在训练的时候 generator 在前几个 epochs 不能 loss 很快地上升到 1,这样的话不利于后面的训练 ,正确的引导方式应该是设计 generator 的 loss 先上升到 0.7,0.8 左右,然后再慢慢降下来,这样就很有利于训练 GAN 网络

  • 一个被良好训练的 GAN 网络应该具有下面的 loss 走向:鉴别器和生成器都在波动,而不是一方的 loss 很快上升到 1,而另一方很快降到 0;

  • 定义网络的顺序是:

    • 定义 discriminator 网络,
    • 创建 discriminator,
    • compile discriminator,
    • 定义 generator,
    • 创建 generator;
    • 然后锁定 discriminator 的参数(dis.trainable=False)这时候千万不要再次 compile discriminator!!
    • 定义 GAN 的联合网络;
    • 创建 GAN 网络,
    • 然后 compile GAN 网络;因为 GAN 网络创建的时候 鉴别器被锁住,因此 GAN 网络 compile 之后能够记住这种状态,因此在 GAN 网络调用的时候 discriminator 永远都是锁住的,但是训练的其他时候,discriminator 都是没有被锁住的。这就是上面的 compile 顺序的巧妙!!!
  • 如果这里你没看懂,一定要在下面的代码中仔细留意关于 discriminator 的compile 部分,因为如果这里出问题,你最终即使网络训练的时候通过调整参数而达到了上图中演示的那种良好的交替波动情况,最后的输出也大概率属于下图中的情况,如果你遇到了下图的情况,请仔细考虑你的 compile 步骤:

  • 在训练 GAN 之前,尽量对数据进行标准化,图片数据除以 255. 即可,下面代码中也有演示,对数据规范化绝对是一个好习惯

代码

1. 导包

import keras,os
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from keras.preprocessing import imagefrom keras.datasets import fashion_mnist,cifar10,cifar100,mnist
from keras.utils import to_categoricalos.environ["CUDA_VISIBLE_DEVICES"] = " 2"

2. 定义 generator

def generator(input_shape):inputs = Input(input_shape)x = Dense(128 * 16 * 16)(inputs)x = LeakyReLU()(x)x = Reshape((16, 16, 128))(x)x = Conv2D(256, 5, padding = 'same')(x)x = LeakyReLU()(x)x = Conv2DTranspose(256, 4, strides = 2, padding = 'same')(x)x = LeakyReLU()(x)x = Conv2D(256, 5, padding = 'same')(x)x = LeakyReLU()(x)x = Conv2D(256, 5, padding = 'same')(x)x = LeakyReLU()(x)x = Conv2D(3, 7, activation='tanh', padding = 'same')(x)return Model(inputs,x)

3. 创建 generator

gen = generator((100,))

4. 定义 discriminator

def discriminator(input_shape):inputs = Input(input_shape)x = Conv2D(128, 3)(inputs)x = LeakyReLU()(x)x = Conv2D(128, 4, strides=2)(x)x = LeakyReLU()(x)x = Conv2D(128, 4, strides=2)(x)x = LeakyReLU()(x)x = Conv2D(128, 4, strides=2)(x)x = LeakyReLU()(x)x = Flatten()(x)x = Dropout(0.4)(x)x = Dense(1, activation='sigmoid')(x) #分类层return Model(inputs,x)

5. 创建 discrminator 并 compile

dis = discriminator((32,32,3))
dis.compile(loss=keras.losses.binary_crossentropy,optimizer= keras.optimizers.RMSprop(lr = 0.0008,clipvalue = 1.0,decay=1e-8))

6. 定义 GAN 网络(先锁定 discrminator 但不 compile)


def GAN():dis.trainable=Falsegan_input = Input((100,))fake_image = gen(gan_input)score = dis(fake_image)return Model(gan_input,score)

7. 创建并 compile GAN 网络

gan = GAN()
gan.compile(loss=keras.losses.binary_crossentropy,optimizer=keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8))

8. 导入数据并规范化

(x_train,y_train),(x_test,y_test)= cifar10.load_data()y_train_label = y_train
y_test_label = y_testx_train = x_train[y_train.flatten() == 7]  #选择马类数据即可
x_train = x_train.reshape(5000,32,32,3).astype('float32')/255.

9. 训练过程

epochs = 4000
batch_size = 64
valid = np.ones((batch_size,1))
fake = np.zeros((batch_size,1))generated_img = []
discriminator_loss = []
generator_loss = []
save_dir = './A-GAN-PHOTO'for epoch in range(epochs):noise = np.random.normal(0,1,size=(batch_size,100))img_index = np.random.randint(0,5000,batch_size)fake_img = gen.predict(noise)real_img = x_train[img_index]data = np.concatenate([fake_img, real_img])label = np.concatenate([fake,valid])label += 0.05 * np.random.random(label.shape)d_loss = dis.train_on_batch(data,label)# ---------------------#  训练生成模型# ---------------------noise_ = np.random.normal(0,1,size=(batch_size,100))g_loss = gan.train_on_batch(noise_, valid)if epoch%100 == 0:im = fake_img[0]generated_img.append(im)img = image.array_to_img(fake_img[0] * 255, scale=False)img.save(os.path.join(save_dir, 'generated_horse' + str(epoch) + '.png'))    #保存一张生成图像img = image.array_to_img(real_img[0] * 255, scale=False)img.save(os.path.join(save_dir, 'real_horse' + str(epoch) +'.png'))   #保存一张真实图像用于对比print('discriminator_loss:',d_loss)print('adversal_loss:',g_loss)discriminator_loss.append(d_loss)generator_loss.append(g_loss)# discriminator_loss.append(d_loss[-1])# generator_loss.append(g_loss[-1])# print("d_loss:%f"%d_loss[-1])# print("g_loss:%f"%g_loss[-1])print("epoch:%d" % epoch + "========")

10. 展示 generated_img 里的所有图片

fig, axes = plt.subplots(nrows=2, ncols=20, sharex=True, sharey=True, figsize=(80,12))
imgs = generated_imgfor image, row in zip([imgs[:20], imgs[20:40]], axes):for img, ax in zip(image, row):ax.imshow(img)ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)fig.tight_layout(pad=0.1)

  • 可以看到生成效果还是不错的。

11. 画出训练过程的 GAN 的 loss 曲线

plt.plot(discriminator_loss,label='discriminator_loss')
plt.plot(generator_loss,label='generator_loss')
plt.legend()

大坑汇总

  • 虽然很多人提倡在训练 GAN 的时候使用 Batchnormalization 层,但是我在实操中发现使用这个层很多时候会导致一方的 loss 很快地偏向 1 或者 0,另一方也是,这种情况导致的大波动非常不利于 GAN 总体的训练
  • 还是上面提到的,注意 compile,签完不要在定义 GAN 网络的时候再 compile discriminator,那样的话 discriminator 就永远地锁定了。如果不在 GAN 中 compile discriminator,那么 discriminator 具有两种状态,那就是在别的位置都是不被锁定的,唯独在 GAN 中是被锁定的。

写在后面

最后,祝大家的 GAN 网络都能训练有素,耗子尾汁,别那么不讲武德。

不服就GAN:GAN网络生成 cifar10 的图片实例(keras 详细实现步骤),GAN 的训练的各种技巧总结,GAN的注意事项和大坑汇总相关推荐

  1. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

  2. php 创建透明png,php生成透明背景图片实例

    例子,php生成背景图片的代码. 复制代码 代码示例: //透明背景图片 header('content-type:text/html;charset=gbk'); $safess = $_get[s ...

  3. python批量生成图片并保存_Python批量生成幻影坦克图片实例代码

    前言 说到幻影坦克,我就想起红色警戒里的-- 幻影坦克(Mirage Tank),<红色警戒2>以及<尤里的复仇>中盟军的一款伪装坦克,盟军王牌坦克之一.是爱因斯坦在德国黑森林 ...

  4. python语言编写一个生成九宫格图片的代码_python简单实现9宫格图片实例

    在日常生活中我们经常在朋友圈看到有人发九宫格图片,其实质就是将一张图片切成九份,然后在微信中一起发这九张图. 那么我们如何自己动手实现呢? 说到切图Python 就可以实现,主要用到的 Python ...

  5. 一文看尽深度学习中的生成对抗(GAN)网络

    参考:<CVHub带你看一看GANs架构发展的8年> 导读 生成对抗网络 (Generative Adversarial Networks, GANs) 在过去几年中被广泛地研究,其在图像 ...

  6. DL之GAN:生成对抗网络GAN的简介、应用、经典案例之详细攻略

    DL之GAN:生成对抗网络GAN的简介.应用.经典案例之详细攻略 目录 生成对抗网络GAN的简介 1.生成对抗网络的重要进展 1.1.1986年的RBM→2006年的DBN

  7. 【视频课】生成对抗网络经典任务,详解基于GAN的图像生成算法!

    前言 欢迎大家关注有三AI的视频课程系列,我们的视频课程系列共分为5层境界,内容和学习路线图如下: 第1层:掌握学习算法必要的预备知识,包括Python编程,深度学习基础,数据使用,框架使用. 第2层 ...

  8. 花书+吴恩达深度学习(二八)深度生成模型之有向生成网络(VAE, GAN, 自回归网络)

    文章目录 0. 前言 1. sigmoid 信念网络 2. 生成器网络 3. 变分自编码器 VAE 4. 生成式对抗网络 GAN 5. 生成矩匹配网络 6. 自回归网络 6.1 线性自回归网络 6.2 ...

  9. 强化学习(二)--让你轻松玩转生成对抗网络(GAN)与生成对抗模仿学习(GAIL)

    GAN的基本结构 GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator) GAN 充分利用"对抗过程"训练两个神经网络,这两个网络会互相 ...

最新文章

  1. vue中利用scss实现整体换肤和字体大小设置
  2. 经常使用的android弹出对话框
  3. .NET 数据访问架构指南(转)
  4. html css web笔记,Web/HTML/CSS/的笔记
  5. 建标库标准怎么导出pdf_保存和导出PDF文档,这款OCR文字识别软件能做到
  6. android 音视频 教程,Android移动端音视频的快速开发教程(九)
  7. asp和php数据库怎么区分,asp与php的数据库有哪些区别
  8. 不要滥用UNLOGGED table 和 hash index
  9. LeetCode 2063. 所有子字符串中的元音(数学)
  10. [深度学习]-基于tensorflow的CNN和RNN-LSTM文本情感分析对比
  11. maven + sonar, gradle + sonar
  12. webpack的五个核心概念---webpack工作笔记002
  13. SharePoint2010 空白站点集无法找到术语管理库
  14. hltm连接css的link,CSS 链接(link)
  15. Unity3D插件 Puppet3D的使用
  16. 基于JAVA医疗器械销售电子商城计算机毕业设计源码+系统+mysql数据库+lw文档+部署
  17. Android 11适配指南之系统相机拍照、打开相册
  18. java 无理数_《数学分析原理》笔记之——无理数的引入
  19. 台式计算机硬盘英寸,2.5英寸的机械硬盘,能安装到台式机里面用么?
  20. 视频编码:H.264编码

热门文章

  1. 品牌对比 | 蜜雪冰城 VS 喜茶
  2. 二阶求导 算法 实现 寻峰问题(转)
  3. Kong API Gateway 管理API详解
  4. 怎么识别图片中的文字?这三款识别软件还不错
  5. 8080 端口被占用的解决方法 netstat -ano;taskkill (命令行)
  6. netstat和netstat -ano
  7. 适合长时间佩戴的耳机有哪些、六款适合久戴的运动耳机推荐
  8. Windows10 安装spyder
  9. C++三角定位法求两圆交点坐标
  10. FAST-LIO公式推导