前言

本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。

本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:

# 用于生成 GIF 图片
pip install -q imageio

目录

前言

一、什么是生成对抗网络?

二、加载数据集

三、创建模型

3.1 生成器

3.1 判别器

四、定义损失函数和优化器

4.1 生成器的损失和优化器

4.2 判别器的损失和优化器

五、训练模型

5.1 保存检查点

5.2 定义训练过程

5.3 训练模型

六、评估模型


一、什么是生成对抗网络?

生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。

生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。

判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。

训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。

当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。

本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。

二、加载数据集

使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内BUFFER_SIZE = 60000
BATCH_SIZE = 256# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

三、创建模型

主要创建两个模型,一个是生成器,另一个是判别器

3.1 生成器

生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。

然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。

后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。

def make_generator_model():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数

使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。

generator = make_generator_model()noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)plt.imshow(generated_image[0, :, :, 0], cmap='gray')

3.1 判别器

判别器是基于 CNN卷积神经网络 的图片分类器。

def make_discriminator_model():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数

四、定义损失函数和优化器

由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。

首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。

# 该方法返回计算交叉熵损失的辅助函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

4.1 生成器的损失和优化器

1)生成器损失

生成器损失是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。

这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。

def generator_loss(fake_output):return cross_entropy(tf.ones_like(fake_output), fake_output)

2)生成器优化器

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

4.2 判别器的损失和优化器

1)判别器损失

判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。

def discriminator_loss(real_output, fake_output):real_loss = cross_entropy(tf.ones_like(real_output), real_output)fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss

2)判别器优化器

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

五、训练模型

5.1 保存检查点

保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,discriminator_optimizer=discriminator_optimizer,generator=generator,discriminator=discriminator)

5.2 定义训练过程

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。

判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。

两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:generated_images = generator(noise, training=True)real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def train(dataset, epochs):for epoch in range(epochs):start = time.time()for image_batch in dataset:train_step(image_batch)# 继续进行时为 GIF 生成图像display.clear_output(wait=True)generate_and_save_images(generator,epoch + 1,seed)# 每 15 个 epoch 保存一次模型if (epoch + 1) % 15 == 0:checkpoint.save(file_prefix = checkpoint_prefix)print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# 最后一个 epoch 结束后生成图片display.clear_output(wait=True)generate_and_save_images(generator,epochs,seed)# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):# 注意 training` 设定为 False# 因此,所有层都在推理模式下运行(batchnorm)。predictions = model(test_input, training=False)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()

5.3 训练模型

调用上面定义的train()函数,来同时训练生成器和判别器。

注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。

%%time
train(train_dataset, EPOCHS)

在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。

训练了15轮的效果:

训练了30轮的效果:

训练过程:

恢复最新的检查点

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

六、评估模型

这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。

# 使用 epoch 数生成单张图片
def display_image(epoch_no):return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))display_image(EPOCHS)
anim_file = 'dcgan.gif'with imageio.get_writer(anim_file, mode='I') as writer:filenames = glob.glob('image*.png')filenames = sorted(filenames)last = -1for i,filename in enumerate(filenames):frame = 2*(i**0.5)if round(frame) > round(last):last = frameelse:continueimage = imageio.imread(filename)writer.append_data(image)image = imageio.imread(filename)writer.append_data(image)import IPython
if IPython.version_info > (6,2,0,''):display.Image(filename=anim_file)

完整代码:

import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import timefrom IPython import display(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内BUFFER_SIZE = 60000
BATCH_SIZE = 256# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)# 创建模型--生成器
def make_generator_model():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
generator = make_generator_model()noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)plt.imshow(generated_image[0, :, :, 0], cmap='gray')
tf.keras.utils.plot_model(generator)# 判别器
def make_discriminator_model():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 生成器的损失和优化器
def generator_loss(fake_output):return cross_entropy(tf.ones_like(fake_output), fake_output)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)# 判别器的损失和优化器
def discriminator_loss(real_output, fake_output):real_loss = cross_entropy(tf.ones_like(real_output), real_output)fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)# 保存检查点
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,discriminator_optimizer=discriminator_optimizer,generator=generator,discriminator=discriminator)# 定义训练过程
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:generated_images = generator(noise, training=True)real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def train(dataset, epochs):for epoch in range(epochs):start = time.time()for image_batch in dataset:train_step(image_batch)# 继续进行时为 GIF 生成图像display.clear_output(wait=True)generate_and_save_images(generator,epoch + 1,seed)# 每 15 个 epoch 保存一次模型if (epoch + 1) % 15 == 0:checkpoint.save(file_prefix = checkpoint_prefix)print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# 最后一个 epoch 结束后生成图片display.clear_output(wait=True)generate_and_save_images(generator,epochs,seed)# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):# 注意 training` 设定为 False# 因此,所有层都在推理模式下运行(batchnorm)。predictions = model(test_input, training=False)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()# 训练模型
train(train_dataset, EPOCHS)# 恢复最新的检查点
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))# 评估模型
# 使用 epoch 数生成单张图片
def display_image(epoch_no):return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))display_image(EPOCHS)anim_file = 'dcgan.gif'with imageio.get_writer(anim_file, mode='I') as writer:filenames = glob.glob('image*.png')filenames = sorted(filenames)last = -1for i,filename in enumerate(filenames):frame = 2*(i**0.5)if round(frame) > round(last):last = frameelse:continueimage = imageio.imread(filename)writer.append_data(image)image = imageio.imread(filename)writer.append_data(image)import IPython
if IPython.version_info > (6,2,0,''):display.Image(filename=anim_file)

参考:https://www.tensorflow.org/tutorials/generative/dcgan

一篇文章“简单”认识《生成对抗网络》(GAN)

深度卷积生成对抗网络DCGAN——生成手写数字图片相关推荐

  1. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  2. Pytorch:GAN生成对抗网络实现MNIST手写数字的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...

  3. 深度学习故障诊断之-使用条件生成对抗网络CGAN生成泵流量信号

    开始填坑 MATLAB统计机器学习,深度学习,计算机视觉 - 哥廷根数学学派的文章 - 知乎 MATLAB统计机器学习,深度学习,计算机视觉 - 知乎 之前写过在使用深度学习对机械系统或电气系统进行故 ...

  4. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  5. 使用Keras训练Lenet网络来进行手写数字识别

    使用Keras训练Lenet网络来进行手写数字识别 这篇博客将介绍如何使用Keras训练Lenet网络来进行手写数字识别. LeNet架构是深度学习中的一项开创性工作,演示了如何训练神经网络以端到端的 ...

  6. 02:一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    标签(空格分隔): 王小草Tensorflow笔记 笔记整理者:王小草 笔记整理时间2017年2月24日 Tensorflow官方英文文档地址:https://www.tensorflow.org/g ...

  7. 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作

    卷积神经网络与循环神经网络实战 - 手写数字识别及诗词创作 文章目录 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作 一.神经网络相关知识 1. 深度学习 2. 人工神经网络回顾 3. ...

  8. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  9. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

最新文章

  1. Ubuntu iso镜像文件写入U盘
  2. 华南理工大学 高级程序设计语言 c++ ,2017华南理工大学网络教育《高级语言程序设计C++》平时作业...
  3. eureka架构图原理
  4. PHP读取TXT UTF-8,2)PHP中把读取.txt中内容并转为UTF-8格式
  5. 2017西安交大ACM小学期数论 [等差数列]
  6. python编程技巧1002python编程技巧_总结Python编程中三条常用的技巧
  7. 我眼中的Web2.0
  8. spring读取jdbc(file方式)
  9. 【元胞自动机】基于matlab元胞自动机车流密度不变下的双向两车道仿真(T 字形路口)【含Matlab源码 1290期】
  10. Linux:libxml2的安装及使用示例(C语言)
  11. 求职必看~10分钟读懂国企、事业单位和公务员的区别
  12. 网站域名假墙处理方法 内含cloudflare API自动更换IP的php脚本
  13. 在进化计算中,软件进行元基编码的新陈代谢方式 V0. 1. 0
  14. KILE生成S19或者BIN文件
  15. 《Java程序小作业之自动贩卖机》#谭子
  16. 吴彩强:从表征到行动---意向性的自然主义进路
  17. java: 关于反射
  18. oracle分区 varchar2,oracle分区怎么使用
  19. DNS(从域名到IP地址的对应)
  20. 雷达编程实战之恒虚警率(CFAR)检测

热门文章

  1. bcnf分解算法_数据库规范化:模式分解算法(3NF,BCNF分解,附带口诀,通俗易懂)...
  2. 不讲武德放大招 云上安全桥头堡
  3. android手机打开java文件_Android Studio打开手机权限
  4. 51单片机--数字电子时钟(单片机基础应用)
  5. java perf_系统级性能分析工具perf的介绍与使用
  6. c语言分母多项乘积怎么算,C++编程 用梯形求积公式求解定积分∫3lnxdx积分区间为(1,2, C语言,用梯形法编程求定积分x^3+x/2+1的值...
  7. 从苏宁电器到卡巴斯基第12篇:我在苏宁电器当营业员 IV
  8. Centos7.6 安装devstack
  9. 春节天猫快递照常送,菜鸟给全国的值守快递员发了团圆基金!
  10. Matplotlib格式化轴