文章目录

  • 一、GAN原理
  • 二、项目实战
    • 2.1 项目背景
    • 2.2 网络描述
    • 2.3 项目实战

一、GAN原理

生成对抗网络简称GAN,是由两个网络组成的,一个生成器网络和一个判别器网络。这两个网络可以是神经网络(从卷积神经网络、循环神经网络到自编码器)。生成器从给定噪声中(一般是指均匀分布或者正态分布)产生合成数据,判别器分辨生成器的的输出和真实数据。前者试图产生更接近真实的数据,相应地,后者试图更完美地分辨真实数据与生成数据。由此,两个网络在对抗中进步,在进步后继续对抗,由生成式网络得的数据也就越来越完美,逼近真实数据,从而可以生成想要得到的数据(图片、序列、视频等)。网络结构如下:

GAN的公式:

更多理论知识请参考:生成对抗网络GAN详细推导。

需要注意的就是GAN模型中没有用到卷积,只是简单的多层神经网络。在网络中,也没有新的tensorflow函数,在经过了前边几篇学习笔记和理论知识学习之后,实现起来也比较容易实现。

二、项目实战

2.1 项目背景

基于MNIST数据集,利用GAN生成手写体数字。有关MNIST的简介,请参考:tensorflow学习笔记(四):利用BP手写体(MNIST)识别。

2.2 网络描述

利用Tensorflow搭建GAN网络,其中生成网络为两层全连接层,判别网络也为两层全连接层,并用MNIST训练,然后生成手写体数字。

2.3 项目实战

from __future__ import division, print_function, absolute_importimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 导入MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)# 训练参数
num_steps = 20000
batch_size = 128
learning_rate = 0.0002# 网络参数
image_dim = 784  # 28*28
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 100 # Noise data points# A custom initialization (see Xavier Glorot init)
def glorot_init(shape):return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))# 保存隐藏层的权重和偏置
weights = {'gen_hidden1': tf.Variable(glorot_init([noise_dim, gen_hidden_dim])),'gen_out': tf.Variable(glorot_init([gen_hidden_dim, image_dim])),'disc_hidden1': tf.Variable(glorot_init([image_dim, disc_hidden_dim])),'disc_out': tf.Variable(glorot_init([disc_hidden_dim, 1])),
}
biases = {'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),'gen_out': tf.Variable(tf.zeros([image_dim])),'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),'disc_out': tf.Variable(tf.zeros([1])),
}# 生成网络
def generator(x):hidden_layer = tf.matmul(x, weights['gen_hidden1'])hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])hidden_layer = tf.nn.relu(hidden_layer)out_layer = tf.matmul(hidden_layer, weights['gen_out'])out_layer = tf.add(out_layer, biases['gen_out'])out_layer = tf.nn.sigmoid(out_layer)return out_layer# 判别网络
def discriminator(x):hidden_layer = tf.matmul(x, weights['disc_hidden1'])hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])hidden_layer = tf.nn.relu(hidden_layer)out_layer = tf.matmul(hidden_layer, weights['disc_out'])out_layer = tf.add(out_layer, biases['disc_out'])out_layer = tf.nn.sigmoid(out_layer)return out_layer##############创建网络
# 网络输入
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')
disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input')# 创建生成网络
gen_sample = generator(gen_input)# 创建两个判别网络 (一个来自噪声输入, 一个来自生成的样本)
disc_real = discriminator(disc_input)
disc_fake = discriminator(gen_sample)# 定义损失函数
gen_loss = -tf.reduce_mean(tf.log(disc_fake))
disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))# 定义优化器
optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)# 训练每个优化器的变量
# 生成网络变量
gen_vars = [weights['gen_hidden1'], weights['gen_out'],biases['gen_hidden1'], biases['gen_out']]
# 判别网络变量
disc_vars = [weights['disc_hidden1'], weights['disc_out'],biases['disc_hidden1'], biases['disc_out']]# 最小损失函数
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)# 初始化变量
init = tf.global_variables_initializer()# 开始训练
with tf.Session() as sess:sess.run(init)for i in range(1, num_steps+1):# 准备数据batch_x, _ = mnist.train.next_batch(batch_size)# 产生噪声给生成网络z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])# 训练feed_dict = {disc_input: batch_x, gen_input: z}_, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],feed_dict=feed_dict)if i % 1000 == 0 or i == 1:print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))# 使用生成器网络从噪声生成图像f, a = plt.subplots(4, 10, figsize=(10, 4))for i in range(10):# 噪声输入.z = np.random.uniform(-1., 1., size=[4, noise_dim])g = sess.run([gen_sample], feed_dict={gen_input: z})g = np.reshape(g, newshape=(4, 28, 28, 1))# 将原来黑底白字转换成白底黑字,更好的显示g = -1 * (g - 1)for j in range(4):# 从噪音中生成图像。 扩展到3个通道,用于matplotlibimg = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),newshape=(28, 28, 3))a[j][i].imshow(img)f.show()plt.draw()plt.waitforbuttonpress()

训练20000次,输出结果:

Step 1: Generator Loss: 1.002087, Discriminator Loss: 1.212741
Step 1000: Generator Loss: 3.819249, Discriminator Loss: 0.063358
Step 2000: Generator Loss: 4.281909, Discriminator Loss: 0.040046
Step 3000: Generator Loss: 3.737413, Discriminator Loss: 0.072012
Step 4000: Generator Loss: 3.734505, Discriminator Loss: 0.121832
Step 5000: Generator Loss: 3.478826, Discriminator Loss: 0.155717
Step 6000: Generator Loss: 3.131607, Discriminator Loss: 0.167828
Step 7000: Generator Loss: 3.458174, Discriminator Loss: 0.176890
Step 8000: Generator Loss: 3.987390, Discriminator Loss: 0.132476
Step 9000: Generator Loss: 3.256813, Discriminator Loss: 0.246182
Step 10000: Generator Loss: 4.022185, Discriminator Loss: 0.106170
Step 11000: Generator Loss: 3.692181, Discriminator Loss: 0.229384
Step 12000: Generator Loss: 3.681010, Discriminator Loss: 0.221918
Step 13000: Generator Loss: 3.232910, Discriminator Loss: 0.276704
Step 14000: Generator Loss: 3.951521, Discriminator Loss: 0.223627
Step 15000: Generator Loss: 3.263102, Discriminator Loss: 0.262820
Step 16000: Generator Loss: 3.180792, Discriminator Loss: 0.326289
Step 17000: Generator Loss: 3.495943, Discriminator Loss: 0.350409
Step 18000: Generator Loss: 3.797458, Discriminator Loss: 0.174091
Step 19000: Generator Loss: 2.964710, Discriminator Loss: 0.286498
Step 20000: Generator Loss: 3.576961, Discriminator Loss: 0.336350

生成的图片:

tensorflow学习笔记(十):GAN生成手写体数字(MNIST)相关推荐

  1. TensorFlow学习笔记(二)手写体数字的识别——环境安装

    手写体数字的识别--环境安装 上一篇 Anaconda 的安装: 建立 TensorFlow 的 Anaconda 虚拟环境 1. 建立工作目录 2. 建立 Anaconda 虚拟环境 3. 启动 A ...

  2. 基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字

    目录 0.前言 一.GAN结构 二.函数代码 2.1 生成器Generator 2.2 判别器Discriminator 2.3 train函数 三.结果演示 四.完整代码 五.常见问题汇总 0.前言 ...

  3. TensorFlow学习笔记(二)把数字标签转化成onehot标签

    在MNIST手写字数据集中,我们导入的数据和标签都是预先处理好的,但是在实际的训练中,数据和标签往往需要自己进行处理. 以手写数字识别为例,我们需要将0-9共十个数字标签转化成onehot标签.例如: ...

  4. TensorFlow学习笔记(一):手写数字识别之softmax回归

    在Tensorflow中实现逻辑回归的步骤: 一般来讲,使用Tensorflow实现机器学习算法模型的步骤如下: 1.定义算法公式: 2.定义loss函数,选择优化器优化loss: 3.使用输入训练集 ...

  5. tensorflow学习笔记十7:tensorflow官方文档学习 How to Retrain Inception's Final Layer for New Categories

    现代物体识别模型有数以百万计的参数,可能需要数周才能完全训练.学习迁移是一个捷径,很多这样的工作,以充分的训练模式的一组类ImageNet技术,并从现有的权重进行新课.在这个例子中,我们将从头再训练最 ...

  6. 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

    图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow) 文章目录 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网 ...

  7. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

  8. tensorflow学习笔记(七):CNN手写体(MNIST)识别

    文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...

  9. tensorflow学习笔记(三十二):conv2d_transpose (解卷积)

    tensorflow学习笔记(三十二):conv2d_transpose ("解卷积") deconv解卷积,实际是叫做conv_transpose, conv_transpose ...

最新文章

  1. 10年,4600万台!树莓派,生日快乐
  2. python【力扣LeetCode算法题库】67-二进制求和
  3. 经典C语言能力测试题(值得一看)
  4. java基础系列:集合总结(7)
  5. Promise 学习心得
  6. C#接口的使用场合,接口应用
  7. Java 通过Executors创建线程池的种类
  8. servlet到mysql_在servlet中搜索代码到mysql?
  9. mockito 静态方法_Mockito模拟静态方法– PowerMock
  10. 第六章 jQuery选择器
  11. 慎用某60软件清理垃圾,导致三星SSD T5不识别了,折腾了一下午,终于弄好了
  12. Windows 10 VMware Workstation Server服务启动一段时间后自动异常关闭
  13. Labview温度采集系统
  14. favi.icon是什么?
  15. 【干货】从QQ群起家的情趣商城站长之路
  16. 蓝牙BQB认证 Profile测试
  17. 风潮唱片-远方的寂静;专辑
  18. Objective C范型
  19. 项目时间管理-知识领域
  20. Rust 数据内存布局

热门文章

  1. IDEA自动导入包和删除包的设置
  2. MapReduce编程框架
  3. 300英雄11月服务器维护,300英雄11月19日全区停机更新公告
  4. WEB 性能测试-介绍 学习笔记
  5. 高新区python培训机构
  6. 如何塑造品牌超级符号?
  7. 安卓实现类微信门户页面
  8. 热点展会2023第十五届上海国际物联网展览会
  9. 什么花可以代表父爱哪?用python分析百句父爱名言竟发现!!!
  10. centos 命令不能用 command not found