Author | Li Qiujian

Editor | Carol

Cover image | CSDN, downloaded from Visual China Group

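In recent years, GAN-based image generation has found increasingly wide use, largely because a GAN keeps improving its modeling ability through an adversarial game until it can generate images realistic enough to pass for real ones. A GAN consists of two neural networks, a generator and a discriminator: the generator tries to produce realistic samples that fool the discriminator, while the discriminator tries to tell real samples from generated ones. This adversarial game drives both networks to keep improving, and once a Nash equilibrium is reached the generator's output is hard to distinguish from real data.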
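The adversarial game described above can be written down as two opposing loss functions. The following is a minimal NumPy sketch, not the article's TensorFlow code: the function names and the example logits are made up for illustration, and the loss form is the standard sigmoid cross-entropy.

```python
import numpy as np


def sigmoid_cross_entropy(logits, labels):
    """Numerically stable sigmoid cross-entropy, averaged over all elements."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    per_elem = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return per_elem.mean()


def discriminator_loss(real_logits, fake_logits):
    # The discriminator wants real samples scored as 1 and generated samples as 0.
    return (sigmoid_cross_entropy(real_logits, np.ones_like(real_logits))
            + sigmoid_cross_entropy(fake_logits, np.zeros_like(fake_logits)))


def generator_loss(fake_logits):
    # The generator wants its own samples scored as 1.
    return sigmoid_cross_entropy(fake_logits, np.ones_like(fake_logits))


# Hypothetical logits from a discriminator that is currently winning the game.
real_logits = np.array([2.0, 3.0])
fake_logits = np.array([-2.0, -1.0])
d_loss = discriminator_loss(real_logits, fake_logits)
g_loss = generator_loss(fake_logits)
print("d_loss:", d_loss, "g_loss:", g_loss)
```

When the discriminator separates the two sets well, its loss is small and the generator's loss is large; training alternates updates so that neither side stays ahead.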

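Image generation is the most prominent application of GANs, but there are many others in computer vision, such as image inpainting, image captioning, object detection, and semantic segmentation. Research applying GANs to natural language processing is also growing, for example in text modeling, dialogue generation, question answering, and machine translation. Training GANs on NLP tasks is harder and requires more techniques, however, which makes it a challenging but interesting research area.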

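Today we will use CC-GAN to train a model that generates a frontal face from a profile face. The results after 20 iterations are shown below: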

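Preparation

We use Python 3.6.5 with the following modules: tensorflow for building the network layers and training the model; numpy for matrix operations; OpenCV for reading and processing images; and os for local file operations such as reading the dataset.

Preparing the data

The face images taken from different angles are placed in the following folder as the training set, as shown below:

The test set images are shown below:

Building the model

The original GAN (see the earlier article "GAN Introduction and Hands-on Code") can in theory approximate the real data distribution arbitrarily well, but it offers little control over the output: small images come out fine, while large images may be incoherent. Constraints therefore need to be added to the GAN so that it generates the images we want, and this is what the conditional GAN (CGAN) provides. The overall structure of the CCGAN model is shown below:

1. Building the network structure and parameters: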

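First we define helper functions for normalization, activation, and pooling. Batch normalization (batch_norm) normalizes each channel to zero mean and unit variance using statistics computed across the whole batch, which keeps activations in a stable range and makes training easier. Instance normalization (instance_norm) instead subtracts the mean and divides by the standard deviation computed per sample and per channel over the spatial dimensions, which can speed up training.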

import tensorflow as tf                  # TensorFlow 1.x
import tensorflow.contrib as tf_contrib
import tensorflow.contrib.slim as slim


def instance_norm(x, scope='instance_norm'):
    return tf_contrib.layers.instance_norm(x, epsilon=1e-05, center=True, scale=True, scope=scope)


def batch_norm(x, scope='batch_norm'):
    return tf_contrib.layers.batch_norm(x, decay=0.9, epsilon=1e-05, center=True, scale=True, scope=scope)


def flatten(x):
    return tf.layers.flatten(x)


def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    return gap


def resblock(x_init, c, scope='resblock'):
    # Residual block: two 3x3 convolutions plus a skip connection.
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = slim.conv2d(x_init, c, kernel_size=[3, 3], stride=1, activation_fn=None)
            x = batch_norm(x)
            x = relu(x)
        with tf.variable_scope('res2'):
            x = slim.conv2d(x, c, kernel_size=[3, 3], stride=1, activation_fn=None)
            x = batch_norm(x)
        return x + x_init
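The difference between the two normalization helpers comes down to which axes the statistics are computed over. Here is a minimal NumPy sketch for NHWC tensors; the function names are illustrative, and the learnable scale and offset as well as the moving averages that the TensorFlow layers maintain are deliberately left out.

```python
import numpy as np


def batch_norm_np(x, eps=1e-5):
    # One mean/variance per channel, shared by every sample: reduce over (N, H, W).
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def instance_norm_np(x, eps=1e-5):
    # One mean/variance per sample and per channel: reduce over (H, W) only.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8, 8, 3))  # random NHWC test tensor
bn = batch_norm_np(x)
inn = instance_norm_np(x)
print(bn.mean(axis=(0, 1, 2)))   # close to 0 for every channel
print(inn.mean(axis=(1, 2)))     # close to 0 for every (sample, channel) pair
```

With batch normalization a single image still carries its own offset relative to the batch statistics; with instance normalization every image is centered individually.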

Next comes the definition of the convolution block:

def conv(x, c):
    # Three parallel stride-2 branches with different kernel sizes.
    x1 = slim.conv2d(x, c, kernel_size=[5, 5], stride=2, padding='SAME', activation_fn=relu)
    # print(x1.shape)
    x2 = slim.conv2d(x, c, kernel_size=[3, 3], stride=2, padding='SAME', activation_fn=relu)
    # print(x2.shape)
    x3 = slim.conv2d(x, c, kernel_size=[1, 1], stride=2, padding='SAME', activation_fn=relu)
    # print(x3.shape)
    # Concatenate along the channel axis, then fuse back to c channels with a 1x1 convolution.
    out = tf.concat([x1, x2, x3], axis=3)
    out = slim.conv2d(out, c, kernel_size=[1, 1], stride=1, padding='SAME', activation_fn=None)
    # print(out.shape)
    return out
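The three branches can be concatenated because, with 'SAME' padding, the spatial output size depends only on the stride and not on the kernel size. The shape bookkeeping for one conv block can be sketched in plain Python; the helper names are illustrative, and the 128x128 input with c = 32 mirrors the values used later in the training script.

```python
import math


def same_out(size, stride):
    # With 'SAME' padding: output = ceil(input / stride), independent of kernel size.
    return math.ceil(size / stride)


def conv_block_shapes(height, width, channels_out, stride=2):
    h, w = same_out(height, stride), same_out(width, stride)
    branch = (h, w, channels_out)        # each of the 5x5, 3x3 and 1x1 branches
    concat = (h, w, 3 * channels_out)    # after concatenating on the channel axis
    fused = (h, w, channels_out)         # after the final 1x1 convolution
    return branch, concat, fused


branch, concat, fused = conv_block_shapes(128, 128, 32)
print(branch, concat, fused)
```

Each block therefore halves the height and width while the final 1x1 convolution keeps the channel count at c.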

Definition of the generator function:

def mixgenerator(x_init, c, org_pose, trg_pose):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0
    with tf.variable_scope('generator', reuse=reuse):
        # Tile the source pose over the image grid and append it as extra channels.
        org_pose = tf.cast(tf.reshape(org_pose, shape=[-1, 1, 1, org_pose.shape[-1]]), tf.float32)
        print(org_pose.shape)
        org_pose = tf.tile(org_pose, [1, x_init.shape[1], x_init.shape[2], 1])
        print(org_pose.shape)
        x = tf.concat([x_init, org_pose], axis=-1)
        print(x.shape)

        # Encoder: five stride-2 conv blocks (trailing comments give the spatial size).
        x = conv(x, c)
        x = batch_norm(x, scope='bat_norm_1')
        x = relu(x)  # 64
        print('----------------')
        print(x.shape)
        x = conv(x, c*2)
        x = batch_norm(x, scope='bat_norm_2')
        x = relu(x)  # 32
        print(x.shape)
        x = conv(x, c*4)
        x = batch_norm(x, scope='bat_norm_3')
        x = relu(x)  # 16
        print(x.shape)
        f_org = x
        x = conv(x, c*8)
        x = batch_norm(x, scope='bat_norm_4')
        x = relu(x)  # 8
        print(x.shape)
        x = conv(x, c*8)
        x = batch_norm(x, scope='bat_norm_5')
        x = relu(x)  # 4
        print(x.shape)

        # Bottleneck: six residual blocks.
        for i in range(6):
            x = resblock(x, c*8, scope=str(i) + "_resblock")

        # Append the target pose to the bottleneck features.
        trg_pose = tf.cast(tf.reshape(trg_pose, shape=[-1, 1, 1, trg_pose.shape[-1]]), tf.float32)
        print(trg_pose.shape)
        trg_pose = tf.tile(trg_pose, [1, x.shape[1], x.shape[2], 1])
        print(trg_pose.shape)
        x = tf.concat([x, trg_pose], axis=-1)
        print(x.shape)

        # Decoder: stride-2 transposed convolutions back to the input resolution.
        x = slim.conv2d_transpose(x, c*8, kernel_size=[3, 3], stride=2, activation_fn=None)
        x = batch_norm(x, scope='bat_norm_8')
        x = relu(x)  # 8
        print(x.shape)
        x = slim.conv2d_transpose(x, c*4, kernel_size=[3, 3], stride=2, activation_fn=None)
        x = batch_norm(x, scope='bat_norm_9')
        x = relu(x)  # 16
        print(x.shape)
        f_trg = x
        x = slim.conv2d_transpose(x, c*2, kernel_size=[3, 3], stride=2, activation_fn=None)
        x = batch_norm(x, scope='bat_norm_10')
        x = relu(x)  # 32
        print(x.shape)
        x = slim.conv2d_transpose(x, c, kernel_size=[3, 3], stride=2, activation_fn=None)
        x = batch_norm(x, scope='bat_norm_11')
        x = relu(x)  # 64
        print(x.shape)
        z = slim.conv2d_transpose(x, 3, kernel_size=[3, 3], stride=2, activation_fn=tf.nn.tanh)

        # Concatenate the 16x16 encoder and decoder features for the second generator.
        f = tf.concat([f_org, f_trg], axis=-1)
        print(f.shape)
        return z, f
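The generator conditions on pose by reshaping the pose vector to [N, 1, 1, P], tiling it over the spatial grid, and concatenating it with the feature map. The same three steps can be sketched in NumPy. The helper name attach_pose is made up, and the poses are assumed here to be one-hot codes of length 9, which matches the placeholder shape in the training script but is not stated explicitly in the article.

```python
import numpy as np


def attach_pose(features, pose):
    """Broadcast a pose vector over the spatial grid and append it as channels."""
    n, h, w, _ = features.shape
    pose_map = pose.reshape(n, 1, 1, -1).astype(np.float32)   # [N, 1, 1, P]
    pose_map = np.tile(pose_map, (1, h, w, 1))                # [N, H, W, P]
    return np.concatenate([features, pose_map], axis=-1)      # [N, H, W, C + P]


images = np.zeros((2, 128, 128, 3), dtype=np.float32)
poses = np.eye(9, dtype=np.float32)[[0, 4]]   # assumed one-hot codes for pose 0 and pose 4
conditioned = attach_pose(images, poses)
print(conditioned.shape)
```

Every pixel of the second image now carries the same nine extra channels, with a 1 in the channel that corresponds to pose 4, so each convolution can see the pose regardless of its position in the image.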

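The discriminator and the remaining functions are defined further on in the source and are not covered here.

2. Setting up the VGG network:

Building the layers of the VGG model: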

import time

# Method of the VGG19 class (the class statement itself is not shown in the article).
def build(self, rgb, include_fc=False):
    """
    load variable from npy to build the VGG
    input format: bgr image with shape [batch_size, h, w, 3]
    scale: (-1, 1)
    """
    start_time = time.time()
    rgb_scaled = (rgb + 1) / 2  # [-1, 1] ~ [0, 1]
    # blue, green, red = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
    # bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0],
    #                                 green - VGG_MEAN[1],
    #                                 red - VGG_MEAN[2]])
    self.conv1_1 = self.conv_layer(rgb_scaled, "conv1_1")
    self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
    self.pool1 = self.max_pool(self.conv1_2, 'pool1')

    self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
    self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
    self.pool2 = self.max_pool(self.conv2_2, 'pool2')

    self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
    self.conv3_2_no_activation = self.no_activation_conv_layer(self.conv3_1, "conv3_2")
    self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
    self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
    self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
    self.pool3 = self.max_pool(self.conv3_4, 'pool3')

    self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
    self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
    self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
    self.conv4_4_no_activation = self.no_activation_conv_layer(self.conv4_3, "conv4_4")
    self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
    self.pool4 = self.max_pool(self.conv4_4, 'pool4')

    self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
    self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
    self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
    self.conv5_4_no_activation = self.no_activation_conv_layer(self.conv5_3, "conv5_4")
    self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
    self.pool5 = self.max_pool(self.conv5_4, 'pool5')

    if include_fc:
        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)
        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)
        self.fc8 = self.fc_layer(self.relu7, "fc8")
        self.prob = tf.nn.softmax(self.fc8, name="prob")

    self.data_dict = None
    print(("Finished building vgg19: %ds" % (time.time() - start_time)))

Definitions of the pooling and convolution layer functions:

def avg_pool(self, bottom, name):
    return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

def max_pool(self, bottom, name):
    return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

def conv_layer(self, bottom, name):
    with tf.variable_scope(name):
        filt = self.get_conv_filter(name)
        conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
        conv_biases = self.get_bias(name)
        bias = tf.nn.bias_add(conv, conv_biases)
        relu = tf.nn.relu(bias)
        return relu

def no_activation_conv_layer(self, bottom, name):
    with tf.variable_scope(name):
        filt = self.get_conv_filter(name)
        conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
        conv_biases = self.get_bias(name)
        x = tf.nn.bias_add(conv, conv_biases)
        return x

def fc_layer(self, bottom, name):
    with tf.variable_scope(name):
        shape = bottom.get_shape().as_list()
        dim = 1
        for d in shape[1:]:
            dim *= d
        x = tf.reshape(bottom, [-1, dim])
        weights = self.get_fc_weight(name)
        biases = self.get_bias(name)
        # Fully connected layer. Note that the '+' operation automatically
        # broadcasts the biases.
        fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
        return fc

def get_conv_filter(self, name):
    return tf.constant(self.data_dict[name][0], name="filter")

def get_bias(self, name):
    return tf.constant(self.data_dict[name][1], name="biases")

def get_fc_weight(self, name):
    return tf.constant(self.data_dict[name][0], name="weights")
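The getters above index self.data_dict by layer name and then by position: element 0 holds the weights and element 1 the biases. The article does not show how the dictionary is loaded, so the following NumPy sketch is an assumption about that layout. It builds a tiny stand-in weight file with a single made-up layer, saves it as .npy, and reads it back the way a pickled dictionary is usually restored.

```python
import os
import tempfile

import numpy as np

# Hypothetical miniature weight file with one layer, zero-filled for illustration.
weights = {
    "conv1_1": [
        np.zeros((3, 3, 3, 64), dtype=np.float32),  # filter: [kh, kw, in_channels, out_channels]
        np.zeros((64,), dtype=np.float32),          # bias: one value per output channel
    ],
}
path = os.path.join(tempfile.mkdtemp(), "vgg_sketch.npy")
np.save(path, weights)  # a dict is stored as a pickled 0-d object array

# Loading a pickled object requires allow_pickle=True; .item() unwraps the dict.
data_dict = np.load(path, allow_pickle=True).item()
filt = data_dict["conv1_1"][0]
bias = data_dict["conv1_1"][1]
print(filt.shape, bias.shape)
```

Because the weights are wrapped in tf.constant rather than tf.Variable, the VGG network stays frozen during training and only serves as a fixed feature extractor for the loss.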

Training the model

To train with GPU acceleration, set up the CUDA environment and install the tensorflow-gpu build.

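import os
import random
import cv2 as cv
import numpy as np
import tensorflow as tf
# Project modules from the source archive; names assumed from the model.*, ops.* and reader.* calls below.
import model, ops, reader

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.reset_default_graph()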

Reading the dataset and splitting it into training batches:

imagedir = './data/'
img_label_org, label_trg, img = reader.images_list(imagedir)
epoch = 800
batch_size = 10
total_sample_num = len(img_label_org)
# Round the number of batches up so that the final partial batch is also used.
if total_sample_num % batch_size == 0:
    n_batch = int(total_sample_num / batch_size)
else:
    n_batch = int(total_sample_num / batch_size) + 1
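The if/else above is a ceiling division: the batch count is rounded up so that a final partial batch is still processed. A minimal pure-Python sketch of the same bookkeeping follows; the helper name batch_slices and the sample count of 23 are made up for illustration.

```python
def batch_slices(total, batch_size):
    """Return (start, end) index pairs covering `total` samples in order."""
    n_batch = -(-total // batch_size)  # ceiling division without floating point
    return [(j * batch_size, min((j + 1) * batch_size, total)) for j in range(n_batch)]


slices = batch_slices(23, 10)
print(slices)
```

The last slice is shorter than batch_size whenever the sample count is not a multiple of it, so any code that later loops over a batch should use the actual batch length.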

Initializing the input and output placeholders, the discriminator, and the related losses and optimizers:

org_image = tf.placeholder(tf.float32,[None,128,128,3], name='org_image')
trg_image = tf.placeholder(tf.float32,[None,128,128,3], name='trg_image')
org_pose = tf.placeholder(tf.float32,[None,9], name='org_pose')
trg_pose = tf.placeholder(tf.float32,[None,9], name='trg_pose')
gen_trg, feat = model.mixgenerator(org_image, 32, org_pose, trg_pose)
out_trg = model.generator(feat, 32, trg_pose)#D_ab
D_r, real_logit, real_pose = model.snpixdiscriminator(trg_image)
D_f, fake_logit, fake_pose = model.snpixdiscriminator(gen_trg)
D_f_, fake_logit_, fake_pose_ = model.snpixdiscriminator(out_trg)
# fake or real D_LOSS
loss_pred_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logit, labels=tf.ones_like(D_r)))
loss_pred_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit_, labels=tf.zeros_like(D_f_)))
loss_d_pred = loss_pred_r + loss_pred_f
#pose loss
loss_d_pose = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=real_pose, labels=trg_pose))
loss_g_pose_ = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=fake_pose_, labels=trg_pose))
loss_g_pose = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=fake_pose, labels=trg_pose))
#G_LOSS
loss_g_pred = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit_, labels=tf.ones_like(D_f_)))
out_pix_loss = ops.L2_loss(out_trg, trg_image)
out_pre_loss, out_feat_texture = ops.vgg_loss(out_trg, trg_image)
out_loss_texture =  ops.texture_loss(out_feat_texture)
out_loss_tv = 0.0002 * tf.reduce_mean(ops.tv_loss(out_trg))
gen_pix_loss = ops.L2_loss(gen_trg, trg_image)
out_g_loss = 100*gen_pix_loss + 100*out_pix_loss + loss_g_pred + out_pre_loss + out_loss_texture + out_loss_tv + loss_g_pose_
gen_g_loss = 100 * gen_pix_loss + loss_g_pose
# d_loss
disc_loss = loss_d_pred + loss_d_pose
out_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)
disc_global_step = tf.Variable(0, trainable=False)
start_decay_step = 500000
start_learning_rate = 0.0001
decay_steps = 500000
end_learning_rate = 0.0
out_lr = (tf.where(tf.greater_equal(out_global_step, start_decay_step), tf.train.polynomial_decay(start_learning_rate, out_global_step-start_decay_step, decay_steps, end_learning_rate, power=1.0),start_learning_rate))
gen_lr = (tf.where(tf.greater_equal(gen_global_step, start_decay_step), tf.train.polynomial_decay(start_learning_rate, gen_global_step-start_decay_step, decay_steps, end_learning_rate, power=1.0),start_learning_rate))
disc_lr = (tf.where(tf.greater_equal(disc_global_step, start_decay_step), tf.train.polynomial_decay(start_learning_rate, disc_global_step-start_decay_step, decay_steps, end_learning_rate, power=1.0),start_learning_rate))
t_vars = tf.trainable_variables()
g_gen_vars = [var for var in t_vars if 'generator' in var.name]
g_out_vars = [var for var in t_vars if 'generator_1' in var.name]
d_vars = [var for var in t_vars if 'discriminator' in var.name]
train_gen = tf.train.AdamOptimizer(gen_lr, beta1=0.5, beta2=0.999).minimize(gen_g_loss, var_list = g_gen_vars, global_step = gen_global_step)
train_out = tf.train.AdamOptimizer(out_lr, beta1=0.5, beta2=0.999).minimize(out_g_loss, var_list = g_out_vars, global_step = out_global_step)
train_disc = tf.train.AdamOptimizer(disc_lr, beta1=0.5, beta2=0.999).minimize(disc_loss, var_list = d_vars, global_step = disc_global_step)
saver = tf.train.Saver(tf.global_variables())
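The three learning rates above share one schedule: constant at start_learning_rate until start_decay_step, then a polynomial decay with power 1.0, that is a straight line, down to end_learning_rate over decay_steps. A pure-Python sketch of that schedule, using the constants from the script as defaults, looks like this; the function name is illustrative.

```python
def learning_rate(step, start_lr=1e-4, start_decay_step=500000,
                  decay_steps=500000, end_lr=0.0):
    """Constant learning rate followed by a linear decay to end_lr."""
    if step < start_decay_step:
        return start_lr
    # Fraction of the decay phase completed, clamped at 1 once the decay is over.
    progress = min(step - start_decay_step, decay_steps) / decay_steps
    return (start_lr - end_lr) * (1.0 - progress) + end_lr


for step in (0, 500000, 750000, 1000000, 2000000):
    print(step, learning_rate(step))
```

The decay only takes effect if a run reaches 500000 optimizer steps; shorter runs train at the constant rate of 0.0001 throughout.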

Training the model, generating images, and saving the model:

with tf.Session(config=config) as sess:
    for d in ['/gpu:0']:
        with tf.device(d):
            # Resume from the latest checkpoint if one exists, otherwise start fresh.
            ckpt = tf.train.get_checkpoint_state('./models/')
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Import models successful!')
            else:
                sess.run(tf.global_variables_initializer())
                print('Initialize successful!')

            for i in range(epoch):
                random.shuffle(img_label_org)
                random.shuffle(label_trg)
                for j in range(n_batch):
                    if j == n_batch - 1:
                        n = total_sample_num
                    else:
                        n = j * batch_size + batch_size
                    img_org_output, img_trg_output, label_org_output, label_trg_output, image_name_output = reader.images_read(
                        img_label_org[j*batch_size:n], label_trg[j*batch_size:n], img, imagedir)
                    feeds = {org_image: img_org_output, trg_image: img_trg_output,
                             org_pose: label_org_output, trg_pose: label_trg_output}

                    # First 400 epochs: update the discriminator on every batch.
                    # Afterwards: update it only on every tenth batch.
                    if i < 400:
                        sess.run(train_disc, feed_dict=feeds)
                        sess.run(train_gen, feed_dict=feeds)
                        sess.run(train_out, feed_dict=feeds)
                    else:
                        sess.run(train_gen, feed_dict=feeds)
                        sess.run(train_out, feed_dict=feeds)
                        if j % 10 == 0:
                            sess.run(train_disc, feed_dict=feeds)

                    # Every second batch, log the losses and save comparison images.
                    if j % 2 == 0:
                        gen_g_loss_, out_g_loss_, disc_loss_, org_image_, gen_trg_, out_trg_, trg_image_ = sess.run(
                            [gen_g_loss, out_g_loss, disc_loss, org_image, gen_trg, out_trg, trg_image], feeds)
                        print("epoch:", i, "iter:", j, "gen_g_loss_:", gen_g_loss_,
                              "out_g_loss_:", out_g_loss_, "loss_disc:", disc_loss_)
                        # The last batch may hold fewer than batch_size samples.
                        for k in range(len(gen_trg_)):
                            org_image_output = (org_image_[k] + 1) * 127.5
                            gen_trg_output = (gen_trg_[k] + 1) * 127.5
                            out_trg_output = (out_trg_[k] + 1) * 127.5
                            trg_image_output = (trg_image_[k] + 1) * 127.5
                            temp = np.concatenate([org_image_output, gen_trg_output, out_trg_output, trg_image_output], 1)
                            cv.imwrite("./record/%d_%d_%d_image.jpg" % (i, j, k), temp)

                if i % 10 == 0 or i == epoch - 1:
                    saver.save(sess, './models/wssGAN.ckpt', global_step=gen_global_step)
            print("Finish!")
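The images saved during training are produced by mapping the tanh output range [-1, 1] back to [0, 255] and placing the source, the two generated images, and the target side by side. A minimal NumPy sketch of that step follows. The helper names are made up, and the explicit clipping and uint8 cast are additions for safety that the original loop does not perform.

```python
import numpy as np


def to_uint8(img):
    # Map [-1, 1] to [0, 255]; clip first so rounding noise cannot wrap around.
    return np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)


def comparison_strip(*images):
    # Concatenate along axis 1 (width) so the images sit next to each other.
    return np.concatenate([to_uint8(im) for im in images], axis=1)


# Constant dummy images standing in for source, generated, and target faces.
org = np.full((128, 128, 3), -1.0, dtype=np.float32)
gen = np.zeros((128, 128, 3), dtype=np.float32)
trg = np.full((128, 128, 3), 1.0, dtype=np.float32)
strip = comparison_strip(org, gen, gen, trg)
print(strip.shape, strip.dtype)
```

Four 128x128 inputs give one 128x512 strip, which makes it easy to compare the generated faces against the target at each logging step.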

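The final results of running the program are shown below:

Result after the first training iteration:

Result after 20 training iterations:

Comparing the two shows a clear improvement.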

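Source code:

https://pan.baidu.com/s/1cpRJlk7yUwhYJSIkRpbNpg

Extraction code: kdxe

About the author:

Li Qiujian is a CSDN blog expert and the author of a CSDN Daren course. He is pursuing a master's degree at China University of Mining and Technology. His projects include an Android martial-arts game on TapTap, a VIP video parser, a text paraphrasing tool, and a writing bot. He has published several papers and won multiple awards in advanced mathematics competitions.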
