ACGAN (Auxiliary Classifier GAN) Explained and Implemented (TensorFlow 2.x)

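  • ACGAN Principles
  • ACGAN Implementation
    • Module Imports
    • Generator
    • Discriminator
    • Model Construction
    • Model Training
    • Generating and Plotting Fake Images: the plot_images Function
    • Training Results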

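ACGAN Principles

ACGAN works much like the conditional GAN (CGAN). In both CGAN and ACGAN, the generator takes a latent vector and a class label as input and outputs a fake image belonging to that class. In CGAN, the discriminator takes an image (real or fake) together with its label and outputs the probability that the image is real. In ACGAN, the discriminator takes only an image and outputs both the probability that the image is real and the probabilities of its class.
In essence, CGAN feeds the label to the network, whereas ACGAN uses an auxiliary decoder network to reconstruct that side information. The idea behind ACGAN is that forcing the network to perform an additional task can improve its performance on the original task. Here the auxiliary task is image classification and the original task is generating fake images.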
Discriminator objective:
$$\mathcal{L}^{(D)} = -\mathbb{E}_{x\sim p_{data}}\log D(x) - \mathbb{E}_{z}\log\left[1 - D(G(z|y))\right] - \mathbb{E}_{x\sim p_{data}}\log p(c|x) - \mathbb{E}_{z}\log p(c|G(z|y))$$
Generator objective:
$$\mathcal{L}^{(G)} = -\mathbb{E}_{z}\log D(G(z|y)) - \mathbb{E}_{z}\log p(c|G(z|y))$$
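To make the two objectives concrete, here is a minimal NumPy sketch that evaluates them on a tiny hand-made batch. The function names (`source_nll`, `class_nll`, `acgan_discriminator_loss`, `acgan_generator_loss`) and the sample probabilities are illustrative assumptions, not part of the implementation below. Expectations are approximated by batch means.

```python
import numpy as np

EPS = 1e-7  # keeps log() away from zero

def source_nll(p_real, is_real):
    """Binary cross-entropy on the real/fake output D(x)."""
    p_real = np.clip(p_real, EPS, 1.0 - EPS)
    return -np.mean(np.where(is_real, np.log(p_real), np.log(1.0 - p_real)))

def class_nll(class_probs, labels):
    """Mean of -log p(c|x), where labels holds integer class ids."""
    picked = class_probs[np.arange(len(labels)), labels]
    return -np.mean(np.log(np.clip(picked, EPS, 1.0)))

def acgan_discriminator_loss(d_real, d_fake, c_real, c_fake, y_real, y_fake):
    """Source loss on real and fake images plus class loss on both."""
    source = source_nll(d_real, True) + source_nll(d_fake, False)
    classes = class_nll(c_real, y_real) + class_nll(c_fake, y_fake)
    return source + classes

def acgan_generator_loss(d_fake, c_fake, y_fake):
    """The generator wants fakes judged real and classified as requested."""
    return source_nll(d_fake, True) + class_nll(c_fake, y_fake)

# Hypothetical discriminator outputs for two real and two fake images
d_real = np.array([0.9, 0.8])                  # D(x) on real images
d_fake = np.array([0.2, 0.1])                  # D(G(z|y)) on fake images
c_real = np.array([[0.7, 0.3], [0.2, 0.8]])    # p(c|x) on real images
c_fake = np.array([[0.6, 0.4], [0.5, 0.5]])    # p(c|G(z|y)) on fake images
y_real = np.array([0, 1])
y_fake = np.array([0, 1])

loss_d = acgan_discriminator_loss(d_real, d_fake, c_real, c_fake, y_real, y_fake)
loss_g = acgan_generator_loss(d_fake, c_fake, y_fake)
print("discriminator loss:", loss_d)
print("generator loss:", loss_g)
```

In the Keras code these terms correspond to the `binary_crossentropy` and `categorical_crossentropy` losses. Keras averages over the combined real-plus-fake batch, so absolute values can differ from this sketch by a constant factor.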

ACGAN Implementation

Module Imports

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import math
from PIL import Image

Generator

def generator(inputs, image_size, activation='sigmoid', labels=None):
    """Build the generator network.

    Arguments:
        inputs (Layer): latent vector input layer
        image_size (int): side length of the output image
        activation (string): activation of the output layer
        labels (tensor): one-hot label input
    Returns:
        Model: the generator model
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]

    # Condition the generator by concatenating the latent vector and the label
    inputs = [inputs, labels]
    x = keras.layers.concatenate(inputs, axis=1)
    x = keras.layers.Dense(image_resize * image_resize * layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize, image_resize, layer_filters[0]))(x)

    for filters in layer_filters:
        # The first two transposed convolutions upsample by 2; the last two keep the size
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                                         kernel_size=kernel_size,
                                         strides=strides,
                                         padding='same')(x)

    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs, x, name='generator')
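The spatial sizes inside the generator can be checked by hand: with `padding='same'`, a `Conv2DTranspose` layer multiplies height and width by its stride. The sketch below traces the feature-map shapes for a 28x28 MNIST image. The helper name `generator_shapes` is hypothetical; it only mirrors the filter and stride logic of the function above and builds no Keras layers.

```python
def generator_shapes(image_size=28, layer_filters=(128, 64, 32, 1)):
    """Trace (height, width, channels) through the generator's layers."""
    size = image_size // 4
    # Shape after the Dense + Reshape layers
    shapes = [(size, size, layer_filters[0])]
    for filters in layer_filters:
        strides = 2 if filters > layer_filters[-2] else 1
        size *= strides  # transposed conv with 'same' padding scales by the stride
        shapes.append((size, size, filters))
    return shapes

for shape in generator_shapes():
    print(shape)
```

The trace starts at 7x7x128 and ends at 28x28x1, which is why the function sets `image_resize = image_size // 4`: two stride-2 layers restore the full resolution.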

Discriminator

def discriminator(inputs, activation='sigmoid', num_labels=None):
    """Build the discriminator network.

    Arguments:
        inputs (Layer): image input layer
        activation (string): activation of the real/fake output layer
        num_labels (int): number of classes
    Returns:
        Model: the discriminator model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]

    x = inputs
    for filters in layer_filters:
        # The first three convolutions downsample by 2; the last one keeps the size
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same')(x)

    x = keras.layers.Flatten()(x)
    # First output: probability that the image is real
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        outputs = keras.layers.Activation(activation)(outputs)

    if num_labels:
        # ACGAN has a second output that predicts the class of the image
        layer = keras.layers.Dense(layer_filters[-2])(x)
        labels = keras.layers.Dense(num_labels)(layer)
        labels = keras.layers.Activation('softmax', name='label')(labels)
        outputs = [outputs, labels]
    return keras.Model(inputs, outputs, name='discriminator')
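The discriminator shrinks the image in the opposite direction: with `padding='same'`, a `Conv2D` layer with stride s outputs ceil(size / s) pixels per side. The following sketch traces the shapes for a 28x28 input; the helper name `discriminator_shapes` is hypothetical and only mirrors the stride logic of the function above.

```python
import math

def discriminator_shapes(image_size=28, layer_filters=(32, 64, 128, 256)):
    """Trace (height, width, channels) through the discriminator's conv layers."""
    size = image_size
    shapes = []
    for filters in layer_filters:
        strides = 1 if filters == layer_filters[-1] else 2
        size = math.ceil(size / strides)  # 'same' padding rounds up
        shapes.append((size, size, filters))
    return shapes

shapes = discriminator_shapes()
height, width, channels = shapes[-1]
flattened = height * width * channels  # input width of the two Dense heads
print(shapes)
print(flattened)
```

Both output heads (real/fake and class) read the same flattened feature vector, so the classifier adds only two small Dense layers on top of the shared convolutional trunk.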

Model Construction

def build_and_train_models():
    """Build the ACGAN models and start training."""
    # Load and preprocess the data
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)

    # Hyperparameters
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels,)

    # Discriminator
    # The local names differ from the builder functions: assigning to
    # `discriminator` or `generator` here would raise UnboundLocalError.
    inputs = keras.layers.Input(shape=input_shape, name='discriminator_input')
    discriminator_model = discriminator(inputs, num_labels=num_labels)
    # InverseTimeDecay reproduces the legacy `decay` argument: lr / (1 + decay * step)
    lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
        lr, decay_steps=1, decay_rate=decay)
    optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
    # One loss per output: real/fake and class label
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator_model.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    discriminator_model.summary()

    # Generator
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape, name='z_input')
    labels = keras.layers.Input(shape=label_shape, name='labels')
    generator_model = generator(inputs, image_size, labels=labels)
    generator_model.summary()

    # Adversarial model: generator followed by a frozen discriminator
    lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
        lr * 0.5, decay_steps=1, decay_rate=decay * 0.5)
    optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
    discriminator_model.trainable = False
    adversarial = keras.Model([inputs, labels],
                              discriminator_model(generator_model([inputs, labels])),
                              name=model_name)
    adversarial.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    adversarial.summary()

    models = (generator_model, discriminator_model, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
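The `decay` hyperparameter is very small, so it is worth checking how much it actually changes the learning rate. The sketch below assumes the time-based rule used by the legacy Keras `decay` argument, lr_t = lr / (1 + decay * t). The helper name `decayed_lr` is hypothetical and not part of Keras.

```python
def decayed_lr(initial_lr, decay, step):
    """Time-based decay (assumed rule): lr / (1 + decay * step)."""
    return initial_lr / (1.0 + decay * step)

lr, decay, train_steps = 2e-4, 6e-8, 40000

# Discriminator optimizer at the start and at the end of training
print(decayed_lr(lr, decay, 0))
print(decayed_lr(lr, decay, train_steps))

# Adversarial optimizer uses half the learning rate and half the decay
print(decayed_lr(lr * 0.5, decay * 0.5, train_steps))
```

Under this assumption the denominator after 40000 steps is only 1.0024, so the learning rate stays almost constant throughout training. The more significant choice is that the adversarial model trains at half the discriminator's rate.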

Model Training

def train(models, data, params):
    """Train the discriminator and adversarial networks.

    Arguments:
        models (list): generator, discriminator, adversarial
        data (list): x_train, y_train
        params (list): network parameters
    """
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params
    save_interval = 500

    # Fixed noise and labels, used to track the generator's progress
    noise_input = np.random.uniform(-1., 1., size=[16, latent_size])
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name,
          'Labels for generated images: ',
          np.argmax(noise_label, axis=1))

    for i in range(train_steps):
        # Train the discriminator for one batch
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # Generate fake images
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        fake_images = generator.predict([noise, fake_labels])
        # Build the input batch: real images followed by fake images
        x = np.concatenate((real_images, fake_images))
        # Class labels for the auxiliary classifier
        labels = np.concatenate((real_labels, fake_labels))
        # Real/fake targets: 1 for real, 0 for fake
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0.0
        # Train the discriminator
        metrics = discriminator.train_on_batch(x, [y, labels])
        fmt = '%d: [disc loss: %f, srcloss: %f, '
        fmt += 'lblloss: %f, srcacc: %f, lblacc: %f]'
        log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])

        # Train the adversarial network for one batch;
        # the fake images are labeled as real to train the generator
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        y = np.ones([batch_size, 1])
        metrics = adversarial.train_on_batch([noise, fake_labels], [y, fake_labels])
        fmt = '%s [advr loss: %f, srcloss: %f, '
        fmt += 'lblloss: %f, srcacc: %f, lblacc: %f]'
        log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        print(log)

        if (i + 1) % save_interval == 0:
            # Plot the generated images
            plot_images(generator,
                        noise_input=noise_input,
                        noise_label=noise_label,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)

    generator.save(model_name + '.h5')
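The training loop relies on two small NumPy idioms: indexing `np.eye` to build one-hot labels, and stacking real and fake targets into one batch. The sketch below isolates both with a batch size of 4 so the arrays are easy to inspect. The seeded `rng` and the tiny batch size are assumptions for illustration only.

```python
import numpy as np

num_labels = 10
batch_size = 4
rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable

# Fixed preview labels: digits 0..9, then 0..5, as one-hot rows
preview_labels = np.eye(num_labels)[np.arange(0, 16) % num_labels]
print(np.argmax(preview_labels, axis=1))

# Random one-hot labels: each row of np.eye is the one-hot vector of its index
real_labels = np.eye(num_labels)[rng.integers(0, num_labels, size=batch_size)]
fake_labels = np.eye(num_labels)[rng.integers(0, num_labels, size=batch_size)]

# Class targets for the whole discriminator batch (real first, then fake)
labels = np.concatenate((real_labels, fake_labels))

# Real/fake targets: first half real (1.0), second half fake (0.0)
y = np.ones([2 * batch_size, 1])
y[batch_size:, :] = 0.0
print(labels.shape, y.ravel())
```

Note that the fake half of the batch still carries class labels: the discriminator is trained to classify generated images as the class the generator was asked to produce.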

Generating and Plotting Fake Images: the plot_images Function

def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate fake images and plot them in a square grid.

    Arguments:
        generator (Model): the generator model
        noise_input (ndarray): latent vectors
        noise_label (ndarray): one-hot labels for the generated images
        noise_codes (list): additional generator inputs, if any
        show (bool): whether to display the figure
        step (int): training step, used in the file name
        model_name (string): model name, used as the output directory
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    rows = int(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')
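Because the generator is conditioned on a label, `plot_images` can also draw a grid of one chosen digit. A minimal sketch of preparing such inputs follows; the helper name `make_conditioned_inputs` and its parameters are hypothetical, and the defaults assume the MNIST setup above (latent size 100, 10 classes).

```python
import numpy as np

def make_conditioned_inputs(digit, num_samples=16, latent_size=100,
                            num_labels=10, seed=None):
    """Return (noise, one-hot labels) requesting the same digit for every sample."""
    rng = np.random.default_rng(seed)
    # Same noise distribution as in training: uniform in [-1, 1)
    noise = rng.uniform(-1.0, 1.0, size=(num_samples, latent_size))
    labels = np.zeros((num_samples, num_labels), dtype="float32")
    labels[:, digit] = 1.0
    return noise, labels

noise_input, noise_label = make_conditioned_inputs(digit=7, seed=0)
print(noise_input.shape, noise_label.shape)

# With a trained generator in scope, the arrays can be passed on, for example:
# plot_images(generator, noise_input=noise_input, noise_label=noise_label,
#             show=True, model_name="acgan-mnist")
```

Keeping `num_samples` a perfect square matters here, since `plot_images` derives the grid side from the square root of the number of samples.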

Training Results

# Run
if __name__ == '__main__':
    build_and_train_models()

[Figure: generated images at step=1000]

[Figure: generated images at step=15000]