前言:最近在学习生成对抗网络(GAN, Generative Adversarial Networks),为了加深自己的理解,并帮助到想入门的同学,我特意写了这篇文章,教大家一步步搭建一个最简单原始的GAN网络 (Vanilla GAN)。代码后面会有详细(通俗易懂)的解释,大神请自动绕路~欢迎小白玩家围观~~ 查看本文jupyter notebook代码请点击这里。

利用GAN生成Mnist手写体图像(第1, 5, 10, 50, 400次迭代结果)

首先,让我们简单回顾一下什么是GAN

图1. GAN网络结构 (来自灵魂画手:我)

GAN最早由GoodFellow在2014年提出,查看原始论文请点击这里。GAN结构如图1所示,包含了一个生成器(Generator)和一个判别器 (Discriminator)。生成器的目的是生成以假乱真的图片,而判别器的目的是尽可能区分输入图片的真假。

举一个简单的例子,比如说假钞的流通。犯罪分子希望制作出逼真的假钞,可是警察的鉴定技术也在不断改良,双方互相博弈,互相提高,最终达到一种动态的平衡。讲到这里,是不是感觉很简单?

鉴于这是一个超级良心的教程~大家可以先跟着我一起,把代码实现。实现过程中,有不懂的先不要问(对,憋着),等跑完代码之后,看到酷炫的效果后,我再一步步解释为啥这么写。

好了,是时候放上代码了。来,先导入包。

from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from google.colab import drive

然后,读取Keras自带的mnist数据集。在这里我们给出一个读取数据的函数load_data()。

# Load the dataset
def load_data():(x_train, y_train), (_, _) = mnist.load_data()x_train = (x_train.astype(np.float32) - 127.5)/127.5# Convert shape from (60000, 28, 28) to (60000, 784)x_train = x_train.reshape(60000, 784)return (x_train, y_train)X_train, y_train = load_data()
print(X_train.shape, y_train.shape)

输出X_train.shape, y_train.shape

由于本文我们旨在实现最原始的GAN网络,因此用最简单MLP全连接层来构建生成器(用卷积层当然更好,在这里先不考虑)

def build_generator():model = Sequential()model.add(Dense(units=256, input_dim=100))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=1024))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=784, activation='tanh'))model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return modelgenerator = build_generator()
generator.summary()

生成器结构如下图所示:

然后建一个判别器,也是一个MLP全连接神经网络:

def build_discriminator():model = Sequential()model.add(Dense(units=1024 ,input_dim=784))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=256))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=1, activation='sigmoid'))model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return modeldiscriminator = build_discriminator()
discriminator.summary()

判别器结构如图所示:

然后,我们建立一个GAN网络,由discriminator和generator组成。

def build_GAN(discriminator, generator):discriminator.trainable=FalseGAN_input = Input(shape=(100,))x = generator(GAN_input)GAN_output= discriminator(x)GAN = Model(inputs=GAN_input, outputs=GAN_output)GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return GANGAN = build_GAN(discriminator, generator)
GAN.summary()

GAN结构如下图所示

然后我们给出绘制图像的函数,用于把generator生成的假图片画出来:

def draw_images(generator, epoch, examples=25, dim=(5,5), figsize=(10,10)):noise= np.random.normal(loc=0, scale=1, size=[examples, 100])generated_images = generator.predict(noise)generated_images = generated_images.reshape(25,28,28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')plt.axis('off')plt.tight_layout()plt.savefig('Generated_images %d.png' %epoch)

OK, 最后一步,写一个train函数,来训练GAN网络。在这里我们设置最大迭代次数400,每次迭代生成128张假图片:

def train_GAN(epochs=1, batch_size=128):#Loading the dataX_train, y_train = load_data()# Creating GANgenerator= build_generator()discriminator= build_discriminator()GAN = build_GAN(discriminator, generator)for i in range(1, epochs+1):print("Epoch %d" %i)for _ in tqdm(range(batch_size)):# Generate fake images from random noisetnoise= np.random.normal(0,1, (batch_size, 100))fake_images = generator.predict(noise)# Select a random batch of real images from MNISTreal_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]# Labels for fake and real images           label_fake = np.zeros(batch_size)label_real = np.ones(batch_size) # Concatenate fake and real images X = np.concatenate([fake_images, real_images])y = np.concatenate([label_fake, label_real])# Train the discriminatordiscriminator.trainable=Truediscriminator.train_on_batch(X, y)# Train the generator/chained GAN model (with frozen weights in discriminator) discriminator.trainable=FalseGAN.train_on_batch(noise, label_real)# Draw generated images every 15 epoches     if i == 1 or i % 10 == 0:draw_images(generator, i)
train_GAN(epochs=400, batch_size=128)

我用了Google colab自带的GPU,训练400代大约用了十多分钟。如果用jupyter notebook在本机跑,会慢一些 (据说2分钟一代?)。

生成的图片如下图所示

第1次迭代
第10次迭代
第50次迭代
第400次迭代

大功告成,接下来我将一步步解释train_GAN()函数是怎么工作的。

首先,导入数据集,这个容易理解。

  #Loading the dataX_train, y_train = load_data()

接下来,建立一个GAN网络,GAN由两个神经网络(generator, discriminator)连接而成。

  # Creating GANgenerator= build_generator()discriminator= build_discriminator()GAN = build_GAN(discriminator, generator)

然后,建立一个循环(400次迭代)。tqdm用来动态显示每次迭代的进度。

 for i in range(1, epochs+1):print("Epoch %d" %i)for _ in tqdm(range(batch_size)):

接着,我们生成呈高斯分布的噪声,利用generator,来生成batch_size(128张)图片。每张图片的输入就是一个1*100的噪声矩阵。

  # Generate fake images from random noisetnoise= np.random.normal(0,1, (batch_size, 100))fake_images = generator.predict(noise)

同样的,我们从Mnist数据集中随机挑选128张真实图片。我们给真实图片标注1,给假图片标注0,然后将256张真假图片混合在一起。

 # Select a random batch of real images from MNISTreal_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]# Labels for fake and real images           label_fake = np.zeros(batch_size)label_real = np.ones(batch_size) # Concatenate fake and real images X = np.concatenate([fake_images, real_images])y = np.concatenate([label_fake, label_real])

此时,我们利用上文提到的256张带标签的真假图片,训练discriminator。训练完毕后,discriminator的weights得到了更新。(打个比方,警察通过研究市面上流通的假币,在一起开会讨论,努力研发出了新一代鉴定假钞的方法)。

# Train the discriminator
discriminator.trainable=True
discriminator.train_on_batch(X, y)

然后,我们冻结住discriminator的weights,让discriminator不再变化。然后就开始训练generator (chained GAN)。在GAN的训练中,我们输入一堆噪声,期待的输出是将假图片预测为真。在这个过程中,generator继续生成假图片,送到discriminator检验,得到检验结果,如果被鉴定为假,就不断更新自己的权重(假钞贩子不断改良造假技术),直到discriminator将加图片鉴定为真图片(直到当前鉴定假钞的技术无法识别出假钞)。

 # Train the generator/chained GAN model (with frozen weights in discriminator) discriminator.trainable=FalseGAN.train_on_batch(noise, label_real)

OK,此时一次迭代进行完毕。接下来是第2, 3, ...次迭代。

现在,我们总结一下每次迭代发生了什么:

  1. Generator利用自己最新的权重,生成了一堆假图片。
  2. Discrminator根据真假图片的真实label,不断训练更新自己的权重,直到可以顺利鉴别真假图片。
  3. 此时discriminator权重被固定,不再发生变化。generator利用最新的discrimintor,苦苦思索,不断训练自己的权重,最终使discriminator将假图片鉴定为真图片。

换成印制假钞的例子,每次迭代发生了如下几件事:

  1. 假钞贩子根据最新造假技术,研发出一代假钞。
  2. 警察反复对比新型假钞和真币的区别,成功改良假钞鉴别方法,从而顺利鉴别出市面流通钞票的真伪。
  3. 假钞贩子生成假钞,马上被警察鉴别出来,痛定思痛,改良技术生成新的假钞。不成想,一上街又被警察识别了出来。日复一日,终于发明了新型假钞,当前的验钞技术已经无法成功检测出这种假钞。

然后通过每次迭代,discrimintor (警察的鉴定技术)和generator (假钞制作技术) 都越来越成熟...后来达到了动态平衡。

嗯,就这样,是不是挺简单的?

今天讲的是最原始的GAN网络,GAN发展到了如今已有许多变种,如将MLP结构换成CNN,Autoencoder,以及loss function的变化等等。我在github上找到一个超级全的用keras编写的各种花式GAN网络集合,有兴趣的小伙伴直接点击这里。本文的jupyter notebook代码请直接点击下面的小卡片~

https://nbviewer.jupyter.org/github/gaonanlee/Deep-Learning-Experiments/blob/master/Vanilla%20GAN_implementation.ipynb​nbviewer.jupyter.org

如有理解不到位之处,欢迎批评指教。

参考文献

  1. https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py

2. https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3


我的其他回答:

哪些 Python 库让你相见恨晚?

python如何画出漂亮的地图?

时间序列数据如何插补缺失值?

机器学习中的因果关系: 从辛普森悖论(常见的统计学谬误)谈起

c语言贪吃蛇最简单代码_让我们跑一个最简单的GAN网络吧!(附Jupyter Notebook 代码)...相关推荐

  1. cmd上写的java简单代码_用cmd编辑一个超级简单的小游戏,求代码

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 贪吃蛇: import java.awt.*; import java.util.LinkedList; import java.util.Scanner ...

  2. python接水果游戏代码_使用Python开发一个超级简单的接水果小游戏,零基础也可以学会...

    Pylash项目地址 创建项目 这样的话我们的项目就创建好了,然后只用往Main.py里填写代码运行即可. 编写Hello World小程序 编写游戏 有以上对pylash的小小了解,我们接下来可以开 ...

  3. c语言五子棋代码_基于控制台的C语言贪吃蛇

    相信对很多人来说,学完C语言以后,都会找一些小程序来练练手.例如贪吃蛇.五子棋.俄罗斯方块等等. 今天给大家分享一个基于控制台的C语言贪吃蛇小程序. 基础知识要求:C语言基础. 知识点补充 这里写一些 ...

  4. c语言安卓贪吃蛇代码下载,C语言贪吃蛇代码

    c语言编写贪吃蛇源代码,简单易懂,文件为VC源代码.如果你正在学习c语言,就来下载吧.很经典的 C语言贪吃蛇代码部分 #include #include #include#include #defin ...

  5. C语言贪吃蛇游戏代码,贪吃蛇C语言代码实现大全

    一.C语言贪吃蛇代码实现前言 设计贪吃蛇游戏的主要目的是让大家夯实C语言基础,训练编程思维,培养解决问题的思路,领略多姿多彩的C语言. 贪吃蛇是非常经典的一款游戏,本次我们模拟在控制台实现贪吃蛇游戏, ...

  6. 贪吃蛇的c语言程序码,C语言贪吃蛇代码下载_C语言贪吃蛇代码官方下载-太平洋下载中心...

    C语言编写贪吃蛇源代码,简单易懂,文件为VC源代码.如果你正在学习c语言,就来下载吧.很经典的. C语言贪吃蛇代码原理: 产生一个固定大小没有边界的游戏区域,蛇从区域的中心开始,由玩家通过键盘控制蛇的 ...

  7. 贪吃蛇统计分数的c语言代码,C/C++编程笔记:C语言贪吃蛇源代码控制台(二),分数和食物!...

    接上文<C/C++编程笔记:C语言贪吃蛇源代码控制台(一),会动的那种哦!>如果你在学习C语言开发贪吃蛇的话,零基础建议从上一篇开始哦!接下来正式开始吧! 三.蛇的运动 上次我已经教大家画 ...

  8. c语言对抗程序代码,C语言贪吃蛇源程序代码双人对抗

    C语言贪吃蛇源程序代码双人对抗 #include #include #include #include #include #include #include #define LEFT 100 #def ...

  9. 超简单的C语言贪吃蛇 不闪屏 双缓冲

    C语言贪吃蛇 今天把以前自己写的贪吃蛇总结了一下,发到博客上,怕放在电脑上哪天丢失了都不知道, 有不当之处还望指教 (*・ω< ) ヾ(◍°∇°◍)ノ゙ 贪吃蛇中, 我们看到的蛇在不断的移动,其 ...

  10. 完整版C语言贪吃蛇代码

    C语言贪吃蛇完整代码 #include <stdio.h> #include <stdlib.h> #include <Windows.h>//windows编程头 ...

最新文章

  1. k8s 使用Nginx Ingress实现灰度发布和蓝绿发布
  2. 免费教材丨第55期:Python机器学习实践指南、Tensorflow 实战Google深度学习框架
  3. 结对编程作业——四则运算GUI程序
  4. 特征工程(3):特征选择
  5. 直播 | AAAI 2021最佳论文:比Transformer更有效的长时间序列预测
  6. shiro的登录 subject.login(token)中执行逻辑和流程
  7. Jerry Wang在SAP社区上获得的徽章
  8. java jdk win10安装_Java 安装 JDK WIN10
  9. 【算法设计与分析】15 分治策略:芯片测试
  10. 【年终总结】可圈可点的2018年
  11. python玩跳一跳_python玩跳一跳
  12. 假期最后一天,出差赶到天津
  13. 使用Zoiper与freeSWITCH开视频会议
  14. 腾讯 IVWEB 团队:前端识别验证码思路分析
  15. 移动硬盘和电脑内置硬盘使用时的区别
  16. java常见单词汇总3(非常使用哦)
  17. 获得拼多多商品详情(商品主图、sku)
  18. jQuery之属性操作
  19. 秒杀系统的设计五大原则
  20. v模拟器(华为、H3C)点滴

热门文章

  1. java技术类网站收录
  2. 缓存与IO(很经典)
  3. 【鱼眼镜头6】[鱼眼畸变模型]:统一相机模型标定
  4. numpy 库使用说明
  5. 进程切换与线程切换的区别
  6. Java ConcurrentHashMap
  7. 素数筛(快速筛)-爱拉托斯特尼筛法+欧拉筛
  8. fatal: Not a git repository (or any of the parent directories): .git的解决办法
  9. android 进程(复习)
  10. 分享:国外著名代码管理网站GitHub访问方式