[生成对抗网络GAN入门指南](10)InfoGAN: Interpretable Representation Learning by Information Maximizing GAN
本篇blog的内容基于原始论文InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets(NPIs2016)和《生成对抗网络入门指南》第六章。完整代码及简析见文章末尾
一、为什么要使用InfoGAN
InfoGAN采用无监督的方式学习,并尝试实现可解释特征。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息(Mutual Information,MI)来对网络模型进行优化。InfoGAN能适用于各种复杂的数据集,可以同时实现离散特征和连续特征。
二、输入端数据
InfoGAN在输入端把随机输入分为两个部分:
第一部分为z,代表噪声;
第二部分为c,代表隐含编码;
目标是希望在每个维度上都具备可解释型特征。
在同时输入噪声z和隐含编码c后,生成概率 ,为了应对这个问题,在InfoGAN中需要对隐含编码c和生成分布G(z,c)求互信息 ,并使其最大化
三、InfoGAN结构
InfoGAN和前面介绍过的GAN区别在于,真实训练数据不有标签数据,二输入数据为隐含编码和随机噪声的组合,最后通过判别器一端和最大化互信息的方式还原隐含编码的信息。也就是说,判别器D最终需要同时具备还原隐含编码和辨别真伪的能力。前者为了生成图像能够很好具备编码中的特性,也就是说隐含编码可以对生网络产生相对显著地成果;后者是要求生成模型在还原信息的同时保证生成的数据与真实数据非常逼近。
1. 互信息
互信息表示两个随机变量之间的依赖程度的度量。对于随机变量X和Y,互信息为I(X;Y),H(X)和H(Y)为边缘熵,H(X|Y)和H(Y|X)为条件熵。
2. 结构
3. 目标函数
当X和Y相互独立时候,互信息为0.给定任意的输入,希望生成器的 有一个相对较小的熵,即希望隐含编码c的信息在生成过程中不会流失。对此我们修改目标函数:
由于概率能以得到,导致互信息难以最大化,实际计算可以定义一个近似概率的辅助分布来获取互信息的下界,推导如下:
由此可以得到互信息的下界值:
4. InfoGAN的推导
我们可以重新改写之前不等式,并重新使蒙特卡洛方法逼近
得到我们最终的目标函数
四、实验效果
1.MNIST数据
我们发现通过控制隐含编码中的可以调节生成数字是几,其他参数可以调节生成字符的倾斜程度、字体宽度等。
2. 3D人脸数据
3. 椅子数据集
4. 门牌号数据集
五、实验代码
1. 导入相关包及超参数
from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as Kimport matplotlib.pyplot as pltimport numpy as npclass INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# Build and the discriminator and recognition networkself.discriminator, self.auxilliary = self.build_disk_and_q_net()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# Build and compile the recognition network Qself.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise and the target label as input# and generates the corresponding digit of that labelgen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated image as input and determines validityvalid = self.discriminator(img)# The recognition network produces the labeltarget_label = self.auxilliary(img)# The combined model (stacked generator and discriminator)self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)
2. 构造生成器和判别器
def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation("tanh"))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# Shared layers between discriminator and recognition networkmodel = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# Discriminatorvalidity = Dense(1, activation='sigmoid')(img_embedding)# Recognitionq_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)# Return discriminator and recognition networkreturn Model(img, validity), Model(img, label)
3. 构造互信息
def mutual_info_loss(self, c, c_given_x):"""The mutual information metric we aim to minimize"""eps = 1e-8conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropydef sample_generator_input(self, batch_size):# Generator inputssampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labels
4. 训练
def train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, y_train), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------# Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and categorical labelssampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# Generate a half batch of new imagesgen_imgs = self.generator.predict(gen_input)# Train on real and generated datad_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# Avg. lossd_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------# Train Generator and Q-network# ---------------------g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])# Plot the progressprint ("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[1], g_loss[2]))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)
5. 可视化
def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(c):sampled_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=self.num_classes)gen_input = np.concatenate((sampled_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')axs[j,i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()def save_model(self):def save(model, model_name):model_path = "saved_model/%s.json" % model_nameweights_path = "saved_model/%s_weights.hdf5" % model_nameoptions = {"file_arch": model_path,"file_weight": weights_path}json_string = model.to_json()open(options['file_arch'], 'w').write(json_string)model.save_weights(options['file_weight'])save(self.generator, "generator")save(self.discriminator, "discriminator")if __name__ == '__main__':infogan = INFOGAN()infogan.train(epochs=50000, batch_size=128, sample_interval=50)
实验结果
完整代码
from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as Kimport matplotlib.pyplot as pltimport numpy as npclass INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# Build and the discriminator and recognition networkself.discriminator, self.auxilliary = self.build_disk_and_q_net()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# Build and compile the recognition network Qself.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise and the target label as input# and generates the corresponding digit of that labelgen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated image as input and determines validityvalid = self.discriminator(img)# The recognition network produces the labeltarget_label = self.auxilliary(img)# The combined model (stacked generator and discriminator)self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation("tanh"))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# Shared layers between discriminator and recognition networkmodel = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# Discriminatorvalidity = Dense(1, activation='sigmoid')(img_embedding)# Recognitionq_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)# Return discriminator and recognition networkreturn Model(img, validity), Model(img, label)def mutual_info_loss(self, c, c_given_x):"""The mutual information metric we aim to minimize"""eps = 1e-8conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropydef sample_generator_input(self, batch_size):# Generator inputssampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labelsdef train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, y_train), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------# Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and categorical labelssampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# Generate a half batch of new imagesgen_imgs = self.generator.predict(gen_input)# Train on real and generated datad_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# Avg. lossd_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------# Train Generator and Q-network# ---------------------g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])# Plot the progressprint ("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[1], g_loss[2]))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(c):sampled_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=self.num_classes)gen_input = np.concatenate((sampled_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')axs[j,i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()def save_model(self):def save(model, model_name):model_path = "saved_model/%s.json" % model_nameweights_path = "saved_model/%s_weights.hdf5" % model_nameoptions = {"file_arch": model_path,"file_weight": weights_path}json_string = model.to_json()open(options['file_arch'], 'w').write(json_string)model.save_weights(options['file_weight'])save(self.generator, "generator")save(self.discriminator, "discriminator")if __name__ == '__main__':infogan = INFOGAN()infogan.train(epochs=50000, batch_size=128, sample_interval=50)
[生成对抗网络GAN入门指南](10)InfoGAN: Interpretable Representation Learning by Information Maximizing GAN相关推荐
- 【论文阅读】InfoGAN: Interpretable Representation Learning by Information Maximizing GAN
论文下载 bib: @inproceedings{chenduan2016infogan,author = {Xi Chen and Yan Duan and Rein Houthooft and J ...
- InfoGAN:Interpretable Representation Learning by Information Maximizing GANs论文解读
概述: InfoGAN是国际神经信息处理系统大会NIPS 2016上的论文,作者来自加州大学伯克利分校和OpenAI团队的研究人员,被OpenAI称为当年的五大突破之一.针对传统生成对抗网络以高度混杂 ...
- InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets ...
- 必读论文 | 生成对抗网络经典论文推荐10篇
生成式对抗网络(Generative adversarial networks, GAN)是当前人工智能学界最为重要的研究热点之一.其突出的生成能力不仅可用于生成各类图像和自然语言数据,还启发和推动了 ...
- GAN(生成对抗网络)入门
- GAN生成对抗网络入门篇
笔记整理:王小草 时间:2019年1月 一.GAN简介 1 背景 全称:generative adversarial network 生成式对抗网络(不一定是深度学习) 论文:https://arxi ...
- 生成对抗网络(GAN)资料打包
进入正文 全文 摘要 生成式对抗网络,即所谓的GAN是近些年来最火的无监督学习方法之一,模型由Goodfellow等人在2014年首次提出,将博弈论中非零和博弈思想与生成模型结合在一起,巧妙避开了传统 ...
- 深度学习中的生成对抗网络GAN
转载:一文看尽深度学习中的生成对抗网络 | CVHub带你看一看GANs架构发展的8年 (qq.com) 导读 生成对抗网络 (Generative Adversarial Networks, GAN ...
- 生成对抗网络(GAN)的前沿进展(论文、报告、框架和Github资源)汇总
生成模型(GenerativeModel)是一种可以通过学习训练样本来产生更多类似样本的模型.在所有生成模型当中,最具潜力的是生成对抗网络(Generative Adversarial Network ...
最新文章
- 电脑桌面便签小工具_可以直接在桌面上显示内容的便签软件电脑版
- Flume Sinks官网剖析(博主推荐)
- 一系列视频教程 收藏
- Subsonic使用中
- 2019秋季PAT甲级_C++题解
- 原理图连线有错误提醒_拔罐方法不对=缩短生命,中医提醒,拔火罐警惕三个禁忌...
- Android API Level对应Android版本一览表
- mysqls压力测试怎么用_用 Swagger 测试接口,怎么在请求头中携带 Token?
- 我终于知道,中国互联网是怎么弯道超车,干翻美国了
- SQL:postgresql中合并多个查询结果UNION (ALL)
- IDL处理葵花8Himawari-8标准HSD数据——制作大气校正数据集(卫星角度数据)
- 服务器系统计划任务不执行,Windows 2008 r2任务计划程序执行批处理失败问题解决方法...
- Idea中文件图标发生变化,导致文件显示出现异常
- 1425:【例题4】加工生产调度
- How to play PRA CandyBox game——GoDapp
- 深度学习笔记(四) cost function来源和证明
- RHCSA8考试练习题
- 历经一个月,终于搞定了SVM(支持向量机)-附源代码解析
- 怎么画好人体结构?男人、女人、孩子的身体比例画法
- 每日英语:The Deeply Odd Lives of Chinese Bureaucrats