训练数据集:填充轮廓->建筑照片

下载链接:https://pan.baidu.com/s/1xUg8AC7NEXyKebSUNtRvdg 密码:2kw1

CGAN是Conditional Generative Adversarial Nets的缩写,也称为条件生成对抗网络。条件生成对抗网络指的是在生成对抗网络中加入条件(condition),条件的作用是监督生成对抗网络。本篇博客通过简单代码搭建,向大家解析了条件生成对抗网络CGAN。

在开始解析CGAN代码之前,笔者想说的是,要理解CGAN,还请大家先明了CGAN的原理,笔者在下面提供一些笔者认为比较好的了解CGAN原理的链接:

(1) 直接进行论文阅读:https://arxiv.org/abs/1411.1784

(2) 可以翻阅站内的一篇博客,笔者认为写得很不错:Conditional Generative Adversarial Nets论文笔记

(3) 笔者也简单解析一下CGAN的原理,原理图如下(截图来自CGAN论文)

如上图所示,和原始的生成对抗网络相比,条件生成对抗网络CGAN在生成器的输入和判别器的输入中都加入了条件y。这个y可以是任何类型的数据(可以是类别标签,或者其他类型的数据等)。目的是有条件地监督生成器生成的数据,使得生成器生成结果的方式不是完全自由无监督的。

CGAN训练的目标函数如下图所示:

从上面的目标函数中可以看到,条件y不仅被送入了判别器的输入中,也被融入了生成器的输入中。下面,笔者就来解析CGAN的代码,首先还是列举一下笔者主要使用的工具和库。

(1) Python 3.5.2

(2) numpy

(3) Tensorflow 1.2

(4) argparse 用来解析命令行参数

(5) random 用来打乱输入顺序

(6) os 用来读取图片路径和文件名

(7) glob 用来读取图片路径和文件名

(8) cv2 用来读取图片

笔者搭建的CGAN代码分成4大部分,分别是:

(1) train.py 训练的主控程序

(2) image_reader.py 数据读取接口

(3) net.py 定义网络结构

(4) evaluate.py 测试的主控程序

其中,训练时使用到的文件是(1),(2),(3)项,测试时使用到的文件时(2),(3),(4)。

下面,笔者放出代码与注释:

首先是train.py文件中的代码:

from __future__ import print_functionimport argparse
from random import shuffle
import random
import os
import sys
import math
import tensorflow as tf
import glob
import cv2from image_reader import *
from net import *parser = argparse.ArgumentParser(description='')parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots") #保存模型的路径
parser.add_argument("--out_dir", default='./train_out', help="path of train outputs") #训练时保存可视化输出的路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--random_seed", type=int, default=1234, help="random seed") #随机数种子
parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') #学习率
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')  #训练的epoch数量
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') #adam优化器的beta1参数
parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary.") #训练中每过多少step保存训练日志(记录一下loss值)
parser.add_argument("--write_pred_every", type=int, default=100, help="times to write.") #训练中每过多少step保存可视化结果
parser.add_argument("--save_pred_every", type=int, default=5000, help="times to save.") #训练中每过多少step保存模型(可训练参数)
parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda") #训练中L1_Loss前的乘数
parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda") #训练中GAN_Loss前的乘数
parser.add_argument("--train_picture_format", default='.png', help="format of training datas.") #网络训练输入的图片的格式(图片在CGAN中被当做条件)
parser.add_argument("--train_label_format", default='.jpg', help="format of training labels.") #网络训练输入的标签的格式(标签在CGAN中被当做真样本)
parser.add_argument("--train_picture_path", default='./dataset/train_picture/', help="path of training datas.") #网络训练输入的图片路径
parser.add_argument("--train_label_path", default='./dataset/train_label/', help="path of training labels.") #网络训练输入的标签路径args = parser.parse_args() #用来解析命令行参数
EPS = 1e-12 #EPS用于保证log函数里面的参数大于零def save(saver, sess, logdir, step): #保存模型的save函数model_name = 'model' #保存的模型名前缀checkpoint_path = os.path.join(logdir, model_name) #模型的保存路径与名称if not os.path.exists(logdir): #如果路径不存在即创建os.makedirs(logdir)saver.save(sess, checkpoint_path, global_step=step) #保存模型print('The checkpoint has been created.')def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图img_rgb = (img + 1.) * 127.5return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到训练过程中的可视化结果picture_image = cv_inv_proc(picture) #还原输入的图像gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的样本label_image = cv_inv_proc(label) #还原真实的样本(标签)inv_picture_image = cv2.resize(picture_image, (width, height)) #还原图像的尺寸inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #还原生成的样本的尺寸inv_label_image = cv2.resize(label_image, (width, height)) #还原真实的样本的尺寸output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #把他们拼起来return outputdef l1_loss(src, dst): #定义l1_lossreturn tf.reduce_mean(tf.abs(src - dst))def main(): #训练程序的主函数if not os.path.exists(args.snapshot_dir): #如果保存模型参数的文件夹不存在则创建os.makedirs(args.snapshot_dir)if not os.path.exists(args.out_dir): #如果保存训练中可视化输出的文件夹不存在则创建os.makedirs(args.out_dir)train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) #得到训练输入图像路径名称列表tf.set_random_seed(args.random_seed) #初始一下随机数train_picture = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_picture') #输入的训练图像train_label = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_label') #输入的与训练图像匹配的标签gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的输出dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator") #判别器返回的对真实标签的判别结果dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator") #判别器返回的对生成(虚假的)标签判别结果gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) #计算生成器损失中的GAN_loss部分gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) #计算生成器损失中的L1_loss部分gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight #计算生成器的lossdis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) #计算判别器的lossgen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) #记录生成器loss的日志dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) #记录判别器loss的日志summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) #日志记录器g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] #所有生成器的可训练参数d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] #所有判别器的可训练参数d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #判别器训练器g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #生成器训练器d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) #计算判别器参数梯度d_train = d_optim.apply_gradients(d_grads_and_vars) #更新判别器参数g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) #计算生成器参数梯度g_train = g_optim.apply_gradients(g_grads_and_vars) #更新生成器参数train_op = tf.group(d_train, g_train) #train_op表示了参数更新操作config = tf.ConfigProto()config.gpu_options.allow_growth = True #设定显存不超量使用sess = tf.Session(config=config) #新建会话层init = tf.global_variables_initializer() #参数初始化器sess.run(init) #初始化所有可训练参数saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型保存器counter = 0 #counter记录训练步数for epoch in range(args.epoch): #训练epoch数shuffle(train_picture_list) #每训练一个epoch,就打乱一下输入的顺序for step in range(len(train_picture_list)): #每个训练epoch中的训练step数counter += 1picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) #获取不包含路径和格式的输入图片名称#读取一张训练图片,一张训练标签,以及相应的高和宽picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name, picture_path=args.train_picture_path, label_path=args.train_label_path, picture_format = args.train_picture_format, label_format = args.train_label_format, size = args.image_size)batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0) #填充维度batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0) #填充维度feed_dict = { train_picture : batch_picture, train_label : batch_label } #构造feed_dictgen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) #得到每个step中的生成器和判别器lossif counter % args.save_pred_every == 0: #每过save_pred_every次保存模型save(saver, sess, args.snapshot_dir, counter)if counter % args.summary_pred_every == 0: #每过summary_pred_every次保存训练日志gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)summary_writer.add_summary(gen_loss_sum_value, counter)summary_writer.add_summary(discriminator_sum_value, counter)if counter % args.write_pred_every == 0: #每过write_pred_every次写一下训练的可视化结果gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #run出生成器的输出write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到训练的可视化结果write_image_name = args.out_dir + "/out"+ str(counter) + ".png" #待保存的训练可视化结果路径与名称cv2.imwrite(write_image_name, write_image) #保存训练的可视化结果print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))if __name__ == '__main__':main()

然后是image_reader.py文件:

import os
import numpy as np
import tensorflow as tf
import cv2#读取图片的函数,接收六个参数
#输入参数分别是图片名,图片路径,标签路径,图片格式,标签格式,需要调整的尺寸大小
def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):picture_name = picture_path + file_name + picture_format #得到图片名称和路径label_name = label_path + file_name + label_format #得到标签名称和路径picture = cv2.imread(picture_name, 1) #读取图片label = cv2.imread(label_name, 1) #读取标签height = picture.shape[0] #得到图片的高width = picture.shape[1] #得到图片的宽picture_resize_t = cv2.resize(picture, (size, size)) #调整图片的尺寸,改变成网络输入的大小picture_resize = picture_resize_t / 127.5 - 1. #归一化图片label_resize_t = cv2.resize(label, (size, size)) #调整标签的尺寸,改变成网络输入的大小label_resize = label_resize_t / 127.5 - 1. #归一化标签return picture_resize, label_resize, height, width #返回网络输入的图片,标签,还有原图片和标签的长宽

接着是net.py文件:

import numpy as np
import tensorflow as tf
import math#构造可训练参数
def make_var(name, shape, trainable = True):return tf.get_variable(name, shape, trainable = trainable)#定义卷积层
def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):input_dim = input_.get_shape()[-1]with tf.variable_scope(name):kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)if biased:biases = make_var(name = 'biases', shape = [output_dim])output = tf.nn.bias_add(output, biases)return output#定义空洞卷积层
def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):input_dim = input_.get_shape()[-1]with tf.variable_scope(name):kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)if biased:biases = make_var(name = 'biases', shape = [output_dim])output = tf.nn.bias_add(output, biases)return output#定义反卷积层
def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):input_dim = input_.get_shape()[-1]input_height = int(input_.get_shape()[1])input_width = int(input_.get_shape()[2])with tf.variable_scope(name):kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim], [1, 2, 2, 1], padding = "SAME")return output#定义batchnorm(批次归一化)层
def batch_norm(input_, name="batch_norm"):with tf.variable_scope(name):input_dim = input_.get_shape()[-1]scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)epsilon = 1e-5inv = tf.rsqrt(variance + epsilon)normalized = (input_-mean)*invoutput = scale*normalized + offsetreturn output#定义lrelu激活层
def lrelu(x, leak=0.2, name = "lrelu"):return tf.maximum(x, leak*x)#定义生成器,采用UNet架构,主要由8个卷积层和8个反卷积层组成
def generator(image, gf_dim=64, reuse=False, name="generator"):input_dim = int(image.get_shape()[-1]) #获取输入通道dropout_rate = 0.5 #定义dropout的比例with tf.variable_scope(name):if reuse:tf.get_variable_scope().reuse_variables()else:assert tf.get_variable_scope().reuse is False#第一个卷积层,输出尺度[1, 128, 128, 64]e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')#第二个卷积层,输出尺度[1, 64, 64, 128]e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')#第三个卷积层,输出尺度[1, 32, 32, 256]e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')#第四个卷积层,输出尺度[1, 16, 16, 512]e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')#第五个卷积层,输出尺度[1, 8, 8, 512]e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')#第六个卷积层,输出尺度[1, 4, 4, 512]e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')#第七个卷积层,输出尺度[1, 2, 2, 512]e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')#第八个卷积层,输出尺度[1, 1, 1, 512]e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')#第一个反卷积层,输出尺度[1, 2, 2, 512]d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')d1 = tf.nn.dropout(d1, dropout_rate) #随机扔掉一般的输出d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)#第二个反卷积层,输出尺度[1, 4, 4, 512]d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')d2 = tf.nn.dropout(d2, dropout_rate) #随机扔掉一般的输出d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)#第三个反卷积层,输出尺度[1, 8, 8, 512]d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')d3 = tf.nn.dropout(d3, dropout_rate) #随机扔掉一般的输出d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)#第四个反卷积层,输出尺度[1, 16, 16, 512]d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)#第五个反卷积层,输出尺度[1, 32, 32, 256]d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)#第六个反卷积层,输出尺度[1, 64, 64, 128]d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)#第七个反卷积层,输出尺度[1, 128, 128, 64]d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)#第八个反卷积层,输出尺度[1, 256, 256, 3]d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')return tf.nn.tanh(d8)#定义判别器
def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):with tf.variable_scope(name):if reuse:tf.get_variable_scope().reuse_variables()else:assert tf.get_variable_scope().reuse is Falsedis_input = tf.concat([image, targets], 3)#第1个卷积模块,输出尺度: 1*128*128*64h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))#第2个卷积模块,输出尺度: 1*64*64*128h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))#第3个卷积模块,输出尺度: 1*32*32*256h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))#第4个卷积模块,输出尺度: 1*32*32*512h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))#最后一个卷积模块,输出尺度: 1*32*32*1output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')dis_out = tf.sigmoid(output) #在输出之前经过sigmoid层,因为需要进行log运算return dis_out

上面就是训练所需的全部代码,大家可以看到,在net.py文件中。生成器使用UNet结构,在生成器和判别器中,image参数就是指的条件,并且在生成器的输入中,随机噪声被去掉了(仅仅输入了条件);在判别器的输入中,条件和待判别的图像被拼接(concat)了起来。

如果需要开启训练,可以调整train.py中的最后四个参数,根据自己的需求调整训练输入的图片和标签文件路径和相应的格式。另外,由于CGAN训练中需要匹配条件与判别图片,因此,训练读取的图片和标签名称应该是匹配的,在image_reader.py中也能看到,程序是按照同一个名称,去检索训练一个批次输入的图像和对应的标签。

下面是evaluate.py文件:

import argparse
import sys
import math
import tensorflow as tf
import numpy as np
import glob
import cv2from image_reader import *
from net import *parser = argparse.ArgumentParser(description='')parser.add_argument("--test_picture_path", default='./dataset/test_picture/', help="path of test datas.")#网络测试输入的图片路径
parser.add_argument("--test_label_path", default='./dataset/test_label/', help="path of test datas.") #网络测试输入的标签路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--test_picture_format", default='.png', help="format of test pictures.") #网络测试输入的图片的格式
parser.add_argument("--test_label_format", default='.jpg', help="format of test labels.") #网络测试时读取的标签的格式
parser.add_argument("--snapshots", default='./snapshots/',help="Path of Snapshots") #读取训练好的模型参数的路径
parser.add_argument("--out_dir", default='./test_output/',help="Output Folder") #保存网络测试输出图片的路径args = parser.parse_args() #用来解析命令行参数def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图img_rgb = (img + 1.) * 127.5return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到网络测试的结果picture_image = cv_inv_proc(picture) #还原输入的图像gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的结果label_image = cv_inv_proc(label) #还原读取的标签inv_picture_image = cv2.resize(picture_image, (width, height)) #将输入图像还原到原大小inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #将生成的结果还原到原大小inv_label_image = cv2.resize(label_image, (width, height)) #将标签还原到原大小output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #拼接得到输出结果return outputdef main():if not os.path.exists(args.out_dir): #如果保存测试结果的文件夹不存在则创建os.makedirs(args.out_dir)test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) #得到测试输入图像路径名称列表test_picture = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='test_picture') #测试输入的图像gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的生成结果restore_var = [v for v in tf.global_variables() if 'generator' in v.name] #需要载入的已训练的模型参数config = tf.ConfigProto()config.gpu_options.allow_growth = True #设定显存不超量使用sess = tf.Session(config=config) #建立会话层saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) #导入模型参数时使用checkpoint = tf.train.latest_checkpoint(args.snapshots) #读取模型参数saver.restore(sess, checkpoint) #导入模型参数for step in range(len(test_picture_list)):picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) #得到一张网络测试的输入图像名字#读取一张测试图片,一张标签,以及相应的高和宽picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,picture_path=args.test_picture_path,label_path=args.test_label_path,picture_format=args.test_picture_format,label_format=args.test_label_format,size=args.image_size)batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) #填充维度feed_dict = {test_picture: batch_picture} #构造feed_dictgen_label_value = sess.run(gen_label, feed_dict=feed_dict) #得到生成结果write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到一张需要存的图像write_image_name = args.out_dir + picture_name + ".png" #为上述的图像构造保存路径与文件名cv2.imwrite(write_image_name, write_image) #保存测试结果print('step {:d}'.format(step))if __name__ == '__main__':main()

如果需要测试训练完毕的模型,相应地更改测试图片和标签输入路径和格式的四个参数即可,并设置读取模型权重的路径即可。

下面,笔者就以训练的填充轮廓生成建筑图片的例子为大家展示一下CGAN的效果:

首先是训练时的可视化输出图像,从左到右,第一张是网络的输入图片(条件),第二张是生成器生成的建筑图像,第三张是真实的建筑图像(标签)。

首先是训练200次的输出:

然后是训练5600次的输出:

然后是训练19000次的输出:

然后是训练36500次的输出:

然后是训练46700次的输出:

然后是训练65700次的输出:

然后是训练72400次的输出:

最后是训练96300次的输出:

下面展示一下训练的loss曲线:

生成器的loss曲线:

判别器的loss曲线:

最后展示一下在测试集上面的效果:

左边是输入的图像(条件),中间是生成的图像,右边是标签(真实的样本)。

上面就是在测试集上面的效果,读者朋友们可以从文章开头笔者放出的链接中下载数据集进行实验。

在train.py中,如果将lamda_l1_weight参数改成100,就是pix2pix的做法,笔者放了一些测试集的效果(训练有一些过拟合):

到这里,CGAN的模型搭建及解析就接近尾声了,很感谢Mehdi Mirza和Simon Osindero,为大家带来条件监督的生成对抗网络算法。CGAN还可以做很多有趣的事情,比如说这个有趣的工作:AI可能真的要代替插画师了……,项目链接https://make.girls.moe/#/,通过CGAN有条件地生成二次元萌妹。

笔者也衷心希望,此篇博客对大家的学习研究有帮助。

欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力!

written by jiong

天下熙熙皆为利来,天下攘攘皆为利往

详解GAN代码之搭建并详解CGAN代码相关推荐

  1. [Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleG ...

  2. [转]大数据环境搭建步骤详解(Hadoop,Hive,Zookeeper,Kafka,Flume,Hbase,Spark等安装与配置)

    大数据环境安装和配置(Hadoop2.7.7,Hive2.3.4,Zookeeper3.4.10,Kafka2.1.0,Flume1.8.0,Hbase2.1.1,Spark2.4.0等) 系统说明 ...

  3. php强类型 vscode,VSCode + WSL 2 + Ruby环境搭建图文详解

    vscode配置ruby开发环境 vscode近年来发展迅速,几乎在3年之间就抢占了原来vim.sublime text的很多份额,犹记得在2015-2016年的时候,ruby推荐的开发环境基本上都是 ...

  4. Linux系统下SVN服务器的搭建过程详解 UpJ}s7+

    Linux系统下SVN服务器的搭建过程详解 UpJ}s7+   1 环境:  服务器放在redhatAS4.0上,客户端在windows 2000. k_lb"5z   Z]jSq@%1H* ...

  5. 服务器和网页接口,WebApi架构详解,WebApi接口搭建与部署WebApi服务器

      WebApi架构详解,WebApi接口搭建与部署WebApi服务器 本文关键词:WebApi架构, WebApi接口搭建, WebApi部署 1. Api是什么? API(Application ...

  6. 明晚8点公开课 | 用AI给旧时光上色!详解GAN在黑白照片上色中的应用

    在改革开放40周年之际,百度联合新华社推出了一个刷屏级的H5应用--用AI技术为黑白老照片上色,浓浓的怀旧风勾起了心底快被遗忘的时光. 想了解如何给老照片上色?本次公开课中,我们邀请到了百度高级研发工 ...

  7. es springboot 不设置id_es(elasticsearch)整合SpringCloud(SpringBoot)搭建教程详解

    注意:适用于springboot或者springcloud框架 1.首先下载相关文件 2.然后需要去启动相关的启动文件 3.导入相关jar包(如果有相关的依赖包不需要导入)以及配置配置文件,并且写一个 ...

  8. 《Java和Android开发实战详解》——2.5节良好的Java程序代码编写风格

    本节书摘来自异步社区<Java和Android开发实战详解>一书中的第2章,第2.5节良好的Java程序代码编写风格,作者 陈会安,更多章节内容可以访问云栖社区"异步社区&quo ...

  9. DL之YoloV3:Yolo V3算法的简介(论文介绍)、各种DL框架代码复现、架构详解、案例应用等配图集合之详细攻略

    DL之YoloV3:Yolo V3算法的简介(论文介绍).各种DL框架代码复现.架构详解.案例应用等配图集合之详细攻略 目录 Yolo V3算法的简介(论文介绍) 0.YoloV3实验结果 1.Yol ...

最新文章

  1. 【PHP】微信官方代码Log调试输出类,面向对象设计模式!来看看,你会有收益!...
  2. 2.2tensorflow2官方demo
  3. LeetCode——排序
  4. Linux同步目录 保留文件修改时间和权限 rsync
  5. 【Linux】VMware连接CRT
  6. numpy ndarray 多维数组的内存管理
  7. LeetCode 73. Set Matrix Zeroes
  8. linux下的系统服务管理及日志管理
  9. 爬虫---如何抓取app的思路和方案
  10. 不能错过的linux驱动开发的经典书籍推荐
  11. java 时间格式化 注解_Java关于时间格式化的方法
  12. kettle软件的使用
  13. 【剑桥摄影协会】伽马校正(Gamma)
  14. LABjs分析 http://labjs.com/documentation.php#queuescript
  15. html中自动换行标记[转]
  16. 光计算机pdf,神威bull;太湖之光计算机系统.PDF
  17. 网上看到了一个关于黑客的练习方式
  18. ROS中的imu_transformer包是什么,在哪里可以下载啊
  19. 今晚直播,你该了解的MySQL 8.0 SQL优化新特性
  20. 超简单vue-devtools工具安装

热门文章

  1. QML Map中测距——QtLocation轻量级地图应用学习
  2. 你有“隐私泄露担忧”吗?适合普通用户的6个方法来了
  3. WormHole是一个简单、易用的api管理平台,支持dubbo服务调用
  4. MFC不同窗口之间传递数据
  5. 攻击重放技术以及什么是重放攻击?
  6. 深度学习(五):FastFCN代码运行、测试与预测
  7. Revit快速标注 | 有求必应的【万能标注】操作步骤
  8. irc php,IRC / 实时聊天系统
  9. 多线程相关实例(多线程经典应用场景)
  10. 如何简单设计接口测试用例