目录

介绍

从零开始的CycleGAN

加载数据集

构建鉴别器

构建残差块

构建生成器

构建CycleGAN

训练CycleGAN

评估绩效

结论


  • 下载项目文件 - 7.2 MB

介绍

在本系列文章中,我们将展示一个基于循环一致对抗网络 (CycleGAN)的移动图像到图像转换系统。我们将构建一个CycleGAN,它可以执行不成对的图像到图像的转换,并向您展示一些有趣但具有学术深度的例子。我们还将讨论如何将这种使用TensorFlow和Keras构建的训练有素的网络转换为TensorFlow Lite并用作移动设备上的应用程序。

我们假设您熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。欢迎您下载项目代码。

在本系列的前一篇文章中,我们训练和评估了一个使用基于U-Net的生成器的CycleGAN。在本文中,我们将使用基于残差的生成器实现CycleGAN。

从零开始的CycleGAN

最初的CycleGan最初是使用基于残差的生成器构建的。让我们从头开始实现这种类型的CycleGAN。我们将构建网络并训练它使用带有和不带有伪影的眼底数据集来减少眼底图像中的伪影。

网络将有伪影的眼底图像转换为没有伪影的眼底图像,反之亦然,如上所示。

CycleGAN 设计将包括以下步骤:

  • 构建鉴别器
  • 构建残差块
  • 构建生成器
  • 构建完整模型

在开始加载数据之前,让我们导入一些必要的库和包。

#the necessary importsfrom random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from matplotlib import pyplot

加载数据集

与我们在上一篇文章中所做的相反,这次我们将使用本地机器(而不是Google Colab)来训练CycleGAN。因此,应首先下载和处理眼底数据集。我们将使用Jupyter Notebook和TensorFlow来构建和训练这个网络。

from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed# load all images in a directory into memorydef load_images(path, size=(256,256)):data_list = list()# enumerate filenames in directory, assume all are imagesfor filename in listdir(path):# load and resize the imagepixels = load_img(path + filename, target_size=size)# convert to numpy arraypixels = img_to_array(pixels)# storedata_list.append(pixels)return asarray(data_list)# dataset path
path = r'C:/Users/abdul/Desktop/ContentLab/P3/Fundus/'
# load dataset A
dataA1 = load_images(path + 'trainA/')
dataAB = load_images(path + 'testA/')
dataA = vstack((dataA1, dataAB))
print('Loaded dataA: ', dataA.shape)
# load dataset B
dataB1 = load_images(path + 'trainB/')
dataB2 = load_images(path + 'testB/')
dataB = vstack((dataB1, dataB2))
print('Loaded dataB: ', dataB.shape)
# save as compressed numpy array
filename = 'Artifcats.npz'
savez_compressed(filename, dataA, dataB)
print('Saved dataset: ', filename)

加载数据后,就可以创建一个显示一些训练图像的函数了:

# load and plot the prepared datasetfrom numpy import load
from matplotlib import pyplot# load the datasetdata = load('Artifacts.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)# plot source imagesn_samples = 3
for i in range(n_samples):pyplot.subplot(2, n_samples, 1 + i)pyplot.axis('off')pyplot.imshow(dataA[i].astype('uint8'))# plot target imagefor i in range(n_samples):pyplot.subplot(2, n_samples, 1 + n_samples + i)pyplot.axis('off')pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()

构建鉴别器

正如我们之前讨论过的,鉴别器是一个由许多卷积层以及LeakReLU和实例归一化层组成的CNN 。

def define_discriminator(image_shape):# weight initializationinit = RandomNormal(stddev=0.02)# source image inputin_image = Input(shape=image_shape)# C64d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)d = LeakyReLU(alpha=0.2)(d)# C128d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = InstanceNormalization(axis=-1)(d)d = LeakyReLU(alpha=0.2)(d)# C256d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = InstanceNormalization(axis=-1)(d)d = LeakyReLU(alpha=0.2)(d)# C512d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)d = InstanceNormalization(axis=-1)(d)d = LeakyReLU(alpha=0.2)(d)# second last output layerd = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)d = InstanceNormalization(axis=-1)(d)d = LeakyReLU(alpha=0.2)(d)# patch outputpatch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)# define modelmodel = Model(in_image, patch_out)# compile modelmodel.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])return model

一旦构建了鉴别器,我们就可以创建它的副本,以便我们有两个相同的鉴别器:DiscA和DiscB。

image_shape=(256,256,3)
DiscA=define_discriminator(image_shape)
DiscB=define_discriminator(image_shape)
DiscA.summary()

构建残差块

下一步是为我们的生成器创建残差块。该块是一组2D卷积层,其中每两层后跟一个实例归一化层。

# generator a resnet blockdef resnet_block(n_filters, input_layer):# weight initializationinit = RandomNormal(stddev=0.02)# first layer convolutional layerg = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)# second convolutional layerg = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)# concatenate merge channel-wise with input layerg = Concatenate()([g, input_layer])return g

构建生成器

残差块的输出将通过生成器的最后一部分(解码器),在那里图像将被上采样并调整到其原始大小。由于编码器尚未定义,我们将构建一个函数来定义解码器和编码器部分并将它们连接到残差块。

# define the generator modeldef define_generator(image_shape, n_resnet=9):# weight initializationinit = RandomNormal(stddev=0.02)# image inputin_image = Input(shape=image_shape)g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)# d128g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)# d256g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)# R256for _ in range(n_resnet):g = resnet_block(256, g)# u128g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)# u64g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)g = Activation('relu')(g)g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)g = InstanceNormalization(axis=-1)(g)out_image = Activation('tanh')(g)# define modelmodel = Model(in_image, out_image)return model

现在,我们定义生成器genA和genB。

genA=define_generator(image_shape, 9)
genB=define_generator(image_shape, 9)

构建CycleGAN

定义了生成器和鉴别器后,我们现在可以构建整个CycleGAN模型并设置其优化器和其他学习参数。

#define a composite modeldef define_composite_model(g_model_1, d_model, g_model_2, image_shape):# ensure the model we're updating is trainableg_model_1.trainable = True# mark discriminator as not trainabled_model.trainable = False# mark other generator model as not trainableg_model_2.trainable = False# discriminator elementinput_gen = Input(shape=image_shape)gen1_out = g_model_1(input_gen)output_d = d_model(gen1_out)# identity elementinput_id = Input(shape=image_shape)output_id = g_model_1(input_id)# forward cycleoutput_f = g_model_2(gen1_out)# backward cyclegen2_out = g_model_2(input_id)output_b = g_model_1(gen2_out)# define model graphmodel = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])# define optimization algorithm configurationopt = Adam(lr=0.0002, beta_1=0.5)# compile model with weighting of least squares loss and L1 lossmodel.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)return model

现在让我们定义两个模型(A和B),其中一个将眼底图像伪影转换为无伪影眼底(AtoB),另一个将无伪影转换为伪影眼底图像(BtoA)。

comb_modelA=define_composite_model(genA,DiscA,genB,image_shape)
comb_modelB=define_composite_model(genB,DiscB,genA,image_shape)

训练CycleGAN

现在我们的模型已经完成,我们将创建一个训练函数,该函数定义训练参数并计算生成器和鉴别器的损失,以及在训练期间更新权重。此功能将按如下方式操作:

  1. 将图像传递给生成器。
  2. 获取生成器生成的图像。
  3. 将生成的图像传回生成器以验证我们可以从生成的图像中预测原始图像。
  4. 使用生成器,执行真实图像的身份映射。
  5. 将步骤1中生成的图像传递给相应的鉴别器。
  6. 找到生成器的总损失(对抗性+循环+身份)。
  7. 找出鉴别器的损失。
  8. 更新生成器权重。
  9. 更新鉴别器权重。
  10. 在字典中返回损失。
# train the cycleGAN modeldef train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):# define properties of the training runn_epochs, n_batch, = 30, 1# determine the output square shape of the discriminatorn_patch = d_model_A.output_shape[1]# unpack datasettrainA, trainB = dataset# prepare image pool for fakespoolA, poolB = list(), list()# calculate the number of batches per training epochbat_per_epo = int(len(trainA) / n_batch)# calculate the number of training iterationsn_steps = bat_per_epo * n_epochs# manually enumerate epochsfor i in range(n_steps):# select a batch of real samplesX_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)# generate a batch of fake samplesX_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)# update fakes from poolX_fakeA = update_image_pool(poolA, X_fakeA)X_fakeB = update_image_pool(poolB, X_fakeB)# update generator B->A via adversarial and cycle lossg_loss2, _, _, _, _  = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])# update discriminator for A -> [real/fake]dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)# update generator A->B via adversarial and cycle lossg_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])# update discriminator for B -> [real/fake]dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)# summarize performanceprint('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))# evaluate the model performance every so oftenif (i+1) % (bat_per_epo * 1) == 0:# plot A->B translationsummarize_performance(i, g_model_AtoB, trainA, 'AtoB')# plot B->A translationsummarize_performance(i, g_model_BtoA, trainB, 'BtoA')if (i+1) % (bat_per_epo * 5) == 0:# save the modelssave_models(i, g_model_AtoB, g_model_BtoA)

下面是一些在训练过程中会用到的函数。

#load and prepare training imagesdef load_real_samples(filename):# load the datasetdata = load(filename)# unpack arraysX1, X2 = data['arr_0'], data['arr_1']# scale from [0,255] to [-1,1]X1 = (X1 - 127.5) / 127.5X2 = (X2 - 127.5) / 127.5return [X1, X2]# The generate_real_samples() function below implements this# select a batch of random samples, returns images and targetdef generate_real_samples(dataset, n_samples, patch_shape):# choose random instancesix = randint(0, dataset.shape[0], n_samples)# retrieve selected imagesX = dataset[ix]# generate 'real' class labels (1)y = ones((n_samples, patch_shape, patch_shape, 1))return X, y# generate a batch of images, returns images and targetsdef generate_fake_samples(g_model, dataset, patch_shape):# generate fake instanceX = g_model.predict(dataset)# create 'fake' class labels (0)y = zeros((len(X), patch_shape, patch_shape, 1))return X, y
# update image pool for fake imagesdef update_image_pool(pool, images, max_size=50):selected = list()for image in images:if len(pool) < max_size:# stock the poolpool.append(image)selected.append(image)elif random() < 0.5:# use image, but don't add it to the poolselected.append(image)else:# replace an existing image and use replaced imageix = randint(0, len(pool))selected.append(pool[ix])pool[ix] = imagereturn asarray(selected)

我们添加了更多功能来保存最佳模型并可视化眼底图像中伪影减少的性能。

def save_models(step, g_model_AtoB, g_model_BtoA):# save the first generator modelfilename1 = 'g_model_AtoB_%06d.h5' % (step+1)g_model_AtoB.save(filename1)# save the second generator modelfilename2 = 'g_model_BtoA_%06d.h5' % (step+1)g_model_BtoA.save(filename2)print('>Saved: %s and %s' % (filename1, filename2))# generate samples and save as a plot and save the modeldef summarize_performance(step, g_model, trainX, name, n_samples=5):# select a sample of input imagesX_in, _ = generate_real_samples(trainX, n_samples, 0)# generate translated imagesX_out, _ = generate_fake_samples(g_model, X_in, 0)# scale all pixels from [-1,1] to [0,1]X_in = (X_in + 1) / 2.0X_out = (X_out + 1) / 2.0# plot real imagesfor i in range(n_samples):pyplot.subplot(2, n_samples, 1 + i)pyplot.axis('off')pyplot.imshow(X_in[i])# plot translated imagefor i in range(n_samples):pyplot.subplot(2, n_samples, 1 + n_samples + i)pyplot.axis('off')pyplot.imshow(X_out[i])# save plot to filefilename1 = '%s_generated_plot_%06d.png' % (name, (step+1))pyplot.savefig(filename1)pyplot.close()train(DiscA, DiscB, genA, genB, comb_modelA, comb_modelB, dataset)

评估绩效

使用上述函数,我们对网络进行了30个epoch的训练。结果表明,我们的网络能够减少眼底图像中的伪影。

工件到无工件转换( AtoB )的结果如下所示:

还计算无伪影到伪影 (BtoA)眼底图像转换;这里有些例子。

结论

正如AI先驱Yann LeCun谈到GAN时所说,“(这是)过去10年深度学习中最有趣的想法”。我们希望,通过这个系列,我们已经帮助您理解了为什么GAN是一些非常有趣的想法。我们知道您可能会发现系列中提出的概念有点沉重和模棱两可,但这完全没问题。CycleGAN在一次阅读中非常难以掌握,在你理解之前可以多读几遍这个系列。

https://www.codeproject.com/Articles/5304928/Building-a-Style-Transfer-CycleGAN-from-Scratch

(五)从头开始构建风格迁移CycleGAN相关推荐

  1. (三)使用Keras构建移动风格迁移CycleGAN

    目录 介绍 处理数据集 构建生成器和鉴别器 下一步 下载项目代码 - 7.2 MB 介绍 在本系列文章中,我们将展示一个基于循环一致对抗网络 (CycleGAN)的移动图像到图像转换系统.我们将构建一 ...

  2. [风格迁移系列五: WaveCT-AIN] 医学图像的风格迁移和跨域自适应(泛化性)

    不同于自然图像的风格迁移,在临床应用上,医学图像更加注重图像生成的纹理细节,并且需要实时的推理速度.因此提出一个实时且高质量的风格迁移方法非常重要,这篇论文实现了这个方法: Remove Appear ...

  3. 一文教会你风格迁移CycleGAN从入门到高阶再到最终成功魔改(附成功魔改代码)

    专栏导读

  4. 深度学习框架PyTorch入门与实践:第八章 AI艺术家:神经网络风格迁移

    本章我们将介绍一个酷炫的深度学习应用--风格迁移(Style Transfer).近年来,由深度学习引领的人工智能技术浪潮越来越广泛地应用到社会各个领域.这其中,手机应用Prisma,尝试为用户的照片 ...

  5. 学会CycleGAN进行风格迁移,实现自定义滤镜

    学会CycleGAN进行风格迁移,实现自定义滤镜 前言 效果展示 数据集介绍与加载 CycleGAN模型 数据预处理 模型构建 损失函数 训练结果可视化函数 训练步骤 更多效果展示 风格迁移系列链接 ...

  6. GANs系列:用于图像风格迁移的CycleGAN网络原理解读

    CycleGAN论文:https://arxiv.org/pdf/1703.10593.pdf 一.前言 目前关于GAN应用,比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了, ...

  7. 深度学习入门(五十二)计算机视觉——风格迁移

    深度学习入门(五十二)计算机视觉--风格迁移 前言 计算机视觉--风格迁移 课件 样式迁移 易于CNN的样式迁移 教材 1 方法 2 阅读内容和风格图像 3 预处理和后处理 4 抽取图像特征 5 定义 ...

  8. (六)使用ResNet50迁移学习进行COVID-19诊断:从头开始构建深度学习网络

    目录 安装库并加载数据集 预处理数据 构建深度学习网络 训练网络 评估网络 下一步? 下载源 - 300.4 KB 在本系列文章中,我们将应用深度学习网络ResNet50来诊断胸部X射线图像中的Cov ...

  9. cycleGAN网络风格迁移,将黑夜转变成白天,低照度图像复原

    cycleGAN网络风格迁移,将黑夜转变成白天,低照度图像复原 链接: [link](用CycleGAN网络复原低照度图像,风格转换,将黑夜颠倒成白昼,图像质量更优 - 小拍的文章 - 知乎https ...

最新文章

  1. wdcp php5.3 pdo_mysql,WDCP常用组件(memcache、mysqli、PDO_MYSQL、mysql innodb、libmcrypt、php zip)的安装方法...
  2. Android内存分析和调优
  3. 基本概念,BGP协议的特征和消息类型,状态转换?
  4. RuoYi-Vue————权限管理
  5. 配置nginx到后端服务器负载均衡
  6. Retina时代的前端视觉优化
  7. 会议交流—PPT下载|DataFunSummit2022:知识图谱在线峰会PPT合集!
  8. 【转】 PDO使用归纳【PHP】
  9. matlab中for循环的步长
  10. android emoji unicode编码表,unicode编码
  11. 博士毕业要发多少篇文章? 72 所高校大比较,发文最多的是……
  12. 长见识了: 一篇文章带你看懂 硬盘数据恢复软件的原理
  13. 2016 年全国房价会呈什么趋势?
  14. Kmplayer音频设置
  15. 创建React + Ts项目
  16. 基于Java的Android区块链开发之生成助记词(位数可选)
  17. 【扒开】关于赢驴准心劫持浏览器首页的病毒类行径
  18. 阿里云国际站:实名认证上传材料填写样例(域名持有者为个人)
  19. LeetCode 51~55
  20. pop客户机程序流程图_labview问题集锦

热门文章

  1. 淄博神爱计算机官网,【最美教师】张萍:大爱无言 育人无声
  2. python读取xml文件内容显示不全_python读取xml文件时的问题
  3. java字符串连接效率_关于java:字符串连接中的“+”是否会影响效率?
  4. oracle+greatest+max,ORACLE 内置函数之 GREATEST 和 LEAST(求多列的最大值,最小值)
  5. js几个页面生成pdf 然后批量打印_太好用了!这款免费PDF工具能够满足你的各种需求...
  6. 插画素材 | 圣诞节设计离不了!
  7. 果汁飞溅海报还不会玩?先从临摹学习PSD分层模板开始
  8. lambda不是python的保留字_python-nonlocal关键字的使用,lambda表达式(学习到function到变...
  9. java中注释的嵌套,java – 使用mybatis注释获取嵌套对象
  10. HeadFirst设计模式之观察者模式学习