训练数据集:手写数字识别

下载链接:https://pan.baidu.com/s/1d9jX5xLHd1x3DFChVCe3LQ 密码:ws28

在本篇博客中,笔者将逐行解析一下NIPS 2014的Generative Adversarial Networks(生成对抗网络,简称GAN)代码,该篇文章作为GAN系列的开山之作,在近3年吸引了无数学者的目光。在2017-2018年,各大计算机顶会中也都能看到各种GAN的身影。因此,本篇博客就来逐行解析一下使用GAN生成手写数字的代码。

在正式开始之前,笔者想说的是,如果要使得本篇博客对各位读者朋友的学习有帮助,请各位读者朋友们先熟悉生成对抗网络的基本原理。由于对于生成对抗网络的原理详解网络上的资源比较多,在本篇博客中笔者就不再对生成对抗网络的原理进行解释,而是给大家推荐一些对生成对抗网络原理进行了解的链接:

1. 直接进行论文阅读:https://arxiv.org/abs/1406.2661

2. 一篇通俗易懂,形象的GAN原理解释:一文看懂生成式对抗网络GANs:介绍指南及前景展望

3. 一篇比较详细的CSDN博文:生成式对抗网络GAN研究进展(二)——原始GAN

4. 知乎专栏上的文章:GAN原理学习笔记

如果对生成对抗网络原理已经熟稔的读者朋友,请自动忽略以上链接。并且,笔者以下放出的代码注释是参考了github上面的代码,链接https://github.com/wiseodd/generative-models

在这里笔者也想衷心感谢一下这位wiseodd大神,在他的generative-models下面的关于生成模型的代码非常全面,本文解析的代码路径是该工程下面的GAN/vanilla_gan/gan_tensorflow.py文件。笔者沿用了作者的代码,只是增加了模型保存与summary记录的少量代码,下面放出代码及注释:import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入os
 
def save(saver, sess, logdir, step): #保存模型的save函数
   model_name = 'model' #模型名前缀
   checkpoint_path = os.path.join(logdir, model_name) #保存路径
   saver.save(sess, checkpoint_path, global_step=step) #保存模型
   print('The checkpoint has been created.')
 
def xavier_init(size): #初始化参数时使用的xavier_init函数
    in_dim = size[0] 
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化标准差
    return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果
 
X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的样本(即真实的手写数字)
 
D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量
 
D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量
 
theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合
 
 
Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵
 
G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量
 
G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量
 
theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合
 
 
def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
    return np.random.uniform(-1., 1., size=[m, n])
 
 
def generator(z): #生成器,z的维度为[N, 100]
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, 128]
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, 784]
    G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, 784]
 
    return G_prob #返回G_prob
 
 
def discriminator(x): #判别器,x的维度为[N, 784]
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, 128]
    D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, 1]
    D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, 1]
 
    return D_prob, D_logit #返回D_prob, D_logit
 
 
def plot(samples): #保存图片时使用的plot函数
    fig = plt.figure(figsize=(4, 4)) #初始化一个4行4列包含16张子图像的图片
    gs = gridspec.GridSpec(4, 4) #调整子图的位置
    gs.update(wspace=0.05, hspace=0.05) #置子图间的间距
 
    for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
 
    return fig
 
 
G_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果
 
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) #对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) #对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss = D_loss_real + D_loss_fake #判别器的误差
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) #生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)
 
dreal_loss_sum = tf.summary.scalar("dreal_loss", D_loss_real) #记录判别器判别真实样本的误差
dfake_loss_sum = tf.summary.scalar("dfake_loss", D_loss_fake) #记录判别器判别虚假样本的误差
d_loss_sum = tf.summary.scalar("d_loss", D_loss) #记录判别器的误差
g_loss_sum = tf.summary.scalar("g_loss", G_loss) #记录生成器的误差
 
summary_writer = tf.summary.FileWriter('snapshots/', graph=tf.get_default_graph()) #日志记录器
 
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器
 
mb_size = 128 #训练的batch_size
Z_dim = 100 #生成器输入的随机噪声的列的维度
 
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集
 
sess = tf.Session() #会话层
sess.run(tf.global_variables_initializer()) #初始化所有可训练参数
 
if not os.path.exists('out/'): #初始化训练过程中的可视化结果的输出文件夹
    os.makedirs('out/')
 
if not os.path.exists('snapshots/'): #初始化训练过程中的模型保存文件夹
    os.makedirs('snapshots/')
 
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型的保存器
 
i = 0 #训练过程中保存的可视化结果的索引
 
for it in range(1000000): #训练100万次
    if it % 1000 == 0: #每训练1000次就保存一下结果
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
 
        fig = plot(samples) #通过plot函数生成可视化结果
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可视化结果
        i += 1
        plt.close(fig)
 
    X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入)
 
    #下面是得到训练一次的结果,通过sess来run出来
    _, D_loss_curr, dreal_loss_sum_value, dfake_loss_sum_value, d_loss_sum_value = sess.run([D_solver, D_loss, dreal_loss_sum, dfake_loss_sum, d_loss_sum], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr, g_loss_sum_value = sess.run([G_solver, G_loss, g_loss_sum], feed_dict={Z: sample_Z(mb_size, Z_dim)})
 
    if it%100 ==0: #每过100次记录一下日志,可以通过tensorboard查看
        summary_writer.add_summary(dreal_loss_sum_value, it)
        summary_writer.add_summary(dfake_loss_sum_value, it)
        summary_writer.add_summary(d_loss_sum_value, it)
        summary_writer.add_summary(g_loss_sum_value, it)
 
    if it % 1000 == 0: #每训练1000次输出一下结果
        save(saver, sess, 'snapshots/', it)
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

在上面的代码中,各位读者朋友可以看到,生成器与判别器都是使用多层感知机实现的(没有使用卷积神经网络)。生成器的输入是随机噪声,生成的是手写数字,生成器与判别器均使用Adam优化器进行训练并训练100w次。

在上面的代码中,笔者添加了各种summary保存了训练中的误差,结果如下所示。

判别器判别真实样本的误差变化:

判别器判别虚假样本(即生成器G生成的手写数字)的误差变化:

判别器的误差变化(上面两者之和):

生成器的误差变化:

下面是训练过程中输出的可视化结果,笔者选择了一些,大家可以看到,生成器输出结果最开始非常糟糕,但是随着训练的进行到训练中期输出效果越来越好:

训练2k次的输出:

训练6k次的输出:

训练4.2w次的输出

训练14.4w次的输出:

训练24.4w次的输出:

训练31.6w次的输出:

在训练的后期(训练80w次之后),大家从生成器的误差曲线可以看出,生成器的误差陡增,生成效果也相应变差了(如下图所示),这是生成器与判别器失衡的结果。

训练85.7w次的输出:

训练93.6w次的输出:

训练97.2w次的输出:

到这里,生成对抗网络的代码讲解就接近尾声了,衷心希望笔者的本篇博客对大家有帮助!

欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力!

written by jiong

日出入安穷?时世不与人同。

故春非我春,夏非我夏,秋非我秋,

冬,亦非我冬。
--------------------- 
作者:jiongnima 
来源:CSDN 
原文:https://blog.csdn.net/jiongnima/article/details/80033169 
版权声明:本文为博主原创文章,转载请附上博文链接!

【转】详解GAN代码之逐行解析GAN代码相关推荐

  1. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  2. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  3. Android中measure过程、WRAP_CONTENT详解以及xml布局文件解析流程浅析(下)

       本文原创, 转载请注明出处:http://blog.csdn.net/qinjuning 上篇文章<<Android中measure过程.WRAP_CONTENT详解以及xml布局文 ...

  4. Mask_RCNN翻译和详解笔记一(原文翻译+源代码+代码使用说明)

    Mask_RCNN翻译和详解笔记一(原文翻译+源代码+代码使用说明) 2018年06月01日 23:45:47 阅读数:332 原文:https://github.com/matterport/Mas ...

  5. java斐波那契查找_详解Java Fibonacci Search斐波那契搜索算法代码实现

    一, 斐波那契搜索算法简述 斐波那契搜索(Fibonacci search) ,又称斐波那契查找,是区间中单峰函数的搜索技术. 斐波那契搜索采用分而治之的方法,其中我们按照斐波那契数列对元素进行不均等 ...

  6. 详解DNS服务、DNS解析、DNS劫持和污染

    简介 DNS(全称:Domain Name System,中文:域名系统)是互联网的一项服务.它作为将域名和 IP 地址相互映射的一个分布式数据库,能够使人更方便地访问互联网.1 前言 要想弄清楚 D ...

  7. Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(一)

    Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(一) 本文目录: 一.[旋转的精灵女孩]案例运行效果 二.Three.js简介 三.Three.js代码正常运行显示条件 (1)不载入 ...

  8. Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(三)

    Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(三) 本篇目录: 六.完整构建整个[旋转的精灵女孩]实例 (1).新建.启动webGL工程空间 (2).构建项目的目录层次结构 (2. ...

  9. Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(二)

    Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(二) 本篇目录: 五.实例中所使用的代码语法详细解释 (1).构建一个三维空间场景 (2).选择一个透视投影相机作为观察点 (a).创 ...

  10. Android中measure过程、WRAP_CONTENT详解以及 xml布局文件解析流程浅析

    转自:http://www.uml.org.cn/mobiledev/201211221.asp 今天,我着重讲解下如下三个内容: measure过程 WRAP_CONTENT.MATCH_PAREN ...

最新文章

  1. 研究人员吐槽当前AI训练效率过于低下
  2. computed vue 不 触发_vuejs render何时执行?以及使用vue.$refs遇到的坑。
  3. how can you understand the world
  4. eclipse设置和启动优化(转)
  5. MySQL修改数据类型语句
  6. js判断对象还是数组
  7. 计算机考试打字对齐,2010年职称计算机考试:对齐方式
  8. dst发育筛查有意义吗_儿童视力筛查,都筛些啥?
  9. sqlserver text最大长度_1156. 单字符重复子串的最大长度
  10. android自定义对话框_Android自定义提醒对话框
  11. codeforces-constructive algorithms(构造算法.)
  12. Atcoder Grand Contest 036 D - Negative Cycle
  13. 空间直线与球面相交算法
  14. 2.5数字传输系统2.6宽带接入技术
  15. WGCNA:(加权共表达网络分析)
  16. 计算机网络实验一、验证性实验
  17. 企业微信可以取消实名认证吗?如何操作?
  18. JSP九大内置对象是什么?
  19. 955.WLB 红包封面来啦!送给希望不加班的你~
  20. 性能优化专题 - JVM 性能优化 - 04 - GC算法与调优

热门文章

  1. Asp.net使用HttpResponse.Filter 实现简繁/繁简转换
  2. 802.1D生成树STP协议
  3. VLAN的分类与实验
  4. MySQL集群Cluste详解(二)——配置实战
  5. 在OSPF网络中,如何判断LSA的新旧?
  6. Arts 第十二周(6/3 ~ 6/9)
  7. Android开发之跟踪应用更新大小
  8. 一些不好理解的名词解释
  9. 为什么不要把ZooKeeper用于服务发现
  10. (四)伪分布式下jdk1.6+Hadoop1.2.1+HBase0.94+Eclipse下运行wordCount例子