刚读完上一篇的GAN原理,是不是觉得GAN特别的有意思,已经迫不及待了,那赶紧趁热打铁上手实战吧!

1.我的环境:

  • win10
  • Pycharm
  • Python3.5
  • tensorflow-gpu-1.4.0
  • matplotlib-3.0.3
  • numpy-1.16.2

2.数据集:
手写数字识别mnist数据集
下载完数据之后,所有数据为压缩包形式,我们需要对train数据进行解压:

3.代码:
网上代码真的是太多了,首先感谢这位大神在Github上总结的代码,直接就能成功运行了:

import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入osdef save(saver, sess, logdir, step): #保存模型的save函数model_name = 'model' #模型名前缀checkpoint_path = os.path.join(logdir, model_name) #保存路径saver.save(sess, checkpoint_path, global_step=step) #保存模型print('The checkpoint has been created.')def xavier_init(size): #初始化参数时使用的xavier_init函数in_dim = size[0] xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化标准差return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的样本(即真实的手写数字)D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入return np.random.uniform(-1., 1., size=[m, n])def generator(z): #生成器,z的维度为[N, 100]G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, 128]G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, 784]G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, 784]return G_prob #返回G_probdef discriminator(x): #判别器,x的维度为[N, 784]D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, 128]D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, 1]D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, 1]return D_prob, D_logit #返回D_prob, D_logitdef plot(samples): #保存图片时使用的plot函数fig = plt.figure(figsize=(4, 4)) #初始化一个4行4列包含16张子图像的图片gs = gridspec.GridSpec(4, 4) #调整子图的位置gs.update(wspace=0.05, hspace=0.05) #置子图间的间距for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return figG_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) #对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) #对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss = D_loss_real + D_loss_fake #判别器的误差
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) #生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)dreal_loss_sum = tf.summary.scalar("dreal_loss", D_loss_real) #记录判别器判别真实样本的误差
dfake_loss_sum = tf.summary.scalar("dfake_loss", D_loss_fake) #记录判别器判别虚假样本的误差
d_loss_sum = tf.summary.scalar("d_loss", D_loss) #记录判别器的误差
g_loss_sum = tf.summary.scalar("g_loss", G_loss) #记录生成器的误差summary_writer = tf.summary.FileWriter('snapshots/', graph=tf.get_default_graph()) #日志记录器D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器mb_size = 128 #训练的batch_size
Z_dim = 100 #生成器输入的随机噪声的列的维度mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集sess = tf.Session() #会话层
sess.run(tf.global_variables_initializer()) #初始化所有可训练参数if not os.path.exists('out/'): #初始化训练过程中的可视化结果的输出文件夹os.makedirs('out/')if not os.path.exists('snapshots/'): #初始化训练过程中的模型保存文件夹os.makedirs('snapshots/')saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型的保存器i = 0 #训练过程中保存的可视化结果的索引for it in range(1000000): #训练100万次if it % 1000 == 0: #每训练1000次就保存一下结果samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})fig = plot(samples) #通过plot函数生成可视化结果plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可视化结果i += 1plt.close(fig)X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入)#下面是得到训练一次的结果,通过sess来run出来_, D_loss_curr, dreal_loss_sum_value, dfake_loss_sum_value, d_loss_sum_value = sess.run([D_solver, D_loss, dreal_loss_sum, dfake_loss_sum, d_loss_sum], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr, g_loss_sum_value = sess.run([G_solver, G_loss, g_loss_sum], feed_dict={Z: sample_Z(mb_size, Z_dim)})if it%100 ==0: #每过100次记录一下日志,可以通过tensorboard查看summary_writer.add_summary(dreal_loss_sum_value, it)summary_writer.add_summary(dfake_loss_sum_value, it)summary_writer.add_summary(d_loss_sum_value, it)summary_writer.add_summary(g_loss_sum_value, it)if it % 1000 == 0: #每训练1000次输出一下结果save(saver, sess, 'snapshots/', it)print('Iter: {}'.format(it))print('D loss: {:.4}'. format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print()

在上面的代码中,各位读者朋友可以看到,生成器与判别器都是使用多层感知机实现的(没有使用卷积神经网络)。生成器的输入是随机噪声,生成的是手写数字,生成器与判别器均使用Adam优化器进行训练并训练100w次。

4.结果:
然后我们的结果在out文件夹里:

5.分析:
1.进行GAN训练时,是将二维MNSIT数据拉伸成了一维数据进行处理,且GAN模型中没有用到卷积,只是多层神经网络的叠加,实现相对容易。
2.GAN的生成器和判别器中使用不同的激活函数。
3.有时候训练到最后所有的数字都变成了1,为什么会这样呢?因为失衡造成生成器崩溃了。

系列传送门:
初窥门径__生成对抗网络(GAN)(一)
融会贯通__条件生成对抗网络(cGAN)(三)
炉火纯青__深度卷积生成对抗网络(DCGAN)(四)
登堂入室__生成对抗网络的信息论扩展(infoGAN)(五)
渐入佳境__距离生成对抗网络(WGAN)(六)
登峰造极__边界均衡生成对抗网络(BEGAN)(七)
一代宗师__循环一致性生成对抗网络(CycleGAN)(八)

小试牛刀__GAN实战项目之mnist数据集(二)相关推荐

  1. Java实战项目(三)——二十一点游戏

    一.项目目标 利用Java swing技术能够实现玩家与电脑进行二十一点游戏.要求如下: 纸牌数:共52张纸牌,除去大小王两张纸牌. 花色:红桃.黑桃.方块.梅花. 纸牌的面值:A到10的纸牌面值按照 ...

  2. Hyperledger Fabric 超级账本实战项目(一、二)

    p1基础介绍 Fabric需要配置环境 hash概念:输入任何数据可以得到与其对应的hash值.发现不管data的长度是多少,它对应的hash长度是不变的:data也可以传空值,即空值也会对应一个ha ...

  3. 【人工智能项目】MNIST手写体识别实验及分析

    [人工智能项目]MNIST数据集实验报告 这是之前接的小作业,现在分享出来,给大家以学习!!! [人工智能项目]MNIST手写体识别实验及分析 1.实验内容简述 1.1 实验环境 本实验采用的软硬件实 ...

  4. 机器学习之sklearn使用下载MNIST数据集进行分类识别

    机器学习之sklearn使用下载MNIST数据集进行分类识别 一.MNIST数据集 1.MNIST数据集简介 2.获取MNIST数据集 二.训练一个二分类器 1.随机梯度下降(SGD)分类器 2.分类 ...

  5. (!详解 Pytorch实战:①)kaggle猫狗数据集二分类:加载(集成/自定义)数据集

    这系列的文章是我对Pytorch入门之后的一个总结,特别是对数据集生成加载这一块加强学习 另外,这里有一些比较常用的数据集,大家可以进行下载: 需要注意的是,本篇文章使用的PyTorch的版本是v0. ...

  6. 机器学习实战10-Artificial Neural Networks人工神经网络简介(mnist数据集)

    目录 一.感知器 1.1.单层感知器 1.2.多层感知器MLP与反向传播 二.用 TensorFlow 高级 API 训练 MLP DNNClassifier(深度神经网络分类器) 2.1.初始化: ...

  7. .NET Core实战项目之CMS 第十二章 开发篇-Dapper封装CURD及仓储代码生成器实现...

    本篇我将带着大家一起来对Dapper进行下封装并实现基本的增删改查.分页操作的同步异步方法的实现(已实现MSSQL,MySql,PgSQL).同时我们再实现一下仓储层的代码生成器,这样的话,我们只需要 ...

  8. TensorFlow:实战Google深度学习框架(四)MNIST数据集识别问题

    第5章 MNIST数字识别问题 5.1 MNIST数据处理 5.2 神经网络的训练以及不同模型结果的对比 5.2.1 TensorFlow训练神经网络 5.2.2 使用验证数据集判断模型的效果 5.2 ...

  9. 二隐层的神经网络实现MNIST数据集分类

    二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...

  10. 鸿蒙开发|呼吸训练实战项目(二)

    文章目录 鸿蒙开发|呼吸训练实战项目(二) 实现训练页面与主页面之间相互跳转 运行效果 实现思路 代码详解 验证应用和每个页面的生命周期时间 运行效果 在主界面中显示logo和两个选择器 实现思路 代 ...

最新文章

  1. 直播 | 小爱通用理解团队负责人雷宗:小爱同学中控意图理解
  2. 如何保证MongoDB的安全性? 1
  3. 使用txt文件导入数据库内容
  4. 房贷断供了,房子就要被收走,首付款怎么办?
  5. matlab中回归系数,最小一乘回归系数估计及其MATLAB实现
  6. SpringMVC 搭建遇到的坑
  7. windows 64 搭建RabbitMQ环境
  8. 【OOM】GC overhead limit exceeded
  9. Java-注解第一篇认识Annotation
  10. 【写博客常用】Word文档中怎么插入分隔线
  11. 基于共振解调的轴承故障诊断方法总结(一)
  12. 弹性地基梁板实用计算_YJK软件前处理之计算参数的设置(上篇)
  13. Beyond 比对工具
  14. java实现在线预览的功能(一)word转html
  15. 耕、林、园地分类搞不定?PIE-Engine机器学习带你攻克难题
  16. 自动生成图片及修改图片尺寸
  17. 如何在debian上安装google pingyin
  18. 不重视技术,何谈掌握核心技术?
  19. Leetcode-01-Tow SUM
  20. 大天使黎明服务器维护,亲爱的玩家: 您好,为保证服务

热门文章

  1. 【TiDB 4.0 新特性前瞻】DBA 减负捷径:拍个 CT 诊断集群热点问题
  2. 计算机网络放大器的作用,运算放大器工作原理是什么?
  3. Linux环境:可变剪切分析软件rMATS安装、使用与解读
  4. 关于累加偶数奇数的c语言程序,c语言 在1-100之间,求所有奇数和偶数的个数和所有奇数和偶数的和(写到一个里面)...
  5. discuz手机端默forum.php,discuz手机wap版模板开发方式简述
  6. HFSS 3D LAYOUT TDR仿真
  7. 最后一本书 上机5(翻书)
  8. 计算机视觉——图像拼接
  9. php制作奥运五环颜色代表的洲,php趣味 - php 奥运五环
  10. 决策树算法--C4.5算法