一、背景

SAGAN全称为Self-Attention Generative Adversarial Networks,是由Han Zhang等人[1]于18年5月提出的一种模型。文章中作者解释到,传统的GAN模型都是在低分辨率特征图的空间局部点上来生成高分辨率的细节,而SAGAN是可以从所有的特征处生成细节,并且SAGAN的判别器可以判别两幅具有明显差异的图像是否具有一致的高度精细特征。SAGAN目前是取得了非常好的效果。

本文以CelebA数据集为例,用SAGAN生成更为精细的人脸图像,主要参考代码为[2]。

[1]文章链接:https://arxiv.org/pdf/1805.08318.pdf

[2]参考代码:https://github.com/taki0112/Self-Attention-GAN-Tensorflow

二、SAGAN原理

感觉入门系列的GAN文章网上的介绍还挺多,越新的文章解读越少,这里简单推荐一篇:

[3]SA-GAN - Self-Attention Generative Adversarial Networks 论文解读(附代码)

下面是自己对于文献的一些理解和介绍。

在我的上一篇文章GAN系列文章中:对抗神经网络学习(十)——attentiveGAN实现影像去雨滴的过程(tensorflow实现),初次了解到了attentiveNet引入GAN中的优势,引入attentiveNet来生成attentive map,能够让网络快速准确的定位到图像中的重点关注区域,当时就隐约觉得可以用这个思路来进一步优化GAN的模型结构,后来就看到了SAGAN采用了这个方法。

首先,作者关注GAN目前存在的问题:当我们训练多类别数据集时,GAN在某些图像类别上很难建模。通俗来说,GAN容易捕捉纹理特征但很难捕捉几何结构特征。

However, by carefully examining the generated samples from these models, we can observe that convolutional GANs have much more difficulty modeling some image classes than others when trained on multi-class datasets. For example, while the state-of-the-art ImageNet GAN model excels at synthesizing image classes with few structural constraints (e.g. ocean, sky and landscape classes, which are distinguished more by texture than by geometry), it fails to capture geometric or structural patterns that occur consistently in some classes (for example, dogs are often drawn with realistic fur texture but without clearly defined separate feet).

原因就在于这类模型依靠卷积来建立不同图像区域之间的依赖关系,而依赖关系的传递只能通过大范围的多个卷积层来实现。随着卷积大小的增加,网络的真实容量也在增加,但却损失了计算效率。而self-attentive,却能够做到依赖性和计算效率的平衡,因此文章引入self-attention机制。

作者主要的贡献在于:

In this work, we propose Self-Attention Generative Adversarial Networks (SAGANs), which introduce a self -attention mechanism into convolutional GANs. The self-attention module is complementary to convolutions and helps with modeling long range, multi-level dependencies across image regions. Armed with self-attention, the generator can draw images in which fine details at every location are carefully coordinated with fine details in distant portions of the image. Moreover, the discriminator can also more accurately enforce complicated geometric constraints on the global image structure. (作者引入self-attention机制,提出SAGAN。引入该机制后,生成器能够精细的细节,判别器能够实行几何限制。)

作者以一幅图来简单介绍self-attention:

简单来理解,最左边的图中的5个点是作者放置的5个点。而右侧的每张图对应一个点,内容则是这个点具有类似特征的区域,也就是上面说的most-attened regions。

之后,作者介绍了self-attention机制:

将self-attention机制引入到GAN中,可以得到SAGAN的loss函数:

为了使得SAGAN的训练稳定,作者还使用一点trick:①在生成器和判别器中都使用了spectral normalization
。②生成器和判别器的非平衡学习率(Imbalanced learning rate)。

最终SAGAN的表现也很好:

SAGAN significantly outperforms the state of the art in image synthesis by boosting the best reported Inception score from 36.8 to 52.52 and reducing Fréchet Inception distance from 27.62 to 18.65.

关于文章中的公式推导和参数的设置,这里不再多说,最后只展示一下作者的效果图,看起来还是很惊艳的:

三、SAGAN实现过程

1. 所有文件结构

SAGAN的文件比较少,所有文件的结构为:

-- dataset                               # 数据集文件,需要自己下载|------ CelebA|------ image1.jpg|------ image2.jpg|------ ...
-- ops.py                                # 图层文件
-- utils.py                              # 操作文件
-- SAGAN.py                              # 模型文件
-- main.py                               # 主函数文件

2. 数据准备

这次用的数据集依旧是CelebA人脸数据集,之前的文章对抗神经网络学习(六)——BEGAN实现不同人脸的生成(tensorflow实现)中有介绍这个数据集,因此这里不再多介绍,只给出这个数据集的官方下载地址:https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA。

下载好数据集之后,将其解压放到'dataset/CelebA/'路径下即可。

3. 操作文件utils.py

utils.py的所有内容如下:

import scipy.misc
import numpy as np
import os
from glob import globimport tensorflow as tf
import tensorflow.contrib.slim as slimclass ImageData:def __init__(self, load_size, channels):self.load_size = load_sizeself.channels = channelsdef image_processing(self, filename):x = tf.read_file(filename)x_decode = tf.image.decode_jpeg(x, channels=self.channels)img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])img = tf.cast(img, tf.float32) / 127.5 - 1return imgdef load_data(dataset_name, size=64) :x = glob(os.path.join("./dataset", dataset_name, '*.*'))return xdef preprocessing(x, size):x = scipy.misc.imread(x, mode='RGB')x = scipy.misc.imresize(x, [size, size])x = normalize(x)return xdef normalize(x):return x/127.5 - 1def save_images(images, size, image_path):return imsave(inverse_transform(images), size, image_path)def merge(images, size):h, w = images.shape[1], images.shape[2]if images.shape[3] in (3, 4):c = images.shape[3]img = np.zeros((h * size[0], w * size[1], c))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w, :] = imagereturn imgelif images.shape[3] == 1:img = np.zeros((h * size[0], w * size[1]))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]return imgelse:raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')def imsave(images, size, path):# image = np.squeeze(merge(images, size)) # 채널이 1인거 제거 ?return scipy.misc.imsave(path, merge(images, size))def inverse_transform(images):return (images+1.)/2.def check_folder(log_dir):if not os.path.exists(log_dir):os.makedirs(log_dir)return log_dirdef show_all_variables():model_vars = tf.trainable_variables()slim.model_analyzer.analyze_vars(model_vars, print_info=True)def str2bool(x):return x.lower() in ('true')

4. 图层文件ops.py

图层文件ops.py的主要内容为:

import tensorflow as tf
import tensorflow.contrib as tf_contrib# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None##################################################################################
# Layer
##################################################################################def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):with tf.variable_scope(scope):if pad_type == 'zero' :x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])if pad_type == 'reflect' :x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')if sn :w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,regularizer=weight_regularizer)x = tf.nn.conv2d(input=x, filter=spectral_norm(w),strides=[1, stride, stride, 1], padding='VALID')if use_bias :bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))x = tf.nn.bias_add(x, bias)else :x = tf.layers.conv2d(inputs=x, filters=channels,kernel_size=kernel, kernel_initializer=weight_init,kernel_regularizer=weight_regularizer,strides=stride, use_bias=use_bias)return xdef deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):with tf.variable_scope(scope):x_shape = x.get_shape().as_list()if padding == 'SAME':output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]else:output_shape =[x_shape[0], x_shape[1] * stride + max(kernel - stride, 0), x_shape[2] * stride + max(kernel - stride, 0), channels]if sn :w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer)x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding=padding)if use_bias :bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))x = tf.nn.bias_add(x, bias)else :x = tf.layers.conv2d_transpose(inputs=x, filters=channels,kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer,strides=stride, padding=padding, use_bias=use_bias)return xdef fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'):with tf.variable_scope(scope):x = flatten(x)shape = x.get_shape().as_list()channels = shape[-1]if sn :w = tf.get_variable("kernel", [channels, units], tf.float32,initializer=weight_init, regularizer=weight_regularizer)if use_bias :bias = tf.get_variable("bias", [units],initializer=tf.constant_initializer(0.0))x = tf.matmul(x, spectral_norm(w)) + biaselse :x = tf.matmul(x, spectral_norm(w))else :x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)return xdef flatten(x) :return tf.layers.flatten(x)def hw_flatten(x) :return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])##################################################################################
# Residual-block
##################################################################################def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):with tf.variable_scope(scope):with tf.variable_scope('res1'):x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)x = batch_norm(x, is_training)x = relu(x)with tf.variable_scope('res2'):x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)x = batch_norm(x, is_training)return x + x_init##################################################################################
# Sampling
##################################################################################def global_avg_pooling(x):gap = tf.reduce_mean(x, axis=[1, 2])return gapdef up_sample(x, scale_factor=2):_, h, w, _ = x.get_shape().as_list()new_size = [h * scale_factor, w * scale_factor]return tf.image.resize_nearest_neighbor(x, size=new_size)##################################################################################
# Activation function
##################################################################################def lrelu(x, alpha=0.2):return tf.nn.leaky_relu(x, alpha)def relu(x):return tf.nn.relu(x)def tanh(x):return tf.tanh(x)##################################################################################
# Normalization function
##################################################################################def batch_norm(x, is_training=True, scope='batch_norm'):return tf_contrib.layers.batch_norm(x,decay=0.9, epsilon=1e-05,center=True, scale=True, updates_collections=None,is_training=is_training, scope=scope)def spectral_norm(w, iteration=1):w_shape = w.shape.as_list()w = tf.reshape(w, [-1, w_shape[-1]])u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)u_hat = uv_hat = Nonefor i in range(iteration):"""power iterationUsually iteration = 1 will be enough"""v_ = tf.matmul(u_hat, tf.transpose(w))v_hat = l2_norm(v_)u_ = tf.matmul(v_hat, w)u_hat = l2_norm(u_)sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))w_norm = w / sigmawith tf.control_dependencies([u.assign(u_hat)]):w_norm = tf.reshape(w_norm, w_shape)return w_normdef l2_norm(v, eps=1e-12):return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)##################################################################################
# Loss function
##################################################################################def discriminator_loss(loss_func, real, fake):real_loss = 0fake_loss = 0if loss_func == 'lsgan' :real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))fake_loss = tf.reduce_mean(tf.square(fake))if loss_func == 'gan' or loss_func == 'dragan' :real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))if loss_func == 'hinge' :real_loss = tf.reduce_mean(relu(1.0 - real))fake_loss = tf.reduce_mean(relu(1.0 + fake))loss = real_loss + fake_lossreturn lossdef generator_loss(loss_func, fake):fake_loss = 0if loss_func == 'lsgan' :fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0))if loss_func == 'gan' or loss_func == 'dragan' :fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))if loss_func == 'hinge' :fake_loss = -tf.reduce_mean(fake)loss = fake_lossreturn loss

5. 模型文件SAGAN.py

SAGAN.py文件的主要内容为:

import time
from ops import *
from utils import *
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batchclass SAGAN(object):def __init__(self, sess, args):self.model_name = "SAGAN"  # name for checkpointself.sess = sessself.dataset_name = args.datasetself.checkpoint_dir = args.checkpoint_dirself.sample_dir = args.sample_dirself.result_dir = args.result_dirself.log_dir = args.log_dirself.epoch = args.epochself.iteration = args.iterationself.batch_size = args.batch_sizeself.print_freq = args.print_freqself.save_freq = args.save_freqself.img_size = args.img_size""" Generator """self.layer_num = int(np.log2(self.img_size)) - 3self.z_dim = args.z_dim  # dimension of noise-vectorself.up_sample = args.up_sampleself.gan_type = args.gan_type""" Discriminator """self.n_critic = args.n_criticself.sn = args.snself.ld = args.ldself.sample_num = args.sample_num  # number of generated images to be savedself.test_num = args.test_num# trainself.g_learning_rate = args.g_lrself.d_learning_rate = args.d_lrself.beta1 = args.beta1self.beta2 = args.beta2self.custom_dataset = Falseif self.dataset_name == 'mnist' :self.c_dim = 1self.data = load_mnist(size=self.img_size)elif self.dataset_name == 'cifar10' :self.c_dim = 3self.data = load_cifar10(size=self.img_size)else :self.c_dim = 3self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)self.custom_dataset = Trueself.dataset_num = len(self.data)self.sample_dir = os.path.join(self.sample_dir, self.model_dir)check_folder(self.sample_dir)print()print("##### Information #####")print("# gan type : ", self.gan_type)print("# dataset : ", self.dataset_name)print("# dataset number : ", self.dataset_num)print("# batch_size : ", self.batch_size)print("# epoch : ", self.epoch)print("# iteration per epoch : ", self.iteration)print()print("##### Generator #####")print("# generator layer : ", self.layer_num)print("# upsample conv : ", self.up_sample)print()print("##### Discriminator #####")print("# discriminator layer : ", self.layer_num)print("# the number of critic : ", self.n_critic)print("# spectral normalization : ", self.sn)################################################################################### Generator##################################################################################def generator(self, z, is_training=True, reuse=False):with tf.variable_scope("generator", reuse=reuse):ch = 1024x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')x = batch_norm(x, is_training, scope='batch_norm')x = relu(x)for i in range(self.layer_num // 2):if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)else:x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)ch = ch // 2# Self Attentionx = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)for i in range(self.layer_num // 2, self.layer_num):if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)else:x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm_' + str(i))x = relu(x)ch = ch // 2if self.up_sample:x = up_sample(x, scale_factor=2)x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')x = tanh(x)else:x = deconv(x, channels=self.c_dim, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')x = tanh(x)return x################################################################################### Discriminator##################################################################################def discriminator(self, x, is_training=True, reuse=False):with tf.variable_scope("discriminator", reuse=reuse):ch = 64x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')x = lrelu(x, 0.2)for i in range(self.layer_num // 2):x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm' + str(i))x = lrelu(x, 0.2)ch = ch * 2# Self Attentionx = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)for i in range(self.layer_num // 2, self.layer_num):x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))x = batch_norm(x, is_training, scope='batch_norm' + str(i))x = lrelu(x, 0.2)ch = ch * 2x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')return xdef attention(self, x, ch, sn=False, scope='attention', reuse=False):with tf.variable_scope(scope, reuse=reuse):f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv') # [bs, h, w, c']g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv') # [bs, h, w, c']h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv') # [bs, h, w, c]# N = h * ws = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N]beta = tf.nn.softmax(s)  # attention mapo = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))o = tf.reshape(o, shape=x.shape) # [bs, h, w, C]x = gamma * o + xreturn xdef gradient_penalty(self, real, fake):if self.gan_type == 'dragan' :shape = tf.shape(real)eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local regionnoise = 0.5 * x_std * eps  # delta in paper# Author suggested U[0,1] in original paper, but he admitted it is bug in github# (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided.alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should be in the space of Xelse :alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)interpolated = alpha*real + (1. - alpha)*fakelogit = self.discriminator(interpolated, reuse=True)grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)grad_norm = tf.norm(flatten(grad), axis=1)  # l2 normGP = 0# WGAN - LPif self.gan_type == 'wgan-lp':GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))return GP################################################################################### Model##################################################################################def build_model(self):""" Graph Input """# imagesif self.custom_dataset :Image_Data_Class = ImageData(self.img_size, self.c_dim)inputs = tf.data.Dataset.from_tensor_slices(self.data)gpu_device = '/gpu:0'inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))inputs_iterator = inputs.make_one_shot_iterator()self.inputs = inputs_iterator.get_next()else :self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.c_dim], name='real_images')# noisesself.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')""" Loss Function """# output of D for real imagesreal_logits = self.discriminator(self.inputs)# output of D for fake imagesfake_images = self.generator(self.z)fake_logits = self.discriminator(fake_images, reuse=True)if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' :GP = self.gradient_penalty(real=self.inputs, fake=fake_images)else :GP = 0# get loss for discriminatorself.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP# get loss for generatorself.g_loss = generator_loss(self.gan_type, fake=fake_logits)""" Training """# divide trainable variables into a group for D and a group for Gt_vars = tf.trainable_variables()d_vars = [var for var in t_vars if 'discriminator' in var.name]g_vars = [var for var in t_vars if 'generator' in var.name]# optimizersself.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)"""" Testing """# for testself.fake_images = self.generator(self.z, is_training=False, reuse=True)""" Summary """self.d_sum = tf.summary.scalar("d_loss", self.d_loss)self.g_sum = tf.summary.scalar("g_loss", self.g_loss)################################################################################### Train##################################################################################def train(self):# initialize all variablestf.global_variables_initializer().run()# graph inputs for visualize training resultsself.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))# saver to save modelself.saver = tf.train.Saver()# summary writerself.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)# restore check-point if it exitscould_load, checkpoint_counter = self.load(self.checkpoint_dir)if could_load:start_epoch = (int)(checkpoint_counter / self.iteration)start_batch_id = checkpoint_counter - start_epoch * self.iterationcounter = checkpoint_counterprint(" [*] Load SUCCESS")else:start_epoch = 0start_batch_id = 0counter = 1print(" [!] Load failed...")# loop for epochstart_time = time.time()past_g_loss = -1.for epoch in range(start_epoch, self.epoch):# get batch datafor idx in range(start_batch_id, self.iteration):batch_z = np.random.uniform(-1, 1, [self.batch_size, 1, 1, self.z_dim])if self.custom_dataset :train_feed_dict = {self.z: batch_z}else :random_index = np.random.choice(self.dataset_num, size=self.batch_size, replace=False)# batch_images = self.data[idx*self.batch_size : (idx+1)*self.batch_size]batch_images = self.data[random_index]train_feed_dict = {self.inputs : batch_images,self.z : batch_z}# update D network_, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], feed_dict=train_feed_dict)self.writer.add_summary(summary_str, counter)# update G networkg_loss = Noneif (counter - 1) % self.n_critic == 0:_, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict=train_feed_dict)self.writer.add_summary(summary_str, counter)past_g_loss = g_loss# display training statuscounter += 1if g_loss == None :g_loss = past_g_lossprint("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \% (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))# save training results for every 300 stepsif np.mod(idx+1, self.print_freq) == 0:samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})tot_num_samples = min(self.sample_num, self.batch_size)manifold_h = int(np.floor(np.sqrt(tot_num_samples)))manifold_w = int(np.floor(np.sqrt(tot_num_samples)))save_images(samples[:manifold_h * manifold_w, :, :, :],[manifold_h, manifold_w],'./' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(epoch, idx+1))if np.mod(idx+1, self.save_freq) == 0:self.save(self.checkpoint_dir, counter)# After an epoch, start_batch_id is set to zero# non-zero value is only for the first epoch after loading pre-trained modelstart_batch_id = 0# save modelself.save(self.checkpoint_dir, counter)# show temporal results# self.visualize_results(epoch)# save model for final stepself.save(self.checkpoint_dir, counter)@propertydef model_dir(self):return "{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, self.sn)def save(self, checkpoint_dir, step):checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)def load(self, checkpoint_dir):import reprint(" [*] Reading checkpoints...")checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:ckpt_name = os.path.basename(ckpt.model_checkpoint_path)self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))print(" [*] Success to read {}".format(ckpt_name))return True, counterelse:print(" [*] Failed to find a checkpoint")return False, 0def visualize_results(self, epoch):tot_num_samples = min(self.sample_num, self.batch_size)image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))""" random condition, random noise """z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')def test(self):tf.global_variables_initializer().run()self.saver = tf.train.Saver()could_load, checkpoint_counter = self.load(self.checkpoint_dir)result_dir = os.path.join(self.result_dir, self.model_dir)check_folder(result_dir)if could_load:print(" [*] Load SUCCESS")else:print(" [!] Load failed...")tot_num_samples = min(self.sample_num, self.batch_size)image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))""" random condition, random noise """for i in range(self.test_num) :z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],[image_frame_dim, image_frame_dim],result_dir + '/' + self.model_name + '_test_{}.png'.format(i))

6. 主文件main.py

main.py文件的主要内容为:

from SAGAN import SAGAN
import argparse
from utils import *"""parsing and configuration"""
def parse_args():desc = "Tensorflow implementation of Self-Attention GAN"parser = argparse.ArgumentParser(description=desc)parser.add_argument('--phase', type=str, default='train', help='train or test ?')parser.add_argument('--dataset', type=str, default='celebA', help='dataset name')parser.add_argument('--epoch', type=int, default=10, help='The number of epochs to run')parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')parser.add_argument('--batch_size', type=int, default=32, help='The size of batch per gpu')parser.add_argument('--print_freq', type=int, default=500, help='The number of image_print_freqy')parser.add_argument('--save_freq', type=int, default=500, help='The number of ckpt_save_freq')parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for generator')parser.add_argument('--d_lr', type=float, default=0.0004, help='learning rate for discriminator')parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector')parser.add_argument('--up_sample', type=str2bool, default=True, help='using upsample-conv')parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')parser.add_argument('--img_size', type=int, default=128, help='The size of image')parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images')parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test')parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',help='Directory name to save the checkpoints')parser.add_argument('--result_dir', type=str, default='results',help='Directory name to save the generated images')parser.add_argument('--log_dir', type=str, default='logs',help='Directory name to save training logs')parser.add_argument('--sample_dir', type=str, default='samples',help='Directory name to save the samples on training')return check_args(parser.parse_args())"""checking arguments"""
def check_args(args):# --checkpoint_dircheck_folder(args.checkpoint_dir)# --result_dircheck_folder(args.result_dir)# --result_dircheck_folder(args.log_dir)# --sample_dircheck_folder(args.sample_dir)# --epochtry:assert args.epoch >= 1except:print('number of epochs must be larger than or equal to one')# --batch_sizetry:assert args.batch_size >= 1except:print('batch size must be larger than or equal to one')return args"""main"""
def main():# parse argumentsargs = parse_args()if args is None:exit()# open sessionwith tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:gan = SAGAN(sess, args)# build graphgan.build_model()# show network architectureshow_all_variables()if args.phase == 'train' :# launch the graph in a sessiongan.train()# visualize learned generatorgan.visualize_results(args.epoch - 1)print(" [*] Training finished!")if args.phase == 'test' :gan.test()print(" [*] Test finished!")if __name__ == '__main__':main()

四、实现结果

写完上述所有文件之后,运行main.py即可。实验过程比较慢,设置epoch为10,每个epoch内迭代10000次,我的GPU是GTX1060 3G,每个epoch大概需要近4000秒,即1个小时5分钟,所以还是非常耗时的。目前只进行了训练过程,所以只有训练过程的生成样本。

当epoch=0, iter=500时,也就是迭代500次,生成样本为:

当epoch=1,iter=0时,也就是运算了10000次,效果为:

当epoch=2,iter=0时,即运算20000次,效果为:

当epoch=4,iter=0,即运算40000次时,效果为:

当epoch=6,iter=0,即运算60000次时,效果为:

由于loss已经很小了,而且后面的训练效果其实差距也不大,所以也就没有再继续训练了。

五、分析

1. 训练过程中似乎出现了模式崩塌的现象,因为生成的样本都非常类似,这个还需要进一步检查。

2. 之前用DCGAN、WGAN、BEGAN也做过人脸生成,下面来比较一下他们的效果:

DCGAN生成图像的大小为64*64的, 但是可以明显的看到,DCGAN的生成结果中,很多人脸的姿态都非常相似,DCGAN很容易出现模式崩塌现象,而且DCGAN生成的人脸肤色偏黑,且图像中的噪点很多,边缘非常不平滑,生成的效果比较差。

WGAN理论上解决了模式崩塌现象,生成人脸的尺寸为128*128,肤色明显自然了很多,但是生成的效果很差,边界几乎看不清,而且有的图像几乎什么也看不出来,噪点也非常多。

BEGAN生成的人脸尺寸为64*64,它是再DCGAN的基础上扩展的,效果明显好了很多,生成图像偶尔有噪点,人物的五官和轮廓都非常清晰,整体来讲效果不错。

SAGAN生成的人脸尺寸是128*128,生成的结果中,出现了模式崩塌的现象,可能是我对人脸数据集没有做预处理,生成人脸的五官是很清晰,但是五官之外的轮廓就非常糟糕了。

这几种模型,目前来看,我的实验效果是BEGAN最好,可能是因为我对其他数据集没怎么做数据预处理吧。

3. 关于celebA数据集的下载,作者也给了代码,这里也同时给出:

import os
import zipfile
import argparse
import requestsfrom tqdm import tqdmparser = argparse.ArgumentParser(description='Download dataset for SAGAN')
parser.add_argument('dataset', metavar='N', type=str, nargs='+', choices=['celebA'],help='name of dataset to download [celebA]')def download_file_from_google_drive(id, destination):URL = "https://docs.google.com/uc?export=download"session = requests.Session()response = session.get(URL, params={'id': id}, stream=True)token = get_confirm_token(response)if token:params = {'id': id, 'confirm': token}response = session.get(URL, params=params, stream=True)save_response_content(response, destination)def get_confirm_token(response):for key, value in response.cookies.items():if key.startswith('download_warning'):return valuereturn Nonedef save_response_content(response, destination, chunk_size=32 * 1024):total_size = int(response.headers.get('content-length', 0))with open(destination, "wb") as f:for chunk in tqdm(response.iter_content(chunk_size), total=total_size,unit='B', unit_scale=True, desc=destination):if chunk:  # filter out keep-alive new chunksf.write(chunk)def download_celeb_a(dirpath):data_dir = 'celebA'if os.path.exists(os.path.join(dirpath, data_dir)):print('Found Celeb-A - skip')returnfilename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"save_path = os.path.join(dirpath, filename)if os.path.exists(save_path):print('[*] {} already exists'.format(save_path))else:download_file_from_google_drive(drive_id, save_path)zip_dir = ''with zipfile.ZipFile(save_path) as zf:zip_dir = zf.namelist()[0]zf.extractall(dirpath)os.remove(save_path)os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))def prepare_data_dir(path='./dataset'):if not os.path.exists(path):os.mkdir(path)if __name__ == '__main__':args = parser.parse_args()prepare_data_dir()if any(name in args.dataset for name in ['CelebA', 'celebA', 'celebA']):download_celeb_a('./dataset')

对抗生成网络学习(十一)——SAGAN生成更为精细的人脸图像(tensorflow实现)相关推荐

  1. SAGAN生成更为精细的人脸图像(tensorflow实现)

    一.背景 SAGAN全称为Self-Attention Generative Adversarial Networks,是由Han Zhang等人[1]于18年5月提出的一种模型.文章中作者解释到,传 ...

  2. 对抗生成网络学习(十五)——starGAN实现人脸属性修改(tensorflow实现)

    一.背景 最近事情比较多,一个多月没写CSDN了,最近打算做一做satrGAN. starGAN是Yunjey Choi等人于17年11月提出的一个模型[1].该模型可以实现人脸的属性修改,原理上来说 ...

  3. 对抗生成网络学习(六)——BEGAN实现不同人脸的生成(tensorflow实现)

    一.背景 BEGAN,即边界平衡GAN(Boundary Equilibrium GAN),是DavidBerthelot等人[1]于2017年03月提出的一种方法.传统的GAN是利用判别器去评估生成 ...

  4. 对抗生成网络学习(四)——WGAN+爬虫生成皮卡丘图像(tensorflow实现)

    一.背景 WGAN的全称为Wasserstein GAN, 是Martin Arjovsky等人于17年1月份提出的一个模型,该文章可以参考[1].WGAN针对GAN存在的问题进行了有针对性的改进,但 ...

  5. 对抗生成网络学习(十四)——DRAGAN对模型倒塌问题的处理和生成图像质量评价(tensorflow实现)

    一.背景 之前在做GAN主要是关注GAN的应用,找了一些比较好的例子实现了下,后面还会持续做这方面的工作.今天来看看DRAGAN对于GAN中一些问题的处理方法,也为今后这方面的研究做一部分基础工作吧, ...

  6. 对抗生成网络_深度卷积生成对抗网络

    本教程演示了如何使用深度卷积生成对抗网络(DCGAN)生成手写数字图片.该代码是使用 Keras Sequential API 与 tf.GradientTape 训练循环编写的. 什么是生成对抗网络 ...

  7. 对抗生成网络学习(七)——SRGAN生成超分辨率影像(tensorflow实现)

    一.背景 SRGAN(Super-Resolution Generative Adversarial Network)即超分辨率GAN,是Christian Ledig等人于16年9月提出的一种对抗神 ...

  8. 对抗生成网络学习(九)——CartoonGAN+爬虫生成《言叶之庭》风格的影像(tensorflow实现)

    一.背景 cartoonGAN是Yang Chen等人于2018年2月提出的一种模型.该模型针对漫画风格图像生成做了进一步研究,提出了新的GAN网络结构和两种损失函数,相较于之前的漫画风格生成的GAN ...

  9. GAN对抗生成网络学习笔记(三)DCGAN原理

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.DCGAN简介 1.1 DCGAN的特点 二.几个重要概念 2.1 下采样(SubSampled) 2.2 上采样 ...

最新文章

  1. 各种的jsp数据库连接方法代码!(以前收集的)
  2. 开发工具之pycharm 快捷键说明
  3. Customization larbin
  4. 中国中老年化妆品行业消费需求现状与产销规模前景展望报告2022年
  5. 1、leetcode437 路和总径3
  6. s4-8 虚拟局域网
  7. 一文读懂 AVL 树
  8. mocha 测试 mysql_node项目mocha自动化测试的疑问
  9. node.js打包环境部署CentOS7.4
  10. 视+AR正式发布EasyAR引擎2.0版,并宣布开放AR相机平台
  11. HTML、 CSS、 JavaScript三者的关系
  12. Deep Learning资源搜集
  13. 集成学习-幸福感预测案例分析
  14. android xml 注释快捷键,xml注释(xml注释掉一段代码)
  15. IDEA------自动导包快捷键
  16. SQL中模式的定义和删除
  17. 用python turtle画画草地天空星星花朵小草
  18. 谷歌股价跌的越多,我们买的越多
  19. 希尔顿与锦江集团续签合作协议,将在华开逾600家希尔顿欢朋酒店
  20. USBTO232的几个问题,乱码,回车无效,驱动安装

热门文章

  1. future java get_关于 Future get方法的疑问
  2. 关于在GET请求中使用body
  3. Ae效果控件快速参考:抠像
  4. Halcon椭圆测量
  5. WIn10 1909 Windows Hello 指纹:出现错误,请稍后再试一次
  6. asponse.word按模板导出word文档
  7. APP是怎么做出来的呢?
  8. 辉芒微FT61F023,FT61F011A
  9. 罗技M590优联无法使用的问题解决
  10. GraphSAGE NIPS 2017 代码分析(Tensorflow版)