参考链接:https://github.com/yenchenlin/pix2pix-tensorflow

https://blog.csdn.net/stdcoutzyx/article/details/78820728

utils.py

from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime,strftime
pp = pprint.PrettyPrinter()


def get_stddev(x, k_h, k_w):
    """Return ``1 / sqrt(k_w * k_h * C)`` where ``C`` is the last dim of ``x``.

    Args:
        x: any object exposing ``get_shape()`` (e.g. a TF tensor); only the
            last entry of its shape is read.
        k_h: kernel height.
        k_w: kernel width.
    """
    fan_in = k_w * k_h * x.get_shape()[-1]
    return 1 / math.sqrt(fan_in)


#########################################################################
# 载入图片
# 读取图片
def imread(path, is_grayscale=False):
    """Read an image file into a float64 NumPy array.

    Args:
        path: path of the image file.
        is_grayscale: when True, ``flatten=True`` is passed to
            ``scipy.misc.imread`` so the file is read as one grey-scale layer.

    Returns:
        The decoded image as an ``np.float64`` array.
    """
    # ``np.float`` was merely an alias of the builtin ``float`` (i.e. float64)
    # and was removed in NumPy 1.24; ``np.float64`` is the identical dtype.
    # NOTE(review): ``scipy.misc.imread`` itself only exists in old SciPy
    # releases (with Pillow installed); newer SciPy needs another reader.
    if is_grayscale:
        image = scipy.misc.imread(path, flatten=True)
    else:
        image = scipy.misc.imread(path)
    return image.astype(np.float64)
# 载入图片
def load_image(image_path):
    """Load a side-by-side paired image and split it into its two halves.

    The file is read with ``imread`` and cut vertically at ``int(width / 2)``.

    Returns:
        ``(img_A, img_B)``: the left half (columns ``[0, w/2)``) and the right
        half (columns ``[w/2, w)``).
    """
    paired = imread(image_path)
    width = int(paired.shape[1])
    half = int(width / 2)
    left, right = paired[:, :half], paired[:, half:width]
    return left, right
# 处理分离后的图片
def preprocess_A_and_B(img_A, img_B, load_size=286, fine_size=256, flip=True, is_test=False):
    """Resize (and, when training, randomly crop / flip) an image pair identically.

    Test mode: both images are resized straight to ``fine_size`` x ``fine_size``.
    Training mode: both images are resized to ``load_size`` x ``load_size``,
    the same random ``fine_size`` x ``fine_size`` window is cropped from each,
    and with probability 0.5 (if ``flip``) both are mirrored left-right.

    Returns:
        The processed ``(img_A, img_B)`` pair.

    Raises:
        ValueError: in training mode, if ``load_size < fine_size`` (no valid
            crop window exists).
    """
    if is_test:
        img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
        img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
        return img_A, img_B

    max_offset = load_size - fine_size
    if max_offset < 0:
        raise ValueError(
            "load_size (%d) must be >= fine_size (%d)" % (load_size, fine_size))

    # Resize both images to a common size before cropping.
    img_A = scipy.misc.imresize(img_A, [load_size, load_size])
    img_B = scipy.misc.imresize(img_B, [load_size, load_size])

    # Draw the crop offset uniformly from [0, max_offset]. The previous
    # ceil(uniform(0.01, max_offset)) never produced 0 and, when
    # load_size == fine_size, always produced 1, so the crop came out one
    # pixel smaller than fine_size.
    h1 = np.random.randint(0, max_offset + 1)
    w1 = np.random.randint(0, max_offset + 1)
    img_A = img_A[h1:h1 + fine_size, w1:w1 + fine_size]
    img_B = img_B[h1:h1 + fine_size, w1:w1 + fine_size]

    if flip and np.random.random() > 0.5:
        # Mirror left-right; both images get the same flip to stay aligned.
        img_A = np.fliplr(img_A)
        img_B = np.fliplr(img_B)

    return img_A, img_B
# 加载数据
def load_data(image_path, flip=True, is_test=False):
    """Load one paired sample as a single normalised, channel-stacked array.

    Steps:
    1. Split the file into its two halves (``load_image``).
    2. Resize, crop and flip them together (``preprocess_A_and_B``).
    3. Rescale each with ``x / 127.5 - 1``, mapping 0..255 to -1..1.
    4. Concatenate the halves along axis 2.

    Returns:
        Array of shape ``(fine_size, fine_size, c_A + c_B)``, e.g.
        ``(256, 256, 6)`` for two RGB halves with the default sizes.
    """
    halves = load_image(image_path)
    img_A, img_B = preprocess_A_and_B(*halves, flip=flip, is_test=is_test)
    normalised = [img / 127.5 - 1. for img in (img_A, img_B)]
    return np.concatenate(normalised, axis=2)
#####################################################################
# 测试
# a,b = load_image("cityscapes/train/1.jpg")
# c,d = preprocess_A_and_B(a,b)
# print(c.shape)
# a = load_data("cityscapes/train/1.jpg")
# print(a.shape)
######################################################################
# 将像素值从 [-1, 1] 映射回 [0, 1]
def inverse_transform(images):
    """Map values from the [-1, 1] range back to [0, 1] via ``(x + 1) / 2``."""
    shifted = images + 1.
    return shifted / 2
# 合并图片
def merge(images, size):
    """Tile a batch of images into one ``size[0]`` x ``size[1]`` grid image.

    Images are placed row-major: image ``idx`` goes to grid row
    ``idx // size[1]`` and column ``idx % size[1]``; unused cells stay zero.

    Args:
        images: array of shape ``(N, H, W, C)`` or ``(N, H, W)``.
        size: ``(rows, cols)`` of the grid.

    Returns:
        float64 array.
        - 4-D input gives shape ``(H*rows, W*cols, D)``. ``D`` is 3 for 1- or
          3-channel images, as before; 1-channel images are broadcast to 3
          identical channels. For any other channel count, ``D`` is ``C``.
        - 3-D input gives shape ``(H*rows, W*cols)``.

    Raises:
        ValueError: if the batch holds more than ``rows * cols`` images.
    """
    rows, cols = size[0], size[1]
    h, w = images.shape[1], images.shape[2]
    if len(images) > rows * cols:
        raise ValueError(
            "cannot tile %d images into a %dx%d grid" % (len(images), rows, cols))

    if images.ndim == 3:
        canvas = np.zeros((h * rows, w * cols))
    else:
        channels = images.shape[3]
        # Keep the historical 3-channel canvas for 1- and 3-channel input;
        # other channel counts (e.g. RGBA) previously failed to broadcast.
        depth = 3 if channels in (1, 3) else channels
        canvas = np.zeros((h * rows, w * cols, depth))

    for idx, image in enumerate(images):
        col = idx % cols
        row = idx // cols
        canvas[row * h:row * h + h, col * w:col * w + w, ...] = image
    return canvas
# 保存图片
def imsave(images, size, path):
    """Tile ``images`` into a ``size`` grid with ``merge`` and write it to ``path``.

    Returns whatever ``scipy.misc.imsave`` returns.
    """
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)
def save_images(images, size, image_path):
    """Save a batch of [-1, 1]-valued images as one tiled image file.

    Values are mapped to [0, 1] with ``inverse_transform``, then tiled and
    written by ``imsave``.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)

ops.py

import math
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from utils import *# 批归一化
class batch_norm(object):
    """Callable wrapper around ``tf.contrib.layers.batch_norm`` with a fixed scope.

    Each instance remembers its ``epsilon``, ``momentum`` (passed as ``decay``)
    and scope ``name``. Every call normalises its input under that scope name.
    """

    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        # The variable_scope here creates no variables; it is kept so that
        # graph name-scope bookkeeping stays as it was.
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        """Batch-normalise ``x``.

        Args:
            x: input tensor.
            train: forwarded as ``is_training``. It used to be ignored, so the
                layer always ran in training mode; the default ``True`` keeps
                that behaviour for existing callers.
        """
        # updates_collections=None: the moving-average updates are forced to
        # run in place with the forward pass (no separate UPDATE_OPS step).
        return tf.contrib.layers.batch_norm(
            x,
            decay=self.momentum,
            updates_collections=None,
            epsilon=self.epsilon,
            scale=True,
            is_training=train,
            scope=self.name,
        )


def binary_cross_entropy(preds, targets, name=None):
    """Compute the mean binary cross-entropy between ``preds`` and ``targets``.

    With ``x = preds`` and ``z = targets`` the returned loss is::

        mean_i( -(z[i] * log(x[i] + eps) + (1 - z[i]) * log(1 - x[i] + eps)) )

    where ``eps = 1e-12`` guards against ``log(0)``.

    Args:
        preds: ``Tensor`` of type ``float32`` or ``float64``; presumably
            probabilities in (0, 1), e.g. sigmoid outputs.
        targets: ``Tensor`` of the same type and shape as ``preds``.
        name: optional name for the enclosing name scope (default
            ``"bce_loss"``).

    Returns:
        A scalar ``Tensor``.
    """
    eps = 1e-12
    # ``ops.op_scope(values, name, default_name)`` is deprecated in favour of
    # ``tf.name_scope(name, default_name, values)``.
    with tf.name_scope(name, "bce_loss", [preds, targets]):
        preds = ops.convert_to_tensor(preds, name="preds")
        targets = ops.convert_to_tensor(targets, name="targets")
        return tf.reduce_mean(-(targets * tf.log(preds + eps) +
                                (1. - targets) * tf.log(1. - preds + eps)))
# concat
def conv_cond_concat(x, y):
    """Broadcast conditioning tensor ``y`` over ``x`` and concatenate on axis 3.

    ``y`` is multiplied by a ones tensor of shape
    ``[x_dim0, x_dim1, x_dim2, y_dim3]``, presumably tiling a per-sample
    condition over an NHWC spatial grid. The result is joined to ``x``
    along axis 3.
    """
    xs = x.get_shape()
    ys = y.get_shape()
    tiled_y = y * tf.ones([xs[0], xs[1], xs[2], ys[3]])
    return tf.concat([x, tiled_y], 3)
# 卷积
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    """2-D convolution with 'SAME' padding plus a learned bias.

    Inside variable scope ``name`` it creates two variables:
    - ``w``: shape ``[k_h, k_w, in_channels, output_dim]``, truncated-normal
      init with ``stddev``.
    - ``biases``: zero-initialised.

    Strides are ``d_h`` / ``d_w``.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable(
            'w', [k_h, k_w, in_channels, output_dim],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv2d(input_, kernel, strides=[1, d_h, d_w, 1], padding='SAME')
        bias = tf.get_variable('biases', [output_dim],
                               initializer=tf.constant_initializer(0.0))
        # Reshape back to the pre-bias static shape.
        return tf.reshape(tf.nn.bias_add(out, bias), out.get_shape())


# Transposed convolution ("deconvolution")
def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="deconv2d", with_w=False):
    """Transposed 2-D convolution to ``output_shape`` plus a learned bias.

    The kernel ``w`` has layout ``[height, width, output_channels, in_channels]``,
    as ``tf.nn.conv2d_transpose`` expects, and is drawn from a normal
    distribution with ``stddev``. ``biases`` start at zero. Both live in
    variable scope ``name``.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is True.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        kernel = tf.get_variable(
            'w', [k_h, k_w, out_channels, input_.get_shape()[-1]],
            initializer=tf.random_normal_initializer(stddev=stddev))
        out = tf.nn.conv2d_transpose(input_, kernel, output_shape=output_shape,
                                     strides=[1, d_h, d_w, 1])
        bias = tf.get_variable('biases', [out_channels],
                               initializer=tf.constant_initializer(0.0))
        out = tf.reshape(tf.nn.bias_add(out, bias), out.get_shape())
        return (out, kernel, bias) if with_w else out
# lrelu激活函数
def lrelu(x, leak=0.2, name='lrelu'):
    """Leaky ReLU, computed element-wise as ``max(x, leak * x)``.

    ``name`` is accepted for API compatibility but is not used.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully-connected layer: ``input_ @ Matrix + bias``.

    Inside variable scope ``scope`` (default ``"Linear"``) it creates two
    variables:
    - ``Matrix``: shape ``[in_features, output_size]``, normal init with
      ``stddev``.
    - ``bias``: constant ``bias_start``.

    ``input_`` is expected to be 2-D with a known second dimension.

    Returns:
        The output tensor, or ``(output, Matrix, bias)`` when ``with_w`` is True.
    """
    in_features = input_.get_shape().as_list()[1]
    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable("Matrix", [in_features, output_size], tf.float32,
                                  tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        output = tf.matmul(input_, weights) + bias
        if with_w:
            return output, weights, bias
        return output

model.py

from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrangefrom ops import *
from utils import *


class pix2pix(object):
    """pix2pix conditional GAN: U-Net generator + convolutional discriminator (TF1 graph)."""

    def __init__(self, sess, image_size=256,
                 batch_size=1, sample_size=1, output_size=256,
                 gf_dim=64, df_dim=64, L1_lambda=100,
                 input_c_dim=3, output_c_dim=3, dataset_name='facades',
                 checkpoint_dir=None, sample_dir=None):
        """Store hyper-parameters, create the batch-norm layers and build the graph.

        Args:
            sess: TensorFlow session
            image_size: height/width of the square input placeholder. [256]
            batch_size: The size of batch. Should be specified before training.
            sample_size: stored on the instance; not otherwise used in this class.
            output_size: (optional) The resolution in pixels of the images. [256]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            L1_lambda: weight of the L1 term in the generator loss. [100]
            input_c_dim: (optional) Dimension of input image color. For grayscale input, set to 1. [3]
            output_c_dim: (optional) Dimension of output image color. For grayscale input, set to 1. [3]
            dataset_name: directory containing ``train/`` and ``val/`` ``*.jpg`` files.
            checkpoint_dir: directory ``train``/``test`` restore checkpoints from.
            sample_dir: accepted but not stored (``train`` uses ``args.sample_dir``).
        """
        self.sess = sess
        self.is_grayscale = (input_c_dim == 1)
        self.batch_size = batch_size
        self.image_size = image_size
        self.sample_size = sample_size
        self.output_size = output_size

        self.gf_dim = gf_dim
        self.df_dim = df_dim

        self.input_c_dim = input_c_dim
        self.output_c_dim = output_c_dim

        self.L1_lambda = L1_lambda

        # Batch normalization: deals with poor initialization, helps gradient flow.
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')

        self.g_bn_e2 = batch_norm(name='g_bn_e2')
        self.g_bn_e3 = batch_norm(name='g_bn_e3')
        self.g_bn_e4 = batch_norm(name='g_bn_e4')
        self.g_bn_e5 = batch_norm(name='g_bn_e5')
        self.g_bn_e6 = batch_norm(name='g_bn_e6')
        self.g_bn_e7 = batch_norm(name='g_bn_e7')
        self.g_bn_e8 = batch_norm(name='g_bn_e8')

        self.g_bn_d1 = batch_norm(name='g_bn_d1')
        self.g_bn_d2 = batch_norm(name='g_bn_d2')
        self.g_bn_d3 = batch_norm(name='g_bn_d3')
        self.g_bn_d4 = batch_norm(name='g_bn_d4')
        self.g_bn_d5 = batch_norm(name='g_bn_d5')
        self.g_bn_d6 = batch_norm(name='g_bn_d6')
        self.g_bn_d7 = batch_norm(name='g_bn_d7')

        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.build_model()

    # ------------------------------------------------------------------ helpers

    def _dataset_files(self, split):
        """Return the ``*.jpg`` paths under ``<dataset_name>/<split>`` (glob order).

        ``os.path.join`` keeps the pattern valid on both POSIX and Windows.
        """
        return glob(os.path.join(self.dataset_name, split, '*.jpg'))

    def _to_batch(self, samples):
        """Stack loaded samples into a float32 array; add a channel axis for grey-scale."""
        batch = np.array(samples).astype(np.float32)
        if self.is_grayscale:
            batch = batch[:, :, :, None]
        return batch

    def _checkpoint_subdir(self, checkpoint_dir):
        """Return ``checkpoint_dir/<dataset_name>_<batch_size>_<output_size>``."""
        model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
        return os.path.join(checkpoint_dir, model_dir)

    # -------------------------------------------------------------------- graph

    def build_model(self):
        """Create placeholders, generator/discriminator graphs, losses and summaries."""
        # Six-channel input (default): img_A and img_B concatenated on channels.
        self.real_data = tf.placeholder(tf.float32,
                                        [self.batch_size, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')

        # The first input_c_dim channels (the LEFT half of each paired file, see
        # load_data) are the target real_B. The remaining channels (the RIGHT
        # half) are the conditioning input real_A.
        self.real_B = self.real_data[:, :, :, :self.input_c_dim]
        self.real_A = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]

        # Generate the target image from the conditioning image.
        self.fake_B = self.generator(self.real_A)

        # Real pair (condition, real target) and fake pair (condition, generated target).
        self.real_AB = tf.concat([self.real_A, self.real_B], 3)
        self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)
        # Discriminator on both pairs; the second call reuses the first's variables.
        self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
        self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)

        # Second U-Net pass over the same (reused) generator variables, used for sampling.
        self.fake_B_sample = self.sampler(self.real_A)

        # TensorBoard summaries.
        self.d_sum = tf.summary.histogram("d", self.D)
        self.d__sum = tf.summary.histogram("d_", self.D_)
        self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)

        # Discriminator losses: real pairs -> 1, fake pairs -> 0.
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        # Generator loss: fool D, plus L1_lambda-weighted L1 distance to the real target.
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
            + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))

        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)

        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        t_vars = tf.trainable_variables()

        # Variables are split by substring of their names ('d_' / 'g_').
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        self.saver = tf.train.Saver()

    # ----------------------------------------------------------------- sampling

    def load_random_samples(self):
        """Load ``batch_size`` random validation files (drawn with replacement) as a batch."""
        data = np.random.choice(self._dataset_files('val'), self.batch_size)
        sample = [load_data(sample_file) for sample_file in data]
        return self._to_batch(sample)

    def sample_model(self, sample_dir, epoch, idx):
        """Run the sampler on random validation images, save them and print the losses."""
        sample_images = self.load_random_samples()
        samples, d_loss, g_loss = self.sess.run(
            [self.fake_B_sample, self.d_loss, self.g_loss],
            feed_dict={self.real_data: sample_images}
        )
        save_images(samples, [self.batch_size, 1],
                    './{}/train_{:02d}_{:04d}.png'.format(sample_dir, epoch, idx))
        print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))

    # ----------------------------------------------------------------- training

    def train(self, args):
        """Train pix2pix"""
        d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        self.g_sum = tf.summary.merge([self.d__sum,
                                       self.fake_B_sum, self.d_loss_fake_sum, self.g_loss_sum])
        # NOTE: this replaces the "d" histogram summary stored by build_model.
        self.d_sum = tf.summary.merge([self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        if not os.path.exists('logs'):
            os.makedirs('logs')
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        counter = 1
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in xrange(args.epoch):
            data = self._dataset_files('train')
            print(len(data))
            # NOTE: shuffling is disabled; files are visited in glob order.
            # int() guards against a float train_size (e.g. a 1e8 default).
            batch_idxs = int(min(len(data), args.train_size)) // self.batch_size

            for idx in xrange(0, batch_idxs):
                batch_files = data[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch = [load_data(batch_file) for batch_file in batch_files]
                batch_images = self._to_batch(batch)

                # Update the discriminator.
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={self.real_data: batch_images})
                self.writer.add_summary(summary_str, counter)

                # Update the generator twice so d_loss does not go to zero (differs from paper).
                for _ in range(2):
                    _, summary_str = self.sess.run([g_optim, self.g_sum],
                                                   feed_dict={self.real_data: batch_images})
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.real_data: batch_images})
                errD_real = self.d_loss_real.eval({self.real_data: batch_images})
                errG = self.g_loss.eval({self.real_data: batch_images})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))

                # Save a sample image every 100 iterations.
                if np.mod(counter, 100) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)

                # Checkpoint every 500 iterations.
                if np.mod(counter, 500) == 2:
                    self.save(args.checkpoint_dir, counter)

    # ----------------------------------------------------------------- networks

    def discriminator(self, image, y=None, reuse=False):
        """Return ``(sigmoid(logits), logits)`` for a (condition, target) image pair.

        ``reuse=True`` reuses the variables created by an earlier call.
        """
        with tf.variable_scope("discriminator") as scope:
            # image is 256 x 256 x (input_c_dim + output_c_dim) with the defaults.
            if reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse == False

            h0 = lrelu(conv2d(image, self.df_dim, 5, 5, 2, 2, name='d_h0_conv'))
            # h0 is (128 x 128 x self.df_dim)
            h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim * 2, 5, 5, 2, 2, name='d_h1_conv')))
            # h1 is (64 x 64 x self.df_dim*2)
            h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim * 4, 5, 5, 2, 2, name='d_h2_conv')))
            # h2 is (32 x 32 x self.df_dim*4)
            h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim * 8, 5, 5, 1, 1, name='d_h3_conv')))
            # h3 is (32 x 32 x self.df_dim*8): stride 1 with SAME padding keeps 32 x 32
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')

            return tf.nn.sigmoid(h4), h4

    def generator(self, image, y=None):
        """Build the U-Net generator (creating its variables) and return tanh output."""
        with tf.variable_scope("generator"):
            return self._unet(image)

    def sampler(self, image, y=None):
        """Build the U-Net again, reusing the generator's variables, for sampling."""
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()
            return self._unet(image)

    def _unet(self, image):
        """Build the 8-level U-Net on ``image`` in the current variable scope.

        Shared by ``generator`` and ``sampler``; previously the two held
        identical copies of this code. Dropout (keep_prob 0.5) on d1-d3 is
        applied unconditionally, including when called from ``sampler``.
        The decoder pre-activations and their weights/biases are stored on
        ``self.d1`` ... ``self.d8``, ``self.dN_w`` and ``self.dN_b``; the most
        recent call overwrites them. Shape comments assume a 256 x 256 input.
        """
        s = self.output_size
        s2, s4, s8, s16, s32, s64, s128 = (int(s / 2), int(s / 4), int(s / 8), int(s / 16),
                                           int(s / 32), int(s / 64), int(s / 128))

        # image is (256 x 256 x input_c_dim)
        e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
        # e1 is (128 x 128 x self.gf_dim)
        e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim * 2, name='g_e2_conv'))
        # e2 is (64 x 64 x self.gf_dim*2)
        e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim * 4, name='g_e3_conv'))
        # e3 is (32 x 32 x self.gf_dim*4)
        e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim * 8, name='g_e4_conv'))
        # e4 is (16 x 16 x self.gf_dim*8)
        e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim * 8, name='g_e5_conv'))
        # e5 is (8 x 8 x self.gf_dim*8)
        e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim * 8, name='g_e6_conv'))
        # e6 is (4 x 4 x self.gf_dim*8)
        e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim * 8, name='g_e7_conv'))
        # e7 is (2 x 2 x self.gf_dim*8)
        e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim * 8, name='g_e8_conv'))
        # e8 is (1 x 1 x self.gf_dim*8)

        self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
            [self.batch_size, s128, s128, self.gf_dim * 8], name='g_d1', with_w=True)
        d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
        d1 = tf.concat([d1, e7], 3)
        # d1 is (2 x 2 x self.gf_dim*8*2)

        self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
            [self.batch_size, s64, s64, self.gf_dim * 8], name='g_d2', with_w=True)
        d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
        d2 = tf.concat([d2, e6], 3)
        # d2 is (4 x 4 x self.gf_dim*8*2)

        self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
            [self.batch_size, s32, s32, self.gf_dim * 8], name='g_d3', with_w=True)
        d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
        d3 = tf.concat([d3, e5], 3)
        # d3 is (8 x 8 x self.gf_dim*8*2)

        self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
            [self.batch_size, s16, s16, self.gf_dim * 8], name='g_d4', with_w=True)
        d4 = self.g_bn_d4(self.d4)
        d4 = tf.concat([d4, e4], 3)
        # d4 is (16 x 16 x self.gf_dim*8*2)

        self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
            [self.batch_size, s8, s8, self.gf_dim * 4], name='g_d5', with_w=True)
        d5 = self.g_bn_d5(self.d5)
        d5 = tf.concat([d5, e3], 3)
        # d5 is (32 x 32 x self.gf_dim*4*2)

        self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
            [self.batch_size, s4, s4, self.gf_dim * 2], name='g_d6', with_w=True)
        d6 = self.g_bn_d6(self.d6)
        d6 = tf.concat([d6, e2], 3)
        # d6 is (64 x 64 x self.gf_dim*2*2)

        self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
            [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
        d7 = self.g_bn_d7(self.d7)
        d7 = tf.concat([d7, e1], 3)
        # d7 is (128 x 128 x self.gf_dim*1*2)

        self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
            [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
        # d8 is (256 x 256 x output_c_dim)

        return tf.nn.tanh(self.d8)

    # -------------------------------------------------------------- checkpoints

    def save(self, checkpoint_dir, step):
        """Save a checkpoint under ``checkpoint_dir/<dataset>_<batch>_<size>/``."""
        model_name = "pix2pix.model"
        checkpoint_dir = self._checkpoint_subdir(checkpoint_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint; return True on success, False if none is found."""
        print(" [*] Reading checkpoint...")

        checkpoint_dir = self._checkpoint_subdir(checkpoint_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        return False

    # ---------------------------------------------------------------- inference

    def test(self, args):
        """Test pix2pix"""
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        # Portable path handling: the previous '{}\\val\\*.jpg' pattern and
        # split('\\') only worked on Windows.
        sample_files = self._dataset_files('val')
        print(sample_files)

        # Sort test inputs numerically by file stem (expects names like 12.jpg).
        n = [int(os.path.splitext(os.path.basename(f))[0]) for f in sample_files]
        sample_files = [x for (y, x) in sorted(zip(n, sample_files))]

        print("Loading testing images ...")
        sample = [load_data(sample_file, is_test=True) for sample_file in sample_files]
        sample_images = self._to_batch(sample)

        sample_images = [sample_images[i:i + self.batch_size]
                         for i in xrange(0, len(sample_images), self.batch_size)]
        sample_images = np.array(sample_images)
        print(sample_images.shape)

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for i, sample_image in enumerate(sample_images):
            idx = i + 1
            print("sampling image ", idx)
            samples = self.sess.run(
                self.fake_B_sample,
                feed_dict={self.real_data: sample_image}
            )
            save_images(samples, [self.batch_size, 1],
                        './{}/test_{:04d}.png'.format(args.test_dir, idx))

main.py

import argparse
import os
import scipy.misc
import numpy as npfrom model import pix2pix
import tensorflow as tf


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so any
    non-empty value used to switch the option on.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_name', dest='dataset_name', default='cityscapes', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
# int(): argparse does not apply ``type`` to non-string defaults, so 1e8 stayed a float.
parser.add_argument('--train_size', dest='train_size', type=int, default=int(1e8), help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=256, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--flip', dest='flip', type=_str2bool, default=True, help='if flip the images for data argumentation')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=1000, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)')
parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=10, help='print the debug information every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=_str2bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--serial_batches', dest='serial_batches', type=_str2bool, default=False, help='f 1, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=_str2bool, default=True, help='iter into serial image list')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=100.0, help='weight on L1 term in objective')

args = parser.parse_args()


def main(_):
    """Create output directories, build the model and run the requested phase."""
    for directory in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Filter counts, channel counts and L1_lambda were parsed but never
        # passed before; the defaults match the model's own, so default runs
        # are unchanged.
        model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size,
                        output_size=args.fine_size, gf_dim=args.ngf, df_dim=args.ndf,
                        L1_lambda=args.L1_lambda, input_c_dim=args.input_nc,
                        output_c_dim=args.output_nc, dataset_name=args.dataset_name,
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir)

        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)


if __name__ == '__main__':
    tf.app.run()

自己训练结果

Pix2Pix代码解析相关推荐

  1. 【对比学习】CUT模型论文解读与NCE loss代码解析

    标题:Contrastive Learning for Unpaired Image-to-Image Translation(基于对比学习的非配对图像转换) 作者:Taesung Park, Ale ...

  2. pix2pixHD代码解析

    前言 环境配置:python3.6.9 + pytorch1.1.0 + CUDA10.1 + RTX 2080TI(12G) 代码链接: NVIDIA/pix2pixHD 原文地址:High-Re ...

  3. matrix_multiply代码解析

    matrix_multiply代码解析 关于matrix_multiply 程序执行代码里两个矩阵的乘法,并将相乘结果打印在屏幕上. 示例的主要目的是展现怎么实现一个自定义CPU计算任务. 参考:ht ...

  4. CornerNet代码解析——损失函数

    CornerNet代码解析--损失函数 文章目录 CornerNet代码解析--损失函数 前言 总体损失 1.Heatmap的损失 2.Embedding的损失 3.Offset的损失 前言 今天要解 ...

  5. 视觉SLAM开源算法ORB-SLAM3 原理与代码解析

    来源:深蓝学院,文稿整理者:何常鑫,审核&修改:刘国庆 本文总结于上交感知与导航研究所科研助理--刘国庆关于[视觉SLAM开源算法ORB-SLAM3 原理与代码解析]的公开课. ORB-SLA ...

  6. java获取object属性值_java反射获取一个object属性值代码解析

    有些时候你明明知道这个object里面是什么,但是因为种种原因,你不能将它转化成一个对象,只是想单纯地提取出这个object里的一些东西,这个时候就需要用反射了. 假如你这个类是这样的: privat ...

  7. python中的doc_基于Python获取docx/doc文件内容代码解析

    这篇文章主要介绍了基于Python获取docx/doc文件内容代码解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 整体思路: 下载文件并修改后缀 ...

  8. mongoose框架示例代码解析(一)

    mongoose框架示例代码解析(一) 参考: Mongoose Networking Library Documentation(Server) Mongoose Networking Librar ...

  9. ViBe算法原理和代码解析

    ViBe - a powerful technique for background detection and subtraction in video sequences 算法官网:http:// ...

  10. 【Android 逆向】使用 Python 代码解析 ELF 文件 ( PyCharm 中进行断点调试 | ELFFile 实例对象分析 )

    文章目录 一.PyCharm 中进行断点调试 二.ELFFile 实例对象分析 一.PyCharm 中进行断点调试 在上一篇博客 [Android 逆向]使用 Python 代码解析 ELF 文件 ( ...

最新文章

  1. controlfile
  2. 144. Binary Tree Preorder Traversal(非递归实现二叉树的前序遍历)
  3. easyui js解析字符串_easyui的解析器Parser
  4. python gui插件_Python进阶量化交易专栏场外篇17- GUI控件在回测工具上的添加
  5. java应用架构设计_java应用架构设计
  6. restful 接口 安全性设计
  7. 实例19:python
  8. 大数据世界要熟悉的5门语言
  9. 使用U盘驱动器安装Linux,使用U盘安装Ubuntu的详细图文教程
  10. [Android]Handler的消息机制
  11. python中args是什么意思_理解Python中的*,*args
  12. 分布式团队中沟通引发的问题, itest 解决之道
  13. HenCoder UI 部分 2-1 布局基础
  14. JAVA 实现《中国象棋》游戏
  15. Python数据分析与机器学习42-Python库分析科比生涯数据
  16. 清华大学计算机系2016名单,清华大学2016年自主招生北京考生入选名单汇总
  17. VS Code 快速打开 settings.json
  18. 学习uc/os-ii
  19. 磁盘与文件系统管理--鸟哥私房菜读书笔记
  20. 【科技橙就新商业】淘系技术走进四川大学,讲述淘宝天猫的前端故事

热门文章

  1. 十年磨一剑:梳理淘宝网技术架构的发展
  2. 不用找,你想要的人物Flash动画素材都在这里
  3. Fedora 9与Windows共享文件
  4. KAIOS软件下载-自己做的
  5. c#控制台应用程序读取 config
  6. 转载--32个鲜为人知的自学网站
  7. 手机无线电驾驶与马歇尔·麦克卢汉的哲学
  8. VARCHAR2 与 NVARCHAR2 区别
  9. Origin复制图形格式
  10. 关于dotnetbar控件