c语言贪吃蛇最简单代码_让我们跑一个最简单的GAN网络吧!(附Jupyter Notebook 代码)...
前言:最近在学习生成对抗网络(GAN, Generative Adversarial Networks),为了加深自己的理解,并帮助到想入门的同学,我特意写了这篇文章,教大家一步步搭建一个最简单原始的GAN网络 (Vanilla GAN)。代码后面会有详细(通俗易懂)的解释,大神请自动绕路~欢迎小白玩家围观~~ 查看本文jupyter notebook代码请点击这里。
首先,让我们简单回顾一下什么是GAN。
GAN最早由GoodFellow在2014年提出,查看原始论文请点击这里。GAN结构如图1所示,包含了一个生成器(Generator)和一个判别器 (Discriminator)。生成器的目的是生成以假乱真的图片,而判别器的目的是尽可能区分输入图片的真假。
举一个简单的例子,比如说假钞的流通。犯罪分子希望制作出逼真的假钞,可是警察的鉴定技术也在不断改良,双方互相博弈,互相提高,最终达到一种动态的平衡。讲到这里,是不是感觉很简单?
鉴于这是一个超级良心的教程~大家可以先跟着我一起,把代码实现。实现过程中,有不懂的先不要问(对,憋着),等跑完代码之后,看到酷炫的效果后,我再一步步解释为啥这么写。
好了,是时候放上代码了。来,先导入包。
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from google.colab import drive
然后,读取Keras自带的mnist数据集。在这里我们给出一个读取数据的函数load_data()。
# Load the dataset
def load_data():(x_train, y_train), (_, _) = mnist.load_data()x_train = (x_train.astype(np.float32) - 127.5)/127.5# Convert shape from (60000, 28, 28) to (60000, 784)x_train = x_train.reshape(60000, 784)return (x_train, y_train)X_train, y_train = load_data()
print(X_train.shape, y_train.shape)
由于本文我们旨在实现最原始的GAN网络,因此用最简单MLP全连接层来构建生成器(用卷积层当然更好,在这里先不考虑)
def build_generator():model = Sequential()model.add(Dense(units=256, input_dim=100))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=1024))model.add(LeakyReLU(alpha=0.2))model.add(Dense(units=784, activation='tanh'))model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return modelgenerator = build_generator()
generator.summary()
生成器结构如下图所示:
然后建一个判别器,也是一个MLP全连接神经网络:
def build_discriminator():model = Sequential()model.add(Dense(units=1024 ,input_dim=784))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=256))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.3))model.add(Dense(units=1, activation='sigmoid'))model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return modeldiscriminator = build_discriminator()
discriminator.summary()
判别器结构如图所示:
然后,我们建立一个GAN网络,由discriminator和generator组成。
def build_GAN(discriminator, generator):discriminator.trainable=FalseGAN_input = Input(shape=(100,))x = generator(GAN_input)GAN_output= discriminator(x)GAN = Model(inputs=GAN_input, outputs=GAN_output)GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))return GANGAN = build_GAN(discriminator, generator)
GAN.summary()
GAN结构如下图所示
然后我们给出绘制图像的函数,用于把generator生成的假图片画出来:
def draw_images(generator, epoch, examples=25, dim=(5,5), figsize=(10,10)):noise= np.random.normal(loc=0, scale=1, size=[examples, 100])generated_images = generator.predict(noise)generated_images = generated_images.reshape(25,28,28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')plt.axis('off')plt.tight_layout()plt.savefig('Generated_images %d.png' %epoch)
OK, 最后一步,写一个train函数,来训练GAN网络。在这里我们设置最大迭代次数400,每次迭代生成128张假图片:
def train_GAN(epochs=1, batch_size=128):#Loading the dataX_train, y_train = load_data()# Creating GANgenerator= build_generator()discriminator= build_discriminator()GAN = build_GAN(discriminator, generator)for i in range(1, epochs+1):print("Epoch %d" %i)for _ in tqdm(range(batch_size)):# Generate fake images from random noisetnoise= np.random.normal(0,1, (batch_size, 100))fake_images = generator.predict(noise)# Select a random batch of real images from MNISTreal_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]# Labels for fake and real images label_fake = np.zeros(batch_size)label_real = np.ones(batch_size) # Concatenate fake and real images X = np.concatenate([fake_images, real_images])y = np.concatenate([label_fake, label_real])# Train the discriminatordiscriminator.trainable=Truediscriminator.train_on_batch(X, y)# Train the generator/chained GAN model (with frozen weights in discriminator) discriminator.trainable=FalseGAN.train_on_batch(noise, label_real)# Draw generated images every 15 epoches if i == 1 or i % 10 == 0:draw_images(generator, i)
train_GAN(epochs=400, batch_size=128)
我用了Google colab自带的GPU,训练400代大约用了十多分钟。如果用jupyter notebook在本机跑,会慢一些 (据说2分钟一代?)。
生成的图片如下图所示
大功告成,接下来我将一步步解释train_GAN()函数是怎么工作的。
首先,导入数据集,这个容易理解。
#Loading the dataX_train, y_train = load_data()
接下来,建立一个GAN网络,GAN由两个神经网络(generator, discriminator)连接而成。
# Creating GANgenerator= build_generator()discriminator= build_discriminator()GAN = build_GAN(discriminator, generator)
然后,建立一个循环(400次迭代)。tqdm用来动态显示每次迭代的进度。
for i in range(1, epochs+1):print("Epoch %d" %i)for _ in tqdm(range(batch_size)):
接着,我们生成呈高斯分布的噪声,利用generator,来生成batch_size(128张)图片。每张图片的输入就是一个1*100的噪声矩阵。
# Generate fake images from random noisetnoise= np.random.normal(0,1, (batch_size, 100))fake_images = generator.predict(noise)
同样的,我们从Mnist数据集中随机挑选128张真实图片。我们给真实图片标注1,给假图片标注0,然后将256张真假图片混合在一起。
# Select a random batch of real images from MNISTreal_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]# Labels for fake and real images label_fake = np.zeros(batch_size)label_real = np.ones(batch_size) # Concatenate fake and real images X = np.concatenate([fake_images, real_images])y = np.concatenate([label_fake, label_real])
此时,我们利用上文提到的256张带标签的真假图片,训练discriminator。训练完毕后,discriminator的weights得到了更新。(打个比方,警察通过研究市面上流通的假币,在一起开会讨论,努力研发出了新一代鉴定假钞的方法)。
# Train the discriminator
discriminator.trainable=True
discriminator.train_on_batch(X, y)
然后,我们冻结住discriminator的weights,让discriminator不再变化。然后就开始训练generator (chained GAN)。在GAN的训练中,我们输入一堆噪声,期待的输出是将假图片预测为真。在这个过程中,generator继续生成假图片,送到discriminator检验,得到检验结果,如果被鉴定为假,就不断更新自己的权重(假钞贩子不断改良造假技术),直到discriminator将加图片鉴定为真图片(直到当前鉴定假钞的技术无法识别出假钞)。
# Train the generator/chained GAN model (with frozen weights in discriminator) discriminator.trainable=FalseGAN.train_on_batch(noise, label_real)
OK,此时一次迭代进行完毕。接下来是第2, 3, ...次迭代。
现在,我们总结一下每次迭代发生了什么:
- Generator利用自己最新的权重,生成了一堆假图片。
- Discrminator根据真假图片的真实label,不断训练更新自己的权重,直到可以顺利鉴别真假图片。
- 此时discriminator权重被固定,不再发生变化。generator利用最新的discrimintor,苦苦思索,不断训练自己的权重,最终使discriminator将假图片鉴定为真图片。
换成印制假钞的例子,每次迭代发生了如下几件事:
- 假钞贩子根据最新造假技术,研发出一代假钞。
- 警察反复对比新型假钞和真币的区别,成功改良假钞鉴别方法,从而顺利鉴别出市面流通钞票的真伪。
- 假钞贩子生成假钞,马上被警察鉴别出来,痛定思痛,改良技术生成新的假钞。不成想,一上街又被警察识别了出来。日复一日,终于发明了新型假钞,当前的验钞技术已经无法成功检测出这种假钞。
然后通过每次迭代,discrimintor (警察的鉴定技术)和generator (假钞制作技术) 都越来越成熟...后来达到了动态平衡。
嗯,就这样,是不是挺简单的?
今天讲的是最原始的GAN网络,GAN发展到了如今已有许多变种,如将MLP结构换成CNN,Autoencoder,以及loss function的变化等等。我在github上找到一个超级全的用keras编写的各种花式GAN网络集合,有兴趣的小伙伴直接点击这里。本文的jupyter notebook代码请直接点击下面的小卡片~
https://nbviewer.jupyter.org/github/gaonanlee/Deep-Learning-Experiments/blob/master/Vanilla%20GAN_implementation.ipynbnbviewer.jupyter.org
如有理解不到位之处,欢迎批评指教。
参考文献
- https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py
2. https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
我的其他回答:
哪些 Python 库让你相见恨晚?
python如何画出漂亮的地图?
时间序列数据如何插补缺失值?
机器学习中的因果关系: 从辛普森悖论(常见的统计学谬误)谈起
c语言贪吃蛇最简单代码_让我们跑一个最简单的GAN网络吧!(附Jupyter Notebook 代码)...相关推荐
- cmd上写的java简单代码_用cmd编辑一个超级简单的小游戏,求代码
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 贪吃蛇: import java.awt.*; import java.util.LinkedList; import java.util.Scanner ...
- python接水果游戏代码_使用Python开发一个超级简单的接水果小游戏,零基础也可以学会...
Pylash项目地址 创建项目 这样的话我们的项目就创建好了,然后只用往Main.py里填写代码运行即可. 编写Hello World小程序 编写游戏 有以上对pylash的小小了解,我们接下来可以开 ...
- c语言五子棋代码_基于控制台的C语言贪吃蛇
相信对很多人来说,学完C语言以后,都会找一些小程序来练练手.例如贪吃蛇.五子棋.俄罗斯方块等等. 今天给大家分享一个基于控制台的C语言贪吃蛇小程序. 基础知识要求:C语言基础. 知识点补充 这里写一些 ...
- c语言安卓贪吃蛇代码下载,C语言贪吃蛇代码
c语言编写贪吃蛇源代码,简单易懂,文件为VC源代码.如果你正在学习c语言,就来下载吧.很经典的 C语言贪吃蛇代码部分 #include #include #include#include #defin ...
- C语言贪吃蛇游戏代码,贪吃蛇C语言代码实现大全
一.C语言贪吃蛇代码实现前言 设计贪吃蛇游戏的主要目的是让大家夯实C语言基础,训练编程思维,培养解决问题的思路,领略多姿多彩的C语言. 贪吃蛇是非常经典的一款游戏,本次我们模拟在控制台实现贪吃蛇游戏, ...
- 贪吃蛇的c语言程序码,C语言贪吃蛇代码下载_C语言贪吃蛇代码官方下载-太平洋下载中心...
C语言编写贪吃蛇源代码,简单易懂,文件为VC源代码.如果你正在学习c语言,就来下载吧.很经典的. C语言贪吃蛇代码原理: 产生一个固定大小没有边界的游戏区域,蛇从区域的中心开始,由玩家通过键盘控制蛇的 ...
- 贪吃蛇统计分数的c语言代码,C/C++编程笔记:C语言贪吃蛇源代码控制台(二),分数和食物!...
接上文<C/C++编程笔记:C语言贪吃蛇源代码控制台(一),会动的那种哦!>如果你在学习C语言开发贪吃蛇的话,零基础建议从上一篇开始哦!接下来正式开始吧! 三.蛇的运动 上次我已经教大家画 ...
- c语言对抗程序代码,C语言贪吃蛇源程序代码双人对抗
C语言贪吃蛇源程序代码双人对抗 #include #include #include #include #include #include #include #define LEFT 100 #def ...
- 超简单的C语言贪吃蛇 不闪屏 双缓冲
C语言贪吃蛇 今天把以前自己写的贪吃蛇总结了一下,发到博客上,怕放在电脑上哪天丢失了都不知道, 有不当之处还望指教 (*・ω< ) ヾ(◍°∇°◍)ノ゙ 贪吃蛇中, 我们看到的蛇在不断的移动,其 ...
- 完整版C语言贪吃蛇代码
C语言贪吃蛇完整代码 #include <stdio.h> #include <stdlib.h> #include <Windows.h>//windows编程头 ...
最新文章
- k8s 使用Nginx Ingress实现灰度发布和蓝绿发布
- 免费教材丨第55期:Python机器学习实践指南、Tensorflow 实战Google深度学习框架
- 结对编程作业——四则运算GUI程序
- 特征工程(3):特征选择
- 直播 | AAAI 2021最佳论文:比Transformer更有效的长时间序列预测
- shiro的登录 subject.login(token)中执行逻辑和流程
- Jerry Wang在SAP社区上获得的徽章
- java jdk win10安装_Java 安装 JDK WIN10
- 【算法设计与分析】15 分治策略:芯片测试
- 【年终总结】可圈可点的2018年
- python玩跳一跳_python玩跳一跳
- 假期最后一天,出差赶到天津
- 使用Zoiper与freeSWITCH开视频会议
- 腾讯 IVWEB 团队:前端识别验证码思路分析
- 移动硬盘和电脑内置硬盘使用时的区别
- java常见单词汇总3(非常使用哦)
- 获得拼多多商品详情(商品主图、sku)
- jQuery之属性操作
- 秒杀系统的设计五大原则
- v模拟器(华为、H3C)点滴