循环生成对抗网络CycleGan实现风格迁移

dataset

  • https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/vangogh2photo.zip

GitHub地址:https://github.com/yunlong-G/tensorflow_learn/blob/master/GAN/CycleGan.ipynb
environment

  • python=3.6
  • tensorflow=1.13.1
  • scipy=1.2.1
  • keras=2.2.4
  • keras-contrib=2.0.8

结构简介

如果要用普通GAN将照片转换为绘画(或着反过来),需要使用成对的图像进行训练。而CycleGAN是一种特殊的GAN,无须使用成对图像进行训练,便可以将图像从一个领域 变换到另一个领域 。CycleGAN训练学习两种映射的生成网络。绝大多数GAN训练只一个生成网络,而CycleGAN会训练两个生成网络和两个判别网络。CycleGAN包含如下两个生成网络。

  • 生成网络A:学习映射G:X→YG:X\rightarrow YG:X→Y,其中XXX是源领域, YYY是目标领域。该映射接收源领域AAA的图像,将其转换成和目标领域BBB中的图像相似的图像。简单说来,该网络旨在学习能使G(X)G(X)G(X)和YYY相似的映射。
  • 生成网络B:学习映射F:Y→XF:Y\rightarrow XF:Y→X,接收目标领域BBB的图像,将其转换成和源领域AAA中图像相似的图像。类似地,该网络旨在学习能使F(G(X))F(G(X))F(G(X))和XXX相似的映射。

两个网络的架构相同,但都单独训练。

CycleGAN包含如下两个判别网络。

  • 判别网络A:判别网络A负责区分生成网络B生成的图像(用F(Y)F(Y)F(Y)表示)和源领域A中的真实图像(表示为XXX)。
  • 判别网络B:判别网络B负责区分生成网络A生成的图像(用$G(X) 表示)和目标领域B中的真实图像(表示为表示)和目标领域B中的真实图像(表示为表示)和目标领域B中的真实图像(表示为Y$)。

两个网络的架构相同,也需要单独训练。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-s465uGrQ-1620893448088)(attachment:268fb040-17fc-47fd-aca7-33fd02e347b5.png)]

损失函数

总损失函数

L(G,F,Dx,Dy)=LGAN(G,DY,X,Y)+LGAN(F,DX,Y,X)+λLcyc(G,F)L(G,F,D_x,D_y) = L_{GAN}(G,D_Y,X,Y)+L_{GAN}(F,D_X,Y,X)+\lambda L_{cyc}(G,F)L(G,F,Dx​,Dy​)=LGAN​(G,DY​,X,Y)+LGAN​(F,DX​,Y,X)+λLcyc​(G,F)

对抗损失

对抗损失是来自概率分布A或概率分布B的图像和生成网络生成的图像之间的损失。该网络涉及两个映射函数,都需要应用对抗损失。
LGAN(G,DY,X,Y)=Eypdata(y)[logDY(y)]+Expdata(x)[log(1−DY(G(x))]L_{GAN}(G,D_Y,X,Y)=E_{y~p_{data}(y)}[logD_Y(y)]+E_{x~p_{data}(x)}[log(1-D_Y(G(x))]LGAN​(G,DY​,X,Y)=Ey pdata​(y)​[logDY​(y)]+Ex pdata​(x)​[log(1−DY​(G(x))]
CycleGAN包含两个生成器G和F,对应两个判别器DX和DY,下面以生成器G和判别器DY进行分析(F和DX的原理与之相同):

  • GGG输入XXX绘画,输出G(X)G(X)G(X)图片,使DYD_YDY​判断G(X)G(X)G(X)与YYY越来越相似。
  • FFF输入YYY绘画,输出F(Y)F(Y)F(Y)图片,使DXD_XDX​判断F(Y)F(Y)F(Y)与XXX越来越相似。

循环一致损失

对抗性损失能够让生成器GGG和生成器FFF学习到域YYY和域XXX的分布,但是却没有保证从XXX得到G(X)G(X)G(X)时图像的内容不变,因为G(X)G(X)G(X)只需要符合域YYY分布即可,并没有对其施加约束,所以XXX到G(X)G(X)G(X)包含很多种可能的映射。

利用循环一致性损失来进行约束,即XXX通过GGG生成G(X)G(X)G(X)后,再通过FFF,生成F(G(X))F(G(X))F(G(X))并使其接近于X。
Lcyc(G,F)=Expdata(x)[∣∣F(G(x))−x∣∣1]+Eypdata(y)[∣∣G(F(y))−y∣∣1]L_{cyc}(G,F)=E_{x~p_{data}(x)}[||F(G(x))-x||_1]+E_{y~p_{data}(y)}[||G(F(y))-y||_1]Lcyc​(G,F)=Ex pdata​(x)​[∣∣F(G(x))−x∣∣1​]+Ey pdata​(y)​[∣∣G(F(y))−y∣∣1​]

构建CycleGAN

# 导入需要的包
import time
from glob import globimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras import Input, Model
from keras.callbacks import TensorBoard
from keras.layers import Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, \ZeroPadding2D, LeakyReLU
from keras.optimizers import Adam
from keras_contrib.layers import InstanceNormalization
from scipy.misc import imread, imresize
Using TensorFlow backend.

残差块

def residual_block(x):"""Residual block"""res = Conv2D(filters=128, kernel_size=3, strides=1, padding="same")(x)res = BatchNormalization(axis=3, momentum=0.9, epsilon=1e-5)(res)res = Activation('relu')(res)res = Conv2D(filters=128, kernel_size=3, strides=1, padding="same")(res)res = BatchNormalization(axis=3, momentum=0.9, epsilon=1e-5)(res)return Add()([res, x])

生成器

def build_generator():"""Create a generator network using the hyperparameter values defined below"""input_shape = (128, 128, 3)residual_blocks = 6input_layer = Input(shape=input_shape)# First Convolution blockx = Conv2D(filters=32, kernel_size=7, strides=1, padding="same")(input_layer)x = InstanceNormalization(axis=1)(x)x = Activation("relu")(x)# 2nd Convolution blockx = Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)x = InstanceNormalization(axis=1)(x)x = Activation("relu")(x)# 3rd Convolution blockx = Conv2D(filters=128, kernel_size=3, strides=2, padding="same")(x)x = InstanceNormalization(axis=1)(x)x = Activation("relu")(x)# Residual blocksfor _ in range(residual_blocks):x = residual_block(x)# Upsampling blocks# 1st Upsampling blockx = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', use_bias=False)(x)x = InstanceNormalization(axis=1)(x)x = Activation("relu")(x)# 2nd Upsampling blockx = Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding='same', use_bias=False)(x)x = InstanceNormalization(axis=1)(x)x = Activation("relu")(x)# Last Convolution layerx = Conv2D(filters=3, kernel_size=7, strides=1, padding="same")(x)output = Activation('tanh')(x)model = Model(inputs=[input_layer], outputs=[output])return model

判别器

def build_discriminator():"""Create a discriminator network using the hyperparameter values defined below"""input_shape = (128, 128, 3)hidden_layers = 3input_layer = Input(shape=input_shape)x = ZeroPadding2D(padding=(1, 1))(input_layer)# 1st Convolutional blockx = Conv2D(filters=64, kernel_size=4, strides=2, padding="valid")(x)x = LeakyReLU(alpha=0.2)(x)x = ZeroPadding2D(padding=(1, 1))(x)# 3 Hidden Convolution blocksfor i in range(1, hidden_layers + 1):x = Conv2D(filters=2 ** i * 64, kernel_size=4, strides=2, padding="valid")(x)x = InstanceNormalization(axis=1)(x)x = LeakyReLU(alpha=0.2)(x)x = ZeroPadding2D(padding=(1, 1))(x)# Last Convolution layeroutput = Conv2D(filters=1, kernel_size=4, strides=1, activation="sigmoid")(x)model = Model(inputs=[input_layer], outputs=[output])return model

载入图片

def load_images(data_dir):imagesA = glob(data_dir + '/testA/*.*')imagesB = glob(data_dir + '/testB/*.*')allImagesA = []allImagesB = []for index, filename in enumerate(imagesA):imgA = imread(filename, mode='RGB')imgB = imread(imagesB[index], mode='RGB')imgA = imresize(imgA, (128, 128))imgB = imresize(imgB, (128, 128))if np.random.random() > 0.5:imgA = np.fliplr(imgA)imgB = np.fliplr(imgB)allImagesA.append(imgA)allImagesB.append(imgB)# Normalize imagesallImagesA = np.array(allImagesA) / 127.5 - 1.allImagesB = np.array(allImagesB) / 127.5 - 1.return allImagesA, allImagesBdef load_test_batch(data_dir, batch_size):imagesA = glob(data_dir + '/testA/*.*')imagesB = glob(data_dir + '/testB/*.*')imagesA = np.random.choice(imagesA, batch_size)imagesB = np.random.choice(imagesB, batch_size)allA = []allB = []for i in range(len(imagesA)):# Load images and resize imagesimgA = imresize(imread(imagesA[i], mode='RGB').astype(np.float32), (128, 128))imgB = imresize(imread(imagesB[i], mode='RGB').astype(np.float32), (128, 128))allA.append(imgA)allB.append(imgB)return np.array(allA) / 127.5 - 1.0, np.array(allB) / 127.5 - 1.0

保存训练结果

def save_images(originalA, generatedB, recosntructedA, originalB, generatedA, reconstructedB, path):"""Save images"""fig = plt.figure()ax = fig.add_subplot(2, 3, 1)ax.imshow(originalA)ax.axis("off")ax.set_title("Original")ax = fig.add_subplot(2, 3, 2)ax.imshow(generatedB)ax.axis("off")ax.set_title("Generated")ax = fig.add_subplot(2, 3, 3)ax.imshow(recosntructedA)ax.axis("off")ax.set_title("Reconstructed")ax = fig.add_subplot(2, 3, 4)ax.imshow(originalB)ax.axis("off")ax.set_title("Original")ax = fig.add_subplot(2, 3, 5)ax.imshow(generatedA)ax.axis("off")ax.set_title("Generated")ax = fig.add_subplot(2, 3, 6)ax.imshow(reconstructedB)ax.axis("off")ax.set_title("Reconstructed")plt.savefig(path)def write_log(callback, name, loss, batch_no):"""Write training summary to TensorBoard"""summary = tf.Summary()summary_value = summary.value.add()summary_value.simple_value = losssummary_value.tag = namecallback.writer.add_summary(summary, batch_no)callback.writer.flush()

集成训练

if __name__ == '__main__':data_dir = "../dataset/vangogh2photo/"batch_size = 4epochs = 1000mode = 'train'if mode == 'train':# 载入数据imagesA, imagesB = load_images(data_dir=data_dir)# 迭代器设置common_optimizer = Adam(0.0002, 0.5)#定义网络,训练判别器discriminatorA = build_discriminator()discriminatorB = build_discriminator()discriminatorA.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])discriminatorB.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])# 构建生成网络generatorAToB = build_generator()generatorBToA = build_generator()"""创建对抗网络"""inputA = Input(shape=(128, 128, 3))inputB = Input(shape=(128, 128, 3))# 构建两个生成器generatedB = generatorAToB(inputA)generatedA = generatorBToA(inputB)     # 构建重构网络reconstructedA = generatorBToA(generatedB)reconstructedB = generatorAToB(generatedA)generatedAId = generatorBToA(inputA)generatedBId = generatorAToB(inputB)# 使判别器不被训练discriminatorA.trainable = FalsediscriminatorB.trainable = FalseprobsA = discriminatorA(generatedA)probsB = discriminatorB(generatedB)# 整体网络构建adversarial_model = Model(inputs=[inputA, inputB],outputs=[probsA, probsB, reconstructedA, reconstructedB,generatedAId, generatedBId])adversarial_model.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],loss_weights=[1, 1, 10.0, 10.0, 1.0, 1.0],optimizer=common_optimizer)# 利用tensorboard记录训练数据tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()), write_images=True, write_grads=True,write_graph=True)tensorboard.set_model(generatorAToB)tensorboard.set_model(generatorBToA)tensorboard.set_model(discriminatorA)tensorboard.set_model(discriminatorB)real_labels = np.ones((batch_size, 7, 7, 1))fake_labels = np.zeros((batch_size, 7, 7, 1))for epoch in range(epochs):print("Epoch:{}".format(epoch))dis_losses = []gen_losses = []num_batches = int(min(imagesA.shape[0], imagesB.shape[0]) / batch_size)print("Number of batches:{}".format(num_batches))for index in range(num_batches):print("Batch:{}".format(index))# 获得样例图片batchA = imagesA[index * batch_size:(index + 1) * batch_size]batchB = imagesB[index * batch_size:(index + 1) * batch_size]# 利用生成器生成图片generatedB = generatorAToB.predict(batchA)generatedA = generatorBToA.predict(batchB)# 对判别器A训练区分真假图片dALoss1 = discriminatorA.train_on_batch(batchA, real_labels)dALoss2 = discriminatorA.train_on_batch(generatedA, fake_labels)# 对判别器B训练区分真假图片dBLoss1 = discriminatorB.train_on_batch(batchB, real_labels)dbLoss2 = discriminatorB.train_on_batch(generatedB, fake_labels)# 计算总的loss值d_loss = 0.5 * np.add(0.5 * np.add(dALoss1, dALoss2), 0.5 * np.add(dBLoss1, dbLoss2))print("d_loss:{}".format(d_loss))"""训练生成网络"""g_loss = adversarial_model.train_on_batch([batchA, batchB],[real_labels, real_labels, batchA, batchB, batchA, batchB])print("g_loss:{}".format(g_loss))dis_losses.append(d_loss)gen_losses.append(g_loss)"""每一次都保留loss值"""write_log(tensorboard, 'discriminator_loss', np.mean(dis_losses), epoch)write_log(tensorboard, 'generator_loss', np.mean(gen_losses), epoch)# 每10个世代进行一次效果展示if epoch % 10 == 0:# 得到测试机图片A,BbatchA, batchB = load_test_batch(data_dir=data_dir, batch_size=2)# 利用生成器生成图片generatedB = generatorAToB.predict(batchA)generatedA = generatorBToA.predict(batchB)# 的到重构图片reconsA = generatorBToA.predict(generatedB)reconsB = generatorAToB.predict(generatedA)# 保存原生图片,生成图片,重构图片for i in range(len(generatedA)):save_images(originalA=batchA[i], generatedB=generatedB[i], recosntructedA=reconsA[i],originalB=batchB[i], generatedA=generatedA[i], reconstructedB=reconsB[i],path="results/gen_{}_{}".format(epoch, i))# 保存模型generatorAToB.save_weights("./weight/generatorAToB.h5")generatorBToA.save_weights("./weight/generatorBToA.h5")discriminatorA.save_weights("./weight/discriminatorA.h5")discriminatorB.save_weights("./weight/discriminatorB.h5")elif mode == 'predict':# 构建生成网络并载入权重generatorAToB = build_generator()generatorBToA = build_generator()generatorAToB.load_weights("./weight/generatorAToB.h5")generatorBToA.load_weights("./weight/generatorBToA.h5")# 获得预测的图片batchA, batchB = load_test_batch(data_dir=data_dir, batch_size=2)# 保存预测图片generatedB = generatorAToB.predict(batchA)generatedA = generatorBToA.predict(batchB)reconsA = generatorBToA.predict(generatedB)reconsB = generatorAToB.predict(generatedA)for i in range(len(generatedA)):save_images(originalA=batchA[i], generatedB=generatedB[i], recosntructedA=reconsA[i],originalB=batchB[i], generatedA=generatedA[i], reconstructedB=reconsB[i],path="results/test_{}".format(i))

训练30个世代

训练50个世代

训练100世代

预测结果

只需要将main中mode = 'train’改成mode = ‘predict’,再提一下就是,main中的epoch = 1000 大家更具自己需要重新设置就好,还有现在的代码是都训练完才会保存模型,可以修改成10个世代保存一次。

训练过程记录

loss值

我们在训练时利用tensorboard记录的判别器和生成器的loss值,之后可以在cmd窗口到logs目录前,使用tensorboard --logdir = ./logs命令可视化训练过程loss值变化

同时,通过上部的选项按键可以看到网络的结构图

小结

本次实验利用CycleGan实现油画风格到照片风格的转变,因为利用自己电脑运行代码,只运行了100个世代已经花费了两个多小时,达到的按效果还不是很明显,可以看到的是图片转成油画比较接近了,但是优化到照片还缺少了一定的清晰度。在之后的学习再尝试别的的网络学习了。

GAN学习记录(五)——循环生成对抗网络CycleGan相关推荐

  1. 深度学习(五) 生成对抗网络入门与实践

    一.生成对抗网络基本概念 1.发展背景 自然界中人类的特性可以概括两大特殊能力,分别是认识和创造.那么在深度学习-神经网络中,我们之前所学习的全连接神经网络.卷积神经网络等,它们都有一个共同的特点就是 ...

  2. 深度学习(2)——生成对抗网络

    深度学习(2)--生成对抗网络 译文,如有错误请与笔者联系 摘要 本文提出一个通过对抗过程来预测生成模型的新框架,其中我们同时训练两个模型:一个用来捕捉数据分布的生成模型G和预测样本来自训练数据而不是 ...

  3. Java学习记录五(多线程、网络编程、Lambda表达式和接口组成更新)

    Java学习记录五(多线程.网络编程.Lambda表达式和接口组成更新) Java 25.多线程 25.1实现多线程 25.1.1进程 25.1.2线程 25.1.3多线程的实现 25.1.4设置和获 ...

  4. GAN(Generative Adversarial Nets (生成对抗网络))

    一.GAN 1.应用 GAN的应用十分广泛,如图像生成.图像转换.风格迁移.图像修复等等. 2.简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成 ...

  5. GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

    CycleGAN的原理可以概述为: 将一类图片转换成另一类图片 .也就是说,现在有两个样 本空间,X和Y,我们希望把X空间中的样本转换成Y空间中 的样本.(获取一个数据集的特征,并转化成另一个数据 集 ...

  6. CycleGAN(循环生成对抗网络)论文解读

    图像到图像的转换的目标是使用配准的图像对训练集来学习输入图像和输出图像之间的映射,而CycleGAN中使用的方法是缺少配对训练集的情况下进行图像转换 传统的图像转换如上图左,训练集是配对的x,y图像{ ...

  7. 深度学习(五十三)对抗网络

  8. 深度学习(九) GAN 生成对抗网络 理论部分

    GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...

  9. 四天搞懂生成对抗网络(一)——通俗理解经典GAN

    点击左上方蓝字关注我们 [飞桨开发者说]吕坤,唐山广播电视台,算法工程师,喜欢研究GAN等深度学习技术在媒体.教育上的应用. 序言 做图像分类.检测任务时,为了提高模型精度,在数据处理方面,我尝试了很 ...

最新文章

  1. 操作系统学习2:操作系统的发展和概览
  2. 异步获取邮件推送结果
  3. 项目的命名规范,为以后的程序开发中养成良好的行为习惯
  4. 部分和问题 (dfs搜索 尺取)
  5. Linux进程间通信五 Posix 信号量简介与示例
  6. android碎片化的解决方法,解决 Android 设备碎片化--屏幕适配
  7. Hbase Rowkey设计原则
  8. PAM+4+matlab,基于PAM4调制的400G光模块
  9. python发展路线_Python进阶路径-从学徒到大师
  10. vector与list的区别
  11. Animation中的scale、rotate、translate、alpha
  12. jQuery基础之核心函数,jQuery对象及伪数组 静态方法和实例方法的定义,各种静态方法(each,map,holdRedady,trim,isWindow,isArray,isFunction)
  13. 记录并分析一些软件,以便以后换电脑重新安装(不定时更新)
  14. linux 802.11无线网卡驱动,Linux无线网络配置——无线网卡驱动安装与WLAN802.11配置...
  15. 软考(22)-网络存储、网络安全、网络规划与设计
  16. 【UE4】WebUI插件实现HTML透明区域事件穿透响应
  17. Android Studio 一个工程打包多个不同包名的APK
  18. 华科计算机系教学大纲,《批判性思维》课程教学大纲
  19. Linux 测试IP和端口是否能访问
  20. ubuntu系统下运行可执行文件 (application/x-executable)

热门文章

  1. python人机猜拳游戏代码_实用宝典|如何用Python实现人机猜拳小游戏
  2. 九龙证券|阿里+鸿蒙+人工智能+元宇宙概念热度爆棚,“会说话的猫”亮了!
  3. 用Scrapy对豆瓣top250进行电影详细信息爬取
  4. python:Django
  5. PHP框架页面作业,高校邦《ThinkPHP框架技术》作业题库
  6. 如何在 ESRI ArcMap 中打开谷歌卫星地图
  7. 内存管理(四)——虚拟内存
  8. JavaScript:京东放大镜效果
  9. 亚马逊云科技Marketplace(中国区)正式支持付费AMI产品
  10. 医疗NLP实践与思考