本文主要是个简单的笔记,参考资料来自下面三部分

  1. Tutorial_HYLee_GAN
  2. Renu Khandelwal 的博客
  3. Jason 的博客

神经网络一览

各种神经网络(全连接前向网络、卷积神经网络、循环神经网络)的区别在于具有不同的输入/输出形式,比如可以是向量、矩阵或者是向量序列等。

GAN的基本思想

GAN由生成器和判别器组成:

生成器的本质也是一个神经网络,或者说是一个函数

如果给定一个向量可以生成一张漫画图片,向量的每一个维度具有不同含义

判别器的本质也是一个神经网络

如果给定一张图片,判别器就会告诉你这是不是真实图片

所以GAN的训练本质就是训练两个神经网络

GAN的工作原理

生成器的目标是产生和训练数据相似的数据(以假乱真的图片),而判别器的目标是辨别真假。

生成器的输入通常为随机噪声,判别器有两个输入,一个来自训练数据中的真图片,一个来自生成器生成的假图片。

GAN的流程如下图所示

每一次迭代过程中:

  1. 更新判别器的网络参数。即给定假图片以及假图片的标签(上图中的generated example)、真图片以及真图片的标签(上图中的real example),让判别器能够区别出真假图片,也就是训练一个尽可能准确的二分类器。
  2. 固定判别器网络参数, 更新生成器网络。即给定假图片以及假标签(让判别器以为假图片是真的),从而误差反向传播来更新生成器,使得生成器生成更加逼真的照片。

GAN训练的目标函数如下所示

  • 判别器想要最大化目标函数使得对于真实数据 D(x) 接近 1,对于假数据 D(G(z)) 接近 0
  • 生成器想要最小化目标函数使得 D(G(z)) 接近 1,也就是欺骗判别器让它认为假数据为真

GAN的实现

这里采用 MNIST 数据集作为实验数据,最后我们会看到生成器能够产生看起来像真的数字!

导入需要用到的库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

导入数据

def load_data():(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train = (x_train.astype(np.float32) - 127.5)/127.5# 将图片转为向量 x_train from (60000, 28, 28) to (60000, 784) # 每一行 784 个元素x_train = x_train.reshape(60000, 784)return (x_train, y_train, x_test, y_test)
(X_train, y_train,X_test, y_test)=load_data()
print(X_train.shape)

定义优化器

def adam_optimizer():return Adam(lr=0.0002, beta_1=0.5)

这里要采用的生成对抗网络的结构如下图所示

定义生成器:输入是 100 维,经过三层隐藏层,输出 784 维的向量(造假的图片)

def create_generator():generator=Sequential()generator.add(Dense(units=256,input_dim=100))generator.add(LeakyReLU(0.2))generator.add(Dense(units=512))generator.add(LeakyReLU(0.2))generator.add(Dense(units=1024))generator.add(LeakyReLU(0.2))generator.add(Dense(units=784, activation='tanh'))generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())return generator
g=create_generator()
g.summary()

定义判别器:判别器的输入为真实图片或者由生成器造出来的假图片(784维),经过三层隐藏层,输出类别(1 维)

def create_discriminator():discriminator=Sequential()discriminator.add(Dense(units=1024,input_dim=784))discriminator.add(LeakyReLU(0.2))discriminator.add(Dropout(0.3))discriminator.add(Dense(units=512))discriminator.add(LeakyReLU(0.2))discriminator.add(Dropout(0.3))discriminator.add(Dense(units=256))discriminator.add(LeakyReLU(0.2))discriminator.add(Dense(units=1, activation='sigmoid'))discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())return discriminator
d =create_discriminator()
d.summary()

定义生成对抗网络

def create_gan(discriminator, generator):discriminator.trainable=False# 这是一个链式模型:输入经过生成器、判别器得到输出gan_input = Input(shape=(100,))x = generator(gan_input)gan_output= discriminator(x)gan= Model(inputs=gan_input, outputs=gan_output)gan.compile(loss='binary_crossentropy', optimizer='adam')return gan
gan = create_gan(d,g)
gan.summary()

定义画图函数来可视化图片的生成

def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):noise= np.random.normal(loc=0, scale=1, size=[examples, 100])generated_images = generator.predict(noise)generated_images = generated_images.reshape(100,28,28)plt.figure(figsize=figsize)for i in range(generated_images.shape[0]):plt.subplot(dim[0], dim[1], i+1)plt.imshow(generated_images[i], interpolation='nearest')plt.axis('off')plt.tight_layout()plt.savefig('gan_generated_image %d.png' %epoch)

生成对抗网络的训练函数

def training(epochs=1, batch_size=128):#导入数据(X_train, y_train, X_test, y_test) = load_data()batch_count = X_train.shape[0] / batch_size# 定义生成器、判别器和GAN网络generator= create_generator()discriminator= create_discriminator()gan = create_gan(discriminator, generator)for e in range(1,epochs+1 ):print("Epoch %d" %e)for _ in tqdm(range(int(batch_count))):#产生噪声喂给生成器noise= np.random.normal(0,1, [batch_size, 100])# 产生假图片generated_images = generator.predict(noise)# 一组随机真图片image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]# 真假图片拼接 X= np.concatenate([image_batch, generated_images])# 生成数据和真实数据的标签y_dis=np.zeros(2*batch_size)y_dis[:batch_size]=0.9# 预训练,判别器区分真假discriminator.trainable=Truediscriminator.train_on_batch(X, y_dis)# 欺骗判别器 生成的图片为真的图片noise= np.random.normal(0,1, [batch_size, 100])y_gen = np.ones(batch_size)# GAN的训练过程中判别器的权重需要固定 discriminator.trainable=False# GAN的训练过程为交替“训练判别器”和“固定判别器权重训练链式模型”gan.train_on_batch(noise, y_gen)if e == 1 or e % 50 == 0:# 画图 看一下生成器能生成什么plot_generated_images(e, generator)
training(400,256)

经过训练后生成的图片

一个epoch后生成器还是个小学生

100个epoch后生成器已经有点样子了

400个epoch后生成器可以出师了

是不是已经学得像模像样了,这样就能够利用噪声通过生成器来生成以假乱真的图片了。

plt生成固定的colormap_白话生成对抗网络GAN及代码实现相关推荐

  1. 简述生成式对抗网络 GAN

    本文主要阐述了对生成式对抗网络的理解,首先谈到了什么是对抗样本,以及它与对抗网络的关系,然后解释了对抗网络的每个组成部分,再结合算法流程和代码实现来解释具体是如何实现并执行这个算法的,最后通过给出一个 ...

  2. 利用Tensorflow构建生成对抗网络GAN以生成数据

    使用生成对抗网络(GAN)生成数据 本文主要内容 介绍了自动编码器的基本原理 比较了生成模型与自动编码器的区别 描述了GAN模型的网络结构 分析了GAN模型的目标核函数以及训练过程 介绍了利用Goog ...

  3. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  4. 用MXNet实现mnist的生成对抗网络(GAN)

    用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...

  5. (五)使用生成对抗网络 (GAN)生成新的时装设计

    目录 介绍 预测新时尚形象的力量 构建GAN 初始化GAN参数和加载数据 从头开始构建生成器 从头开始构建鉴别器 初始化GAN的损失和优化器 下一步 下载源 - 120.7 MB 介绍 DeepFas ...

  6. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  7. 『一起学AI』生成对抗网络(GAN)原理学习及实战开发

     参考并翻译教程:https://d2l.ai/chapter_generative-adversarial-networks/gan.html,加入笔者的理解和心得 1.生成对抗网络原理 在Col ...

  8. ECCV2022 | 生成对抗网络GAN论文汇总(图像转换-图像编辑-图像修复-少样本生成-3D等)...

    图像转换/图像可控编辑 视频生成 少样本生成 图像外修复/结合transformer GAN改进 新数据集 图像增强 3D 图像来源归属分析 一.图像转换/图像可控编辑 1.VecGAN: Image ...

  9. 一文看懂「生成对抗网络 - GAN」基本原理+10种典型算法+13种应用

    生成对抗网络 – Generative Adversarial Networks | GAN 文章目录 GAN的设计初衷 生成对抗网络 GAN 的基本原理 GAN的优缺点 10大典型的GAN算法 GA ...

最新文章

  1. 解决linux系统CentOS下调整home和根分区大小
  2. iOS之Block总结以及内存管理
  3. 一个express老系统csrf漏洞修复
  4. POJ 2299 Ultra-QuickSort(树状数组 + 离散)
  5. 磨刀不误砍柴工-git新手教程
  6. Google Chrome保存插件方法
  7. mysql递归查询 缓存_MySQL-递归查询方法解析
  8. Linux下的磁盘空间管理
  9. java 两个页面传递数据,请问Cookie怎么在两个页面间传递数据?
  10. 计算机硬盘标称容量怎么看,电脑硬盘标称容量、分区大小与实际容量之间的差异...
  11. c语言综合项目实践 结构体及应用,C51单片机应用与C语言程序设计(第3版) 基于机器人工程对象的项目实践简介,目录书摘...
  12. markdown/LaTeX中在字母下方输入圆点的方法
  13. 如何创建一个最简单的Windows桌面应用程序 (C++)
  14. 第二章:计算思维——知识点整理
  15. 利用Apple Developer申请苹果开发者账号(支付宝微信付款)
  16. ORA-15064 ORA-03113 - 测试库案例
  17. docker push 时 tag does not exist
  18. 幻读(phantom read)详解
  19. python 70行完成requests抓取csdn阅读量.
  20. TCL:不断扩张的业务,不断下跌的股价

热门文章

  1. Struts DispatchAction
  2. php 判断http还是https,以及获得当前url的方法
  3. 通过管道传输快速将MySQL的数据导入Redis(自己做过测试)
  4. JQUERY获取DOM
  5. Yii的GridView
  6. team网卡配置_Windows下的网卡Teaming 配置教程(图文)
  7. java capacity_关于Java中StringBuffer的capacity问题
  8. 代理服务器ip地址如何获得_详细教程:如何使用代理服务器进行网页抓取?
  9. android平板安装python_Notepad++配置Python开发环境
  10. android p版本 字符串常量池,Android OOM 问题