InfoGAN详细介绍及特征解耦图像生成

  • 一.InfoGAN框架理解
    • 特征耦合
    • InfoGAN
    • InfoGAN论文实验结果
  • 二.VAE-GAN框架理解
    • VAE-GAN算法步骤
  • 三.BiGAN框架理解
  • 四.InfoGAN论文复现
    • 使用MNIST数据集复现InfoGAN
    • 代码编写
      • 初始化判别器
      • 初始化生成器
      • 初始化分类器
      • 训练InfoGAN网络
  • 总结
  • 参考文献及博客

一.InfoGAN框架理解

特征耦合

我们知道最基本的GAN就是输入一个随机的向量,输出一个图片。以手写数字为例,我们希望修改随机向量的某一维,能改变数字的特想,比如角度,粗细,数字等
特征解耦:

如上图:实际情况中的特征是非常杂乱无章的,然后我们希望的特征关系是比较整齐明了的,具体哪一列表示什么很清晰,从而便于控制它。而infogan的目的就是将这些杂乱无章的特征清晰化规律化。
特征解耦举例:
我们可以找到某一个控制某个特征对应的神经元,然后去改变它的值进而就可以改天具体某个特征。如下图举例所示:

InfoGAN

InfoGAN的框架如下:

Z为输入,输入分为两部分(标签和噪声),这里的C是在初始化的时候随机给定的,这里的生成器类似于编码器,其次还有一个分类器中与生成器组成一个类似于自动编码器。将真实数据或者生成数据放入分类器中,它可以去学习C,因为生成器生成的数据中就隐藏着C的一些关系,分类器就将数据重新提取重新得到C的这样一个过程。
更进一步说明:

InfoGAN论文实验结果


a图所示的向量C1可控制生成具体哪一类数字,b图几乎没有差异,C3是控制宽度的一个向量,通过控制C3可生成不同宽度的数字。

二.VAE-GAN框架理解

自动编码器是将输入数据压缩到高度抽象的特征空间然后进行重构的过程。一开始对于输入数据X进行压缩重构成一个向量z,在进行解码器进行解码,使输出和输入尽可能接近。这一步的操作的目的是为了更好的欺骗判别器,其余部分跟传统的GAN 网络是一致的,如下图所示:

VAE-GAN算法步骤


编码器要做的就是让P(z|x)逼近分布P(z),比如标准正太分布,同时最小化生成器(解码器)和输入x的差距。生成器(解码器)要做的就是最小化输出和输入x的差距,同时又要骗过判别器。判别器要做的就是给真实的高分,跟P(z)采样生成的和重建的低分。
具体算法:

三.BiGAN框架理解

BiGAN就是双向GAN的意思,这里的判别器与上面介绍的判别器不一样,这里的判别器接收的是图像和编码,判别图像和编码是来自编码器还是解码器。
算法思想:将编码器和解码器分开,但是加一个判别器,将他们的输入和输出同时作为判别器的输入,然后区分是来自编码器还是解码器,如果无法分别来自哪个,就说明编码器的输入图片和解码器生成的图片很接近,编码器输出的z和解码器输入的z很接近,目的就达到了。
简单的原理就是将编码器看成一个P(x,z)分布,将解码器看成Q(x,z)分布,通过判别器,让他们的差异越来越小。理想情况下就会:


具体算法:

四.InfoGAN论文复现

使用MNIST数据集复现InfoGAN

无标记信息情况下学习生成可控的图片

通过输入向量的控制进行控制自己想要生成的图像。

代码编写

初始化判别器

初始化判别器代码:
// An highlighted blockdef _init_dicriminator(self, input, isTrain=True, reuse=False):"""初始化判别器网络模型:param input: 输入数据op:param isTrain: 是否训练状态:param reuse: 是否复用内部参数:return: 判断结果op"""with tf.variable_scope('discriminator', reuse=reuse):# input [none,32,32,1]conv1 = tf.layers.conv2d(input, 32, [4, 4], strides=(2, 2), padding='same')  # [none,16,16,32]bn1 = tf.layers.batch_normalization(conv1, training=isTrain)active1 = tf.nn.leaky_relu(bn1)# layer 2conv2 = tf.layers.conv2d(active1, 64, [4, 4], strides=(2, 2), padding="same")  # [none,8,8,64]bn2 = tf.layers.batch_normalization(conv2, training=isTrain)active2 = tf.nn.leaky_relu(bn2)# layer 3conv3 = tf.layers.conv2d(active2, 128, [4, 4], strides=(2, 2), padding='same')  # [none,4,4,128]bn3 = tf.layers.batch_normalization(conv3, training=isTrain)active3 = tf.nn.leaky_relu(bn3)# layer 4active4 = tf.reshape(active3, shape=[-1, 4 * 4 * 128])out = tf.layers.dense(inputs=active4, units=self.d_dim)return out;

初始化生成器

 初始化生成器代码
// An highlighted blockdef _init_generator(self, input_c, input_z, isTrain=True, resue=False):"""初始化生成器网络模型:param input_c: 输入条件[none,c_dim]:param input_z: 输入随机噪声[None,z_dim]:param isTrain: 是否训练状态:param resue: 是否复用内部参数:return: 生成数据op"""with tf.variable_scope("generator", reuse=resue):# layer1input = tf.concat([input_c, input_z], axis=1)  # [none,c_dim+z_dim]input = tf.reshape(input, shape=[-1, 1, 1, self.c_dim + self.z_dim])de1 = tf.layers.conv2d_transpose(input, 256, [4, 4], strides=(1, 1), padding='valid')  # [none,4,4,256]de_bn1 = tf.layers.batch_normalization(de1, training=isTrain)de_active1 = tf.nn.leaky_relu(de_bn1)# layer 2de2 = tf.layers.conv2d_transpose(de_active1, 128, [4, 4], strides=(2, 2), padding="same")  # [none,8,8,128]de_bn2 = tf.layers.batch_normalization(de2, training=isTrain)de_active2 = tf.nn.leaky_relu(de_bn2)# layer 3de3 = tf.layers.conv2d_transpose(de_active2, 64, [4, 4], strides=(2, 2), padding="same")  # [none,16,16,64]de_bn3 = tf.layers.batch_normalization(de3, training=isTrain)de_active3 = tf.nn.leaky_relu(de_bn3)# layer 4de4 = tf.layers.conv2d_transpose(de_active3, 1, [4, 4], strides=(2, 2), padding="same")  # [none,32,32,1]out = tf.nn.sigmoid(de4)  #0,1return out

初始化分类器

初始化分类器代码
// An highlighted blockdef _init_classifier(self, input, isTrain=True, reuse=False):"""初始化分类器网络模型:param input: 输入数据(图像)[none,img_h,img_w,img_c]:param isTrain: 是否训练状态:param reuse: 是否复用内部参数:return: 分类条件结果op"""with tf.variable_scope("classifier", reuse=reuse):# input [none,32,32,1]conv1 = tf.layers.conv2d(input, 32, [4, 4], strides=(2, 2), padding='same')  # [none,16,16,32]bn1 = tf.layers.batch_normalization(conv1, training=isTrain)active1 = tf.nn.leaky_relu(bn1)# layer 2conv2 = tf.layers.conv2d(active1, 64, [4, 4], strides=(2, 2), padding="same")  # [none,8,8,64]bn2 = tf.layers.batch_normalization(conv2, training=isTrain)active2 = tf.nn.leaky_relu(bn2)# layer 3conv3 = tf.layers.conv2d(active2, 128, [4, 4], strides=(2, 2), padding='same')  # [none,4,4,128]bn3 = tf.layers.batch_normalization(conv3, training=isTrain)active3 = tf.nn.leaky_relu(bn3)# layer 4active4 = tf.reshape(active3, shape=[-1, 4 * 4 * 128])out_c = tf.layers.dense(inputs=active4, units=self.c_dim, activation=tf.nn.softmax)return out_c

训练InfoGAN网络

训练InfoGAN网络代码
// An highlighted blockdef train(self, batch_size=64, itrs=100000, save_time=1000):"""训练InfoGAN网络:param batch_size: 采样数据量:param itrs: 迭代训练次数:param save_time: 保存,测试模型周期:return: None"""start_time = time.time()data = dh.load_mnist_resize(path="data/MNIST_data",img_w=32,img_h=32)for i in range(itrs):mask = np.random.choice(data['data'].shape[0],batch_size,replace=True)batch_x = data['data'][mask]batch_noise_c = np.random.multinomial(1,self.c_dim*[0.1],size=batch_size)batch_noise_z = np.random.normal(0,1,(batch_size,self.z_dim))#训练判别器_,D_loss_curr= self.sess.run([self.D_trainer,self.D_loss],feed_dict={self.x:batch_x,self.gen_z:batch_noise_z,self.gen_c:batch_noise_c,self.isTrain:True})# 训练生成器batch_noise_c = np.random.multinomial(1, self.c_dim * [0.1], size=batch_size)batch_noise_z = np.random.normal(0, 1, (batch_size, self.z_dim))_,G_loss_curr = self.sess.run([self.G_trainer,self.G_loss],feed_dict={self.gen_c:batch_noise_c,self.gen_z:batch_noise_z,self.isTrain:True})# 训练分类器idx = np.random.randint(0,self.c_dim)batch_noise_z =np.random.normal(0,1,(batch_size,self.z_dim))batch_noise_c = np.zeros([batch_size,self.c_dim])batch_noise_c[:,idx] =1_,C_loss_curr = self.sess.run([self.C_trainer,self.C_loss],feed_dict={self.gen_z:batch_noise_z,self.gen_c:batch_noise_c,self.isTrain:True})# 保存模型if i%save_time==0:idx = np.random.randint(0, self.c_dim)batch_noise_z = np.random.normal(0, 1, (25, self.z_dim))batch_noise_c = np.zeros([25, self.c_dim])batch_noise_c[:, idx] = 1self.gen_data(c=batch_noise_c,z=batch_noise_z,save_path="out/InfoGAN_MNIST/"+str(i).zfill(6)+".png")self.test_model()print("i:",i," D_loss",D_loss_curr," G_loss",G_loss_curr," C_loss",C_loss_curr)self.save()end_time = time.time()time_loss = end_time-start_timeprint("时间消耗",int(time_loss),"秒")start_time = time.time()self.sess.close()

总结

无论是InfoGAN还是VAE—GAN,BiGAN都是自动编码器+GAN的框架,核心就是利用自动编码器压缩后的特征与GAN网络建立联系

参考文献及博客

[1] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley,S. Ozair, A. Courville, and Y. Bengio, “Generative adversarial nets,” in Advances in Neural Information Processing Systems (NIPS), pp. 2672–2680,2014.
[2] T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, X. Chen,and X. Chen, “Improved techniques for training gans,” in Advances inNeural Information Processing Systems (NIPS), pp. 2226–2234, 2016.
[3]M. Arjovsky, S. Chintala, and L. Bottou, “Wasserstein gan,”arXiv:1701.07875, 2017.
博客:
https://so.csdn.net/so/search?q=BiGAN&t=blog&u=wangwei19871103

InfoGAN详细介绍及特征解耦图像生成相关推荐

  1. 一文总结图像生成必备经典模型(一)

    本文将分 2 期进行连载,共介绍 16 个在图像生成任务上曾取得 SOTA 的经典模型. 第 1 期:ProGAN.StyleGAN.StyleGAN2.StyleGAN3.VDVAE.NCP-VAE ...

  2. ChatGPT使用案例之图像生成

    ChatGPT使用案例之图像生成 这里一节我们介绍一下ChatGPT的图像生成,这里我们使用代码来完成,也就是通过API 来完成,因为ChatGPT 本身是不能生成图片的,言外之意我们图片生成是Cha ...

  3. “用于无监督图像生成解耦的正交雅可比正则化”论文解读

    Tikhonov regularization terms https://blog.csdn.net/jiejinquanil/article/details/50411617 本文是对博客http ...

  4. StyleGAN2发展介绍 花卉图像生成 模型修改

    第一部分-----StyleGAN2发展介绍 stylegan发展:gan要做的事情是通过生成器由简单分布生成复杂图像分布,生成的图片分布和真实图片分布的差距越小越好,这一步通过判别器实现.生成大的. ...

  5. 生成对抗网络(GAN)详细介绍及数字手写体生成应用仿真(附代码)

    生成对抗网络(GAN)详细介绍及生成数字手写体仿真(附代码) 生成对抗网络简介 深度学习基础介绍 损失函数与梯度下降 反向传播算法推导 批量标准化介绍 Dropout介绍 GAN原始论文理解 生成对抗 ...

  6. AI图片生成Stable Diffusion参数及使用方式详细介绍

    Stable Diffusion环境搭建与运行请参考上一篇博文<AI图片生成Stable Diffusion环境搭建与运行>,地址为"https://blog.csdn.net/ ...

  7. CVPR 2021 | 天津大学提出PISE:形状与纹理解耦的人体图像生成与编辑方法

    ©PaperWeekly 原创 · 作者|张劲松 学校|天津大学硕士生 研究方向|计算机视觉 导读:由单张人体图像来生成任意视角任意姿态下的图像,是近几年视觉领域研究的热点问题.现有方法无法实现灵活的 ...

  8. Opencv Surf特征实现图像无缝拼接生成全景图像(三)

    转自:https://guo-pu.blog.csdn.net/article/details/90657830 图像拼接在实际的应用场景很广,比如无人机航拍,遥感图像等等,图像拼接是进一步做图像理解 ...

  9. 图像生成对抗生成网络gan_生成对抗网络(GAN)的直观介绍

    图像生成对抗生成网络gan by Thalles Silva 由Thalles Silva 暖身 (Warm up) Let's say there's a very cool party going ...

  10. CV之IG:图像生成(Image Generation)的简介、使用方法、案例应用之详细攻略

    CV之IG:图像生成(Image Generation)的简介.使用方法.案例应用之详细攻略 目录 图像生成(Image Generation)的简介 图像生成(Image Generation)的使 ...

最新文章

  1. iOS夯实:RunLoop
  2. [YTU]_2474( C++习题 输入输出--保护继承)
  3. Django 框架篇: 一. Django介绍; 二. 安装; 三. 创建项目;
  4. product category no need to optimize
  5. 揭秘神仙高校的课堂!网友跪了:这就是差距啊!
  6. SLS机器学习介绍(02):时序聚类建模
  7. 使用Powershell远程管理Windows Server(WinRM)
  8. 服务器路径和本地路径在使用cd时候的区别
  9. oracle多条数据合并成一条_建议将北京地铁13-B线和28号线合并成一条线
  10. Windows下Redis的使用
  11. JS 打印 data数据_用D3.js 十分钟实现字符跳动效果
  12. nxlog 中文乱码解决
  13. 【GAN论文解读系列】NeurIPS 2016 InfoGAN 使用InfoGAN解耦出可解释的特征
  14. matlab帧差法测速,matlab帧差法物体检测
  15. PDF文件拆分为图片
  16. 三流领导管下级,二流领导管同级,一流领导管......
  17. SQP 序列二次规划法
  18. 第一位程序员原来是一个女性!
  19. 【AWS云从业者基础知识笔记】——模块11:AWS认证的云从业者基础
  20. ZGF建筑事务所公布波特兰国际机场新航站楼设计方案,木构屋顶展现自然景观

热门文章

  1. sata接口 图解 定义_SATA数据和电源接口定义详解
  2. 数据挖掘第三版课后习题
  3. 运用SPSS进行PCA主成分分析(因子分析)
  4. 西门子S7-300 PLC视频教程(百度网盘)收集于网络-供参考学习
  5. win10 卸载mysql5.7
  6. 计算机网络:自顶向下(Top-Down)学习笔记_1.1
  7. Java常见异常和解决办法
  8. DbgView不能显示OutputDebugString的输出内容
  9. log4cxx OutputDebugString DebugView dbgview
  10. 解决 Oracle 密码过期 the password has expired