朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow

前面讲解了那么多GAN的基础知识,我们已经比较深入地了解GAN了,但如果不动手将上面的理论知识融入到实战中,你依旧无法内化上面的内容,所以接着就通过TensorFlow来实现一个朴素GAN。(文章中使用的是Tensorflow 1.x版本的语法)

我们主要是创建一个最简单的GAN,然后训练它,使它可以生成与真实图片一样的手写数字图片。下面直接进行代码的编写。

(1)导入第三方库。

 import tensorflow as tfimport numpy as npimport pickleimport matplotlib.pyplot as plt

我们使用TensorFlow来实现GAN的网络架构,并对构建的GAN进行训练;使用numpy来生成随机噪声,用于给生成器生成输入数据;使用pickle来持久化地保存变量;最后使用matplotlib来可视化GAN训练时两个网络结构损失的变化以及GAN生成的图片。

(2)因为是要训练GAN生成MNIST手写数据集中的图片,需要读入MNIST数据集中的真实图片作为训练判别器D的真实数据,TensorFlow提供了处理MNIST的方法,可以使用它读入MNIST数据。

 from tensorflow.examples.tutorials.mnist import input_data# 读入MNIST数据mnist = input_data.read_data_sets('./data/MNIST_data')img = mnist.train.images[500]   #以灰度图的形式读入plt.imshow(img.reshape((28, 28)), cmap='Greys_r')plt.show()

读入MNIST图片后,每一张图片都由一个一维矩阵表示。

  print(type(img))print(img.shape)
输出如下。<class 'numpy.ndarray'>(784,)

PS:TensorFlow在1.9版本后,input_data.read_data_sets方法不会自动下载,如果本地没有MNIST数据集,就会报错,所以我们必须事先将它下载好。

接着定义用于接收输入的方法,使用TensorFlow的placeholder占位符来获得输入的数据。

  def get_inputs(real_size, noise_size):real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')return real_img, noise_img

然后就可以实现生成器和判别器了,先来看生成器,代码如下。

  def generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):'''生成器:paramnoise_img: 生成器生成的噪声图片:paramn_units: 隐藏层单元数:paramout_dim: 生成器输出的tensor的size,应该是32×32=784:param reuse: 是否重用空间:param alpha: leakeyReLU系数:return:'''with tf.variable_scope("generator", reuse=reuse):#全连接hidden1 = tf.layers.dense(noise_img, n_units)#返回最大值hidden1 = tf.maximum(alpha * hidden1, hidden1)hidden1 = tf.layers.dropout(hidden1, rate=0.2, training=True)#dense:全连接logits = tf.layers.dense(hidden1, out_dim)outputs = tf.tanh(logits)return logits, outputs

可以发现生成器的网络结构非常简单,只是一个具有单隐藏层的神经网络,其整体结构为输入层→隐藏层→输出层,一开始只是编写最简单的GAN,在后面的高级内容中,生成器和判别器的结构会复杂一些。

简单解释一下上面的代码,首先使用tf.variable_scope创建了一个名为generator的空间,主要目的是实现在该空间中,变量可以被重复使用且方便区分不同卷积层之间的组件。

接着使用tf.layers下的dense方法将输入层和隐藏层进行全连接。tf.layers模块提供了很多封装层次较高的方法,使用这些方法,我们可以更加轻松地构建相应的神经网络结构。这里使用dense方法,其作用就是实现全连接。

我们选择Leaky ReLU作为隐藏层的激活函数,使用tf.maximum方法返回通过Leaky ReLU激活后较大的值。

然后使用tf.layers的dropout方法,其做法就是按一定的概率随机弃用神经网络中的网络单元(即将该网络单元的参数置0),防止发生过拟合现象,dropout只能在训练时使用,在测试时不能使用。最后再通过dense方法,实现隐藏层与输出层全连接,并使用Tanh作为输出层的激活函数(试验中用Tanh作为激活函数生成器效果更好),Tanh函数的输出范围是−1~1,即表示生成图片的像素范围是−1~1,但MNIST数据集中真实图片的像素范围是0~1,所以在训练时,要调整真实图片的像素范围,让其与生成图片一致。

Leakey ReLU函数是ReLU函数的变种,与ReLU函数的不同之处在于,ReLU将所有的负值都设为零,而Leakey ReLU则给负值乘以一个斜率。

接着看判别器的代码。

  def discirminator(img, n_units, reuse=False, alpha=0.01):'''判别器:paramimg: 图片(真实图片/生成图片):paramn_units::param reuse::param alpha::return:'''with tf.variable_scope('discriminator', reuse=reuse):hidden1 = tf.layers.dense(img, n_units)hidden1 = tf.maximum(alpha * hidden1, hidden1)logits = tf.layers.dense(hidden1, 1)outputs = tf.sigmoid(logits)return logits, outputs

判别器的实现代码与生成器没有太大差别,稍有不同的地方就是,判别器的输出层只有一个网络单元且使用sigmoid作为输出层的激活函数,sigmoid函数输出值的范围是0~1。

生成器和判别器编写完成后,接着就来编写具体的计算图,首先做一些初始化工作,如定义需要的变量、清空default graph计算图。

img_size = mnist.train.images[0].shape[0]#真实图片大小noise_size = 100 #噪声,Generator的初始输入g_units = 128#生成器隐藏层参数d_units = 128alpha = 0.01 #leaky ReLU参数learning_rate = 0.001 #学习速率smooth = 0.1 #标签平滑# 重置default graph计算图以及nodes节点tf.reset_default_graph()

然后我们通过get_inputs方法获得真实图片的输入和噪声输入,并传入生成器和判别器进行训练,当然,现在只是构建GAN整个网络的训练结构。


#生成器g_logits, g_outputs = generator(noise_img, g_units, img_size)#判别器d_logits_real, d_outputs_real = discirminator(real_img, d_units)# 传入生成图片,为其打分d_logits_fake, d_outputs_fake = discirminator(g_outputs, d_units, reuse=True)

上面的代码将噪声、生成器隐藏层节点数、真实图片大小传入生成器,传入真实图片的大小是因为要求生成器可以生成与真实图片大小一样的图片。

判别器一开始先传入真实图片和判别器隐藏层节点,为真实图片打分,接着再用相同的参数训练生成图片,为生成图片打分。

训练逻辑构建完成,接着就定义生成器和判别器的损失。先回忆一下前面对损失的描述,判别器的损失由判别器给真实图片打分与其期望分数的差距、判别器给生成图片打分与其期望分数的差距两部分构成。这里定义最高分为1、最低分为0,也就是说判别器希望给真实图片打1分,给生成图片打0分。生成器的损失实质上是生成图片与真实图片概率分布上的差距,这里将其转换为,生成器期望判别器给自己的生成图片打多少分与实际上判别器给生成图片打多少分的差距。

  d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real))*(1-smooth))d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))#判别器总损失d_loss = tf.add(d_loss_real, d_loss_fake)g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake))*(1-smooth))

计算损失时使用tf.nn.sigmoid_cross_entropy_with_logits方法,它对传入的logits参数先使用sigmoid函数计算,然后再计算它们的cross entropy交叉熵损失,同时该方法优化了cross entropy的计算方式,使得结果不会溢出。从方法的名字就可以直观地看出它的作用。

损失定义好后,要做的就是最小化这个损失。

  # generator中的tensorg_vars = [var for var in train_vars if var.name.startswith("generator")]# discriminator中的tensord_vars = [var for var in train_vars if var.name.startswith("discriminator")]#AdamOptimizer优化损失d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

要最小化损失,先要获得对应网络结构中的参数,也就是生成器和判别器的变量,这是最小化损失时要修改的对象。这里使用AdamOptimizer方法来最小化损失,其内部实现了Adam算法,该算法基于梯度下降算法,但它可以动态地调整每个参数的学习速率。

至此整个计算结果大致定义完成,接着开始实现具体的训练逻辑,先初始化一些与训练有关的变量。

  batch_size = 64 #每一轮训练数量epochs = 500 #训练迭代轮数n_sample = 25 #抽取样本数samples = [] #存储测试样例losses = [] #存储loss#保存生成器变量saver = tf.train.Saver(var_list=g_vars)

编写训练具体代码。

  with tf.Session() as sess:# 初始化模型的参数sess.run(tf.global_variables_initializer())for e in range(epochs):for batch_i in range(mnist.train.num_examples // batch_size):batch = mnist.train.next_batch(batch_size)#28 × 28 = 784batch_images = batch[0].reshape((batch_size, 784))# 对图像像素进行scale,这是因为Tanh输出的结果介于(-1,1)之间,real和fake图片共享discriminator的参数batch_images = batch_images * 2 -1#生成噪声图片batch_noise = np.random.uniform(-1,1,size=(batch_size, noise_size))#先训练判别器,再训练生成器_= sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img:batch_noise})_= sess.run(g_train_opt, feed_dict={noise_img:batch_noise})#每一轮训练完后,都计算一下losstrain_loss_d = sess.run(d_loss, feed_dict={real_img:batch_images, noise_img:batch_noise})# 判别器训练时真实图片的损失train_loss_d_real = sess.run(d_loss_real, feed_dict={real_img:batch_images,noise_img:batch_noise})# 判别器训练时生成图片的损失train_loss_d_fake = sess.run(d_loss_fake, feed_dict={real_img:batch_images,noise_img:batch_noise})# 生成器损失train_loss_g = sess.run(g_loss, feed_dict= {noise_img: batch_noise})print("训练轮数 {}/{}...".format(e + 1, epochs),"判别器总损失: {:.4f}(真实图片损失: {:.4f} + 虚假图片损失: {:.4f})...".format(train_loss_d,train_loss_d_real,train_loss_d_fake),"生成器损失: {:.4f}".format(train_loss_g))# 记录各类loss值losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))# 抽取样本后期进行观察sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))#生成样本,保存起来后期观察gen_samples = sess.run(generator(noise_img, g_units, img_size, reuse=True),feed_dict={noise_img:sample_noise})samples.append(gen_samples)# 存储checkpointssaver.save(sess, './data/generator.ckpt')with open('./data/train_samples.pkl', 'wb') as f:pickle.dump(samples,f)

一开始是创建Session对象,然后使用双层for循环进行GAN的训练,第一层表示要训练多少轮,第二层表示每一轮训练时,要取的样本量,因为一口气训练完所有的真实图片效率会比较低,一般的做法是将其分割成多组,然后进行多轮训练,这里64张为一组。

接着就是读入一组真实数据,因为生成器使用Tanh作为输出层的激活函数,导致生成图片的像素范围是−1~1,所以这里也简单调整一下真实图片的像素访问,将其从0~1变为−1~1,然后使用numpy的uniform方法生成−1~1之间的随机噪声。准备好真实数据和噪声数据后,就可以丢入生成器和判别器了,数据会按我们之前设计好的计算图运行,值得注意的是,要先训练判别器,再训练生成器。

当本轮将所有的真实图片都训练了一遍后,计算一下本轮生成器和判别器的损失,并将损失记录起来,方便后面可视化GAN训练过程中损失的变化。为了直观地感受GAN训练时生成器的变化,每一轮GAN训练完都用此时的生成器生成一组图片并保存起来。训练逻辑编写完后,就可以让训练代码运行起来,输出如下内容。

训练轮数 1/500… 判别器总损失: 0.0190(真实图片损失: 0.0017 + 虚假图片损失: 0.0173)…

生成器损失: 4.1502

训练轮数 2/500… 判别器总损失: 1.0480(真实图片损失: 0.3772 + 虚假图片损失: 0.6708)…

生成器损失: 3.1548

训练轮数 3/500… 判别器总损失: 0.5315(真实图片损失: 0.3580 + 虚假图片损失: 0.1736)…

生成器损失: 2.8828

训练轮数 4/500… 判别器总损失: 2.9703(真实图片损失: 1.5434 + 虚假图片损失: 1.4268)…

生成器损失: 0.7844

训练轮数 5/500… 判别器总损失: 1.0076(真实图片损失: 0.5763 + 虚假图片损失: 0.4314)…

生成器损失: 1.8176

训练轮数 6/500… 判别器总损失: 0.7265(真实图片损失: 0.4558 + 虚假图片损失: 0.2707)…

生成器损失: 2.9691

训练轮数 7/500… 判别器总损失: 1.5635(真实图片损失: 0.8336 + 虚假图片损失: 0.7299)…

生成器损失: 2.1342

整个训练过程会花费30~40分钟。

教你编写第一个生成式对抗网络GAN相关推荐

  1. 王飞跃教授:生成式对抗网络GAN的研究进展与展望

    本次汇报的主要内容包括GAN的提出背景.GAN的理论与实现模型.发展以及我们所做的工作,即GAN与平行智能.  生成式对抗网络GAN GAN是Goodfellow在2014年提出来的一种思想,是一种比 ...

  2. 如何用 TensorFlow 实现生成式对抗网络(GAN)

    我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodfellow 在 14 年发表了 论文 Generative Adversarial Nets 以 ...

  3. 简述生成式对抗网络 GAN

    本文主要阐述了对生成式对抗网络的理解,首先谈到了什么是对抗样本,以及它与对抗网络的关系,然后解释了对抗网络的每个组成部分,再结合算法流程和代码实现来解释具体是如何实现并执行这个算法的,最后通过给出一个 ...

  4. 深度学习之生成式对抗网络 GAN(Generative Adversarial Networks)

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.它源于2014年发表的论文:& ...

  5. 生成式对抗网络GAN(一)—基于python实现

    基于python实现生成式对抗网络GAN 构建和训练一个生成对抗网络(GAN) ,使其可以生成数字(0-9)的手写图像. 学习目标 从零开始构建GAN的生成器和判别器. 创建GAN的生成器和判别器的损 ...

  6. 《生成式对抗网络GAN的研究进展与展望》论文笔记

    本文主要是对论文:王坤峰, 苟超, 段艳杰, 林懿伦, 郑心湖, 王飞跃. 生成式对抗网络GAN的研究进展与展望. 自动化学报, 2017, 43(3): 321-332. 进行总结. 相关博客地址: ...

  7. 深度学习之生成式对抗网络GAN

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块 ...

  8. 生成式对抗网络(GAN, Generaitive Adversarial Networks)总结

    最近要做有关图像生成的工作-也是小白,今天简单学习一些有关GAN的基础知识,很浅,入个门,大神勿喷. GAN目前确实是在深度学习领域最热门,最有前景的方向之一.近几年有关于GAN的论文非常非常之多,从 ...

  9. 生成式对抗网络GAN模型搭建

    生成式对抗网络GAN模型搭建 目录 一.理论部分 1.GAN基本原理介绍 2.对KL散度的理解 3.模块导入命令 二.编程实现 1.加载所需要的模块和库,设定展示图片函数以及其他对图像预处理函数 1) ...

最新文章

  1. 百度编辑器ueditor每次编辑后多一个空行的解决办法
  2. 在openstack上创建第一个虚拟机
  3. titanium开发教程-04-11其他属性和方法
  4. 图像中添加二项式分布噪声
  5. 高级数据结构与算法 | LFU缓存机制(Least Frequently Used)
  6. java 方法 示例_Java ArrayDeque带有示例的removeFirstOccurrence()方法
  7. 云计算入门科普系列:小型云计算平台怎么搭建?
  8. 安装erlang没有bin文件夹_RabbitMQ安装教程
  9. 老李分享云计算基本概念 2
  10. matplotlib绘制李萨如图(二) animation实现动态2D李萨如图
  11. 算法设计与分析 期末考试试卷
  12. Ubuntu 20.04安装微信,QQ,TIM
  13. mouseenter、mouseleave、mouseover和mouseout的区别
  14. java 操作主机,告诉你java怎么实现键盘操作
  15. iOS UITableView 指定组头悬停位置
  16. python实现推广小项目
  17. c#WinForm使用OpencvSharp4实现简易抓边
  18. 用USART来替代SPI,效果非常不错
  19. 仿网易云音乐日推界面(监听AppBarLayout滑动+动态高斯模糊)
  20. 一个参数 nls_date_language

热门文章

  1. 王者荣耀高清壁纸脚本Python文件
  2. VS语音信号处理(2) C语言分段读取WAV语音文件语音数据
  3. PLC的编程语言跟CNC的编程语言有什么区别?
  4. 各博客平台编辑器使用测评
  5. oracle11g連不上em,oracle11gem重建失败的几点解决办法.doc
  6. UG二次开发GRIP修改属性
  7. 基于网络爬虫的新闻实时监测分析可视化系统(Java+MySQL+Web+Eclipse)
  8. ios7新特性--4
  9. 解决win10 文件属性没有数字签名及详细信息等选项卡
  10. 【面试】764高频前端开发面试问题及答案整理