【GANs】Deep Convolution Generative Adversarial Network

  • 3 DCGAN
    • 3.1 简介
    • 3.2 DGGAN实现

3 DCGAN

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

3.1 简介


深度卷积神经网络(Deep Convolution Generative Adversarial Nets)的生成网络

在DCGAN中,

  • 判别网络是一个传统的深度卷积神经网络,但使用了带步长的卷积来实现下采样操作,不用 m a x p o o l i n g maxpooling maxpooling操作;
  • 生成网络使用了一个特殊的深度卷积网络来实现。如上图,使用微步卷积来生成 64 × 64 64×64 64×64大小的图像。第一层是全连接层,输入是从均匀分布中随机采样的100维向量 z z z,输出是 4 × 4 × 1024 4×4×1024 4×4×1024的向量,后面是四层微步卷积,没有汇聚层。

DCGAN的主要优点是通过一些经验性的网络结构设计使得对抗训练更加稳定。比如:

  • 使用带步长的卷积(在判别网络中)和微步卷积(在生成网络中)来代替汇聚操作,以免损失信息;
  • 使用批量归一化;
  • 去除卷积层之后的全连接层;
  • 在生成网络中,除了最后一层使用了 T a n h Tanh Tanh激活函数以外,其余层都使用 R e L U ReLU ReLU函数;
  • 在判别网络中,都是用 L e a k y R e L U LeakyReLU LeakyReLU函数。

3.2 DGGAN实现

# DCGAN_2016.py
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Input
from tensorflow.keras.layers import UpSampling2D, Conv2D, Activation, ZeroPadding2D, GlobalAveragePooling2D
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adamclass DCGAN():def __init__(self):# 输入shapeself.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)# 分十类self.num_classes = 10self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)# 判别模型self.discriminator = self.build_discriminator()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# 生成模型self.generator = self.build_generator()# conbine是生成模型和判别模型的结合# 判别模型的trainable为False# 用于训练生成模型z = Input(shape=(self.latent_dim,))img = self.generator(z)self.discriminator.trainable = Falsevalid = self.discriminator(img)self.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):model = Sequential()# 先全连接到64*7*7的维度上model.add(Dense(32 * 7 * 7, activation="relu", input_dim=self.latent_dim))# reshape成特征层的样式model.add(Reshape((7, 7, 32)))# 7, 7, 64model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))# 上采样# 7, 7, 64 -> 14, 14, 64model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))# 上采样# 14, 14, 128 -> 28, 28, 64model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))# 上采样# 28, 28, 64 -> 28, 28, 1model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):model = Sequential()# 28, 28, 1 -> 14, 14, 32model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))# 14, 14, 32 -> 7, 7, 64model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))# 7, 7, 64 -> 4, 4, 128model.add(ZeroPadding2D(((0, 1), (0, 1))))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(GlobalAveragePooling2D())# 全连接model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, save_interval=50):(X_train, _), (_, _) = mnist.load_data()       # 载入数据X_train = X_train / 127.5 - 1.  # 归一化X_train = np.expand_dims(X_train, axis=3)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------- ##  训练判别模型# --------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)# 训练并计算lossd_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  训练生成模型# ---------------------g_loss = self.combined.train_on_batch(noise, valid)print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))if epoch % save_interval == 0:self.save_imgs(epoch)def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig("images/mnist_%d.png" % epoch)plt.close()if __name__ == '__main__':# if not os.path.exists("./images"):#     os.makedirs("./images")dcgan = DCGAN()dcgan.train(epochs=20000, batch_size=256, save_interval=500)

【GANs】Deep Convolution Generative Adversarial Network相关推荐

  1. 【GAN】《ENERGY-BASED GENERATIVE ADVERSARIAL NETWORKS》 ICLR‘17

    <ENERGY-BASED GENERATIVE ADVERSARIAL NETWORKS> 先介绍EBGAN,再详细解读为什么这样做. Auto-encoder Discriminato ...

  2. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network论文翻译——中英文对照

    文章作者:Tyan 博客:noahsnail.com  |  CSDN  |  简书 声明:作者翻译论文仅为学习,如有侵权请联系作者删除博文,谢谢! 翻译论文汇总:https://github.com ...

  3. 【论文翻译】中英对照翻译--(Attentive Generative Adversarial Network for Raindrop Removal from A Single Image)

    [开始时间]2018.10.08 [完成时间]2018.10.09 [论文翻译]Attentive GAN论文中英对照翻译--(Attentive Generative Adversarial Net ...

  4. 【GANs】Generative Adversarial Nets

    [GANs]Generative Adversarial Nets 1 GAN 1.1 GANs的简介 1.2 思想与目标函数 1.3 GAN代码 1.4 全局最优推导 1.5 GANs方向展望 1 ...

  5. 【GANs】Conditional Generative Adversarial Nets

    [GANs]Conditional Generative Adversarial Nets 2 CGAN 2.1 CGAN简介 前言 流程图 目标函数 2.2 CGAN代码 2 CGAN 2.1 CG ...

  6. 【论文笔记】DOA-GAN: Dual-Order Attentive Generative Adversarial Network for Image Copy-move Forgery Detec

    DOA-GAN: Dual-Order Attentive Generative Adversarial Network for Image Copy-move Forgery Detection a ...

  7. 【深度学习】Generative Adversarial Network 生成式对抗网络(GAN)

    文章目录 一.神经网络作为生成器 1.1 什么是生成器? 1.2 为什么需要输出一个分布? 1.3 什么时候需要生成器? 二.Generative Adversarial Network 生成式对抗网 ...

  8. GANs学习系列(8):Deep Convolutional Generative Adversarial Nerworks,DCGAN

    [前言]      本文首先介绍生成式模型,然后着重梳理生成式模型(Generative Models)中生成对抗网络(Generative Adversarial Network)的研究与发展.作者 ...

  9. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

    Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network Ledig C, Theis ...

最新文章

  1. 1071 mysql_mysql 出现1071错误怎么办
  2. 大叔手记(16):分析URL Routing和URL Rewriting两者之间的不同
  3. SQL查询库、表,列等的一些操作
  4. docker安装redis提示没有日记写入权限_浅析Linux下Redis的攻击面(一)
  5. MySQL事务隔离级别详解
  6. Java一种错误的实例化方法:在默认无参构造函数中进行实例化
  7. 主存和cache每一块相等_CPU中的Cache实现原理
  8. centos7.5 安装配置supervisor管理python进程(也就是服务)
  9. MS SQL 表字段增加,删除,修改
  10. VIM编辑器的常用命令
  11. 一个简单的小程序demo
  12. 第一次结对作业:原型设计
  13. 一看就会的ios配置证书及描述文件
  14. 数据库发展的三个阶段及特点
  15. 安卓机如果相册不选图片就退出_微商相册如何在「多台设备登录」?
  16. linux 内存告警门限,H3C LA系列无线网关 配置指导(V7)-R0304-6W100_基础配置指导_设备管理配置-新华三集团-H3C...
  17. 计算机图形学之二维平移旋转缩放代码
  18. Youtube内容正在失控
  19. Windows8 照片查看器,图片发黄解决方法~
  20. 【MPLAB X IPE】:XIPE烧写教程

热门文章

  1. DynamicModuleUtility对象在.net不同版本下的兼容性问题
  2. UX术语详解:任务流,用户流,流程图以及其它全新术语
  3. 关于srand()与rand()函数的理解-----必看系列
  4. 【信号与系统】(六)连续系统的时域分析 ——冲激响应与阶跃响应
  5. 常见基本编程练习与思考
  6. 计算机网络—数据链路层
  7. 办北京居住证,定制社保缴费记录,个人权益记录最近6个月的查询与打印,社保,北京市社会保险,北京市社会保险网上服务平台,北京市社会保险网上申报查询系统
  8. MATLAB代码实现三次样条插值
  9. 如何看linux版本
  10. MATLAB处理EXCEL文件