Series Table of Contents

Deep Learning GAN (1): A Brief Introduction
Deep Learning GAN (2): A DCGAN Example on the CIFAR-10 Dataset
Deep Learning GAN (3): A DCGAN Example on the MNIST Handwritten Digit Dataset
Deep Learning GAN (4): A cGAN (Conditional GAN) Example
Deep Learning GAN (5): A Pix2Pix GAN Example
Deep Learning GAN (6): A CycleGAN Example


A CycleGAN Getting-Started Example with TensorFlow 2.1 and Keras

  • Series Table of Contents
  • 1. What is CycleGAN?
  • 2. Preparing the Dataset
  • 3. Environment Setup
  • 4. How to Build a CycleGAN to Translate Horses to Zebras
    • 4.1. Defining the Discriminator
    • 4.2. Defining the Generator
    • 4.3. Defining the Composite Models
    • 4.4 Loading Real Images and Generating Fake Images
    • 4.5 Saving the Models
    • 4.6 Generating Images with the Generators
    • 4.7 Updating the Image Pool
    • 4.8 Training the Models
  • 5. Training Results
  • 6. Complete Code
  • 7. How to Perform Image Translation with the CycleGAN Generators
  • 8. Personal Summary

1. What is CycleGAN?

The CycleGAN model was introduced in the 2017 paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

An advantage of the CycleGAN model is that it can be trained without paired examples. That is, it does not require before-and-after examples of the translation in order to train, such as photos of the same cityscape taken both during the day and at night. Instead, the model uses a collection of photographs from each domain and extracts and harnesses the underlying style of the images in each collection to perform the translation.

The model architecture is composed of two generator models: one generator (Generator-A) for generating images for the first domain (Domain-A), and a second generator (Generator-B) for generating images for the second domain (Domain-B).

  • Generator-A -> Domain-A
  • Generator-B -> Domain-B

The generator models perform image translation, meaning that the image generation process is conditioned on an input image, specifically an image from the other domain: Generator-A takes an image from Domain-B as input, and Generator-B takes an image from Domain-A as input.

  • Domain-B -> Generator-A -> Domain-A
  • Domain-A -> Generator-B -> Domain-B

Each generator model has a corresponding discriminator model. The first discriminator model (Discriminator-A) takes real images from Domain-A and generated images from Generator-A and predicts whether they are real or fake. The second discriminator model (Discriminator-B) takes real images from Domain-B and generated images from Generator-B and predicts whether they are real or fake.

  • Domain-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Generator-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Discriminator-B -> [Real/Fake]
  • Domain-A -> Generator-B -> Discriminator-B -> [Real/Fake]

Like regular GAN models, the discriminators and generators are trained in an adversarial, zero-sum process. The generators learn to better fool the discriminators and the discriminators learn to better detect fake images; together, the models reach an equilibrium during training.

Additionally, the generator models are regularized so that they do not just create new images in the target domain, but instead translate more faithful reconstructions of the input images from the source domain. This is achieved by feeding generated images into the corresponding (other) generator model and comparing the output image to the original image. Passing an image through both generators is called a cycle, and each pair of generator models is trained jointly to better reproduce the original source image, which is referred to as cycle consistency.

  • Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

The architecture has one further element, called identity mapping. Here a generator is given an image from its own target domain as input and is expected to produce the same image without change. This addition to the architecture is optional, although it can result in better matching of the color profile of the input image.

  • Domain-A -> Generator-A -> Domain-A
  • Domain-B -> Generator-B -> Domain-B
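Putting the pieces above together, the overall training objective from the paper can be summarized informally, in this article's naming, as (the same weights reappear in section 4.3 below):

Total loss = adversarial loss(Generator-A, Discriminator-A) + adversarial loss(Generator-B, Discriminator-B) + λ × cycle-consistency loss + 0.5λ × identity loss, with λ = 10 in the paper.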

2. Preparing the Dataset

The dataset we use is "horse2zebra". The zip file for this dataset is about 111 MB and can be downloaded from the CycleGAN webpage:
Download Horses to Zebras Dataset (111 megabytes)

After unzipping it, you will see the following directory structure:

horse2zebra
├── testA
├── testB
├── trainA
└── trainB

Open the testA folder and you will see pictures of horses.

Open the testB folder and you will see pictures of zebras.

Every image is 256×256 pixels.

The code below loads all photos from the train and test folders and creates one array of images for class A (horses) and another for class B (zebras).

Both arrays are then saved to a new file, horse2zebra_256.npz, in compressed NumPy array format.

from os import listdir
import numpy as np
from numpy import asarray
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

# load all images in a directory into memory
def load_images(path, size=(256, 256)):
    data_list = list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # store
        data_list.append(pixels)
    return asarray(data_list)

# dataset path
path = 'D:/ML/datasets/horse2zebra/'
# load dataset A
dataA1 = load_images(path + 'trainA/')
dataAB = load_images(path + 'testA/')
dataA = np.vstack((dataA1, dataAB))
print('Loaded dataA: ', dataA.shape)
# load dataset B
dataB1 = load_images(path + 'trainB/')
dataB2 = load_images(path + 'testB/')
dataB = np.vstack((dataB1, dataB2))
print('Loaded dataB: ', dataB.shape)
# save as compressed numpy array
filename = 'horse2zebra_256.npz'
np.savez_compressed(filename, dataA, dataB)
print('Saved dataset: ', filename)
Loaded dataA:  (1187, 256, 256, 3)
Loaded dataB:  (1474, 256, 256, 3)
Saved dataset:  horse2zebra_256.npz

The code below loads the horse2zebra_256.npz dataset and plots some of the images with matplotlib.

# load and plot the prepared dataset
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('horse2zebra_256.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(dataA[i].astype('uint8'))
# plot target images
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()

3. Environment Setup

I use TensorFlow 2.1 together with tensorflow_addons 0.9:

pip install tensorflow-gpu==2.1.0
pip install tensorflow_addons==0.9.1
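To quickly confirm that both packages are installed and that the InstanceNormalization layer used throughout this article is available, a small check such as the following can be run (the printed versions will of course depend on your environment):

import tensorflow as tf
import tensorflow_addons as tfa

print(tf.__version__)   # expected: 2.1.x
print(tfa.__version__)  # expected: 0.9.x
print(tfa.layers.InstanceNormalization)  # the normalization layer used below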

4. How to Build a CycleGAN to Translate Horses to Zebras

The full architecture is made up of four models: two discriminator models and two generator models.

The discriminator is a deep convolutional neural network that performs image classification. It takes an image as input and predicts the likelihood of that image being real or fake. Two discriminator models are used, one for Domain-A (horses) and one for Domain-B (zebras).

The discriminator design is based on the effective receptive field of the model, which defines the relationship between one output of the model and the number of pixels in the input image. This is called a PatchGAN model and is carefully designed so that each output prediction of the model maps to a 70×70 square, or patch, of the input image. The benefit of this approach is that the same model can be applied to input images of different sizes, e.g. larger or smaller than 256×256 pixels.

The output of the model depends on the size of the input image, but it may be a single value or a square activation map of values. Each value is the probability that a patch of the input image is real. If needed, these values can be averaged to give an overall likelihood or classification score.
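As a small illustration of that last point, once the define_discriminator() function from section 4.1 below is available, the patch predictions can be averaged into a single score roughly as follows (the random input is only a stand-in for a real, preprocessed image batch):

import numpy as np

d_model = define_discriminator((256, 256, 3))
# dummy batch of one image scaled to [-1, 1], shape (1, 256, 256, 3)
img = np.random.uniform(-1, 1, (1, 256, 256, 3)).astype('float32')
patch_scores = d_model.predict(img)           # shape (1, 16, 16, 1): one score per patch
overall_score = float(np.mean(patch_scores))  # averaged into a single classification score
print(patch_scores.shape, overall_score)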

The model uses a pattern of Convolution-Normalization-LeakyReLU layers common to deep convolutional discriminator models. Unlike many other models, however, the CycleGAN discriminator uses InstanceNormalization instead of BatchNormalization. This is a very simple type of normalization that standardizes (e.g. scales to a standard Gaussian) the values on each output feature map, rather than across the feature maps in a batch.

4.1. Defining the Discriminator

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot
import tensorflow_addons as tfa
import numpy as np
from random import random

# define the discriminator model (70x70 PatchGAN)
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_image = Input(shape=image_shape)
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    # define model
    model = Model(in_image, patch_out)
    # compile model
    model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
    return model

if __name__ == '__main__':
    model = define_discriminator((256,256,3))
    print(model.summary())
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 64)      3136
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 128, 128, 64)      0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 128)       131200
_________________________________________________________________
instance_normalization (Inst (None, 64, 64, 128)       256
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 128)       0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 256)       524544
_________________________________________________________________
instance_normalization_1 (In (None, 32, 32, 256)       512
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 256)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 512)       2097664
_________________________________________________________________
instance_normalization_2 (In (None, 16, 16, 512)       1024
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 16, 16, 512)       0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 512)       4194816
_________________________________________________________________
instance_normalization_3 (In (None, 16, 16, 512)       1024
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 16, 16, 512)       0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 1)         8193
=================================================================
Total params: 6,962,369
Trainable params: 6,962,369
Non-trainable params: 0

4.2. Defining the Generator

The generator model is more complex than the discriminator model.

The generator is an encoder-decoder architecture. The model takes a source image (e.g. a photo of a horse) and generates a target image (e.g. a photo of a zebra). It does this by first downsampling or encoding the input image to a bottleneck layer, then interpreting the encoding with a number of ResNet layers, and finally decoding it with a series of layers that upsample the representation to the size of the output image.

First, we need a function to define a ResNet block. These are blocks made up of two 3×3 convolutional layers in which the input to the block is concatenated channel-wise with the output of the block.

This is implemented in the resnet_block() function below, which creates two Convolution-InstanceNorm blocks with 3×3 filters and 1×1 stride and which, matching the build_conv_block() function in the official Torch implementation, omits the ReLU activation after the second block. Same padding is used for simplicity instead of the reflection padding recommended in the paper.

# generator: a resnet block
def resnet_block(n_filters, input_layer):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # first convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # second convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    # concatenate merge channel-wise with input layer
    g = Concatenate()([g, input_layer])
    return g

Next, we can define a function that creates the 9-ResNet-block version of the generator for 256×256 input images. This can easily be changed to the 6-block version by setting image_shape to (128,128,3) and the n_resnet function argument to 6.

Importantly, the model outputs an image of the same shape as its input, with pixel values in the range [-1,1], as is typical for GAN generator models.

# define the standalone generator model
def define_generator(image_shape, n_resnet=9):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # c7s1-64
    g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d128
    g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d256
    g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # R256
    for _ in range(n_resnet):
        g = resnet_block(256, g)
    # u128
    g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # u64
    g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # c7s1-3
    g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model

if __name__ == '__main__':
    model = define_generator((256,256,3))
    print(model.summary())
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 9472        input_1[0][0]
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 256, 256, 64) 128         conv2d[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 64) 0           instance_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 128 73856       activation[0][0]
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 128, 128, 128 256         conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 128 0           instance_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 256)  295168      activation_1[0][0]
__________________________________________________________________________________________________
instance_normalization_2 (Insta (None, 64, 64, 256)  512         conv2d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 256)  0           instance_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 256)  590080      activation_2[0][0]
__________________________________________________________________________________________________
instance_normalization_3 (Insta (None, 64, 64, 256)  512         conv2d_3[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 256)  0           instance_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 256)  590080      activation_3[0][0]
__________________________________________________________________________________________________
instance_normalization_4 (Insta (None, 64, 64, 256)  512         conv2d_4[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 64, 64, 512)  0           instance_normalization_4[0][0]   activation_2[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 256)  1179904     concatenate[0][0]
__________________________________________________________________________________________________
instance_normalization_5 (Insta (None, 64, 64, 256)  512         conv2d_5[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 64, 64, 256)  0           instance_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 256)  590080      activation_4[0][0]
__________________________________________________________________________________________________
instance_normalization_6 (Insta (None, 64, 64, 256)  512         conv2d_6[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 64, 64, 768)  0           instance_normalization_6[0][0]   concatenate[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 256)  1769728     concatenate_1[0][0]
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 64, 64, 256)  512         conv2d_7[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 256)  0           instance_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 256)  590080      activation_5[0][0]
__________________________________________________________________________________________________
instance_normalization_8 (Insta (None, 64, 64, 256)  512         conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 1024) 0           instance_normalization_8[0][0]   concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 256)  2359552     concatenate_2[0][0]
__________________________________________________________________________________________________
instance_normalization_9 (Insta (None, 64, 64, 256)  512         conv2d_9[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 64, 64, 256)  0           instance_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 64, 64, 256)  590080      activation_6[0][0]
__________________________________________________________________________________________________
instance_normalization_10 (Inst (None, 64, 64, 256)  512         conv2d_10[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 1280) 0           instance_normalization_10[0][0]  concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 64, 64, 256)  2949376     concatenate_3[0][0]
__________________________________________________________________________________________________
instance_normalization_11 (Inst (None, 64, 64, 256)  512         conv2d_11[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 256)  0           instance_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 64, 64, 256)  590080      activation_7[0][0]
__________________________________________________________________________________________________
instance_normalization_12 (Inst (None, 64, 64, 256)  512         conv2d_12[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 64, 64, 1536) 0           instance_normalization_12[0][0]  concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 256)  3539200     concatenate_4[0][0]
__________________________________________________________________________________________________
instance_normalization_13 (Inst (None, 64, 64, 256)  512         conv2d_13[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 64, 64, 256)  0           instance_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 256)  590080      activation_8[0][0]
__________________________________________________________________________________________________
instance_normalization_14 (Inst (None, 64, 64, 256)  512         conv2d_14[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 64, 64, 1792) 0           instance_normalization_14[0][0]  concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 256)  4129024     concatenate_5[0][0]
__________________________________________________________________________________________________
instance_normalization_15 (Inst (None, 64, 64, 256)  512         conv2d_15[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 64, 64, 256)  0           instance_normalization_15[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 64, 64, 256)  590080      activation_9[0][0]
__________________________________________________________________________________________________
instance_normalization_16 (Inst (None, 64, 64, 256)  512         conv2d_16[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 64, 64, 2048) 0           instance_normalization_16[0][0]  concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 256)  4718848     concatenate_6[0][0]
__________________________________________________________________________________________________
instance_normalization_17 (Inst (None, 64, 64, 256)  512         conv2d_17[0][0]
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 64, 64, 256)  0           instance_normalization_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 256)  590080      activation_10[0][0]
__________________________________________________________________________________________________
instance_normalization_18 (Inst (None, 64, 64, 256)  512         conv2d_18[0][0]
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 64, 64, 2304) 0           instance_normalization_18[0][0]  concatenate_6[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 64, 64, 256)  5308672     concatenate_7[0][0]
__________________________________________________________________________________________________
instance_normalization_19 (Inst (None, 64, 64, 256)  512         conv2d_19[0][0]
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 64, 64, 256)  0           instance_normalization_19[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 256)  590080      activation_11[0][0]
__________________________________________________________________________________________________
instance_normalization_20 (Inst (None, 64, 64, 256)  512         conv2d_20[0][0]
__________________________________________________________________________________________________
concatenate_8 (Concatenate)     (None, 64, 64, 2560) 0           instance_normalization_20[0][0]  concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 128, 128, 128 2949248     concatenate_8[0][0]
__________________________________________________________________________________________________
instance_normalization_21 (Inst (None, 128, 128, 128 256         conv2d_transpose[0][0]
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 128, 128, 128 0           instance_normalization_21[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 256, 256, 64) 73792       activation_12[0][0]
__________________________________________________________________________________________________
instance_normalization_22 (Inst (None, 256, 256, 64) 128         conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 256, 256, 64) 0           instance_normalization_22[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 256, 256, 3)  9411        activation_13[0][0]
__________________________________________________________________________________________________
instance_normalization_23 (Inst (None, 256, 256, 3)  6           conv2d_21[0][0]
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 256, 256, 3)  0           instance_normalization_23[0][0]
==================================================================================================
Total params: 35,276,553
Trainable params: 35,276,553
Non-trainable params: 0
__________________________________________________________________________________________________
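As a quick sanity check of the properties noted before the listing (the output has the same shape as the input, with pixel values kept in [-1,1] by the tanh activation), something like the following can be run once the imports and define_generator() above are in place; the random input is only a placeholder:

g_model = define_generator((256, 256, 3))
x = np.random.uniform(-1, 1, (1, 256, 256, 3)).astype('float32')
out = g_model.predict(x)
print(out.shape, out.min(), out.max())  # (1, 256, 256, 3), values inside [-1, 1]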

The discriminator models are trained directly on real and generated images, whereas the generator models are not.

Instead, the generator models are trained via their associated discriminator models. Specifically, they are updated to minimize the loss predicted by the discriminator for generated images marked as "real", called the adversarial loss. As such, they are encouraged to generate images that better fit into the target domain.

The generator models are also updated based on how effective they are at regenerating a source image when used with the other generator model, called the cycle loss. Finally, a generator model is expected to output an image without translation when provided an example from the target domain, called the identity loss.

In summary, each generator model is optimized via the combination of the following four loss functions:

  • Adversarial loss (L2 or mean squared error).
  • Identity loss (L1 or mean absolute error).
  • Forward cycle loss (L1 or mean absolute error).
  • Backward cycle loss (L1 or mean absolute error).

MSE: mean squared error
MAE: mean absolute error

This can be achieved by defining a composite model used to train each generator model that is responsible for only updating the weights of that generator model, even though it shares weights with the related discriminator model and the other generator model.

This is implemented in the define_composite_model() function below, which takes a defined generator model (g_model_1) as well as the defined discriminator model for that generator's output (d_model) and the other generator model (g_model_2). The weights of the other models are marked as not trainable, as we are only interested in updating the first generator model, the focus of this composite model.

The discriminator is connected to the output of the generator in order to classify generated images as real or fake. A second input to the composite model is an image from the target domain (instead of the source domain), which the generator is expected to output without translation for the identity mapping. Next, the forward cycle loss involves connecting the output of the generator to the other generator, which reconstructs the source image. Finally, the backward cycle loss involves the image from the target domain used for identity mapping also being passed through the other generator, whose output is connected to our main generator as input and which then outputs a reconstructed version of that target-domain image.

To summarize, each composite model has two inputs, for real photos from Domain-A and Domain-B, and four outputs: the discriminator output, the identity-generated image, the forward-cycle-generated image, and the backward-cycle-generated image.

4.3. Defining the Composite Models

For a composite model, only the weights of the first (main) generator model are updated, and this is done through a weighted sum of all of the loss functions. As described in the paper, the cycle loss is given more weight (10x) than the adversarial loss, and the identity loss is always used with a weight of half the cycle loss (5x), matching the official implementation source code.

# define a composite model for updating generators by adversarial and cycle loss
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element
    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)
    # identity element
    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)
    # forward cycle
    output_f = g_model_2(gen1_out)
    # backward cycle
    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)
    # define model graph
    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
    # define optimization algorithm configuration
    opt = Adam(lr=0.0002, beta_1=0.5)
    # compile model with weighting of least squares loss and L1 loss
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model

We need to create one composite model for each generator model, e.g. Generator-A (B -> A) for zebra-to-horse translation and Generator-B (A -> B) for horse-to-zebra translation.

The loss weights are in the ratio:
Adversarial loss : Identity loss : Forward cycle loss : Backward cycle loss = 1 : 5 : 10 : 10
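Keras applies this weighting through the loss_weights argument in the compile() call of define_composite_model(); conceptually, the quantity minimized for each generator is just the weighted sum sketched below (an illustrative helper with made-up example values, not part of the training code):

def weighted_generator_loss(adv_mse, identity_mae, forward_cycle_mae, backward_cycle_mae):
    # matches loss_weights=[1, 5, 10, 10] in define_composite_model()
    return 1.0 * adv_mse + 5.0 * identity_mae + 10.0 * forward_cycle_mae + 10.0 * backward_cycle_mae

print(weighted_generator_loss(0.30, 0.08, 0.12, 0.15))  # example values only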

All of this forward and backward across two domains can get confusing. Below is a complete listing of all the inputs and outputs for each of the composite models. Identity and cycle losses are calculated as the L1 distance (MAE) between the input and output images of each translation sequence. Adversarial loss is calculated as the L2 distance (MSE) between the model output and the target values of 1.0 for real and 0.0 for fake.

1. Generator-A Composite Model (B -> A, or Zebra to Horse)

The inputs, transformations, and outputs of the model are:

  • Adversarial Loss: Domain-B -> Generator-A -> Domain-A -> Discriminator-A -> [real/fake]
  • Identity Loss: Domain-A -> Generator-A -> Domain-A
  • Forward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Backward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

The inputs and outputs are:

Inputs: Domain-B, Domain-A
Outputs: Real, Domain-A, Domain-B, Domain-A

2. Generator-B Composite Model (A -> B or Horse -> Zebra)

The inputs, transformations, and outputs of the model are:

  • Adversarial Loss: Domain-A -> Generator-B -> Domain-B -> Discriminator-B -> [real/fake]
  • Identity Loss: Domain-B -> Generator-B -> Domain-B
  • Forward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A
  • Backward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B

The inputs and outputs are:
Inputs: Domain-A, Domain-B
Outputs: Real, Domain-B, Domain-A, Domain-B

Defining the models is the hard part of a CycleGAN; what follows is standard GAN training.

Next, we can load the image dataset in its compressed NumPy array format. This returns a list of two NumPy arrays: the first for the Domain-A (horse) images and the second for the Domain-B (zebra) images.

4.4 Loading Real Images and Generating Fake Images

The load_real_samples() function loads the real images and scales them to [-1,1].
The generate_real_samples() function selects a random batch of real images; the label for each image is an array of ones with shape (16,16,1).
The generate_fake_samples() function uses a generator model to produce fake images; the label for each image is an array of zeros with shape (16,16,1).
Note that the labels are unusual here: instead of a single number per image, each label is a three-dimensional array of shape (16,16,1), matching the PatchGAN output of the discriminator.
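The value 16 is not arbitrary: the discriminator downsamples with four stride-2 convolutions, so a 256×256 input becomes a 256 / 2^4 = 16×16 patch map. The training loop below reads this from the model instead of hard-coding it, for example:

# assumes define_discriminator() from section 4.1
d_model = define_discriminator((256, 256, 3))
n_patch = d_model.output_shape[1]
print(n_patch)  # 16 for 256x256 inputs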

# load and prepare training images
def load_real_samples(filename):
    # load the dataset
    data = np.load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # choose random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1)
    y = np.ones((n_samples, patch_shape, patch_shape, 1))
    return X, y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
    # generate fake instance
    X = g_model.predict(dataset)
    # create 'fake' class labels (0)
    y = np.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

4.5 Saving the Models

GAN models do not typically converge; instead, an equilibrium is found between the generator and discriminator models, so we cannot easily judge when training should stop. Therefore, we can save the models and use them to generate sample image-to-image translations periodically during training, such as every one or five training epochs.

We can then review the generated images at the end of training and use image quality to choose a final model.

The save_models() function below saves each generator model to the current directory in H5 format, including the training iteration number in the filename. This requires the h5py library to be installed.
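If h5py is not already present in your environment, it can usually be installed with pip:

pip install h5py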

# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
    # save the first generator model
    filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
    g_model_AtoB.save(filename1)
    # save the second generator model
    filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
    g_model_BtoA.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

4.6 Generating Images with the Generators

The summarize_performance() function below uses a given generator model to generate translated versions of a few randomly selected source photographs and saves the plot to file.

The source images are plotted on the first row and the generated images on the second row. Again, the plot filename includes the training iteration number.

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, trainX, name, n_samples=5):
    # select a sample of input images
    X_in, _ = generate_real_samples(trainX, n_samples, 0)
    # generate translated images
    X_out, _ = generate_fake_samples(g_model, X_in, 0)
    # scale all pixels from [-1,1] to [0,1]
    X_in = (X_in + 1) / 2.0
    X_out = (X_out + 1) / 2.0
    # plot real images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_in[i])
    # plot translated images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_out[i])
    # save plot to file
    filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
    pyplot.savefig(filename1)
    pyplot.close()

4.7 Updating the Image Pool

The discriminator models are updated directly on real and generated images, although in an effort to further manage how quickly the discriminator models learn, a pool of fake images is maintained.

The paper defines an image pool of 50 generated images for each discriminator model. The pool is first filled up, and then new images are either added to the pool by replacing an existing image or a generated image is used directly, decided probabilistically. We can implement this as a Python list of images for each discriminator and use the update_image_pool() function below to maintain each pool.

# update image pool for fake images
def update_image_pool(pool, images, max_size=50):
    selected = list()
    for image in images:
        if len(pool) < max_size:
            # stock the pool
            pool.append(image)
            selected.append(image)
        elif random() < 0.5:
            # use image, but don't add it to the pool
            selected.append(image)
        else:
            # replace an existing image and use replaced image
            ix = np.random.randint(0, len(pool))
            selected.append(pool[ix])
            pool[ix] = image
    return np.asarray(selected)

4.8 Training the Models

The train() function below takes all six models (two discriminators, two generators, and two composite models) together with the dataset as arguments and trains them.

The batch size is fixed at one image to match the description in the paper, and the models are fit for 100 epochs. Given that the horse dataset has 1,187 images, one epoch is defined as 1,187 batches and the same number of training iterations, for 1,187 × 100 = 118,700 iterations in total. Images are generated with both generators every epoch, and the models are saved every five epochs, i.e. every (1187 * 5) 5,935 training iterations.

The order of model updates matches the official Torch implementation. First, a batch of real images is selected from each domain, then a batch of fake images is generated for each domain. The fake images are then used to update each discriminator's pool of fake images.

Next, the Generator-A model (zebra to horse) is updated via its composite model, followed by the Discriminator-A (horse) model. Then the Generator-B (horse to zebra) composite model and the Discriminator-B (zebra) model are updated.

The loss of each updated model is then reported at the end of the training iteration. Importantly, only the weighted average loss used to update each generator is reported.

# train cyclegan models
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
    # define properties of the training run
    n_epochs, n_batch = 100, 1
    # determine the output square shape of the discriminator
    n_patch = d_model_A.output_shape[1]  # 16
    # unpack dataset
    trainA, trainB = dataset
    # prepare image pool for fakes
    poolA, poolB = list(), list()
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)  # 1187
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs  # 1187 * 100
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
        X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)  # B -> A
        X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)  # A -> B
        # update fakes from pool
        X_fakeA = update_image_pool(poolA, X_fakeA)
        X_fakeB = update_image_pool(poolB, X_fakeB)
        # update generator B->A via adversarial and cycle loss
        g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
        # update discriminator for A -> [real/fake]
        dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
        dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
        # update generator A->B via adversarial and cycle loss
        g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
        # update discriminator for B -> [real/fake]
        dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
        dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
        # summarize performance
        print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2))
        # evaluate the model performance every so often
        if (i+1) % (bat_per_epo * 1) == 0:
            # plot A->B translation
            summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
            # plot B->A translation
            summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
        if (i+1) % (bat_per_epo * 5) == 0:
            # save the models
            save_models(i, g_model_AtoB, g_model_BtoA)

5. Training Results

The loss is printed each training iteration, including the Discriminator-A loss on real and fake examples (dA), the Discriminator-B loss on real and fake examples (dB), and the Generator-AtoB and Generator-BtoA losses, each of which is a weighted average of the adversarial, identity, forward cycle, and backward cycle losses (g).

If the discriminator loss goes to zero and stays there for a long time, consider restarting the training run, as this is an example of a training failure.

>1, dA[2.284,0.678] dB[1.422,0.918] g[18.747,18.452]
>2, dA[2.129,1.226] dB[1.039,1.331] g[19.469,22.831]
>3, dA[1.644,3.909] dB[1.097,1.680] g[19.192,23.757]
>4, dA[1.427,1.757] dB[1.236,3.493] g[20.240,18.390]
>5, dA[1.737,0.808] dB[1.662,2.312] g[16.941,14.915]
...
>118696, dA[0.004,0.016] dB[0.001,0.001] g[2.623,2.359]
>118697, dA[0.001,0.028] dB[0.003,0.002] g[3.045,3.194]
>118698, dA[0.002,0.008] dB[0.001,0.002] g[2.685,2.071]
>118699, dA[0.010,0.010] dB[0.001,0.001] g[2.430,2.345]
>118700, dA[0.002,0.008] dB[0.000,0.004] g[2.487,2.169]
>Saved: g_model_AtoB_118700.h5 and g_model_BtoA_118700.h5

The results after about 9 epochs are shown below.
You can see some change, but it is not very pronounced yet.

The results after about 50 epochs are shown below.

The zebra-to-horse translation appears to be more challenging for the model to learn, although somewhat plausible translations also begin to appear after 50 to 60 epochs.

As in the paper, better quality results might be achieved with another 100 training epochs with weight decay, and perhaps also with a data generator that works through each dataset systematically rather than sampling images at random.
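As a rough sketch of that last idea (not part of the original code), a simple epoch-wise iterator over the arrays returned by load_real_samples() could look like the following; it visits every image exactly once per epoch instead of sampling with replacement:

def sequential_batches(data, n_batch=1, patch_shape=16, shuffle=True):
    # yield (images, 'real' labels) batches that cover the whole array once per epoch
    ix = np.arange(len(data))
    if shuffle:
        np.random.shuffle(ix)
    for start in range(0, len(data) - n_batch + 1, n_batch):
        X = data[ix[start:start + n_batch]]
        y = np.ones((len(X), patch_shape, patch_shape, 1))
        yield X, y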

6. Complete Code

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot
import tensorflow_addons as tfa
import numpy as np
from random import random

# define the discriminator model (70x70 PatchGAN)
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_image = Input(shape=image_shape)
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = tfa.layers.InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    # define model
    model = Model(in_image, patch_out)
    # compile model
    model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
    return model

# generator: a resnet block
def resnet_block(n_filters, input_layer):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # first convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # second convolutional layer
    g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    # concatenate merge channel-wise with input layer
    g = Concatenate()([g, input_layer])
    return g

# define the standalone generator model
def define_generator(image_shape, n_resnet=9):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # c7s1-64
    g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d128
    g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # d256
    g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # R256
    for _ in range(n_resnet):
        g = resnet_block(256, g)
    # u128
    g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # u64
    g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # c7s1-3
    g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
    g = tfa.layers.InstanceNormalization(axis=-1)(g)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model

# define a composite model for updating generators by adversarial and cycle loss
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element
    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)   # e.g. A -> B
    output_d = d_model(gen1_out)      # classify the generated image
    # identity element
    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)   # identity mapping, e.g. B -> B
    # forward cycle
    output_f = g_model_2(gen1_out)    # e.g. A -> B -> A
    # backward cycle
    gen2_out = g_model_2(input_id)    # e.g. B -> A
    output_b = g_model_1(gen2_out)    # e.g. B -> A -> B
    # define model graph
    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
    # define optimization algorithm configuration
    opt = Adam(lr=0.0002, beta_1=0.5)
    # compile model with weighting of least squares loss and L1 loss
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model

# load and prepare training images
def load_real_samples(filename):
    # load the dataset
    data = np.load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # choose random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate 'real' class labels (1), shape (n_samples, 16, 16, 1)
    y = np.ones((n_samples, patch_shape, patch_shape, 1))
    return X, y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
    # generate fake instance
    X = g_model.predict(dataset)
    # create 'fake' class labels (0), shape (n_samples, 16, 16, 1)
    y = np.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
    # save the first generator model
    filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
    g_model_AtoB.save(filename1)
    # save the second generator model
    filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
    g_model_BtoA.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

# generate samples and save as a plot
def summarize_performance(step, g_model, trainX, name, n_samples=5):
    # select a sample of input images
    X_in, _ = generate_real_samples(trainX, n_samples, 0)
    # generate translated images
    X_out, _ = generate_fake_samples(g_model, X_in, 0)
    # scale all pixels from [-1,1] to [0,1]
    X_in = (X_in + 1) / 2.0
    X_out = (X_out + 1) / 2.0
    # plot real images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_in[i])
    # plot translated images
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_out[i])
    # save plot to file
    filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
    pyplot.savefig(filename1)
    pyplot.close()

# update image pool for fake images
def update_image_pool(pool, images, max_size=50):
    selected = list()
    for image in images:
        if len(pool) < max_size:
            # stock the pool
            pool.append(image)
            selected.append(image)
        elif random() < 0.5:
            # use image, but don't add it to the pool
            selected.append(image)
        else:
            # replace an existing image and use replaced image
            ix = np.random.randint(0, len(pool))
            selected.append(pool[ix])
            pool[ix] = image
    return np.asarray(selected)

# train cyclegan models
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
    # define properties of the training run
    n_epochs, n_batch = 100, 1
    # determine the output square shape of the discriminator
    n_patch = d_model_A.output_shape[1]  # 16
    # unpack dataset
    trainA, trainB = dataset
    # prepare image pool for fakes
    poolA, poolB = list(), list()
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)  # 1187
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs  # 1187 * 100
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
        X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)  # B -> A
        X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)  # A -> B
        # update fakes from pool
        X_fakeA = update_image_pool(poolA, X_fakeA)
        X_fakeB = update_image_pool(poolB, X_fakeB)
        # update generator B->A via adversarial and cycle loss
        g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
        # update discriminator for A -> [real/fake]
        dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
        dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
        # update generator A->B via adversarial and cycle loss
        g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
        # update discriminator for B -> [real/fake]
        dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
        dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
        # summarize performance
        print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2))
        # evaluate the model performance every so often
        if (i+1) % (bat_per_epo * 1) == 0:
            # plot A->B translation
            summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
            # plot B->A translation
            summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
        if (i+1) % (bat_per_epo * 5) == 0:
            # save the models
            save_models(i, g_model_AtoB, g_model_BtoA)

if __name__ == '__main__':
    # load image data
    dataset = load_real_samples('horse2zebra_256.npz')
    print('Loaded', dataset[0].shape, dataset[1].shape)
    # define input shape based on the loaded dataset
    image_shape = dataset[0].shape[1:]
    # generator: A -> B
    g_model_AtoB = define_generator(image_shape)
    # generator: B -> A
    g_model_BtoA = define_generator(image_shape)
    # discriminator: A -> [real/fake]
    d_model_A = define_discriminator(image_shape)
    # discriminator: B -> [real/fake]
    d_model_B = define_discriminator(image_shape)
    # composite: A -> B -> [real/fake, A]
    c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
    # composite: B -> A -> [real/fake, B]
    c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)
    # train models
    train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset)

7. How to Perform Image Translation with the CycleGAN Generators

# example of using saved cyclegan models for image translation
from keras.models import load_model
from numpy import load
from numpy import vstack
from matplotlib import pyplot
from numpy.random import randint
#import tensorflow_addons as tfa
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
# load and prepare training images
def load_real_samples(filename):
    # load the dataset
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a random sample of images from the dataset
def select_sample(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    return X

# plot the image, the translation, and the reconstruction
def show_plot(imagesX, imagesY1, imagesY2):
    images = vstack((imagesX, imagesY1, imagesY2))
    titles = ['Real', 'Generated', 'Reconstructed']
    # scale from [-1,1] to [0,1]
    images = (images + 1) / 2.0
    # plot images row by row
    for i in range(len(images)):
        # define subplot
        pyplot.subplot(1, len(images), 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(images[i])
        # title
        pyplot.title(titles[i])
    pyplot.show()

# load dataset
A_data, B_data = load_real_samples('horse2zebra_256.npz')
print('Loaded', A_data.shape, B_data.shape)
# load the models
cust = {'InstanceNormalization': InstanceNormalization}
model_AtoB = load_model('g_model_AtoB_089025.h5', cust)
model_BtoA = load_model('g_model_BtoA_089025.h5', cust)
# plot A->B->A
A_real = select_sample(A_data, 1)
B_generated = model_AtoB.predict(A_real)
A_reconstructed = model_BtoA.predict(B_generated)
show_plot(A_real, B_generated, A_reconstructed)
# plot B->A->B
B_real = select_sample(B_data, 1)
A_generated = model_BtoA.predict(B_real)
B_reconstructed = model_AtoB.predict(A_generated)
show_plot(B_real, A_generated, B_reconstructed)

8. Personal Summary

The saved weights of each generator are roughly 137 MB, which is quite large, and this also explains why training is slow: running on a GPU, it took me about an hour and a half for 9 epochs.
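That file size is consistent with the parameter count reported in the generator summary earlier:

params = 35_276_553          # total params from the generator summary above
print(params * 4 / 1024**2)  # ~134.6 MB of float32 weights, close to the ~137 MB files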
