从一个小白的方式理解GAN网络(生成对抗网络),可以认为是一个造假机器,造出来的东西跟真的一样,下面开始讲如何造假:(主要讲解GAN代码,代码很简单)

我们首先以造小狗的假图片为例。

首先需要一个生成小狗图片的模型,我们称之为generator,还有一个判断小狗图片是否是真假的判别模型discrimator,

首先输入一个1000维的噪声,然后送入生成器,生成器的具体结构如下所示(不看也可以,看完全篇回来再看也一样):

其实比较简单,代码如下所示:

def generator_model():model = Sequential()model.add(Dense(input_dim=1000, output_dim=1024))model.add(Activation('tanh'))model.add(Dense(128 * 8 * 8))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(Reshape((8, 8, 128), input_shape=(8 * 8 * 128,)))model.add(UpSampling2D(size=(4, 4)))model.add(Conv2D(64, (5, 5), padding='same'))model.add(Activation('tanh'))model.add(UpSampling2D(size=(2, 2)))model.add(Conv2D(3, (5, 5), padding='same'))model.add(Activation('tanh'))return model

生成器接受一个1000维的随机生成的数组,然后输出一个64×64×3通道的图片数据。输出就是一个图片。不必太过深究,输入是1000个随机数字,输出是一张图片。

下面再看判别器代码与结构:

代码如下所示:


def discriminator_model():model = Sequential()model.add(Conv2D(64, (5, 5), padding='same', input_shape=(64, 64, 3)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (5, 5)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(1024))model.add(Activation('tanh'))model.add(Dense(1))model.add(Activation('sigmoid'))return model

输入是64,64,3的图片,输出是一个数1或者0,代表图片是否是狗。

下面根据代码讲具体操作:

把真图与假图。进行拼接,然后打上标签,真图标签是1,假图标签是0,送入训练的网络。

# 随机生成的1000维的噪声
noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000))# X_train是训练的图片数据,这里取出一个batchsize的图片用于训练,这个是真图(64张)
image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]# 这里是经过生成器生成的假图
generated_images = generator_model.predict(noise, verbose=0)# 将真图与假图进行拼接
X = np.concatenate((image_batch, generated_images))# 与X对应的标签,前64张图为真,标签是1,后64张图是假图,标签为0
y = [1] * BATCH_SIZE + [0] * BATCH_SIZE# 把真图与假图的拼接训练数据1送入判别器进行训练判别器的准确度
d_loss = discriminator_model.train_on_batch(X, y)

这里要是看不明白的话可以结合别人的讲解结合来看。

在这里训练好之后,判别器的精度会不断提高。

下面是重头戏了,也是GAN网络的核心:

def generator_containing_discriminator(g, d):model = Sequential()model.add(g)# 判别器参数不进行修改d.trainable = Falsemodel.add(d)return model

他的网络结构如下所示:

这个模型有生成器与判别器组成:看代码,这个模型上半部分是生成网络,下半部分是判别网络,生成网络首先生成假图,然后送入判别网络中进行判断,这里有一个d.trainable=False,意思是,只调整生成器,判别的的参数不做更改。简直巧妙。

然后我们来看如何训练生成网络,这一块也是核心区域:

        # 训练一个batchsize里面的数据for index in range(int(X_train.shape[0]/BATCH_SIZE)):# 产生随机噪声noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000))# 这里面都是真图片image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]# 这里产生假图片generated_images = g.predict(noise, verbose=0)# 将真图片与假图片拼接在一起X = np.concatenate((image_batch, generated_images))# 前64张图片标签为1,即真图,后64张照片为假图y = [1] * BATCH_SIZE + [0] * BATCH_SIZE# 对于判别器进行训练,不断提高判别器的识别精度d_loss = d.train_on_batch(X, y)# 再次产生随机噪声noise = np.random.uniform(-1, 1, (BATCH_SIZE, 1000))# 设置判别器的参数不可调整d.trainable = False# ××××××××××××××××××××××××××××××××××××××××××××××××××××××××××# 在此我们送入噪声,并认为这些噪声是真实的标签g_loss = generator_containing_discriminator.train_on_batch(noise, [1] * BATCH_SIZE)# ××××××××××××××××××××××××××××××××××××××××××××××××××××××××××# 此时设置判别器可以被训练,参数可以被修改d.trainable = True# 打印损失值print("batch %d d_loss : %s, g_loss : %f" % (index, d_loss, g_loss))

重点在于这句代码

g_loss = generator_containing_discriminator.train_on_batch(noise, [1] * BATCH_SIZE)

首先这个网络模型(定义在上面),先传入生成器中,然后生成器生成图片之后,把图片传入判别器中,标签此刻传入的是1,真实的图片,但实际上是假图,此刻判别器就会判断为假图,然后模型就会不断调整生成器参数,此刻的判别器的参数被设置为为不可调整,d.trainable=False,所以为了不断降低loss值,模型就会一直调整生成器的参数,直到判别器认为这是真图。此刻判别器与生成器达到了一个平衡。也就是说生成器产生的假图,判别器已经分辨不出来了。所以继续迭代,提高判别器精度,如此往复循环,直到生成连人都辨别不了的图片。

最后我训练了大概65轮,实际上生成比较真实的狗的图片我估计可能上千轮了,当然不同的网络结构,所需要的迭代次数也不一样。我这个因为太费时间,就跑了大概,可以看出大概有个狗模样。这个是训练了65轮之后的效果:

以上就是全部的内容了。

https://github.com/jensleeGit/Kaggle_self_use/tree/master/Generative%20Dog%20Images

GAN网络详解(从零入门)相关推荐

  1. 智安网络详解:零信任网络访问 (ZTNA)原理

    传统的基于边界的网络保护将普通用户和特权用户.不安全连接和安全连接,以及外部和内部基础设施部分结合在一起,创建了一个可信区域的假象,很多潜在的安全问题无法解决,越来越多的企业开始转向零信任网络访问来解 ...

  2. 对抗学习中GAN网络详解

    GAN GAN属于结构化学习,结构化学习的输出比较复杂,structured learning 需要有全局观 结构化学习的例子: 对抗网络(GAN)有两部分组成,一个是生成器(generator),一 ...

  3. MYSQL数据库详解-从零入门,一篇正式入门

    MYSQL 1,数据库相关概念 1.1 数据库 1.2 数据库管理系统 1.3 常见的数据库管理系统 1.4 SQL 2,MySQL 2.1 MySQL安装 2.1.1 下载 2.1.2 安装(解压) ...

  4. Python零基础速成班-第14讲-Python处理Excel和Word,使用openpyxl和docx包详解,图表入门

    Python零基础速成班-第14讲-Python处理Excel和Word,使用openpyxl和docx包详解,图表入门 学习目标 Python处理Excel(使用openpyxl包).图表入门\ P ...

  5. ResNet网络详解与keras实现

    ResNet网络详解与keras实现 ResNet网络详解与keras实现 Resnet网络的概览 Pascal_VOC数据集 第一层目录 第二层目录 第三层目录 梯度退化 Residual Lear ...

  6. IFM网络详解及torch复现

    文章目录 IFM网络详解 网络结构代码 训练代码 main IFM网络详解 https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247 ...

  7. 【GAN】二、原始GAN论文详解

    写在前面 在前面一篇文章:[GAN]一.利用keras实现DCGAN生成手写数字图像中我们利用keras实现了简单的DCGAN,并生成了手写数字图像.程序结果让我们领略了GAN的强大,接下来我们开始一 ...

  8. Keras深度学习实战(22)——生成对抗网络详解与实现

    Keras深度学习实战(22)--生成对抗网络详解与实现 0. 前言 1. 生成对抗网络原理 2. 模型分析 3. 利用生成对抗网络生成手写数字图像 小结 系列链接 0. 前言 生成对抗网络 (Gen ...

  9. 思科ei ccie认证体系最新内容下一代编址IPV6技术最全面的基础详解 从零到精通必读

    思科ei ccie认证体系最新内容下一代编址IPV6技术最全面的基础详解 从零到精通必读 IPv6(Internet Protocol Version 6,因特网协议版本6)是网络层协议的第二代标准协 ...

最新文章

  1. 关于正则表达式 \1 \2之类的问题
  2. PC端微信小程序wxapkg解密
  3. openSAP中国新平台的介绍
  4. 信仰的力量—海归毕业季的选择与入职后的蜕变记
  5. 大牛用SSM框架实现了支付宝的支付功能,满满干货指导
  6. java 原子量Atomic举例(AtomicReference)
  7. VLC支持的视频和音频文件扩展名
  8. 学习 vi —— “学习清单”式
  9. 软件测试服务方案ppt,测试方案(测试策略).ppt
  10. c语言json使用,cJSON使用(二)
  11. JMF介绍之媒体框架
  12. 数值分析:研究高次插值的龙格现象
  13. 油田系统三维布局可视化解决方案
  14. php 查询功能,php如何实现查询功能实现
  15. java实现图片验证码_JAVA实现图片验证码
  16. pjmedia系列之媒体设备pjmedia_snd_port
  17. 内网学习笔记 | SSH 隧道使用
  18. python写入文件没反应_python写入文本 如何用python将变量及其值写入文本文件?...
  19. WebGL 手撸3d贺卡+小草飘动滤镜
  20. 自动控制原理——概述

热门文章

  1. 2020年汽车驾驶员(高级)报名考试及汽车驾驶员(高级)在线考试
  2. GET请求里的body问题
  3. Unity中的热更新 - Lua和C#通信
  4. mc通用计算机,大神程序员标配:花365天在《我的世界》打造一台能运行的计算机...
  5. 关于EVAL()函数(一)
  6. UniApp文件上传
  7. 最近经常看到网上程序员被抓,如何避免面向监狱编程!?
  8. 【工具使用系列】关于 MATLAB Simulink 物理建模,你需要知道的事
  9. HEVC码率控制代码分析
  10. 想说说关于在刷题网站(牛客 、C语言网、力扣)上测试样例过了但是OJ判错这档子事