1.首先说明一下CGAN的意义

GAN的原始模型有很多可以改进的缺点,首当其中就是“模型不可控”。从上面对GAN的介绍能够看出,模型以一个随机噪声为输入。显然,我们很难对输出的结构进行控制。例如,使用纯粹的GAN,我们可以训练出一个生成器:输入随机噪声,产生一张写着0-9某一个数字的图片。然而,在现实应用中,我们往往想要生成“指定”的一张图片。

2.直观解决方案

在GAN上增加一个额外的输入。也就是说,以前我们的生成模型是,现在,我们的生成模型是在一个条件c的控制下产生。而这个c就是我们用来控制模型的额外的输入。

c可以是表示我们意图的一串编码,例如我们想要做0-9的手写数字生成,则c可以是一个10维的one-hot向量。则在训练过程中,我们将这些label加入到训练数据中,从而得到一个按照我们需求产生图片的生成器。

这就是Conditional Generative Adversarial Nets最基本的想法。这里要注意的是,这个c不但附加在了生成器上,同时也附加在了判别器上,相当于给了判别器一个额外的信息:现在这个图片是以条件c生成的?还是以条件c控制下的真正的图片?

3.训练目标

原文中有这样一张图,在其他博客中也常见到

对于GAN来说,我们训练的目标是:

而对于Conditional的GAN来说,训练目标只需要变成:

(原文中的公式有误,后面一项的判别器D中忘了加以y为条件的概率)

其实这个改动形象一些表示就是将原来只接受一个输入z的生成器变成接受两个输入(z和y),将原来只接受一个输入x的判别器变成接受两个输入(x和y)。

CGAN代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os#数据输入
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128#返回随机值
def xavier_init(size):in_dim = size[0]xavier_stddev = 1. / tf.sqrt(in_dim / 2.)return tf.random_normal(shape=size, stddev=xavier_stddev)#X代表输入图片,应该是28*28,但是这里没有使用CNN,y是相应的label
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
#权重,CGAN的输入是将图片输入与label concat起来,所以权重维度为784+10
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
#第二层有h_dim个节点
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))theta_D = [D_W1, D_W2, D_b1, D_b2]#D网络,这里是一个简单的神经网络,x是输入图片向量,y是相应的label
def discriminator(x, y):inputs = tf.concat(axis=1, values=[x, y])D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit)return D_prob, D_logit#G网络参数,输入维度为Z_dim+y_dim,中间层有h_dim个节点,输出X_dim的数据
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
#权重
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))theta_G = [G_W1, G_W2, G_b1, G_b2]#G网络
def generator(z, y):inputs = tf.concat(axis=1, values=[z, y])G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob)return G_prob#噪声产生的函数
def sample_Z(m, n):return np.random.uniform(-1., 1., size=[m, n])def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return fig#生成网络,基本和GAN一致
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
#优化式
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
#训练
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)sess = tf.Session()
sess.run(tf.global_variables_initializer())
#输出图片在out文件夹
if not os.path.exists('out/'):os.makedirs('out/')i = 0for it in range(1000000):if it % 1000 == 0:#n_sample 是G网络测试用的Batchsize,为16,所以输出的png图有16张n_sample = 16Z_sample = sample_Z(n_sample, Z_dim)#输入的噪声,尺寸为batchsize*noise维度y_sample = np.zeros(shape=[n_sample, y_dim])#输入的label,尺寸为batchsize*label维度y_sample[:, 7] = 1 #输出7samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})#G网络的输入fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')#输出生成的图片i += 1plt.close(fig)#mb_size是网络训练时用的Batchsize,为100X_mb, y_mb = mnist.train.next_batch(mb_size)#Z_dim是noise的维度,为100Z_sample = sample_Z(mb_size, Z_dim)#交替最小化训练_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})#输出训练时的参数if it % 1000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'. format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print()

生成效果如下:

为了方便理解,本文只用了最简单的神经网络,有时间会使用CNN重写该网络。

CGAN原理及tensorflow代码相关推荐

  1. 边框检测原理与Tensorflow代码

    要学习目标检测算法吗?任何一个ML学习者都希望能够给图像中的目标物体圈个漂亮的框框,在这篇文章中我们将学习目标检测中的一个基本概念:边框回归/Bounding Box Regression.边框回归并 ...

  2. DeepFM原理及tensorflow代码实战

    目录 1.背景 2.引入 FM挑战 DNN的局限 3.组件介绍 FM Deep 4.代码解析 1.背景 之前说了wide&deepGoogle于 2016 年在DLRS上发表了一篇文章:201 ...

  3. python神经网络原理pdf_《深度学习原理与 TensorFlow实践》高清完整PDF版 下载

    1.封面介绍 2.出版时间 2019年7月 3.推荐理由 本书介绍了深度学习原理与TensorFlow实践.着重讲述了当前学术界和工业界的深度学习核心知识:机器学习概论.神经网络.深度学习.着重讲述了 ...

  4. 如何高效的学习TensorFlow代码?

    如何高效的学习TensorFlow代码? 如题,或者如何掌握TensorFlow,应用到任何领域? 添加评论分享 10 个回答 爱琳李,老李,明天就辍学了 8 人赞同 本来都忘了这个问题了,不过看到很 ...

  5. 生成对抗网络简介(包含TensorFlow代码示例)【翻译】

    判别模型 vs. 生成模型 示例:近似一维高斯分布 提高样本多样性 最后的思考 关于GAN的一些讨论 最近,大家对生成模型的兴趣又开始出现(OpenAI关于生成模型的案例).生成模型可以学习如何生成数 ...

  6. 强化学习教程(四):从PDG到DDPG的原理及tf代码实现详解

    强化学习教程(四):从PDG到DDPG的原理及tf代码实现详解 原创 lrhao 公众号:ChallengeHub 收录于话题 #强化学习教程 前言 在前面强化学习教程(三)中介绍了基于策略「PG」算 ...

  7. tensorflow63 《深度学习原理与TensorFlow实战》03 Hello TensorFlow

    00 基本信息 <深度学习原理与TensorFlow实战>书中涉及到的代码主要来源于: A:Tensorflow/TensorflowModel/TFLean的样例, B:https:// ...

  8. 深度强化学习系列(14): A3C算法原理及Tensorflow实现

    在DQN.DDPG算法中均用到了一个非常重要的思想经验回放,而使用经验回放的一个重要原因就是打乱数据之间的相关性,使得强化学习的序列满足独立同分布. 本文首先从Google于ICML2016顶会上发的 ...

  9. CNN卷积神经网络—LeNet原理以及tensorflow实现mnist手写体训练

    CNN卷积神经网络-LeNet原理以及tensorflow实现minst手写体训练 1. LeNet原理 2.tensorflow实现Mnist手写体识别 1.安装tensorflow 2.代码实现手 ...

最新文章

  1. C++字符串反转(C++11)
  2. linux vim分屏:水平和垂直分屏
  3. 以 vim 的方式来使用 chrome 浏览器(利用 vimium 插件)
  4. Gradle里Copy任务(task)的使用
  5. MAYA建模桌面一角_maya怎么建模逼真的学生书桌书桌桌面?
  6. 一道简单的sql语句题
  7. 超简单将Centos的yum源更换为国内的阿里云源
  8. Spring4.x()--注解通知的写法
  9. Spring Boot 高效入门实战
  10. @程序员,入行物联网的避坑指南!| 技术头条
  11. left join on or 优化_pandas中merge/join有什么区别?
  12. Linux操作系统原理
  13. 前后端交互过程、常见软件架构、服务器分类
  14. 一文极速读懂 KEGG 数据库
  15. cocoa和cocoa Touch的区别
  16. android.hardware.Camera 5.1之后操作照相机是不是不能用啦,我用小米手机(基本android 6.0)
  17. 红米手机4A怎么样刷入开发版获得ROOT权限
  18. SQL学习笔记——task4:集合运算与内连结
  19. <Linux> Ubuntu kernel 源码编译 替换
  20. JEECG 新手常见问题大全,入门必读

热门文章

  1. #数论#洛谷 3951 JZOJ 5473 小凯的疑惑
  2. WARNING: Ignoring invalid distribution -ip 解决方案
  3. 第2章 基础设施即服务(IaaS)-2-Docker
  4. 使用python爬取有道词典翻译
  5. 浅谈神经网络之链式法则与反向传播算法
  6. oracle 按旬统计并且每月小计 行转列 PIVOT函数 与分组小计 ROLLUP 函数
  7. Java 字节码技术:不积细流,无以成江河
  8. LikeLib区块链底层公链技术应用
  9. 【C++标准头文件】<string>
  10. 导出数据库的longblob