学习目标

  • 目标

    • 了解GAN的作用
    • 说明GAN的训练过程
    • 知道DCGAN的结构
  • 应用
    • 应用DCGAN模型实现手写数字的生成

5.1.1 GAN能做什么

GAN是非监督式学习的一种方法,在2014年被提出。

GAN主要用途:

  • 生成以假乱真的图片

  • 生成视频、模型

5.1.2 什么GAN

5.1.2.1 定义

生成对抗网络(Generative Adversarial Network,简称GAN),主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。

  • 生成器(Generator),能够输入一个向量,输出需要生成固定大小的像素图像
  • 判别器(Discriminator),用来判别图片是真的还是假的,输入图片(训练的数据或者生成的数据),输出为判别图片的标签

5.1.2.2 理解

  • 思想:从训练库里获取很多训练样本,从而学习这些训练案例生成的概率分布

  • 黑色虚线:真是样本的分布
  • 绿色实线:生成样本的分布
  • 蓝色虚线:判别器判断的概率分布
  • zz表示噪声,zz到xx表示生成器生成的分布映射

过程分析:

  • 1、定义GAN结构生成数据

    • (a)(a)状态处于最初始的状态,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不稳定
  • 2、在真实数据上训练 n epochs判别器,产生fake(假数据)并训练判别器识别为假
    • 通过多次训练判别器来达到(b)(b)样本状态,此时判别样本区分得非常显著
  • 3、训练生成器达到欺骗判别器的效果
    • 训练生成器之后达到(c)(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。经过多次反复训练迭代之后。
    • 最终希望能够达到(d)(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的。

5.1.2.3 训练损失

  • V(G, D)V(G,D):表示P_ x和 P_z 的差异程度。
  • \max \limits_DV(D, G)​D​max​​V(D,G) :固定生成器G, 尽可能地让判别器能够最大化地判别出样本来自于真实数据还是生成的数据
  • \min \limits_G L​G​min​​L:固定判别器D的条件下得到生成器G,能够最小化真实样本与生成样本的差异。

整个优化我们其实只看做一个部分:

  • 判别器:相当于一个分类器,判断图片的真伪,二分类问题,使用交叉熵损失

对于真实样本:对数预测概率损失,提高预测的概率

对于生成样本:对数预测概率损失,降低预测概率

最终可以这样:

5.1.2.4 G、D结构

G、D结构是两个网络,特点是能够反向传播可导计算要介绍G、D结构,需要区分不同版本的GAN。

  • 2014年最开始的模型:

    • G、D都是multilayer perceptron(MLP)
    • 缺点:实践证明训练难度大,效果不行
  • 2015:使用卷积神经网络+GAN(DCGAN(Deep Convolutional GAN))
    • 改进:

      • 1、判别器D中取出pooling,全部变成卷积、生成器G中使用反卷积(下图)
      • 2、D、G中都增加了BN层
      • 3、去除了所有的全连接层
      • 4、判别器D中全部使用Leaky ReLU,生成器除了最后输出层使用tanh其它层全换成ReLU

5.1.3 案例:GAN生成手写数字图像

5.1.3.1 案例演示与结果显示

  • 迭代不同次数生成的图片效果
  • 1次

  • 50

  • 2000次

5.1.3.2 代码步骤流程

  • 初始化GAN模型结构

    • init_model(self)
    • 判别器:CNN,build_discriminator
    • 生成器:CNN,build_generator
  • 训练过程:train(self, epochs, batch_size=128, save_interval=50)
    • 训练判别器
    • 训练生成器
    • 生成图片保存

5.1.3.3 代码编写

  • 1、模型类
class DCGAN():def __init__(self):# 输入图片的形状self.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)
  • 2、初始化GAN模型结构

    • 建立D判别器CNN结构,初始化判别器训练优化参数
    • 联合建立G生成器CNN结构,初始化生成器训练优化参数
      • 输入噪点数据,输出预测的类别概率
      • 注意生成器训练时,判别器不进行训练
    • from keras.optimizers import Adam
       def init_model(self):# 生成原始噪点数据大小self.latent_dim = 100optimizer = Adam(0.0002, 0.5)# 1、建立判别器训练参数# 选择损失,优化器,以及衡量准确率self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 2、联合建立生成器训练参数,指定生成器损失self.generator = self.build_generator()z = Input(shape=(self.latent_dim,))img = self.generator(z)# 合并模型的损失,并且之后只训练生成器,判别器不训练self.discriminator.trainable = Falsevalid = self.discriminator(img)# 训练生成器欺骗判别器self.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
  • 定义模型的判别器

    • from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model
    def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)
  • 定义模型的生成器

    • CNN结构
    def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)
  • 3、训练模型代码

    • from keras.datasets import mnist
    • import matplotlib.pyplot as plt
    • import numpy as np
    • model:train_on_batch(feature, target)
    def train(self, epochs, batch_size=32):# 加载手写数字(X_train, _), (_, _) = mnist.load_data()# 进行归一化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 正负样本的目标值建立valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# 1、训练判别器# 选择随机的一些真实样本idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 生成器产生假样本noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)# 训练判别器过程d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# 计算平均两部分损失d_loss = np.add(d_loss_real, d_loss_fake) / 2# 2、训练生成器,停止判别器#  合并训练,并停止训练判别器# 用目标值为1去训练,目的使得生成器生成的样本越来越接近真是样本g_loss = self.combined.train_on_batch(noise, valid)# 画出结果print("迭代次数:%d [D 损失: %f, 准确率: %.2f%], [G 损失: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))# 保存生成的图片if epoch % 50 == 0:self.save_imgs(epoch)
  • 保存生成的图片
    def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig("./images/mnist_%d.png" % epoch)plt.close()

5.1.4 总结

  • 掌握GAN模型的原理过程
  • 掌握GAN手写数字的训练过程

生成对抗网络(GAN)相关推荐

  1. 简述一下生成对抗网络GAN(Generative adversarial nets)模型?

    简述一下生成对抗网络GAN(Generative adversarial nets)模型? 生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构. 要全面理解生 ...

  2. 生成对抗网络gan原理_中国首个“芯片大学”即将落地;生成对抗网络(GAN)的数学原理全解...

    开发者社区技术周刊又和大家见面了,萌妹子主播为您带来第三期"开发者技术联播".让我们一起听听,过去一周有哪些值得我们开发者关注的重要新闻吧. 中国首个芯片大学,南京集成电路大学即将 ...

  3. 必读!TOP10生成对抗网络GAN论文(附链接)

    来源:新智元 本文约2200字,建议阅读7分钟. 本文所选论文提供了一个易读的对GAN的介绍,帮助你理解GAN技术的基础. [ 导读 ]生成对抗网络 (GAN) 是深度学习中最有趣.最受欢迎的应用之一 ...

  4. 生成对抗网络GAN综述

    题目:生成对抗网络GAN综述 系别:工程物理系 姓名:王雨阳 简 介: 生成对抗网络(GAN)是目前深度学习中应用较为广泛的一种网络.在我今后的研究中,可能会用到GAN,并且我也想了解一下GAN,因此 ...

  5. 权重对生成对抗网络GAN性能的影响

    本文制作了一个生成对抗网络GAN网络,并通过调节权重的初始化方法来观察权重对网络性能的影响. 生成网络的结构是784*300*784,对抗网络的结构是784*300*1.生成网络的输入是一个28*28 ...

  6. DL之GAN:生成对抗网络GAN的简介、应用、经典案例之详细攻略

    DL之GAN:生成对抗网络GAN的简介.应用.经典案例之详细攻略 目录 生成对抗网络GAN的简介 1.生成对抗网络的重要进展 1.1.1986年的RBM→2006年的DBN

  7. 【GAN优化】长文综述解读如何定量评价生成对抗网络(GAN)

    欢迎大家来到<GAN优化>专栏,这里将讨论GAN优化相关的内容,本次将和大家一起讨论GAN的评价指标. 作者&编辑 | 小米粥 编辑 | 言有三 在判别模型中,训练完成的模型要在测 ...

  8. 生成对抗网络(GAN)相比传统训练方法有什么优势?(一)

    作者:元峰 链接:https://www.zhihu.com/question/56171002/answer/148593584 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载 ...

  9. [Python图像识别] 四十九.图像生成之什么是生成对抗网络GAN?基础原理和代码普及

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  10. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

最新文章

  1. 【408】C函数中的ADT
  2. ASP.NET MVC 5 学习教程:添加控制器
  3. 学习java 的30个目标
  4. python持久层框架_想设计一个Python的持久层框架
  5. Saltstack-6:模块
  6. python 类、模块、包的区别
  7. 谈谈怎样提高炼丹手速
  8. 饶军:Apache Kafka的过去,现在,和未来
  9. 注意力机制可视化_Attention isn’t all you need!BERT的力量之源远不止注意力
  10. Spring @Bean @Scope @Qualifier
  11. Visual Studio From DataBase (1)
  12. Oracle12C用户创建、授权、登录
  13. [转]WebService压缩
  14. paxos算法java实现_Paxos算法——前世
  15. PLM与PDM的概念与区别
  16. Android Studio模拟器AndroidWifi连接成功但无法访问网络问题
  17. eovs实训报告总结心得_实训报告心得体会范文大全
  18. 清华大学计算机国际大赛,动态 | 清华大学超算团队摘得 SC 2018 总冠军,包揽三大国际大学生超算竞赛总冠军...
  19. 知其然,知其所以然之Java基础系列(一)
  20. 人与人的区别在于八小时之外如何运用

热门文章

  1. 2022-2028年中国安防视频行业市场前景分析预测报告
  2. JVM年轻代,老年代,永久代详解​​​​​​​
  3. 2022-2028年中国硅质原料行业全景调研及投资前景展望报告
  4. 2022-2028年中国内衣行业研究及前瞻分析报告
  5. python装饰器学习
  6. SpringBoot (四) :SpringBoot整合使用JdbcTemplate
  7. pandas以前笔记
  8. view(*args)改变张量的大小和形状_pytorch reshape numpy
  9. 高精地图与自动驾驶(上)
  10. GPU加速:宽深度推理