深度学习笔记(三)——GAN入门实现MNIST数据集

文章目录

  • 深度学习笔记(三)——GAN入门实现MNIST数据集
    • 闲聊
    • 什么是GAN?
    • Generator
      • Discriminator
    • GAN的特点
      • 为什么GAN不适合处理文本数据
      • GAN的优化器为什么不常用SGD
      • 训练技巧
    • 具体实现
      • Generator
      • Discriminator
      • 训练
      • 全部代码
    • 结语
      • 参考:
      • 训练结果

闲聊

这周的任务完成了哇,但是本着好好学习,天天向上的宏伟计划(PS:一直感觉GAN蛮好玩的,打算玩玩),所以打算尝试学习下GAN。

什么是GAN?

论文
生成式对抗网络(GAN)是近年来复杂分布上无监督学习最具前景的方法之一,GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习中来说是通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使G学习到数据的分布 ,如果是在图片生成上进行应用,就是

  • G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像

  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片

  • G的目的是生成让判别的模型无法判断真伪的输出

  • D的目的是判断这个是否真实

Generator

什么是生成(generation)?就是模型通过学习一些数据,然后生成类似的数据。让机器看一些图片,然后自己来产生的图片,这就是生成。

以前就有很多生成的技术,比如auto-encoder(自编码器):

训练一个encoder(编码器),把image input转换成code,然后训练一个decoder解码器,将code转换成一个Image,然后计算得到的image和image input之间的MSE(mean square error)

训练完模型后,取出后半部分的decoder,输入一个随机的code,就能通过Generator生成一个image

Discriminator

上面讲到了Generator可以根据一个随机的code生成图片,但是这里涉及到了一个问题,生成出来的图片,到底让我们看是怎么样的呢?这个时候就需要Discriminator对Generator生成的图片进行判别

如果非要简单地形容一下就是自己和自己下围棋?

GAN的特点

  • GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链
  • GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域
  • GAN不适合处理离散形式的数据,比如文本
  • GAN中的G的梯度更新信息来自于D,而不是样本数据
  • 它存在两个不同的网络,而不是传统的单一网络

为什么GAN不适合处理文本数据

对于文本来说,通常需要将一个词映射为一个高维向量,最终预测输出的是一个one-hot向量,假设softmax的输出(0.2,0.3,0.1,0.2,0.15,0.05)那么变为onehot是(0,1,0,0,0,0),如果softmax输出是(0.2,0.25,0.2,0.1,0.15,0.1),而onehot仍然是(0,1,0,0,0,0)。对于G来说,输出了不同的结果,但是D给出了相同的判别结果,并不能梯度更新信息很好的传递到G

GAN的优化器为什么不常用SGD

  • SGD容易震荡,容易使GAN训练不稳定
  • GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是求最小值问题,GAN是一个博弈问题

训练技巧

  • 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)
  • 使用wassertein GAN的损失函数,
  • 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑
  • 使用mini-batch norm, 如果不用batch norm可以使用instance norm或者weight norm
  • 避免使用RELUpooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数
  • 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,
  • 给D的网络层增加高斯噪声,相当于是一种正则

具体实现

代码参考

Generator

生成网络的目标是输入一行正太分布的随机数,生成mnist手写体图片,输出一个28x28x1的图片

全连接层:256——>512——>1024——>784-变形->28x28

def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)

Discriminator

判别模型的目的是根据输入的图片判断出真伪。因此它的输入一个28,28,1维的图片,输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。

全连接层:28x28-打平->784——>512——>256——>1(0或1真伪)

def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

训练

GAN的训练分为如下几个步骤:

  • 随机选取batch_size个真实的图片
  • 随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片
  • 真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练
  • 将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”)

全部代码

from __future__ import print_function, divisionfrom tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
# from tensorflow.keras.layers.advanced_activations import LeakyReLU
from tensorflow.keras.layers import LeakyReLU
# from tensorflow.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as pltimport sys
import os
import numpy as npclass GAN():def __init__(self):# --------------------------------- ##   行28,列28,也就是mnist的shape# --------------------------------- #self.img_rows = 28self.img_cols = 28self.channels = 1# 28,28,1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()gan_input = Input(shape=(self.latent_dim,))img = self.generator(gan_input)# 在训练generate的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)self.combined = Model(gan_input, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# 获得数据(X_train, _), (_, _) = mnist.load_data()# 进行标准化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------------- ##   随机选取batch_size个图片#   对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ##  训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)

结语

终于跑了一次GAN了,还是蛮好玩的

参考:

深度学习----GAN(生成对抗神经网络)原理解析
好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

训练结果





深度学习笔记(三)——GAN入门实现MNIST数据集相关推荐

  1. 深度学习笔记:Tensorflow手写mnist数字识别

    文章出处:深度学习笔记11:利用numpy搭建一个卷积神经网络 免费视频课程:Hellobi Live | 从数据分析师到机器学习(深度学习)工程师的进阶之路 上一讲笔者和大家一起学习了如何使用 Te ...

  2. 深度学习笔记三:反向传播(backpropagation)算法

    接上一篇的最后,我们要训练多层网络的时候,最后关键的部分就是求梯度啦.纯数学方法几乎是不可能的,那么反向传播算法就是用来求梯度的,用了一个很巧妙的方法. 反向传播算法应该是神经网络最基本最需要弄懂的方 ...

  3. Python深度学习之分类模型示例,MNIST数据集手写数字识别

    MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 我们把60000个训练样本分成两部分,前 ...

  4. 深度学习笔记三:Softmax Regression

    Softmax回归模型 整体理解 回归与分类 借用网上一个帖子的回复:分类问题和回归问题都要根据训练样本找到一个实值函数g(x). 回归问题的要求是:给定一个新的模式,根据训练集推断它所对应的输出y( ...

  5. 深度学习笔记:优化方法总结(BGD,SGD,Momentum,AdaGrad,RMSProp,Adam)

    深度学习笔记(一):logistic分类  深度学习笔记(二):简单神经网络,后向传播算法及实现  深度学习笔记(三):激活函数和损失函数  深度学习笔记:优化方法总结  深度学习笔记(四):循环神经 ...

  6. 生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记

    生成对抗网络入门详解及TensorFlow源码实现–深度学习笔记 一.生成对抗网络(GANs) 生成对抗网络是一种生成模型(Generative Model),其背后最基本的思想就是从训练库里获取很多 ...

  7. AI Studio 飞桨 零基础入门深度学习笔记6.3-手写数字识别之数据处理

    AI Studio 飞桨 零基础入门深度学习笔记6.3-手写数字识别之数据处理) 概述 前提条件 读入数据并划分数据集 扩展阅读:为什么学术界的模型总在不断精进呢? 训练样本乱序.生成批次数据 校验数 ...

  8. 【深度学习】Tensorflow2.x入门(一)建立模型的三种模式

    前言 最近做实验比较焦虑,因此准备结合推荐算法梳理下Tensorflow2.x的知识.介绍Tensorflow2.x的文章有很多,但本文(系列)是按照作者构建模型的思路来展开的,因此不会从Eager ...

  9. 深度学习者的入门福利-Keras深度学习笔记

    Keras深度学习笔记 最近本人在github上发现一个不错的资源,是利用keras来学习深度学习的笔记,笔记内容充实,数据完善,本人亲自实操了里面的所有例子,深感收获颇丰,今天特意推荐给大家,希望能 ...

  10. [深度学习-实践]GAN基于手写体Mnist数据集生成新图片

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子 深度学习GAN(三)之基于手写体Mnist数据集的例子 深度学习GAN(四)之PIX2PIX GAN ...

最新文章

  1. Nagios插件NDOUtils安装
  2. 2月书讯 | 冬奥结束看什么?看看“天花板”级别新作!
  3. 兰山天书(贺兰山岩画)
  4. centos7无法使用epel的解决方法
  5. sql server中截取字符串的常用函数(自己经常到用的时候想不起来所以拿到这里)...
  6. SQL Prompt 没激活联网后突然无法使用 解决办法
  7. bat计算机清理原理,电脑清理系统垃圾bat的操作步骤
  8. 论文导读 | 图上的可达性问题
  9. MyBatis条件查询
  10. 油罐清洗抽吸系统设计
  11. 解决ffmpeg合并视频后播放条拖不动,画面出错的问题
  12. 解决jQuery(e).addclass(‘xxx‘)始终不生效的问题 - $(...).addclass is not a function
  13. mininet和ryu简单实现自定义topo
  14. 【笔记本电脑连接真无线 jbl flash x耳机】pin 是 000000
  15. hive的一些常用命令
  16. PMP快速通过经验分享
  17. 自定义ironic-python-agent镜像 ipa ramdisk and kernel
  18. 彩票小贩潜伏50天惊人绽放携12人合买中52万
  19. 时间运算函数 CATT_ADD_TO_TIME
  20. 腾讯云TCA认证证书含金量如何?考试通知哪里获取?

热门文章

  1. Java线程池 与Lambda
  2. PHP Windows系统下调用OpenOffice
  3. 寄售Consignment和VMI有什么区别?
  4. JAVA随机数生成 | Math.random()方法 | 随机生成int、double类型
  5. Windws Server 2012 Server Backup(备份与还原)
  6. 经典排序算法(十三)--奇偶排序Odd-even Sort
  7. C++11 chrono库
  8. 10. 了解分配子(allocateor)的约定和限制
  9. 微服务学习之Eureka注册中心集群环境构建【Hoxton.SR1版】
  10. c# 存储图片到oracle,c# winform 读取oracle中blob字段的图片并且显示到pictureBox里 保存进库...