一、背景

前一阶段比较忙,很久没有继续做GAN的实验了。近期终于抽空做完了infoGAN,个人认为infoGAN是对GAN的更进一步改进,由于GAN是输入的随机生成噪声,所以生成的图像也是随机的,而infoGAN想要生成的是指定特征的图像,因此infoGAN对GAN的随机输入加了约束,这是其最大的改进之处。infoGAN是16年6月份由Xi Chen等人提出的一种模型。本实验主要利用infoGAN生成宽窄不一,高低各异的服装影像。

本实验以fashion-mnist数据集为例,用尽可能少的代码实现infoGAN。

[1]文章链接:https://arxiv.org/abs/1606.03657

二、infoGAN原理

infoGAN的原理网上介绍的也比较多,这里不再过多叙述。推荐一篇对原理讲解比较清楚的文章:

[2]InfoGAN介绍

从文章中作者的介绍来看,infoGAN最主要的贡献是引入了互信息(mutual information),通过最大化(maximizing)GAN噪声变量子集和观测值之间的互信息,以实现对学习过程的可解译性。作者将实验应用于MNIST, CelebA,SVHN数据集,结果表明引入互信息的模型都能够取得更好的效果。

In this paper, we present a simple modification to the generative adversarial network objective that encourages it to learn interpretable and meaningful representations. We do so by maximizing the mutual information between a fixed small subset of the GAN’s noise variables and the observations, which turns out to be relatively straightforward. Despite its simplicity, we found our method to be surprisingly effective: it was able to discover highly semantic and meaningful hidden representations on a number of image datasets: digits (MNIST), faces (CelebA), and house numbers (SVHN). The quality of our unsupervised disentangled representation matches previous works that made use of supervised label information [5–9]. These results suggest that generative modelling augmented with a mutual information cost could be a fruitful approach for learning disentangled representations.

通俗一点来说,“GAN模型在生成器使用噪声z的时候没有加任何的限制,所以在以一种高度混合的方式使用z,z的任何一个维度都没有明显的表示一个特征,所以在数据生成过程中,我们无法得知什么样的噪声z可以用来生成数字1,什么样的噪声z可以用来生成数字3,我们对这些一无所知,这从一点程度上限制了我们对GAN的使用[2]”。而作者对infoGAN的改进便针对这个问题,“Info代表互信息,它表示生成数据x与隐藏编码c之间关联程度的大小,为了使的x与c之间关联密切,所以我们需要最大化互信息的值,据此对原始GAN模型的值函数做了一点修改,相当于加了一个互信息的正则化项。是一个超参,通过之后的实验选择了一个最优值1。[2]”

后文中作者也提到了,infoGAN主要针对GAN的问题进行了改进,采用的模型基础仍是DCGAN,因此infoGAN的具体实现过程可以参照DCGAN。文章中作者进行了很多公式的推导,得出的最终结论为:

Hence, InfoGAN is defined as the following minimax game with a variational regularization of mutual information and a hyperparameter λ:

不用理解上述公式也没关系,我们只要知道了infoGAN主要做了哪方面的改进就行。关于infoGAN的实现代码,网上也比较多,下面给出几个比较好的代码:

[3]https://github.com/openai/InfoGAN

[4]https://github.com/AndyHsiao26/InfoGAN-tensorflow

[5]https://github.com/hwalsuklee/tensorflow-generative-model-collections

本实验的目的就在于用最少,最简单的代码实现infoGAN,主要参考了[5]的实现过程,并在原代码的基础上进行了少量改进,只不过这次保留了原代码中的类,而我之前的GAN实现都是尽量写成函数的形式。

三、infoGAN实现

1.数据准备

这次的实验数据采用的是fashion-mnist数据集,顾名思义,该数据集与mnist数据集的格式相同,只不过该数据集是10类服饰,但图像仍是28*28的灰度图:

该数据集的下载地址为:https://github.com/zalandoresearch/fashion-mnist

打开上述地址,找到下面的数据集,点击download即可开始下载:

下载好的数据集,我们放在'./data/fashion-mnist/'文件夹下,这样就准备好了数据,不用解压,下面即可开始实验部分:

2.数据操作函数准备(utils.py)

这一部分主要准备一些数据操作函数,包括数据的加载,图像的存储,文件夹的建立等函数。这部分函数DCGAN中也能用到,因此直接将其拷贝过来,进行后续的使用。这里将该文件命名为utils.py,直接给出该文件的代码:

"""
Most codes from https://github.com/carpedm20/DCGAN-tensorflow
"""
from __future__ import division
import scipy.misc
import numpy as np
import os, gzipimport tensorflow as tf
import tensorflow.contrib.slim as slimdef load_mnist(dataset_name):data_dir = 'data/' + dataset_namedef extract_data(filename, num_data, head_size, data_size):with gzip.open(filename) as bytestream:bytestream.read(head_size)buf = bytestream.read(data_size * num_data)data = np.frombuffer(buf, dtype=np.uint8).astype(np.float)return datadata = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)trX = data.reshape((60000, 28, 28, 1))data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)trY = data.reshape((60000))data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)teX = data.reshape((10000, 28, 28, 1))data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)teY = data.reshape((10000))trY = np.asarray(trY)teY = np.asarray(teY)X = np.concatenate((trX, teX), axis=0)y = np.concatenate((trY, teY), axis=0).astype(np.int)seed = 547np.random.seed(seed)np.random.shuffle(X)np.random.seed(seed)np.random.shuffle(y)y_vec = np.zeros((len(y), 10), dtype=np.float)for i, label in enumerate(y):y_vec[i, y[i]] = 1.0return X / 255., y_vecdef check_folder(log_dir):if not os.path.exists(log_dir):os.makedirs(log_dir)return log_dirdef show_all_variables():model_vars = tf.trainable_variables()slim.model_analyzer.analyze_vars(model_vars, print_info=True)def save_images(images, size, image_path):return imsave(inverse_transform(images), size, image_path)def merge(images, size):h, w = images.shape[1], images.shape[2]if (images.shape[3] in (3,4)):c = images.shape[3]img = np.zeros((h * size[0], w * size[1], c))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w, :] = imagereturn imgelif images.shape[3]==1:img = np.zeros((h * size[0], w * size[1]))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]return imgelse:raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')def imsave(images, size, path):image = np.squeeze(merge(images, size))return scipy.misc.imsave(path, image)def inverse_transform(images):return (images+1.)/2.

3.图层函数(layers.py)

这一部分主要是编写图层函数,并将这部分函数保存到layers.py文件当中,这里直接给出layers.py的代码:

"""
Most codes from https://github.com/carpedm20/DCGAN-tensorflow
"""
from utils import *if "concat_v2" in dir(tf):def concat(tensors, axis, *args, **kwargs):return tf.concat_v2(tensors, axis, *args, **kwargs)
else:def concat(tensors, axis, *args, **kwargs):return tf.concat(tensors, axis, *args, **kwargs)def bn(x, is_training, scope):return tf.contrib.layers.batch_norm(x,decay=0.9,updates_collections=None,epsilon=1e-5,scale=True,is_training=is_training,scope=scope)def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):with tf.variable_scope(name):w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],initializer=tf.truncated_normal_initializer(stddev=stddev))conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())return convdef deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False):with tf.variable_scope(name):# filter : [height, width, output_channels, in_channels]w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],initializer=tf.random_normal_initializer(stddev=stddev))try:deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])# Support for verisons of TensorFlow before 0.7.0except AttributeError:deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())if with_w:return deconv, w, biaseselse:return deconvdef lrelu(x, leak=0.2, name="lrelu"):return tf.maximum(x, leak*x)def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):shape = input_.get_shape().as_list()with tf.variable_scope(scope or "Linear"):matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,tf.random_normal_initializer(stddev=stddev))bias = tf.get_variable("bias", [output_size],initializer=tf.constant_initializer(bias_start))if with_w:return tf.matmul(input_, matrix) + bias, matrix, biaselse:return tf.matmul(input_, matrix) + bias

4.infoGAN的模型实现(infoGAN.py)

这一部分主要编写infoGAN的实现,包括参数,生成器和判别器,以及loss函数,这部分的代码比较长,回头我在对其进行详细的解释。最终infoGAN的代码为:

from __future__ import division
import timefrom layers import *
from utils import *class infoGAN(object):model_name = "infoGAN"     # name for checkpointdef __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir, SUPERVISED=True):self.sess = sessself.dataset_name = dataset_nameself.checkpoint_dir = checkpoint_dirself.result_dir = result_dirself.log_dir = log_dirself.epoch = epochself.batch_size = batch_sizeif dataset_name == 'mnist' or dataset_name == 'fashion-mnist':# parametersself.input_height = 28self.input_width = 28self.output_height = 28self.output_width = 28self.z_dim = z_dim         # dimension of noise-vectorself.y_dim = 12         # dimension of code-vector (label+two features)self.c_dim = 1self.SUPERVISED = SUPERVISED # if it is true, label info is directly used for code# trainself.learning_rate = 0.0002self.beta1 = 0.5# testself.sample_num = 64  # number of generated images to be saved# codeself.len_discrete_code = 10  # categorical distribution (i.e. label)self.len_continuous_code = 2  # gaussian distribution (e.g. rotation, thickness)# load mnistself.data_X, self.data_y = load_mnist(self.dataset_name)# get number of batches for a single epochself.num_batches = len(self.data_X) // self.batch_sizeelse:raise NotImplementedErrordef classifier(self, x, is_training=True, reuse=False):# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)# Architecture : (64)5c2s-(128)5c2s_BL-FC1024_BL-FC128_BL-FC12S’# All layers except the last two layers are shared by discriminator# Number of nodes in the last layer is reduced by half. It gives better results.with tf.variable_scope("classifier", reuse=reuse):net = lrelu(bn(linear(x, 64, scope='c_fc1'), is_training=is_training, scope='c_bn1'))out_logit = linear(net, self.y_dim, scope='c_fc2')out = tf.nn.softmax(out_logit)return out, out_logitdef discriminator(self, x, is_training=True, reuse=False):# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_Swith tf.variable_scope("discriminator", reuse=reuse):net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))net = tf.reshape(net, [self.batch_size, -1])net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))out_logit = linear(net, 1, scope='d_fc4')out = tf.nn.sigmoid(out_logit)return out, out_logit, netdef generator(self, z, y, is_training=True, reuse=False):# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_Swith tf.variable_scope("generator", reuse=reuse):# merge noise and codez = concat([z, y], 1)net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))net = tf.reshape(net, [self.batch_size, 7, 7, 128])net = tf.nn.relu(bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,scope='g_bn3'))out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))return outdef build_model(self):# some parametersimage_dims = [self.input_height, self.input_width, self.c_dim]bs = self.batch_size""" Graph Input """# imagesself.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')# labelsself.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')# noisesself.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')""" Loss Function """## 1. GAN Loss# output of D for real imagesD_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)# output of D for fake imagesG = self.generator(self.z, self.y, is_training=True, reuse=False)D_fake, D_fake_logits, input4classifier_fake = self.discriminator(G, is_training=True, reuse=True)# get loss for discriminatord_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))self.d_loss = d_loss_real + d_loss_fake# get loss for generatorself.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))## 2. Information Losscode_fake, code_logit_fake = self.classifier(input4classifier_fake, is_training=True, reuse=False)# discrete code : categoricaldisc_code_est = code_logit_fake[:, :self.len_discrete_code]disc_code_tg = self.y[:, :self.len_discrete_code]q_disc_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_code_est, labels=disc_code_tg))# continuous code : gaussiancont_code_est = code_logit_fake[:, self.len_discrete_code:]cont_code_tg = self.y[:, self.len_discrete_code:]q_cont_loss = tf.reduce_mean(tf.reduce_sum(tf.square(cont_code_tg - cont_code_est), axis=1))# get information lossself.q_loss = q_disc_loss + q_cont_loss""" Training """# divide trainable variables into a group for D and a group for Gt_vars = tf.trainable_variables()d_vars = [var for var in t_vars if 'd_' in var.name]g_vars = [var for var in t_vars if 'g_' in var.name]q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]# optimizerswith tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \.minimize(self.d_loss, var_list=d_vars)self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \.minimize(self.g_loss, var_list=g_vars)self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \.minimize(self.q_loss, var_list=q_vars)"""" Testing """# for testself.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)""" Summary """d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)q_loss_sum = tf.summary.scalar("g_loss", self.q_loss)q_disc_sum = tf.summary.scalar("q_disc_loss", q_disc_loss)q_cont_sum = tf.summary.scalar("q_cont_loss", q_cont_loss)# final summary operationsself.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])self.q_sum = tf.summary.merge([q_loss_sum, q_disc_sum, q_cont_sum])def train(self):# initialize all variablestf.global_variables_initializer().run()# graph inputs for visualize training resultsself.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))self.test_labels = self.data_y[0:self.batch_size]self.test_codes = np.concatenate((self.test_labels, np.zeros([self.batch_size, self.len_continuous_code])),axis=1)# saver to save modelself.saver = tf.train.Saver()# summary writerself.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)# restore check-point if it exitscould_load, checkpoint_counter = self.load(self.checkpoint_dir)if could_load:start_epoch = (int)(checkpoint_counter / self.num_batches)start_batch_id = checkpoint_counter - start_epoch * self.num_batchescounter = checkpoint_counterprint(" [*] Load SUCCESS")else:start_epoch = 0start_batch_id = 0counter = 1print(" [!] Load failed...")# loop for epochstart_time = time.time()for epoch in range(start_epoch, self.epoch):# get batch datafor idx in range(start_batch_id, self.num_batches):batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]# generate codeif self.SUPERVISED == True:batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]else:batch_labels = np.random.multinomial(1,self.len_discrete_code * [float(1.0 / self.len_discrete_code)],size=[self.batch_size])batch_codes = np.concatenate((batch_labels, np.random.uniform(-1, 1, size=(self.batch_size, 2))),axis=1)batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)# update D network_, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],feed_dict={self.inputs: batch_images, self.y: batch_codes,self.z: batch_z})self.writer.add_summary(summary_str, counter)# update G and Q network_, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss],feed_dict={self.inputs: batch_images, self.z: batch_z, self.y: batch_codes})self.writer.add_summary(summary_str_g, counter)self.writer.add_summary(summary_str_q, counter)# display training statuscounter += 1print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \% (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))# save training results for every 300 stepsif np.mod(counter, 300) == 0:samples = self.sess.run(self.fake_images,feed_dict={self.z: self.sample_z, self.y: self.test_codes})tot_num_samples = min(self.sample_num, self.batch_size)manifold_h = int(np.floor(np.sqrt(tot_num_samples)))manifold_w = int(np.floor(np.sqrt(tot_num_samples)))save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],'./' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx))# After an epoch, start_batch_id is set to zero# non-zero value is only for the first epoch after loading pre-trained modelstart_batch_id = 0# save modelself.save(self.checkpoint_dir, counter)# show temporal resultsself.visualize_results(epoch)# save model for final stepself.save(self.checkpoint_dir, counter)def visualize_results(self, epoch):tot_num_samples = min(self.sample_num, self.batch_size)image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))""" random noise, random discrete code, fixed continuous code """y = np.random.choice(self.len_discrete_code, self.batch_size)y_one_hot = np.zeros((self.batch_size, self.y_dim))y_one_hot[np.arange(self.batch_size), y] = 1z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')""" specified condition, random noise """n_styles = 10  # must be less than or equal to self.batch_sizenp.random.seed()si = np.random.choice(self.batch_size, n_styles)for l in range(self.len_discrete_code):y = np.zeros(self.batch_size, dtype=np.int64) + ly_one_hot = np.zeros((self.batch_size, self.y_dim))y_one_hot[np.arange(self.batch_size), y] = 1samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})# save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],#             check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)samples = samples[si, :, :, :]if l == 0:all_samples = sampleselse:all_samples = np.concatenate((all_samples, samples), axis=0)""" save merged images to check style-consistency """canvas = np.zeros_like(all_samples)for s in range(n_styles):for c in range(self.len_discrete_code):canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]save_images(canvas, [n_styles, self.len_discrete_code],check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')""" fixed noise """assert self.len_continuous_code == 2c1 = np.linspace(-1, 1, image_frame_dim)c2 = np.linspace(-1, 1, image_frame_dim)xv, yv = np.meshgrid(c1, c2)xv = xv[:image_frame_dim,:image_frame_dim]yv = yv[:image_frame_dim, :image_frame_dim]c1 = xv.flatten()c2 = yv.flatten()z_fixed = np.zeros([self.batch_size, self.z_dim])for l in range(self.len_discrete_code):y = np.zeros(self.batch_size, dtype=np.int64) + ly_one_hot = np.zeros((self.batch_size, self.y_dim))y_one_hot[np.arange(self.batch_size), y] = 1y_one_hot[np.arange(image_frame_dim*image_frame_dim), self.len_discrete_code] = c1y_one_hot[np.arange(image_frame_dim*image_frame_dim), self.len_discrete_code+1] = c2samples = self.sess.run(self.fake_images,feed_dict={ self.z: z_fixed, self.y: y_one_hot})save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_c1c2_%d.png' % l)@propertydef model_dir(self):return "{}_{}_{}_{}".format(self.model_name, self.dataset_name,self.batch_size, self.z_dim)def save(self, checkpoint_dir, step):checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)def load(self, checkpoint_dir):import reprint(" [*] Reading checkpoints...")checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:ckpt_name = os.path.basename(ckpt.model_checkpoint_path)self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))print(" [*] Success to read {}".format(ckpt_name))return True, counterelse:print(" [*] Failed to find a checkpoint")return False, 0

5.main函数实现模型训练(main.py)

最终就是训练模型过程了,这了main.py文件的主要工作就是设置参数,创建infoGAN模型,进行模型的训练,main.py文件的主要代码为:

from infoGAN import infoGANimport tensorflow as tf"""main"""
def main():with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:# declare instance for GANinfogan = infoGAN(sess,epoch=20,batch_size=64,z_dim=62,dataset_name='fashion-mnist',checkpoint_dir='checkpoint',result_dir='results',log_dir='logs')# build graphinfogan.build_model()# show network architecture# show_all_variables()# launch the graph in a sessioninfogan.train()print(" [*] Training finished!")# visualize learned generatorinfogan.visualize_results(20-1)print(" [*] Testing finished!")if __name__ == '__main__':main()

6.模型的执行

模型的执行只需运行main.py文件即可,然后等模型训练完毕即可查看模型的结果。

四、实验结果

实验一共设置了20个epoch,训练效果比较好。这里直接展示关于衣服的训练结果。

当epoch=1时候,实验的结果为:

当epoch=5的时候,实验的结果为:

当epoch=10的时候,实验的结果为:

当epoch=20的时候,实验的结果为:

五、分析

1.实验的结果还是比较好的,即使epoch=1,也能够比较清晰的看出最后的生成服饰图像,而且也能够明显的看到生成的服饰宽窄各异的衣服。

2.所有文件的结构为:

-- data            (原始数据集的文件夹)|------ fashion-mnist|------ t10k-images-idx3-ubyte.gz|------ t10k-labels-idx1-ubyte.gz|------ train-images-idx3-ubyte.gz|------ train-labels-idx1-ubyte.gz
-- utils.py{import...def load_mnist(dataset_name):...def check_folder(log_dir):...def show_all_variables():...def save_images(images, size, image_path):...def merge(images, size):...def imsave(images, size, path):...def inverse_transform(images):...}
-- layers.py{import...def bn(x, is_training, scope):...def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):...def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False):...def lrelu(x, leak=0.2, name="lrelu"):...def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):...}
-- infoGAN.py{import...class infoGAN(object):def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir, SUPERVISED=True):...def classifier(self, x, is_training=True, reuse=False):...def discriminator(self, x, is_training=True, reuse=False):...def generator(self, z, y, is_training=True, reuse=False):...def build_model(self):...def train(self):...def visualize_results(self, epoch):...def model_dir(self):...def save(self, checkpoint_dir, step):...def load(self, checkpoint_dir):...}
-- main.py{import ...def main():...if __name__ == '__main__':...}   

对抗生成网络学习(五)——infoGAN生成宽窄不一,高低各异的服装影像(tensorflow实现)相关推荐

  1. 对抗生成网络学习(十三)——conditionalGAN生成自己想要的手写数字(tensorflow实现)

    一.背景 其实我原本是不打算做这个模型,因为conditionalGAN能做的,infoGAN也能做,infoGAN我在之前的文章中写到了:对抗神经网络学习(五)--infoGAN生成宽窄不一,高低各 ...

  2. 对抗生成网络学习(十五)——starGAN实现人脸属性修改(tensorflow实现)

    一.背景 最近事情比较多,一个多月没写CSDN了,最近打算做一做satrGAN. starGAN是Yunjey Choi等人于17年11月提出的一个模型[1].该模型可以实现人脸的属性修改,原理上来说 ...

  3. 对抗生成网络学习(十)——attentiveGAN实现影像去雨滴的过程(tensorflow实现)

    一.背景 attentiveGAN是Rui Qian等人于17年11月份提出的一种模型.<Attentive Generative Adversarial Network for Raindro ...

  4. 对抗生成网络_深度卷积生成对抗网络

    本教程演示了如何使用深度卷积生成对抗网络(DCGAN)生成手写数字图片.该代码是使用 Keras Sequential API 与 tf.GradientTape 训练循环编写的. 什么是生成对抗网络 ...

  5. 对抗生成网络学习(十一)——SAGAN生成更为精细的人脸图像(tensorflow实现)

    一.背景 SAGAN全称为Self-Attention Generative Adversarial Networks,是由Han Zhang等人[1]于18年5月提出的一种模型.文章中作者解释到,传 ...

  6. 对抗生成网络学习(十四)——DRAGAN对模型倒塌问题的处理和生成图像质量评价(tensorflow实现)

    一.背景 之前在做GAN主要是关注GAN的应用,找了一些比较好的例子实现了下,后面还会持续做这方面的工作.今天来看看DRAGAN对于GAN中一些问题的处理方法,也为今后这方面的研究做一部分基础工作吧, ...

  7. 对抗生成网络学习(七)——SRGAN生成超分辨率影像(tensorflow实现)

    一.背景 SRGAN(Super-Resolution Generative Adversarial Network)即超分辨率GAN,是Christian Ledig等人于16年9月提出的一种对抗神 ...

  8. 对抗生成网络学习(九)——CartoonGAN+爬虫生成《言叶之庭》风格的影像(tensorflow实现)

    一.背景 cartoonGAN是Yang Chen等人于2018年2月提出的一种模型.该模型针对漫画风格图像生成做了进一步研究,提出了新的GAN网络结构和两种损失函数,相较于之前的漫画风格生成的GAN ...

  9. 对抗生成网络学习(四)——WGAN+爬虫生成皮卡丘图像(tensorflow实现)

    一.背景 WGAN的全称为Wasserstein GAN, 是Martin Arjovsky等人于17年1月份提出的一个模型,该文章可以参考[1].WGAN针对GAN存在的问题进行了有针对性的改进,但 ...

  10. GAN对抗生成网络学习笔记(三)DCGAN原理

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.DCGAN简介 1.1 DCGAN的特点 二.几个重要概念 2.1 下采样(SubSampled) 2.2 上采样 ...

最新文章

  1. Spring MVC-表单(Form)标签-单选按钮集合(RadioButtons)示例(转载实践)
  2. Blazor将.NET带回到浏览器
  3. SourceForge 停止在被遗弃项目捆绑第三方软件
  4. mysql外键教程_MySQL外键使用详解
  5. Hi3516A开发--根文件系统
  6. Func 与Action
  7. npm上传自己的项目
  8. Spring学习总结(23)——Spring Framework 5.0 新特性
  9. row_number() over使用方法
  10. php 自定义条件,php如何自定义一个方法
  11. tensorflow sigmoid 如何计算训练数据的正确率_初探 TensorFlow.js
  12. 人工智能的发展方向与机遇
  13. 初识微服务之Eureka
  14. php杂谈【基础篇】之_7.PHP涉及的所有英文单词
  15. 超好用的pdf编辑+pdf转word工具 – Adobe Acrobat Pro DC下载
  16. 苹果xr如何截屏_iphone敲两下截屏如何操作 苹果手机触控截屏方法【教程步骤】...
  17. Serdes系列总结——Xilinx serdes IP使用(二)——10G serdes
  18. 一个简单的软件测试流程
  19. 【微信抢红包】红包助手-修改版
  20. 从网红店到家居设计,“Ins风”正在无孔不入

热门文章

  1. python导入鸢尾花数据集_python 鸢尾花数据集报表展示
  2. 网络中的常见的各种协议--报文格式总结学习
  3. List集合的各种排序
  4. java实习鉴定书个人鉴定_大学生实习鉴定表自我鉴定范文
  5. 利用tensorflow加载VGG19
  6. CxImage功能强大的图形处理程序
  7. 华为星环大数据_大数据平台-华为和星环
  8. python 分割不等长字符串表格_python如何将字符串等长分割
  9. 网闸端口限制时,用HaneWin NFS Server来部署单一接口来交互,实现挂载便于访问
  10. EXCEL中实现经纬度距离计算、高斯坐标转换、GIS数据导出等功能