InfoGAN详解与实现(采用tensorflow2.x实现)

  • InfoGAN原理
  • InfoGAN实现
    • 导入必要库
    • 生成器
    • 鉴别器
    • 模型构建
    • 模型训练
    • 效果展示

InfoGAN原理

最初的GAN能够产生有意义的输出,但是缺点是它的属性无法控制。例如,无法明确向生成器提出生成女性名人的脸,该女性名人是黑发,白皙的肤色,棕色的眼睛,微笑着。这样做的根本原因是因为使用的100-dim噪声矢量合并了生成器输出的所有显着属性。
如果能够修改原始GAN,从而将表示形式分为合并和分离可解释的潜在编码向量,则可以告诉生成器要合成什么。
合并和分离编码可以表示如下:
具有分离表示的GAN也可以以与普通GAN相同的方式进行优化。生成器的输出可以表示为:
G(z,c)=G(z)G(z,c)=G(z)G(z,c)=G(z)
编码z=(z,c)z = (z,c)z=(z,c)包含两个元素,zzz表示合并表示,c=c1,c2,...,cLc=c_1,c_2,...,c_Lc=c1​,c2​,...,cL​表示分离的编码表示。
为了强制编码的解耦,InfoGAN提出了一种针对原始损失函数的正则化函数,该函数将潜在编码ccc和G(z,c)G(z,c)G(z,c)之间的互信息最大化:
I(c;G(z,c))=IG(c;z)I(c;G(z,c))=IG(c;z)I(c;G(z,c))=IG(c;z)
正则化器强制生成器考虑潜在编码。在信息论领域,潜在编码ccc和G(z,c)G(z,c)G(z,c)之间的互信息定义为:
I(G(c;z)=H(c)−H(c∣G(z,c))I(G(c;z)=H(c)-H(c|G(z,c))I(G(c;z)=H(c)−H(c∣G(z,c))
其中H(c)H(c)H(c)是潜在编码ccc的熵,而H(c∣G(z,c))H(c|G(z,c))H(c∣G(z,c))是得到生成器的输出G(z,c)G(z,c)G(z,c)后c的条件熵。
最大化互信息意味着在生成得到生成的输出时将H(c∣G(z,c))H(c|G(z,c))H(c∣G(z,c))最小化或减小潜在编码中的不确定性。
但是由于估计H(c∣G(z,c))H(c|G(z,c))H(c∣G(z,c))需要后验分布p(c∣G(z,c))=p(c∣x)p(c|G(z,c))=p(c|x)p(c∣G(z,c))=p(c∣x),因此难以估算H(c∣G(z,c))H(c|G(z,c))H(c∣G(z,c))。
解决方法是通过使用辅助分布Q(c∣x)Q(c|x)Q(c∣x)估计后验概率来估计互信息的下限,估计相互信息的下限为:
I(c;G(z,c))≥LI(G,Q)=Ec∼p(c),x∼G(z,c)[logQ(c∣x)]+H(c)I(c;G(z,c)) \ge L_I(G,Q)=E_{c \sim p(c),x \sim G(z,c)}[logQ(c|x)]+H(c)I(c;G(z,c))≥LI​(G,Q)=Ec∼p(c),x∼G(z,c)​[logQ(c∣x)]+H(c)
在InfoGAN中,假设H(c)H(c)H(c)为常数。因此,使互信息最大化是使期望最大化的问题。生成器必须确信已生成具有特定属性的输出。此期望的最大值为零。因此,互信息的下限的最大值为H(c)H(c)H(c)。在InfoGAN中,离散潜在编码Q(c∣x)Q(c|x)Q(c∣x)的可以用softmax表示。期望是tf.keras中的负categorical_crossentropy损失。
对于一维连续编码,期望是ccc和xxx上的二重积分,这是由于期望样本同时来自分离编码分布和生成器分布。估计期望值的一种方法是通过假设样本是连续数据的良好度量。因此,损失估计为clogQ(c∣x)clogQ(c|x)clogQ(c∣x)。
为了完成InfoGAN的网络,应该有一个logQ(c∣x)logQ(c|x)logQ(c∣x)的实现。为简单起见,网络Q是附加到鉴别器的辅助网络。
鉴别器损失函数
L(D)=−Ex∼pdatalogD(x)−Ez,clog[1−D(G(z,c))]−λI(c;G(z,c))\mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_{z,c}log[1 − D(G(z,c))]-\lambda I(c;G(z,c))L(D)=−Ex∼pdata​​logD(x)−Ez,c​log[1−D(G(z,c))]−λI(c;G(z,c))
生成器损失函数:
L(G)=−Ez,clogD(G(z,c))−λI(c;G(z,c))\mathcal L^{(G)} = -\mathbb E_{z,c}logD(G(z,c))-\lambda I(c;G(z,c))L(G)=−Ez,c​logD(G(z,c))−λI(c;G(z,c))
其中λ\lambdaλ是正的常数

InfoGAN实现

如果将其应用于MNIST数据集,InfoGAN可以学习分离的离散编码和连续编码,以修改生成器输出属性。 例如,像CGAN和ACGAN一样,将使用10维独热标签形式的离散编码来指定要生成的数字。但是,可以添加两个连续的编码,一个用于控制书写样式的角度,另一个用于调整笔划宽度。保留较小尺寸的编码以表示所有其他属性:

导入必要库

import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
from matplotlib import pyplot as plt
import math
from PIL import Image
from tensorflow.keras import backend as K

生成器

def generator(inputs,image_size,activation='sigmoid',labels=None,codes=None):"""generator modelArguments:inputs (layer): input layer of generatorimage_size (int): Target size of one sideactivation (string): name of output activation layerlabels (tensor): input labelscodes (list): 2-dim disentangled codes for infoGANreturns:model: generator model"""image_resize = image_size // 4kernel_size = 5layer_filters = [128,64,32,1]inputs = [inputs,labels] + codesx = keras.layers.concatenate(inputs,axis=1)x = keras.layers.Dense(image_resize*image_resize*layer_filters[0])(x)x = keras.layers.Reshape((image_resize,image_resize,layer_filters[0]))(x)for filters in layer_filters:if filters > layer_filters[-2]:strides = 2else:strides = 1x = keras.layers.BatchNormalization()(x)x = keras.layers.Activation('relu')(x)x = keras.layers.Conv2DTranspose(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)if activation is not None:x = keras.layers.Activation(activation)(x)return keras.Model(inputs,x,name='generator')

鉴别器

def discriminator(inputs,activation='sigmoid',num_labels=None,num_codes=None):"""discriminator modelArguments:inputs (Layer): input layer of the discriminatoractivation (string): name of output activation layernum_labels (int): dimension of one-hot labels for ACGAN & InfoGANnum_codes (int): num_codes-dim 2 Q network if InfoGANReturns:Model: Discriminator model"""kernel_size = 5layer_filters = [32,64,128,256]x = inputsfor filters in layer_filters:if filters == layer_filters[-1]:strides = 1else:strides = 2x = keras.layers.LeakyReLU(0.2)(x)x = keras.layers.Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)x = keras.layers.Flatten()(x)outputs = keras.layers.Dense(1)(x)if activation is not None:print(activation)outputs = keras.layers.Activation(activation)(outputs)if num_labels:layer = keras.layers.Dense(layer_filters[-2])(x)labels = keras.layers.Dense(num_labels)(layer)labels = keras.layers.Activation('softmax',name='label')(labels)# 1-dim continous Q of 1st c given xcode1 = keras.layers.Dense(1)(layer)code1 = keras.layers.Activation('sigmoid',name='code1')(code1)# 1-dim continous Q of 2nd c given xcode2 = keras.layers.Dense(1)(layer)code2 = keras.layers.Activation('sigmoid',name='code2')(code2)outputs = [outputs,labels,code1,code2]return keras.Model(inputs,outputs,name='discriminator')

模型构建

#mi_loss
def mi_loss(c,q_of_c_give_x):"""mi_loss = -c * log(Q(c|x))"""return K.mean(-K.sum(K.log(q_of_c_give_x + K.epsilon()) * c,axis=1))def build_and_train_models(latent_size=100):"""Load the dataset, build InfoGAN models,Call the InfoGAN train routine."""(x_train,y_train),_ = keras.datasets.mnist.load_data()image_size = x_train.shape[1]x_train = np.reshape(x_train,[-1,image_size,image_size,1])x_train = x_train.astype('float32') / 255.num_labels = len(np.unique(y_train))y_train = keras.utils.to_categorical(y_train)#超参数model_name = 'infogan_mnist'batch_size = 64train_steps = 40000lr = 2e-4decay = 6e-8input_shape = (image_size,image_size,1)label_shape = (num_labels,)code_shape = (1,)#discriminator modelinputs = keras.layers.Input(shape=input_shape,name='discriminator_input')#discriminator with 4 outputsdiscriminator_model = discriminator(inputs,num_labels=num_labels,num_codes=2)optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)loss = ['binary_crossentropy','categorical_crossentropy',mi_loss,mi_loss]loss_weights = [1.0,1.0,0.5,0.5]discriminator_model.compile(loss=loss,loss_weights=loss_weights,optimizer=optimizer,metrics=['acc'])discriminator_model.summary()input_shape = (latent_size,)inputs = keras.layers.Input(shape=input_shape,name='z_input')labels = keras.layers.Input(shape=label_shape,name='labels')code1 = keras.layers.Input(shape=code_shape,name='code1')code2 = keras.layers.Input(shape=code_shape,name='code2')generator_model = generator(inputs,image_size,labels=labels,codes=[code1,code2])generator_model.summary()optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)discriminator_model.trainable = Falseinputs = [inputs,labels,code1,code2]adversarial_model = keras.Model(inputs,discriminator_model(generator_model(inputs)),name=model_name)adversarial_model.compile(loss=loss,loss_weights=loss_weights,optimizer=optimizer,metrics=['acc'])adversarial_model.summary()models = (generator_model,discriminator_model,adversarial_model)data = (x_train,y_train)params = (batch_size,latent_size,train_steps,num_labels,model_name)train(models,data,params)

模型训练

def train(models,data,params):"""Train the network#Argumentsmodels (Models): generator,discriminator,adversarial modeldata (tuple): x_train,y_train dataparams (tuple): Network params"""generator,discriminator,adversarial = modelsx_train,y_train = databatch_size,latent_size,train_steps,num_labels,model_name = paramssave_interval = 500code_std = 0.5noise_input = np.random.uniform(-1.0,1.,size=[16,latent_size])noise_label = np.eye(num_labels)[np.arange(0,16) % num_labels]noise_code1 = np.random.normal(scale=code_std,size=[16,1])noise_code2 = np.random.normal(scale=code_std,size=[16,1])train_size = x_train.shape[0]print(model_name,"Labels for generated images: ",np.argmax(noise_label, axis=1))for i in range(train_steps):rand_indexes = np.random.randint(0,train_size,size=batch_size)real_images = x_train[rand_indexes]real_labels = y_train[rand_indexes]#random codes for real imagesreal_code1 = np.random.normal(scale=code_std,size=[batch_size,1])real_code2 = np.random.normal(scale=code_std,size=[batch_size,1])#生成假图片,标签和编码noise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]fake_code1 = np.random.normal(scale=code_std,size=[batch_size,1])fake_code2 = np.random.normal(scale=code_std,size=[batch_size,1])inputs = [noise,fake_labels,fake_code1,fake_code2]fake_images = generator.predict(inputs)x = np.concatenate((real_images,fake_images))labels = np.concatenate((real_labels,fake_labels))codes1 = np.concatenate((real_code1,fake_code1))codes2 = np.concatenate((real_code2,fake_code2))y = np.ones([2 * batch_size,1])y[batch_size:,:] = 0#train discriminator networkoutputs = [y,labels,codes1,codes2]# metrics = ['loss', 'activation_1_loss', 'label_loss',# 'code1_loss', 'code2_loss', 'activation_1_acc',# 'label_acc', 'code1_acc', 'code2_acc']metrics = discriminator.train_on_batch(x, outputs)fmt = "%d: [dis: %f, bce: %f, ce: %f, mi: %f, mi:%f, acc: %f]"log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4], metrics[6])#train the adversarial networknoise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]fake_code1 = np.random.normal(scale=code_std,size=[batch_size,1])fake_code2 = np.random.normal(scale=code_std,size=[batch_size,1])y = np.ones([batch_size,1])inputs = [noise,fake_labels,fake_code1,fake_code2]outputs = [y,fake_labels,fake_code1,fake_code2]metrics = adversarial.train_on_batch(inputs,outputs)fmt = "%s [adv: %f, bce: %f, ce: %f, mi: %f, mi:%f, acc: %f]"log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4], metrics[6])print(log)if (i + 1) % save_interval == 0:# plot generator images on a periodic basisplot_images(generator,noise_input=noise_input,noise_label=noise_label,noise_codes=[noise_code1, noise_code2],show=False,step=(i + 1),model_name=model_name)# save the modelif (i + 1) % (2 * save_interval) == 0:generator.save(model_name + ".h5")

效果展示

#绘制生成图片
def plot_images(generator,noise_input,noise_label=None,noise_codes=None,show=False,step=0,model_name="gan"):"""Generate fake images and plot themFor visualization purposes, generate fake imagesthen plot them in a square grid# Argumentsgenerator (Model): The Generator Model for fake images generationnoise_input (ndarray): Array of z-vectorsshow (bool): Whether to show plot or notstep (int): Appended to filename of the save imagesmodel_name (string): Model name"""os.makedirs(model_name, exist_ok=True)filename = os.path.join(model_name, "%05d.png" % step)rows = int(math.sqrt(noise_input.shape[0]))if noise_label is not None:noise_input = [noise_input, noise_label]if noise_codes is not None:noise_input += noise_codesimages = generator.predict(noise_input)plt.figure(figsize=(2.2, 2.2))num_images = images.shape[0]image_size = images.shape[1]for i in range(num_images):plt.subplot(rows, rows, i + 1)image = np.reshape(images[i], [image_size, image_size])plt.imshow(image, cmap='gray')plt.axis('off')plt.savefig(filename)if show:plt.show()else:plt.close('all')
#模型训练
build_and_train_models(latent_size=62)
steps = 500

steps = 16000

修改书写角度的分离编码

InfoGAN详解与实现(采用tensorflow2.x实现)相关推荐

  1. 变分自编码器(VAE)详解与实现(tensorflow2.x)

    变分自编码器(VAE)详解与实现(tensorflow2.x) VAE介绍 VAE原理 变分推理 VAE核心方程 优化方式 重参数化技巧(Reparameterization trick) VAE实现 ...

  2. ACGAN(Auxiliary Classifier GAN)详解与实现(tensorflow2.x实现)

    ACGAN(Auxiliary Classifier GAN)详解与实现(tensorflow2.x实现) ACGAN原理 ACGAN实现 模块导入 生成器 鉴别器 模型构建 模型训练 虚假图像生成及 ...

  3. 深度残差网络(ResNet)详解与实现(tensorflow2.x)

    深度残差网络(ResNet)详解与实现(tensorflow2.x) ResNet原理 ResNet实现 模型创建 数据加载 模型编译 模型训练 测试模型 训练过程 ResNet原理 深层网络在学习任 ...

  4. YOLOv5算法详解

    目录 1.需求解读 2.YOLOv5算法简介 3.YOLOv5算法详解 3.1 YOLOv5网络架构 3.2 YOLOv5实现细节详解 3.2.1 YOLOv5基础组件 3.2.2 输入端细节详解 3 ...

  5. **组播PIM-SM详解****

    组播PIM-SM详解** PIM_SM:采用PULL的模式. 特点:需要建立SPT(源树)和RPT(共享树)两种树. 涉及设备:BSR.RP.DR.源端.组成员(接收者). 设计pim的路由表项:(S ...

  6. 天梯赛基础题型详解(2019 - 08 - 12)

    A.枚举 (1) 详解:用枚举法,从最开始的只有一层沙漏开始枚举,直至找到一个沙漏所用符号的总和小于等于输入的数(将每一次不同层数的沙漏的符号和都用数组储存起来),然后标记那个最大的和.要注意的是每增 ...

  7. 【蓝桥杯Python组】2022年第十三届蓝桥杯省赛B组Python解题思路详解

    第十三届蓝桥杯省赛B组Python解题思路详解 因为今年采用线上的举办方式进行比赛,所以组委会对题目做了一定的调整,将原来的5道填空+5道编程题变成了2道填空+8道编程题,据说是为了防止抄袭.其实题目 ...

  8. StackedGAN详解与实现(采用tensorflow2.x实现)

    StackedGAN详解与实现(采用tensorflow2.3实现) StackedGAN原理 StackedGAN实现 编码器 对抗网络 鉴别器 生成器 模型构建 模型训练 效果展示 Stacked ...

  9. 自编码器模型详解与实现(采用tensorflow2.x实现)

    自编码器模型详解与实现(采用tensorflow2.x实现) 使用自编码器学习潜变量 编码器 解码器 构建自编码器 从潜变量生成图像 完整代码 使用自编码器学习潜变量 由于高维输入空间中有很多冗余,可 ...

  10. CycleGAN详解与实现(采用tensorflow2.x实现)

    CycleGAN详解与实现(采用tensorflow2.x实现) CycleGAN原理 CycleGAN概述 CycleGAN原理 前向循环 反向循环 训练过程 CycleGAN实现 加载库 生成器 ...

最新文章

  1. Ubuntu 系统打不开图片提示Fatal error reading PNG image File: Not a PNG file
  2. tcp udp区别优缺点_CCNA必懂篇,传输层协议TCP/UDP的区别和作用
  3. (ql)30W单片精密开关电源 电路图加分析
  4. 使用GPG校验sign签名
  5. Citrix VDI实战攻略之八:测试验收
  6. Python3 hex() 函数
  7. C#计算程序的运行时间
  8. mipi协议_Cadence发布业界首款面向多协议PHY的验证IP产品
  9. [算法] 已知前序和后序遍历,建立二叉树
  10. 设计模式学习一:strategyPattern
  11. pytorch-使用GPU加速模型训练
  12. java基础总结06-常用api类-时间日期类
  13. bulk insert java_java oracle bulk insert
  14. 计算机音乐广东爱情故事,改编自网易云音乐——广东十年爱情故事热评
  15. 摄像头能用计算机里不显示,摄像头没有显示
  16. 收美之鸿蒙灵戒,顺网神戒之鸿蒙
  17. java实现手机尾号评分
  18. 3D 池化(MaxPool3D) 和 3D(Conv3d) 卷积详解
  19. 华为A1路由器虚拟服务器,华为a1路由器怎么用手机设置DMZ主机
  20. 四天搞懂生成对抗网络(三)——用CGAN做图像转换的鼻祖pix2pix

热门文章

  1. 大学高数常微分方程思维导图_思维导图_2016考研数学:高数中六种常见题型归纳_沪江英语...
  2. for linux pdf转mobi_在Linux上,如何为Amazon Kindle转换各种电子书格式
  3. 多个app用同一个签名文件_运动设备和运动APP的合理搭配
  4. 计算机高效课堂建设,基于信息技术的小学音乐高效课堂的构建
  5. 计算机软件系统管理说课,计算机软件系统 说课稿
  6. 如何解决stata数据管理器中变量变红的问题
  7. dell服务器无线网卡,dell笔记本内置无线网卡找不到怎么处理
  8. Postman中的Pre-request Scrip详解
  9. Java线程同步和锁定
  10. navicat和mysql有必要都装吗_MySQL基本介绍及Navicat安装