深度卷积生成对抗网络(DCGAN)原理与实现(采用Tensorflow2.x)

  • GAN直观理解
  • DCGAN网络结构
  • GAN训练目标
  • DCGAN实现
    • 数据加载
    • 网络
      • 鉴别网络
      • 生成网络
    • 网络训练
      • 定义损失函数
      • 实例化网络及优化器
      • 训练
    • 效果展示
      • 定义可视化函数
      • 可视化
  • 效果
  • 小问题
  • 后记

GAN直观理解

Ian Goodfellow 在首次提出GAN,使用了形象的比喻来介绍 GAN 模型:生成网络 G 的功能就是产生逼真的假钞试图欺骗鉴别器 D,鉴别器 D 通过学习真钞和生成器 G 生成的假钞来掌握钞票的鉴别方法。这两个网络在相互博弈中进行训练,直到生成器 G 产生的假钞使鉴别器 D 难以分辨。而DCGAN是使用卷积操作和反卷积操作来替代原始GAN中的全连接操作。

DCGAN网络结构

GAN包含生成网络(Generator, G )和判别网络(Discriminator, D ),其中 G 用于学习数据的真实分布, D 用于将 G 生成的数据与真实样本区分开。

生成网络G(z)G(z)G(z) G 从先验分布pz(⋅)p_z(\cdot )pz​(⋅)中采样潜变量z∼pz(⋅)z\sim p_z(\cdot)z∼pz​(⋅),通过 G 学习分布pg(x∣z)p_g (x|z)pg​(x∣z),获得生成样本x∼pg(x∣z)x\sim ~p_g (x|z)x∼ pg​(x∣z)。其中潜变量z的先验分布pz(⋅)p_z (\cdot)pz​(⋅)可以假设为常见的分布。
判别网络D(x)D(x)D(x) **D**是一个二分类网络,它判断采样自真实数据分布pr(⋅)p_r (\cdot)pr​(⋅)的数据xr∼pr(⋅)x_r\sim p_r(\cdot )xr​∼pr​(⋅)和采样自生成网络的生成的数据xf∼pg(x∣z)x_f\sim p_g (x|z)xf​∼pg​(x∣z),判别网络的训练数据集由xrx_rxr​和xfx_fxf​组成。真实样本xrx_rxr​的标签标为1,生成网络产生的样本xfx_fxf​标为0,通过最小化判别网络 D 的预测值与标签之间的误差来优化判别网络。

GAN训练目标

判别网络目标是分辨出真样本xrx_rxr​与假样本xfx_fxf​。它的目标是最小化预测值和真实值之间的交叉熵损失函数:
mθinL=CE(Dθ(xr),yr,Dθ(xf),yf)\underset{θ}min \mathcal L = CE(D_θ (x_r ), y_r,D_θ (x_f),y_f)θm​inL=CE(Dθ​(xr​),yr​,Dθ​(xf​),yf​)
CE表示交叉熵损失函数CrossEntropy:
L=−∑xr∼pr(⋅)logDθ(xr)−∑xf∼pg(⋅)log(1−Dθ(xf))\mathcal L = − \sum_{x_r \sim p_r (\cdot)}logD_θ (x_r ) −\sum_{x_f \sim p_g (\cdot) } log (1 − D_θ (x_f ))L=−xr​∼pr​(⋅)∑​logDθ​(xr​)−xf​∼pg​(⋅)∑​log(1−Dθ​(xf​))
判别网络 D 的优化目标是:
θ∗=aθrgmin−∑xr∼pr(⋅)logDθ(xr)−∑xf∼pg(⋅)log(1−Dθ(xf))θ^∗ = \underset{θ}argmin − \sum_{x_r \sim p_r (\cdot)}logD_θ (x_r ) −\sum_{x_f \sim p_g (\cdot) } log (1 − D_θ (x_f ))θ∗=θa​rgmin−xr​∼pr​(⋅)∑​logDθ​(xr​)−xf​∼pg​(⋅)∑​log(1−Dθ​(xf​))
把minLmin \mathcal LminL转换为max−Lmax −\mathcal Lmax−L:
θ∗=aθrgmaxExr∼pr(⋅)logDθ(xr)+Exf∼pg(⋅)log(1−Dθ(xf))θ^∗ = \underset{θ}argmax \mathbb E_{x_r \sim p_r (\cdot)}\ logD_θ (x_r ) +\mathbb E_{x_f \sim p_g (\cdot) } log (1 − D_θ (x_f ))θ∗=θa​rgmaxExr​∼pr​(⋅)​ logDθ​(xr​)+Exf​∼pg​(⋅)​log(1−Dθ​(xf​))
对于生成网络G(z)G(z)G(z),希望生成数据能够骗过判别网络 D,假样本xfx_fxf​在判别网络的输出越接近真实的标签越好。即在训练生成网络时,希望判别网络的输出D(G(z))D(G(z))D(G(z))越逼近 1 越好,最小化D(G(z))D(G(z))D(G(z))与 1 之间的交叉熵损失函数:
mφinL=CE(D(Gφ(z)),1)=−logD(Gφ(z))\underset{φ}min \mathcal L= CE (D (G_φ (z)) , 1) = −logD (G_φ (z))φm​inL=CE(D(Gφ​(z)),1)=−logD(Gφ​(z))
把minLmin \mathcal LminL转换为max−Lmax −\mathcal Lmax−L:
φ∗=aφrgminL=Ez∼pz(⋅)log[1−D(Gφ(z))]φ^∗ =\underset{φ} argmin \mathcal L = \mathbb E_{z\sim p_z(\cdot)}log[1 − D(G_φ(z))]φ∗=φa​rgminL=Ez∼pz​(⋅)​log[1−D(Gφ​(z))]
其中φφφ为生成网络 G 的参数。
在训练过程中迭代训练鉴别器和生成器

DCGAN实现

使用cifar10的训练集作为GAN训练集实现DCGAN。

数据加载

加载cifar10的训练集,并对数据进行预处理

import tensorflow as tf
from tensorflow import keras
import numpy as np#批大小
batch_size = 64
(train_x,_),_ = keras.datasets.cifar10.load_data()
#数据归一化
train_x = train_x / (255. / 2) - 1
print(train_x.shape)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)

网络

网络由鉴别网络与生成网络构成

鉴别网络

class Discriminator(keras.Model):def __init__(self):super(Discriminator,self).__init__()filters = 64self.conv1 = keras.layers.Conv2D(filters,4,2,'valid',use_bias=False)self.bn1 = keras.layers.BatchNormalization()self.conv2 = keras.layers.Conv2D(filters*2,4,2,'valid',use_bias=False)self.bn2 = keras.layers.BatchNormalization()self.conv3 = keras.layers.Conv2D(filters*4,3,1,'valid',use_bias=False)self.bn3 = keras.layers.BatchNormalization()self.conv4 = keras.layers.Conv2D(filters*8,3,1,'valid',use_bias=False)self.bn4 = keras.layers.BatchNormalization()#全局池化self.pool = keras.layers.GlobalAveragePooling2D()self.flatten = keras.layers.Flatten()self.fc = keras.layers.Dense(1)def call(self,inputs,training=True):x = inputsx = tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))x = tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))x = tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))x = tf.nn.leaky_relu(self.bn4(self.conv4(x),training=training))x = self.pool(x)x = self.flatten(x)logits = self.fc(x)return logits

生成网络

class Generator(keras.Model):def __init__(self):super(Generator,self).__init__()filters = 64self.conv1 = keras.layers.Conv2DTranspose(filters*4,4,1,'valid',use_bias=False)self.bn1 = keras.layers.BatchNormalization()self.conv2 = keras.layers.Conv2DTranspose(filters*3,4,2,'same',use_bias=False)self.bn2 = keras.layers.BatchNormalization()self.conv3 = keras.layers.Conv2DTranspose(filters*1,4,2,'same',use_bias=False)self.bn3 = keras.layers.BatchNormalization()self.conv4 = keras.layers.Conv2DTranspose(3,4,2,'same',use_bias=False)def call(self,inputs,training=False):x = inputsx = tf.reshape(x,(x.shape[0],1,1,x.shape[1]))x = tf.nn.relu(x)x = tf.nn.relu(self.bn1(self.conv1(x),training=training))x = tf.nn.relu(self.bn2(self.conv2(x),training=training))x = tf.nn.relu(self.bn3(self.conv3(x),training=training))x = self.conv4(x)x = tf.tanh(x)return x

网络训练

训练时可以训练鉴别器多次然后训练一次生成器

定义损失函数

def celoss_ones(logits):# 计算属于与标签为1的交叉熵y = tf.ones_like(logits)loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)return tf.reduce_mean(loss)def celoss_zeros(logits):# 计算属于与标签为0的交叉熵y = tf.zeros_like(logits)loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)return tf.reduce_mean(loss)def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):# 计算鉴别器的损失函数# 采样生成图片fake_image = generator(batch_z, is_training)# 判定生成图片d_fake_logits = discriminator(fake_image, is_training)# 判定真实图片d_real_logits = discriminator(batch_x, is_training)# 真实图片与1之间的误差d_loss_real = celoss_ones(d_real_logits)# 生成图片与0之间的误差d_loss_fake = celoss_zeros(d_fake_logits)# 合并误差loss = d_loss_fake + d_loss_realreturn lossdef g_loss_fn(generator, discriminator, batch_z, is_training):#计算生成器的损失函数# 采样生成图片fake_image = generator(batch_z, is_training)# 在训练生成网络时,需要迫使生成图片判定为真d_fake_logits = discriminator(fake_image, is_training)# 计算生成图片与1之间的误差loss = celoss_ones(d_fake_logits)return loss

实例化网络及优化器

#定义超参数
#潜变量维度
z_dim = 100
#epoch大小
epochs = 300
#批大小
batch_size = 64
#学习率
lr = 0.0002
is_training = True
#实例化网络
discriminator = Discriminator()
discriminator.build(input_shape=(4,32,32,3))
discriminator.summary()
generator = Generator()
generator.build(input_shape=(4,z_dim))
generator.summary()
#实例化优化器
g_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
d_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)

训练

#统计损失值
d_losses = []
g_losses = []
for epoch in range(epochs):for _,batch_x in enumerate(dataset):batch_z = tf.random.normal([batch_size,z_dim])with tf.GradientTape() as tape:d_loss = d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)grads = tape.gradient(d_loss,discriminator.trainable_variables)d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))with tf.GradientTape() as tape:g_loss = g_loss_fn(generator,discriminator,batch_z,is_training)grads = tape.gradient(g_loss,generator.trainable_variables)g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))

效果展示

训练测试,可以通过调整超参数来获得更好的效果。

定义可视化函数

def save_result(val_out,val_block_size,image_path,color_mode):def preprocessing(img):img = ((img + 1.0)*(255./2)).astype(np.uint8)return imgpreprocessed = preprocessing(val_out)final_image = np.array([])single_row = np.array([])for b in range(val_out.shape[0]):# concat image into a rowif single_row.size == 0:single_row = preprocessed[b,:,:,:]else:single_row = np.concatenate((single_row,preprocessed[b,:,:,:]),axis=1)# concat image row to final_imageif (b+1) % val_block_size == 0:if final_image.size == 0:final_image = single_rowelse:final_image = np.concatenate((final_image, single_row), axis=0)# reset single rowsingle_row = np.array([])if final_image.shape[2] == 1:final_image = np.squeeze(final_image, axis=2)Image.fromarray(final_image).save(image_path)

可视化

在一定epoch后,保存生成结果

        if epoch % 2 == 0:batch_z = tf.random.normal([batch_size,z_dim])print(epoch,'d_loss:',float(d_loss),'g_loss:',float(g_loss))#可视化z = tf.random.normal([100,z_dim])fake_image = generator(z,training=False)img_path = r'gan-{}.png'.format(epoch)save_result(fake_image.numpy(),10,img_path,color_mode='P')d_losses.append(float(d_loss))g_losses.append(float(g_loss))

效果

训练26epoch的效果

小问题

tensorflow2.2训练报错

cuDNN launch failure : input shape ([64,4,4,512]) [Op:FusedBatchNormV3]

好像是批归一化层用于input层后有问题,升级tensorflow可以解决.

后记

Enjoy learning.

深度卷积生成对抗网络(DCGAN)原理与实现(采用Tensorflow2.x)相关推荐

  1. 深度卷积生成对抗网络DCGAN之实现动漫头像的生成(基于keras Tensorflow2.0实现)

    起飞目录 DCGAN简介 反卷积(上采样upsampling2D) 数据集 代码实战 数据导入和预处理 生成器G 判别器D 训练模块 完整代码 结果 2020 8/13补充 DCGAN简介 原始GAN ...

  2. 深度卷积生成对抗网络--DCGAN

    本问转自:https://ask.julyedu.com/question/7681,详情请查看原文 --前言:如何把CNN与GAN结合?DCGAN是这方面最好的尝试之一,DCGAN的原理和GAN是一 ...

  3. 深度卷积生成对抗网络DCGAN——生成手写数字图片

    前言 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的. 本文用到 ...

  4. 对抗生成网络_深度卷积生成对抗网络

    本教程演示了如何使用深度卷积生成对抗网络(DCGAN)生成手写数字图片.该代码是使用 Keras Sequential API 与 tf.GradientTape 训练循环编写的. 什么是生成对抗网络 ...

  5. 生成对抗网络简介,深度卷积生成对抗网络(DCGAN)简介

    本博客是个人学习的笔记,讲述的是生成对抗网络(generate adversarial network ) 的一种架构:深度生成对抗网络 的简单介绍,下一节将使用 tensorflow 搭建 DCGA ...

  6. DCGAN——深度卷积生成对抗网络

    译文 | 让深度卷积网络对抗:DCGAN--深度卷积生成对抗网络 原文: https://arxiv.org/pdf/1511.06434.pdf -- 前言:如何把CNN与GAN结合?DCGAN是这 ...

  7. 深度卷积生成对抗网络

    深度卷积生成对抗网络 Deep Convolutional Generative Adversarial Networks GANs如何工作的基本思想.可以从一些简单的,易于抽样的分布,如均匀分布或正 ...

  8. 理解与学习深度卷积生成对抗网络

    一.GAN 引言:生成对抗网络GAN,是当今的一大热门研究方向.在2014年,被Goodfellow大神提出来,当时的G神还是蒙特利尔大学的博士生.据有关媒体统计:CVPR2018的论文里,有三分之一 ...

  9. GAN生成对抗网络-DCGAN原理与基本实现-深度卷积生成对抗网络03

    什么是DCGAN 实现代码 import tensorflow as tf from tensorflow import keras from tensorflow.keras import laye ...

最新文章

  1. JavaScript作用域学习笔记
  2. 编写java程序的常见问题_Java程序的编写与执行、Java新手常见的问题解决
  3. hive like 模糊匹配
  4. maven jetty/tomcat/wildfly plugin部署应用到本地容器
  5. ik分词器实现原理_SpringBoot整合Elasticsearch实现商品搜索
  6. 【PyTorch】eval() ==>主要是针对某些在train和predict两个阶段会有不同参数的层,比如Dropout层和BN层
  7. react antd confirm content list_react简单的项目架构搭建过程
  8. I,P,B帧和PTS,DTS的关系 转载
  9. pythonpandas无列名数据合并_python – Pandas:合并多个数据帧和控制列名?
  10. 高效开发Android App的10个建议
  11. 《Gradle实战》如何配置利用Maven本地仓库
  12. 二分类模型评价指标-总结
  13. apa引用要在文中吗_英文论文格式要求玩转APA
  14. Vue+elementUI导出xlsl表格,支持复杂表头,自动合拼单元格。xlsx+file-saver插件
  15. 写代码遇到的灵异事件
  16. API文档,已取消到该网页的导航
  17. request_threaded_irq
  18. c语言中v作用是什么意思,C语言里,\v是什么意思?
  19. 1134: 字符串转换 C语言
  20. 【Java.JMS】一个简单的JMS实例

热门文章

  1. require() 源码解读
  2. SpringAOP-基于@AspectJ的简单入门
  3. install intel c/c++ compiler
  4. Hacker News的全文输出RSS地址
  5. 【Assembly】Mixed mode dll unable to load in .net 4.0
  6. Silverlight 3正式版新鲜出炉
  7. 狼来了!中国房地产的实质--比喻太生动了
  8. [转载] python函数isdisjoint方法_Python中的isdisjoint()函数
  9. 第二周四则运算汇报及总结
  10. UCOSII学习笔记[开篇]