1.引言

在现实生活当中,除了语言之间的翻译之外,我们也经常会遇到各种图像的“翻译”任务,即给定一张图像,生成目标图像,常见的场景有:图像风格迁移、图像超级分辨率、图像上色、图像去马赛克等。而在现实生活当中,图像翻译任务更常见的场景可能是图像的修图与美化,因此,本文将准备介绍另一个新的图像翻译任务——AI修图,即给定一张图像,让机器自动对该图像进行修图,从而达到一个更加美化的效果。

本文将利用GAN网络中一个比较经典的模型,即pix2pix模型,该网络采用一种完全监督的方法,即利用完全配对的输入和输出图像训练模型,通过训练好的模型将输入的图像生成指定任务的目标图像。目前该方法是图像翻译任务中完全监督方法里面效果和通用性最好的一个模型,在介绍这个模型的结构之前,可以先来看下作者利用这个网络所做的一些有趣的实验:

  • 图像语义标签——真实图像
  • 白天——夜景
  • 简笔画上色
  • 黑白图像——彩色图像

具体效果如下图所示 :

2.pix2pix网络介绍

pix2pix网络是GAN网络中的一种,主要是采用cGAN网络的结构,它依然包括了一个生成器和一个判别器。生成器采用的是一个U-net的结构,其结构有点类似Encoder-decoder,总共包含15层,分别有8层卷积层作为encoder,7层反卷积层(关于反卷积层的概念可以参考这篇博客:反卷积原理不可多得的好文)作为decoder,与传统的encoder-decoder不同的是引入了一个叫做“skip-connect”的技巧,即每一层反卷积层的输入都是:前一层的输出+与该层对称的卷积层的输出,从而保证encoder的信息在decoder时可以不断地被重新记忆,使得生成的图像尽可能保留原图像的一些信息。具体如下图所示:

对于判别器,pix2pix采用的是一个6层的卷积网络,其思想与传统的判别器类似,只是有以下两点比较特别的地方:

  • 将输入图像与目标图像进行堆叠:pix2pix的判别器的输入不仅仅只是真实图像与生成图像,还将输入图像也一起作为输入的一部分,即将输入图像与真实图像、生成图像分别在第3通道进行拼接,然后一起作为输入传入判别器模型。
  • 引入PatchGAN的思想:传统的判别器是对一张图像输出一个softmax概率值,而pix2pix的判别器则引入了PatchGAN的思想,将一张图像通过多层卷积层后最终输出了一个比较小的矩阵,比如30*30,然后对每个像素点输出一个softmax概率值,这就相当于对一张输入图像切分为很多小块,对每一小块分别计算一个输出。作者表示引入PatchGAN其实可以起到一种类似计算风格或纹理损失的效果。

其具体的结构如下图所示:

3.模型的损失函数

pix2pix的损失函数除了标准的GAN网络的损失函数之外,还引入了的损失函数。记为输入的图像,为真实图像(输出图像),为生成器,为判别器,则标准的GAN网络的损失函数为:

对G施加惩罚,即:

因此,最终GAN网络的损失函数为:

这样一来,标准的GAN损失负责捕捉图像高频特征,而损失则负责捕捉低频特征,使得生成结果既真实且清晰。

4.pix2pix的tensorflow实现

本文利用pix2pix进行AI修图,采用的框架是tensorflow实现。首先是将输入图像和真实图像(输出图像)分别压缩至256*256的规格,并将两者拼接在一起,形式如下:

其中,左侧为修图前的原图,右侧为人工修图的结果,总共采集了1700对这样的图像作为模型的训练集,模型的主要代码模块如下:

import tensorflow as tf
import numpy as np
from PIL import Image
from data_loader import get_batch_data
import os
import reclass pix2pix(object):def __init__(self, sess, batch_size, L1_lambda):""":param sess: tf.Session:param batch_size: batch_size. [int]:param L1_lambda: L1_loss lambda. [int]"""self.sess = sessself.k_initializer = tf.random_normal_initializer(0, 0.02)self.g_initializer = tf.random_normal_initializer(1, 0.02)self.L1_lambda = L1_lambdaself.bulid_model()def bulid_model(self):"""初始化模型:return:"""# init variableself.x_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='x')self.y_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='y')# generatorself.g = self.generator(self.x_)# discriminatorself.d_real = self.discriminator(self.x_, self.y_)self.d_fake = self.discriminator(self.x_, self.g, reuse=True)# lossself.loss_g, self.loss_d = self.loss(self.d_real, self.d_fake, self.y_, self.g)# summarytf.summary.scalar("loss_g", self.loss_g)tf.summary.scalar("loss_d", self.loss_d)self.merged = tf.summary.merge_all()# varsself.vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]self.vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]# saverself.saver = tf.train.Saver()def discriminator(self, x, y, reuse=None):"""判别器:param x: 输入图像. [tensor]:param y: 目标图像. [tensor]:param reuse: reuse or not. [boolean]:return:"""with tf.variable_scope('discriminator', reuse=reuse):x = tf.concat([x, y], axis=3)h0 = self.lrelu(self.d_conv(x, 64, 2))  # 128 128 64h0 = self.d_conv(h0, 128, 2)h0 = self.lrelu(self.batch_norm(h0))  # 64 64 128h0 = self.d_conv(h0, 256, 2)h0 = self.lrelu(self.batch_norm(h0))  # 32 32 256h0 = self.d_conv(h0, 512, 1)h0 = self.lrelu(self.batch_norm(h0))  # 31 31 512h0 = self.d_conv(h0, 1, 1)  # 30 30 1h0 = tf.nn.sigmoid(h0)return h0def generator(self, x):"""生成器:param x: 输入图像. [tensor]:return: h0,生成的图像. [tensor]"""with tf.variable_scope('generator', reuse=None):layers = []h0 = self.g_conv(x, 64)layers.append(h0)for filters in [128, 256, 512, 512, 512, 512, 512]:  # [128, 256, 512, 512, 512, 512, 512]h0 = self.lrelu(layers[-1])h0 = self.g_conv(h0, filters)h0 = self.batch_norm(h0)layers.append(h0)encode_layers_num = len(layers)  # 8for i, filters in enumerate([512, 512, 512, 512, 256, 128, 64]):  # [512, 512, 512, 512, 256, 128, 64]skip_layer = encode_layers_num - i - 1if i == 0:inputs = layers[-1]else:inputs = tf.concat([layers[-1], layers[skip_layer]], axis=3)h0 = tf.nn.relu(inputs)h0 = self.g_deconv(h0, filters)h0 = self.batch_norm(h0)if i < 3:h0 = tf.nn.dropout(h0, keep_prob=0.5)layers.append(h0)inputs = tf.concat([layers[-1], layers[0]], axis=3)h0 = tf.nn.relu(inputs)h0 = self.g_deconv(h0, 3)h0 = tf.nn.tanh(h0, name='g')return h0def loss(self, d_real, d_fake, y, g):"""定义损失函数:param d_real: 真实图像判别器的输出. [tensor]:param d_fake: 生成图像判别器的输出. [tensor]:param y: 目标图像. [tensor]:param g: 生成图像. [tensor]:return: loss_g, loss_d, 分别对应生成器的损失函数和判别器的损失函数"""loss_d_real = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_real, tf.ones_like(d_real)))loss_d_fake = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.zeros_like(d_fake)))loss_d = loss_d_real + loss_d_fakeloss_g_gan = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.ones_like(d_fake)))loss_g_l1 = tf.reduce_mean(tf.abs(y - g))loss_g = loss_g_gan + loss_g_l1 * self.L1_lambdareturn loss_g, loss_ddef lrelu(self, x, leak=0.2):"""lrelu函数:param x::param leak::return:"""return tf.maximum(x, leak * x)def d_conv(self, inputs, filters, strides):"""判别器卷积层:param inputs: 输入. [tensor]:param filters: 输出通道数. [int]:param strides: 卷积核步伐. [int]:return:"""padded = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')return tf.layers.conv2d(padded,kernel_size=4,filters=filters,strides=strides,padding='valid',kernel_initializer=self.k_initializer)def g_conv(self, inputs, filters):"""生成器卷积层:param inputs: 输入. [tensor]:param filters: 输出通道数. [int]:return:"""return tf.layers.conv2d(inputs,kernel_size=4,filters=filters,strides=2,padding='same',kernel_initializer=self.k_initializer)def g_deconv(self, inputs, filters):"""生成器反卷积层:param inputs: 输入. [tensor]:param filters: 输出通道数. [int]:return:"""return tf.layers.conv2d_transpose(inputs,kernel_size=4,filters=filters,strides=2,padding='same',kernel_initializer=self.k_initializer)def batch_norm(self, inputs):"""批标准化函数:param inputs: 输入. [tensor]:return:"""return tf.layers.batch_normalization(inputs,axis=3,epsilon=1e-5,momentum=0.1,training=True,gamma_initializer=self.g_initializer)def sigmoid_cross_entropy_with_logits(self, x, y):"""交叉熵函数:param x::param y::return:"""return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,labels=y)def train(self, images, epoch, batch_size):"""训练函数:param images: 图像路径列表. [list]:param epoch: 迭代次数. [int]:param batch_size: batch_size. [int]:return:"""# optimizerupdate_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):optim_d = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(self.loss_d, var_list=self.vars_d)optim_g = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(self.loss_g, var_list=self.vars_g)# init variablesinit_op = tf.global_variables_initializer()self.sess.run(init_op)self.writer = tf.summary.FileWriter("./log", self.sess.graph)# trainingfor i in range(epoch):# 获取图像列表print("Epoch:%d/%d:" % ((i + 1), epoch))batch_num = int(np.ceil(len(images) / batch_size))# batch_list = np.array_split(random.sample(images, len(images)), batch_num)batch_list = np.array_split(images, batch_num)# 训练生成器和判别器for j in range(len(batch_list)):batch_x, batch_y = get_batch_data(batch_list[j])_, loss_d = self.sess.run([optim_d, self.loss_d],feed_dict={self.x_: batch_x, self.y_: batch_y})_, loss_g = self.sess.run([optim_g, self.loss_g],feed_dict={self.x_: batch_x, self.y_: batch_y})print("%d/%d -loss_d:%.4f -loss_g:%.4f" % ((j + 1), len(batch_list), loss_d, loss_g))# 保存损失值summary = self.sess.run(self.merged,feed_dict={self.x_: batch_x, self.y_: batch_y})self.writer.add_summary(summary, global_step=i)# 保存模型,每10次保存一次if (i + 1) % 10 == 0:self.saver.save(self.sess, './checkpoint/epoch_%d.ckpt' % (i + 1))# 测试,每循环一次测试一次if (i + 1) % 1 == 0:# 对训练集最后一张图像进行测试train_save_path = os.path.join('./result/train',re.sub('.jpg','',os.path.basename(images[-1])) + '_' + str(i + 1) + '.jpg')train_g = self.sess.run(self.g,feed_dict={self.x_: batch_x})train_g = 255 * (np.array(train_g[0] + 1) / 2)im = Image.fromarray(np.uint8(train_g))im.save(train_save_path)# 对验证集进行测试img = np.zeros((256, 256 * 3, 3))val_img_path = np.array(['./data/val/color/10901.jpg'])batch_x, batch_y = get_batch_data(val_img_path)val_g = self.sess.run(self.g, feed_dict={self.x_: batch_x})img[:, :256, :] = 255 * (np.array(batch_x + 1) / 2)img[:, 256:256 * 2, :] = 255 * (np.array(batch_y + 1) / 2)img[:, 256 * 2:, :] = 255 * (np.array(val_g[0] + 1) / 2)img = Image.fromarray(np.uint8(img))img.save('./result/val/10901_%d.jpg' % (i + 1))def save_img(self, g, data, save_path):"""保存图像:param g: 生成的图像. [array]:param data: 测试数据. [list]:param save_path: 保存路径. [str]:return:"""if len(data) == 1:img = np.zeros((256, 256 * 2, 3))img[:, :256, :] = 255* (np.array(data[0] + 1) / 2)img[:, 256:, :] = 255 * (np.array(g[0] + 1) / 2)else:img = np.zeros((256, 256 * 3, 3))img[:, :256, :] = 255 * (np.array(data[0] + 1) / 2)img[:, 256:256 * 2, :] = 255 * (np.array(data[1] + 1) / 2)img[:, 256 * 2:, :] = 255 * (np.array(g[0] + 1) / 2)im = Image.fromarray(np.uint8(img))im.save(os.path.join('./result/test', os.path.basename(save_path)))def test(self, images, batch_size=1, save_path=None, mode=None):"""测试函数:param images: 测试图像列表. [list]:param batch_size: batch_size. [int]:param save_path: 保存路径:return:"""# init variablesinit_op = tf.global_variables_initializer()self.sess.run(init_op)# load modelself.saver.restore(self.sess,tf.train.latest_checkpoint('./checkpoint'))# testif mode != 'orig':for j in range(len(images)):batch_x, batch_y = get_batch_data(np.array([images[j]]))g = self.sess.run(self.g, feed_dict={self.x_: batch_x})if save_path == None:self.save_img(g,data=[batch_x[0], batch_y[0]],save_path=images[j])else:self.save_img(g,data=[batch_x[0], batch_y[0]],save_path=save_path)else:for j in range(len(images)):batch_x = get_batch_data(np.array([images[j]]), mode=mode)g = self.sess.run(self.g, feed_dict={self.x_: batch_x})batch_x = 255 * (np.array(batch_x[0] + 1) / 2)g = 255 * (np.array(g[0] + 1) / 2)img = np.hstack((batch_x, g))im = Image.fromarray(np.uint8(img))im.save(os.path.join('./result/test', os.path.basename(images[j])))

最终经过训练40个epoch后,判别器和生成器的损失函数均达到了平衡状态,因此,对训练过程进行了终止,如下图所示:

5.模型的效果

利用训练40个epoch后的模型对测试集进行测试,得到模型最终的效果如下:

其中,从左到右分别对应原图、人工修图、AI修图,可以发现,AI修图的结果会使得色彩更加艳丽,并且修图的效果比人工修图更加真实一点,本文也利用训练好的模型对任意规格的高清图像进行了测试,得到效果如下:

左边是从百度上直接下载下来的两张风景图,右边是本文训练出来的模型修图后的结果,可以发现,虽然这两张原图的已经是经过p图之后的结果,但是用AI修图后在亮度、色彩对比度等方面还是有进一步的提升,模型的泛化效果还是蛮不错滴!

最后,大概讲一下模型的缺点吧,pix2pix虽然通用性很强,但是模型能否收敛对数据的质量要求很高,如果数据质量比较差的话,则训练出来的模型效果就比较差,笔者最开始没有对数据进行清洗,因此训练出来的效果比较模糊,另外,pix2pix要求必须是严格的配对数据,因此,对数据的要求更加苛刻,如果对这方面比较感兴趣的朋友,也可以考虑一下非监督学习方面的模型,比如WESPE模型等。以下是原论文的地址和作者的pytorch实现:

  • 论文地址:http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
  • pytorch实现:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

招聘信息:

熊猫书院算法工程师:

https://www.lagou.com/jobs/4842081.html

希望对深度学习算法感兴趣的小伙伴们可以加入我们,一起改变教育!

AI修图!pix2pix网络介绍与tensorflow实现相关推荐

  1. AI修图!pix2pix网络介绍

    语言翻译是大家都知道的应用.但图像作为一种交流媒介,也有很多种表达方式,比如灰度图.彩色图.梯度图甚至人的各种标记等.在这些图像之间的转换称之为图像翻译,是一个图像生成任务. 多年来,这些任务都需要用 ...

  2. 华为AI开发平台ModelArts介绍和应用

    目录 一.ModelArts介绍 1.注册账号或登录账号 2.ModelArts功能 二.AI Gallery介绍 三.PyCharm ToolKit介绍 四.垃圾分类应用 一.ModelArts介绍 ...

  3. ConvNeXt网络介绍,搭建以及训练

    ConvNeXt网络介绍 今年(2022)一月份,Facebook AI Research和UC Berkeley一起发表了一篇文章A ConvNet for the 2020s,在文章中提出了Con ...

  4. 百度陆奇:AI是5G网络下最好的加速器,技术商业化还要更快

    问耕 假装发自 LV 量子位 出品 | 公众号 QbitAI 下一代通信网络5G,意味着什么? 百度集团总裁兼COO陆奇美国时间10日在出席CES一个对话活动时说,AI是5G网络下最好的加速器,陆奇表 ...

  5. MOOC网深度学习应用开发5——生成式对抗网络原理及Tensorflow实现

    生成式对抗网络原理及Tensorflow实现 生成式对抗网络GAN的简介 利用GAN生成Fashion-MNIST图像 鸢尾花品种识别:TensorFlow.js应用开发 TensorFlow.js介 ...

  6. 【大模型】—Open AI GPT大模型介绍

    大模型-- Open AI GPT大模型介绍 人工智能技术的快速发展引发了对智能系统和应用的巨大需求.多模态大模型已经成为了人工智能领域的重要研究方向之一.OpenAI作为一家全球领先的人工智能公司, ...

  7. OpenStack网络介绍

    OpenStack网络介绍     OpenStack里面的网络相对复杂.经常有人对几个网络概念搞混淆.因此,本文对OpenStack里面的Provider network 和 Tenant netw ...

  8. FNN 网络介绍与源码浅析

    FNN 网络介绍与源码浅析 前言 周五晚上分享 paper !!! 感动自己一把~

  9. AI修图市场潜力大,分析全方位

    随着人工智能技术的发展,AI修图已经成为了一个热门话题.AI修图是指使用人工智能技术对图片进行自动化处理,包括自动美化.自动去除瑕疵.自动调整光影等功能.目前,AI修图已经被广泛应用于各个领域,市场潜 ...

最新文章

  1. React源码分析与实现(一):组件的初始化与渲染
  2. Linux KVM与Xen的性能比较
  3. word python 域 操作_python实现在windows下操作word的方法
  4. Pinyin4j中文字符和拼音之间的转换
  5. SpringSecurity认证用户状态的判断
  6. android 启动第三方程序的代码(利用ComponentName)
  7. Beta冲刺(9/7)——2019.5.31
  8. oracle数据库支持2颗cpu,2.3 Oracle数据库中常见的性能问题
  9. (92)多人投票器(七人投票器)
  10. python面试题之Python 的特点和优点是什么
  11. 5.1作业5 四则运算 测试与封装
  12. thinkphp5之配置tp5重写伪静态
  13. 32bit64bit Win7系统下的IE8离线升级到IE11方法
  14. 软件开发全过程必备文档下载(@附所有文档)
  15. python 微信开发库_WeRoBot 是一个微信公众号开发框架
  16. 利用SEQ2SEQ模型实现车牌识别
  17. 计算机思维培训心得,2020参加计算机培训心得体会精选
  18. 【EMMC】MSM8953里时钟是如何分频的
  19. Matplotlib 多个子图使用一个图例
  20. 「大数据的关键思考系列」15:阿里巴巴的大数据实践(1)

热门文章

  1. 值得看三次的高干文_值得看三次的高干文,熬夜都要看
  2. python环境准备(一)
  3. 最简明扼要的 Systemd 教程,只需十分钟
  4. Nebula Graph|信息图谱在携程酒店的应用
  5. 如何做云班课上的计算机作业,云班课不分组怎么提交作业
  6. 深入理解Spring两大特性:IoC和AOP
  7. h20r1203功率管参数_电磁炉功率管H20R1353可以用H20R1203代换吗
  8. 【经验】新手选择插画学习书籍的方法有哪些?热门插画书籍推荐!
  9. Windows上python读取grib2文件(不用Linux)
  10. 计算机顶会论文多少钱,计算机视觉顶会文章的解读汇总(CVPR/ECCV/ICCV/NIPS)