1.介绍

在原始gan(GAN 简介与代码实战)中,生成数据的来源一般是一个固定分布噪声z,z可以生成不同的图片,z代表着很多意思,我们无法知道z的那个维度代表什么(比如在生成数字手写图片的时候,0维度是否代表笔画的风格,我们不得而知),z是不可解释的。为了解决这个问题,InfoGAN就横空出世了,更加详细的内容参见论文:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

2.模型结构

网络是基于DC-GAN(Deep Convolutional GAN)的,G和D都由CNN构成。在此基础上,Q和D共享卷积网络,然后分别通过各自的全连接层输出不同的内容:Q输出对应于生成图片的c'(与之对应的c是可以控制图片的生成,比如生成什么数字),D则仍然判别真伪。

3.模型特点

相对于原始gan,作者将Z分成z(固定分布噪声)和c(一些隐变量信息,比如笔画风格,字体大小等),损失函数里面用到互信息,使得隐变量c与生成的变量G(z,c)拥有尽可能多的共同信息。

4.代码实现keras

class INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# Build and the discriminator and recognition networkself.discriminator, self.auxilliary = self.build_disk_and_q_net()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# Build and compile the recognition network Qself.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise and the target label as input# and generates the corresponding digit of that labelgen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated image as input and determines validityvalid = self.discriminator(img)# The recognition network produces the labeltarget_label = self.auxilliary(img)# The combined model  (stacked generator and discriminator)self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(Activation("relu"))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation("tanh"))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# Shared layers between discriminator and recognition networkmodel = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# Discriminatorvalidity = Dense(1, activation='sigmoid')(img_embedding)# Recognitionq_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)# Return discriminator and recognition networkreturn Model(img, validity), Model(img, label)def mutual_info_loss(self, c, c_given_x):"""The mutual information metric we aim to minimize"""eps = 1e-8conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropydef sample_generator_input(self, batch_size):# Generator inputssampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labelsdef train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, y_train), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and categorical labelssampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# Generate a half batch of new imagesgen_imgs = self.generator.predict(gen_input)# Train on real and generated datad_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# Avg. lossd_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator and Q-network# ---------------------g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])# Plot the progressprint ("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[1], g_loss[2]))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(c):sampled_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=self.num_classes)gen_input = np.concatenate((sampled_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')axs[j,i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()

InfoGAN 简介与代码实战相关推荐

  1. BART原理简介与代码实战

    写在前面 最近huggingface的transformer库,增加了BART模型,Bart是该库中最早的Seq2Seq模型之一,在文本生成任务,例如摘要抽取方面达到了SOTA的结果. 本次放出了三组 ...

  2. Flume NG 简介及配置实战

    2019独角兽企业重金招聘Python工程师标准>>> Flume NG 简介及配置实战 博客分类: 分布式计算 1.Flume 的一些核心概念: 1.1 数据流模型 1.2 高可靠 ...

  3. 【强化学习】Sarsa算法求解悬崖行走问题 + Python代码实战

    文章目录 一.Sarsa算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视化寻路过程 ...

  4. 【强化学习】优势演员-评论员算法(Advantage Actor-Critic , A2C)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.优势演员-评论员算法简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.2.1 网络参数不共享版本 4.2.2 网络参数共享版本 ...

  5. 【强化学习】Q-Learning算法求解悬崖行走问题 + Python代码实战

    文章目录 一.Q-Learning算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视 ...

  6. 【强化学习】双深度Q网络(DDQN)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.双深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 A ...

  7. 【强化学习】竞争深度Q网络(Dueling DQN)求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.竞争深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 ...

  8. 【强化学习】PPO算法求解倒立摆问题 + Pytorch代码实战

    文章目录 一.倒立摆问题介绍 二.PPO算法简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 Ag ...

  9. Neural Collaborative Filtering(NCF) 代码实战(Keras)

    博客主要分为两部分.第一部分为论文简介,第二部分为代码实战. 论文简介: 1. 通用框架 下图是作者提出的用神经网络解决推荐系统问题的通用框架.论文先将用户与物品分别进行one-hot编码,然后通过一 ...

最新文章

  1. libcurl多线程下载开发过程中需要注意的一个问题
  2. 201C Fragile Bridges
  3. libQtCore.so.4相关错误
  4. android sdl,Android下SDL2实现五子棋游戏
  5. 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口 6...
  6. Qt中的测试 枚举与 QFlags详解
  7. Linux上安装ZooKeeper并设置开机启动(CentOS7+ZooKeeper3.4.10)
  8. Windows 10 下一版本更新代号为“Manganese”
  9. rf连oracle版本一致,Navicat premium连不上Oracle的问题解决
  10. innodb存储引擎之内存
  11. Pycharm-SSH连接服务器
  12. 提示“8080端口号被占用
  13. 2021-07-06
  14. Type of the default value for 'songs' prop must be a function
  15. 社保交满15年就不用交了吗?常见重点问答请查收,千万别误解了~
  16. 如何为新的微信公众号做引流矩阵的8个渠道
  17. Numpy IO:npy、npz
  18. EtherCAT IGH 命令行介绍
  19. 机器学习模型评估与改进:网格化调参(grid search)
  20. 频谱分析系列:1dB增益压缩点概述及测试

热门文章

  1. MySQL数据库查询的最全面总结
  2. java爬取网易云歌单_[原创]基于Java网易云音乐评论抓取~【悠着点玩啊~】
  3. UEditor编辑器存储型XSS漏洞
  4. “穷光蛋出身,但志存高远” 联想控股董事长柳传志
  5. 牛皮席差异不同的地方!
  6. ​Mac触控板常用的手势操作
  7. 【Proteus仿真】【51单片机】路灯控制系统
  8. jquery 动态分页,本地分页
  9. 基于AIOT的智能家居系统
  10. 设计模式(21):创建型-单例模式(Singleton)