条件生成对抗网络CGAN

CGAN是最早使目标数据生成成为可能的GAN创新之一,可以说是最具影响力的一种。接下来,介绍CGAN的工作方式以及如何用MNIST数据集实现它的小规模版本。

CGAN原理

生成器学习为训练数据集中的每个标签生成逼真的样本,而鉴别器则学习区分真的样本-标签对与假的样本-标签对。半监督GAN的鉴别器除了区分真实样本与伪样本,还为每个真实样本分配正确的标签;而CGAN中的鉴别器不会学习识别哪个样本是哪个类。它只学习接受真实的且样本-标签匹配正确的对,拒绝不匹配的对和样本为假的对。

例如:无论样本1是真是假,CGAN的判别器都拒绝该(样本1与标签2)对,为了欺骗鉴别器,CGAN生成器仅生成逼真的数据是不够的,生成的样本还需要与标签相匹配。在对生成器进行充分训练之后,就可以通过传递所需的标签来指定希望CGAN合成的样本。

CGAN的生成器

利用噪声z和标签y合成一个为样本x*|y

CGAN的判别器

接受带标签的真实样本(x,y),以及带有标签的伪样本(x*|y,y),在真实样本-标签对上,鉴别器学习如何识别真实数据以及如何识别匹配对。在生成器生成的样本中,鉴别器学习识别伪样本-标签对,以将它们与真实样本-标签对区分开来。
判别器输出表明输入是真实的匹配对的概率,它的目标是学会接受所有的真实样本-标签对,并拒绝所有伪样本和所有与标签不匹配的样本。

架构图与汇总表

对于每个伪样本,相同的标签y同时被传递给生成器和鉴别器。另外,通过在带有不匹配标签的真实样本上训练鉴别器来拒绝不匹配的对;它识别不匹配对的能力是被训练成只接收真实匹配对时的副产品。

CGAN的实现

# 导入包
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import backend as K
from tensorflow.keras.datasets import mnist
from keras.layers import Embedding, Multiply, Dropout, Lambda, Concatenate, Input, Dense, Flatten, Reshape, Activation, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
Using TensorFlow backend.
# 模型输入维度
img_rows = 28
img_cols = 28
channels = 1
# 图像大小
img_shape = (img_rows, img_cols, channels)
# 噪声向量大小
z_dim = 100num_classes = 10

构造生成器

(1)使用Keras的Embedding层将标签y(0到9的整数)转换为大小为z_dim(随机噪声向量的长度)的稠密向量。

(2)使用Keras的Multiply层将标签与噪声向量z嵌入联合表示中。顾名思义,该层将两个等长向量的对应项相乘,并输出作为结果乘积的单个向量。

(3)将得到的向量作为输入,保留CGAN生成器网络的其余部分以合成图像。

def build_generator(z_dim):model = Sequential()model.add(Dense(256 * 7 * 7, input_dim=z_dim))model.add(Reshape((7, 7, 256)))model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))model.add(BatchNormalization())model.add(LeakyReLU(alpha=0.01))model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'))model.add(BatchNormalization())model.add(LeakyReLU(alpha=0.01))model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))model.add(Activation('tanh'))return modeldef build_cgan_genertator(z_dim):z = Input(shape=(z_dim, ))label = Input(shape=(1,), dtype='int32')label_embedding = Embedding(num_classes, z_dim, input_length=1)(label)label_embedding = Flatten()(label_embedding)joined_representation = Multiply()([z, label_embedding])generator = build_generator(z_dim)conditioned_img = generator(joined_representation)return Model([z, label], conditioned_img)

构造CGAN的判别器

步骤:

(1)取一个标签(0到9的整数),使用Keras的Embedding层将标签变成大小为28 × 28 × 1 = 784(扁平化图像的长度)的稠密向量。

(2)将嵌入标签调整为图像尺寸(28 × 28 × 1)。

(3)将重塑后的嵌入标签连接到对应图像上,生成形状(28 × 28× 2)的联合表示。可以将其视为在顶部“贴有”嵌入标签的图像。

(4)将图像-标签的联合表示输入CGAN的鉴别器网络中。注意,为了训练正常进行,必须将模型输入尺寸调整为(28 × 28 × 2)来对应新的输入形状。

def build_discriminator(img_shape):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(img_shape[0], img_shape[1], img_shape[2]+1),padding='same'))model.add(LeakyReLU(alpha=0.01))model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape,padding='same'))model.add(BatchNormalization())model.add(LeakyReLU(alpha=0.01))model.add(Conv2D(128, kernel_size=3, strides=2, input_shape=img_shape,padding='same'))model.add(BatchNormalization())model.add(LeakyReLU(alpha=0.01))model.add(Dropout(0.5))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))return model
def build_cgan_discriminator(img_shape):img = Input(shape=img_shape)label = Input(shape=(1, ), dtype='int32')label_embedding = Embedding(num_classes, np.prod(img_shape), input_length=1)(label)label_embedding = Flatten()(label_embedding)label_embedding = Reshape(img_shape)(label_embedding) # 将标签调整和输入图像一样的维度concatenated = Concatenate(axis= -1)([img, label_embedding])# 将图像与其嵌入标签链接discriminator = build_discriminator(img_shape)classification = discriminator(concatenated)return Model([img, label], classification)

搭建整个模型

def build_cgan(generator, discriminator):z = Input(shape=(z_dim, ))label = Input(shape=(1, ))img = generator([z, label])classification = discriminator([img, label])model = Model([z, label], classification)return modeldiscriminator = build_cgan_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
generator = build_cgan_genertator(z_dim)
discriminator.trainable = Falsecgan = build_cgan(generator, discriminator)
cgan.compile(loss='binary_crossentropy', optimizer=Adam())

训练

losses = []
accuracies = []def train(iterations, batch_size, sample_interval):(X_train, y_train), (_, _) = mnist.load_data('./MNIST')X_train = X_train / 127.5 - 1.0X_train = np.expand_dims(X_train, axis=3)real = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for iteration in range(iterations):idx = np.random.randint(0, X_train.shape[0], batch_size)
#         print(X_train.shape[0])imgs, labels = X_train[idx], y_train[idx]z = np.random.normal(0, 1, (batch_size, z_dim))gen_imgs = generator.predict([z, labels])
#         print(imgs.shape)d_loss_real = discriminator.train_on_batch([imgs, labels], real)d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)z = np.random.normal(0, 1, (batch_size, z_dim))labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)g_loss = cgan.train_on_batch([z, labels], real)if (iteration + 1) % sample_interval == 0:losses.append((d_loss[0], g_loss))accuracies.append(100.0 * d_loss[1])print("%d [D loss: %f, acc.: %.2f%%] [G loss:%f]"%(iteration + 1, d_loss[0], 100.0 * d_loss[1], g_loss))sample_images()
def sample_images (image_grid_rows=2, image_grid_columns=5):z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))labels = np.arange(0, 10).reshape(-1, 1)gen_imgs = generator.predict([z, labels])gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(image_grid_rows,image_grid_columns,figsize=(10,4),sharey=True,sharex=True)cnt = 0for i in range(image_grid_rows):for j in range(image_grid_columns):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')axs[i,j].set_title("Digit: %d" % labels[cnt])cnt +=1
iterations  = 12000
batch_size = 32
sample_interval = 1000
train(iterations, batch_size, sample_interval)
1000 [D loss: 0.000204, acc.: 100.00%] [G loss:9.885448]
2000 [D loss: 0.000059, acc.: 100.00%] [G loss:9.908726]
3000 [D loss: 0.230777, acc.: 90.62%] [G loss:4.183795]
4000 [D loss: 0.040735, acc.: 98.44%] [G loss:3.380749]
5000 [D loss: 0.192189, acc.: 90.62%] [G loss:3.410103]
6000 [D loss: 0.134279, acc.: 98.44%] [G loss:3.005539]
7000 [D loss: 0.412724, acc.: 82.81%] [G loss:1.312850]
8000 [D loss: 0.211682, acc.: 90.62%] [G loss:3.666016]
9000 [D loss: 0.080928, acc.: 98.44%] [G loss:7.182220]
10000 [D loss: 0.107635, acc.: 98.44%] [G loss:2.332113]
11000 [D loss: 0.194184, acc.: 93.75%] [G loss:3.737709]
12000 [D loss: 0.191671, acc.: 89.06%] [G loss:4.127837]

训练1000次

训练6000次

训练1200次

小结

CGAN实现了不光是生成类似于真的样本而且还要生成一个符合条件的真实样本。通过增加生成器判别器的输入进一步提高了GAN的功能。

Github地址:https://github.com/yunlong-G/tensorflow_learn/blob/master/GAN/CGAN.ipynb

GAN学习记录(四)——条件生成对抗网络CGAN相关推荐

  1. [深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

  2. 深度学习故障诊断之-使用条件生成对抗网络CGAN生成泵流量信号

    开始填坑 MATLAB统计机器学习,深度学习,计算机视觉 - 哥廷根数学学派的文章 - 知乎 MATLAB统计机器学习,深度学习,计算机视觉 - 知乎 之前写过在使用深度学习对机械系统或电气系统进行故 ...

  3. GAN(生成对抗网络) and CGAN(条件生成对抗网络)

    前言 GAN(生成对抗网络)是2014年由Goodfellow大佬提出的一种深度生成模型,适用于无监督学习.监督学习.但是GAN进行生成时是不可控的,所以后来又有人提出可控的CGAN(条件生成对抗网络 ...

  4. 深度学习(2)——生成对抗网络

    深度学习(2)--生成对抗网络 译文,如有错误请与笔者联系 摘要 本文提出一个通过对抗过程来预测生成模型的新框架,其中我们同时训练两个模型:一个用来捕捉数据分布的生成模型G和预测样本来自训练数据而不是 ...

  5. CGAN(条件生成-对抗网络)简述教程

    一直以来提醒自己说,你要坚持,要坚持,要坚持更新,坚持读paper,要是现在都坚持不下去,还有将近五年该怎么办呐. 所幸,我还是有底线的,在今天抽了一个下午,看了看论文,跑了跑代码,过程谈不上舒服,也 ...

  6. 深度学习(五) 生成对抗网络入门与实践

    一.生成对抗网络基本概念 1.发展背景 自然界中人类的特性可以概括两大特殊能力,分别是认识和创造.那么在深度学习-神经网络中,我们之前所学习的全连接神经网络.卷积神经网络等,它们都有一个共同的特点就是 ...

  7. 人工智能--条件生成对抗网络

    学习目标: 理解条件生成对抗网络的基本原理. 掌握利用条件生成对抗网络生成新样本的方法. 学习内容: fashion_mnist数据库(from keras.datasets import fashi ...

  8. CGAN之条件生成对抗网络(Matlab)

    代码来源 代码全文: clear all; close all; clc; %% Conditional Generative Adversarial Network %% Load Data loa ...

  9. GAN(Generative Adversarial Nets (生成对抗网络))

    一.GAN 1.应用 GAN的应用十分广泛,如图像生成.图像转换.风格迁移.图像修复等等. 2.简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成 ...

  10. 条件生成对抗网络(CGAN)

    记录一下 上资源:(github), 基于Pytorch的条件对抗生成网络 条件对抗生成网络和生成对抗网络的区别在于,条件对抗网络生成器和鉴别器额外输入了条件信息(以minist为例,就是额外输入了标 ...

最新文章

  1. oracle修改数据高性能,oracle数据库的性能调整
  2. java/android 做题中整理的碎片小贴士(12)
  3. 软件外包项目中的进度管理
  4. 51nod1229-序列求和V2【数学,拉格朗日插值】
  5. JavaScript学习随记——数组一
  6. Django模块学习- django-pagination
  7. 1.创建一个 Slim 应用
  8. linux查看rabbitmq的插件,docker安装rabbitmq延时队列插件
  9. 2021华为软件精英挑战赛的baseLine,Java版,仅供参考,无核心算法
  10. C编程实例-“约瑟夫问题” 解法
  11. uos服务器系统rpm安装oracle 19c
  12. 优酷视频如何登录优酷账号?
  13. CAD最常用的快捷键大全来啦
  14. 开店选址分析(转自:https://www.sohu.com/a/228415364_167028)
  15. java 关注公众号回调_处理公众号回调消息
  16. pat a1096(因式分解)
  17. codeforces24D
  18. 华为服务器euler系统,华为euler服务器
  19. fagin 启动报错 Fallback/fallbackFactory
  20. Java+OpenCV实现图片中的人脸识别

热门文章

  1. C51单片机串口初始化为何是这样:SCON=0X52;TMOD=0X20;TH1=0XF3;TR1=1;
  2. Flutter-防京东商城项目-修改收货地址 删除收货地址-43
  3. 答案--Java面试笔试题(3年以上)
  4. 基于Unity的极乐净土/others MMD动画制作
  5. 数据分析-信用卡反欺诈模型
  6. treetable php,jQuery树型表格插件jQuery treetable
  7. 有源滤波器: 基于UAF42的50Hz陷波器设计
  8. 【蓝牙串口无线烧写程序】适用于STM32F103和STM32F107的Bootloader
  9. 高通平台开发系列讲解(外设篇)BMI160基本配置
  10. DBV命令行工具检测坏块