1、结构图

2、知识点

生成器(G):将噪音数据生成一个想要的数据
判别器(D):将生成器的结果进行判别,

3、代码及案例

# coding: utf-8# ## 对抗生成网络案例 ##
#
#
# <img src="jpg/3.png" alt="FAO" width="590" ># - 判别器 : 火眼金睛,分辨出生成和真实的 <br />
# <br />
# - 生成器 : 瞒天过海,骗过判别器 <br />
# <br />
# - 损失函数定义 : 一方面要让判别器分辨能力更强,另一方面要让生成器更真 <br />
# <br />
#
# <img src="jpg/1.jpg" alt="FAO" width="590" ># In[1]:import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as pltget_ipython().run_line_magic('matplotlib', 'inline')# # 导入数据# In[2]:from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/data')# ## 网络架构
#
# ### 输入层 :待生成图像(噪音)和真实数据
#
# ### 生成网络:将噪音图像进行生成
#
# ### 判别网络:
# - (1)判断真实图像输出结果
# - (2)判断生成图像输出结果
#
# ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
#
# <img src="jpg/2.png" alt="FAO" width="590" ># ## Inputs# In[3]:#真实数据和噪音数据
def get_inputs(real_size, noise_size):real_img = tf.placeholder(tf.float32, [None, real_size])noise_img = tf.placeholder(tf.float32, [None, noise_size])return real_img, noise_img# ## 生成器
# * noise_img: 产生的噪音输入
# * n_units: 隐层单元个数
# * out_dim: 输出的大小(28 * 28 * 1)# In[4]:def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):with tf.variable_scope("generator", reuse=reuse):# hidden layerhidden1 = tf.layers.dense(noise_img, n_units)# leaky ReLUhidden1 = tf.maximum(alpha * hidden1, hidden1)# dropouthidden1 = tf.layers.dropout(hidden1, rate=0.2)# logits & outputslogits = tf.layers.dense(hidden1, out_dim)outputs = tf.tanh(logits)return logits, outputs# ## 判别器
# * img:输入
# * n_units:隐层单元数量
# * reuse:由于要使用两次# In[5]:def get_discriminator(img, n_units, reuse=False, alpha=0.01):with tf.variable_scope("discriminator", reuse=reuse):# hidden layerhidden1 = tf.layers.dense(img, n_units)hidden1 = tf.maximum(alpha * hidden1, hidden1)# logits & outputslogits = tf.layers.dense(hidden1, 1)outputs = tf.sigmoid(logits)return logits, outputs# ## 网络参数定义
# * img_size:输入大小
# * noise_size:噪音图像大小
# * g_units:生成器隐层参数
# * d_units:判别器隐层参数
# * learning_rate:学习率# In[6]:
img_size = mnist.train.images[0].shape[0]noise_size = 100g_units = 128d_units = 128learning_rate = 0.001alpha = 0.01# ## 构建网络# In[7]:
tf.reset_default_graph()real_img, noise_img = get_inputs(img_size, noise_size)# generator
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)# discriminator
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)# ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
#
# <img src="jpg/2.png" alt="FAO" width="590" ># In[8]:# discriminator的loss
# 识别真实图片
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real)))
# 识别生成的图片
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
# 总体loss
d_loss = tf.add(d_loss_real, d_loss_fake)# generator的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.ones_like(d_logits_fake)))# ## 优化器# In[9]:
train_vars = tf.trainable_variables()# generator
g_vars = [var for var in train_vars if var.name.startswith("generator")]
# discriminator
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]# optimizer
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)# # 训练# In[10]:# batch_size
batch_size = 64
# 训练迭代轮数
epochs = 300
# 抽取样本数
n_sample = 25# 存储测试样例
samples = []
# 存储loss
losses = []
# 保存生成器变量
saver = tf.train.Saver(var_list = g_vars)
# 开始训练
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for e in range(epochs):for batch_i in range(mnist.train.num_examples//batch_size):batch = mnist.train.next_batch(batch_size)batch_images = batch[0].reshape((batch_size, 784))# 对图像像素进行scale,这是因为tanh输出的结果介于(-1,1),real和fake图片共享discriminator的参数batch_images = batch_images*2 - 1# generator的输入噪声batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))# Run optimizers_ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})_ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})# 每一轮结束计算losstrain_loss_d = sess.run(d_loss, feed_dict = {real_img: batch_images, noise_img: batch_noise})# real img losstrain_loss_d_real = sess.run(d_loss_real, feed_dict = {real_img: batch_images, noise_img: batch_noise})# fake img losstrain_loss_d_fake = sess.run(d_loss_fake, feed_dict = {real_img: batch_images, noise_img: batch_noise})# generator losstrain_loss_g = sess.run(g_loss, feed_dict = {noise_img: batch_noise})print("Epoch {}/{}...".format(e+1, epochs),"判别器损失: {:.4f}(判别真实的: {:.4f} + 判别生成的: {:.4f})...".format(train_loss_d, train_loss_d_real, train_loss_d_fake),"生成器损失: {:.4f}".format(train_loss_g))    losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))# 保存样本sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),feed_dict={noise_img: sample_noise})samples.append(gen_samples)saver.save(sess, './checkpoints/generator.ckpt')# 保存到本地
with open('train_samples.pkl', 'wb') as f:pickle.dump(samples, f)# # loss迭代曲线# In[11]:
fig, ax = plt.subplots(figsize=(20,7))
losses = np.array(losses)
plt.plot(losses.T[0], label='判别器总损失')
plt.plot(losses.T[1], label='判别真实损失')
plt.plot(losses.T[2], label='判别生成损失')
plt.plot(losses.T[3], label='生成器损失')
plt.title("对抗生成网络")
ax.set_xlabel('epoch')
plt.legend()# # 生成结果# In[12]:# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:samples = pickle.load(f)# In[13]:#samples是保存的结果 epoch是第多少次迭代
def view_samples(epoch, samples):fig, axes = plt.subplots(figsize=(7,7), nrows=5, ncols=5, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples[epoch][1]): # 这里samples[epoch][1]代表生成的图像结果,而[0]代表对应的logits
        ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')return fig, axes# In[14]:
_ = view_samples(-1, samples) # 显示最终的生成结果# # 显示整个生成过程图片# In[15]:# 指定要查看的轮次
epoch_idx = [10, 30, 60, 90, 120, 150, 180, 210, 240, 290]
show_imgs = []
for i in epoch_idx:show_imgs.append(samples[i][1])# In[16]:# 指定图片形状
rows, cols = 10, 25
fig, axes = plt.subplots(figsize=(30,12), nrows=rows, ncols=cols, sharex=True, sharey=True)idx = range(0, epochs, int(epochs/rows))for sample, ax_row in zip(show_imgs, axes):for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):ax.imshow(img.reshape((28,28)), cmap='Greys_r')ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)# # 生成新的图片# In[17]:# 加载我们的生成器变量
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),feed_dict={noise_img: sample_noise})# In[18]:
_ = view_samples(0, [gen_samples])

View Code

4、优化目标

转载于:https://www.cnblogs.com/ywjfx/p/11042162.html

深度学习之GAN对抗神经网络相关推荐

  1. 深度学习~生成式对抗神经网络GAN

    目录 出现背景(why?) 概念 出现背景(why?) 在分类任务中,训练机器学习和深度学习模块需要大量的真实世界数据,并且在某些情况下,获取足够数量的真实数据存在局限性,或者仅仅是时间和人力资源的投 ...

  2. 深度学习之生成对抗网络(6)GAN训练难题

    深度学习之生成对抗网络(6)GAN训练难题 1. 超参数敏感 2. 模式崩塌  尽管从理论层面分析了GAN网络能够学习到数据的真实分布,但是在工程实现中,常常出现GAN网络训练困难的问题,主要体现在G ...

  3. 深度学习之生成对抗网络(4)GAN变种

    深度学习之生成对抗网络(4)GAN变种 1. DCGAN 2. InfoGAN 3. CycleGAN 4. WGAN 5. Equal GAN 6. Self-Attention GAN 7. Bi ...

  4. 深度学习之生成对抗网络(2)GAN原理

    深度学习之生成对抗网络(2)GAN原理 1. 网络结构 生成网络G(z)\text{G}(\boldsymbol z)G(z) 判别网络D(x)\text{D}(\boldsymbol x)D(x) ...

  5. 【深度学习】生成对抗网络(GAN)的tensorflow实现

    [深度学习]生成对抗网络(GAN)的tensorflow实现 一.GAN原理 二.GAN的应用 三.GAN的tensorflow实现 参考资料 GAN( Generative Adversarial ...

  6. 深度学习之生成对抗网络(1)博弈学习实例

    深度学习之生成对抗网络(1)博弈学习实例 博弈学习实例  在 生成对抗网络(Generative Adversarial Network,简称GAN)发明之前,变分自编码器被认为是理论完备,实现简单, ...

  7. 深度学习、生成对抗、Pytorch优秀教材推荐

    推荐一批深度学习.生成对抗.Pytorch优秀教材推荐,绝对不容错过. <GANs in Action> Jakub Langr and Vladimir Bok 本书简介: <GA ...

  8. 深度学习笔记:卷积神经网络的可视化--卷积核本征模式

    目录 1. 前言 2. 代码实验 2.1 加载模型 2.2 构造返回中间层激活输出的模型 2.3 目标函数 2.4 通过随机梯度上升最大化损失 2.5 生成滤波器模式可视化图像 2.6 将多维数组变换 ...

  9. 【转载】Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule

    Hinton"深度学习之父"和"神经网络先驱",新论文Capsule将推翻自己积累了30年的学术成果时 Hinton"深度学习之父"和&qu ...

  10. Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时...

    Hinton"深度学习之父"和"神经网络先驱",新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组 ...

最新文章

  1. JSTL fmt标签格式化日期时分秒显示为00:00:00和12:00:00问题
  2. 蓝牙MESH网关_水哥智能教学视频一米家蓝牙mesh设备如何升级固件
  3. 合肥学院计算机对口升学2019,15高校招生4340人!2019安徽省对口升学本科招生计划出炉!...
  4. 先有鸡还是先有蛋?--IT公司用人困惑
  5. mysql between and 包含边界吗_10分钟让你明白MySQL是如何利用索引的
  6. boost::program_options模块实现处理选项组的测试程序
  7. 使用Nodejs发送邮件
  8. Centos7 查看/关闭/启动防火墙
  9. Android开发教程:手机震动控制浅析
  10. c语言洗牌发牌结构体,C语言程序设计课程设计多功能计算器、洗牌发牌、学生文件处理、链表处理.doc...
  11. Java 使用阿里巴巴 Dns Cache Manipulator
  12. Python批量生成垃圾邮件内容
  13. 7-3 DAG图优化-A (15 分)(更新版)
  14. 全球完美打通元宇宙、DeFi、NFT的区块链游戏平台
  15. Unity随机创造敌人
  16. 管理成长计划(二):定目标--战略制定承接
  17. 台式机dp接口_了解笔记本电脑的各种视频接口
  18. 【Python】经典问题创建一个矩形类,定义方法 属性 初始化
  19. 终于完美解决OneNote无法同步的问题!如此简单!
  20. SQA,SQC是什么意思,有什么区别?

热门文章

  1. DEDECMS 另类***
  2. 在c语言Windows窗口添加按钮,C语言用windows.h创建按钮的问题
  3. python 阿里云短信接口_阿里云短信接口怎么使用
  4. IEquatable「T」和Equal详解
  5. springboot的使用html页面及css、js路径的配置
  6. android webview 选择图片上传,Android webview打开本地图片上传实现代码
  7. kafka依赖_Kafka集群搭建及必知必会
  8. matlab做三次拉格朗日插值多项式_买菜必用的MATLAB拉格朗日插值函符号解输出
  9. 【渝粤教育】国家开放大学2018年春季 0266-21T设计构成 参考试题
  10. 【渝粤教育】国家开放大学2018年春季 0680-22T会计基础知识 参考试题