1 GAN 介绍

GAN,叫做生成对抗网络 (Generative Adversarial Network) 。其基本原理是生成器网络 G(Generator) 和判别器网络 D(Discriminator) 相互博弈。生成器网络 G 的主要作用是生成图片,在输入一个随机编码 (random code) z后,自动的生成假样本 G(z) 。判别器网络 D 的主要作用是判断输入是否为真实样本并提供反馈机制,真样本则输出 1 ,反之为 0 。在两个网络相互博弈的过程中,两个网络的能力都越来越高:G 生成的图片越来越像真样本,D 也越来越会判断图片的真假,然后我们在最大化 D 的前提下,最小化 D 对 G 的判断能力,这实际上就是最小最大值问题,或者说二人零和博弈,其目标函数表达式:minGmaxDE[logD(G(z)])+log(1−D(x))]\rm \underset{G}{min} \; \underset{D}{max} \; E[log D(G(z)])+log(1-D(x))]Gmin​Dmax​E[logD(G(z)])+log(1−D(x))]其中表达式中的第一项 D(G(z)) 处理的是假图像 G(z) ,我们尽量降低评分 D(G(z)) ;第二项处理的是真图像 x ,此时评分要高。但是 GAN 并不是完美的,也有自己的局限性。比如说没有用户控制的能力和低分辨率与低质量的问题。

为了提高 GAN 的用户控制能力,人类进行了一些列的探索研究。比如 Pix2Pix 模型采用有条件的使用用户输入,使用**成对的数据 (paired data) 进行训练; CycleGAN 模型使用不成对的数据 (unpaired data) **的就能训练 。但无论是 Pix2Pix 还是 CycleGAN ,都是解决了从一个领域到另一个领域的图像转换问题。当有很多领域需要转换时,对于每一个领域转换,都需要重新训练一个模型去解决。目前,存在的模型处理多领域图像生成任务时,学习 k 个领域之间所有映射就必须训练 k * (k-1) 个生成器。如果训练一对一的图像多领域生成任务时,主要会导致两个问题:

  • 训练低效,每次训练耗时很大。
  • 训练效果有限,因为一个领域转换单独训练的话就不能利用其它领域的数据来增大泛化能力。

上图中 (a) 模型说明如何训练 12 个不同生成器网络以达到 4 个不同领域图像之间转换任务。很明显每个生成器不能够充分利用整个训练数据,只能从 4 个领域中 2 个领域相互学习,这样就会生成图片质量不好。而上图(b)中的模型就可以解决这些问题,该模型接受多个领域训练数据,并仅使用一个生成器来学习多领域图像之间映射关系。根据模型的长相将该模型称为星形网络,外文名就是 StarGAN 。


上图是根据 StarGAN 模型训练出的效果。在同一种模型下,可以做多领域图像之间的转换,比如更换头发颜色、更换表情、更换年龄等。

2 StarGAN模型及其优点

2.1 starGAN介绍

上图是对 StarGAN 的简单介绍,主要包含判别器 D 和生成器 G 。
(a)D 对真假图片进行判别,真图片判真,假图片判假,真图片被分类到相应域。
(b)G 接受真图片和目标域标签并生成假图片;
(c)G 在给定原始域标签的情况下将假图片重建为原始图片(重构损失);
(d)G 尽可能生成与真实图像无法区分的图像,并且通过 D 分类到目标域。

2.2 StarGAN 优点

  • 提出 StarGAN 网络模型,仅使用一个 G 和 D 就可以实现多个领域之间图像生成和训练。
  • 采用 mask vector 方法控制所有可用域图像标签以实现训练集之间的多领域图像转换。
  • StarGAN 相对于基准模型, 在面部属性转移和面部表情合成的任务中有更好的效果 (具体数据请参看原论文中的实验部分)

3 StarGAN

首先描述 StarGAN 网络,在一个数据集中进行多领域的图像转换任务;然后我们讨论了如何使 StarGAN 能合并包含不同标签的数据集以及对其中任意的标签属性灵活进行图像转换。

3.1 多领域图像转换

训练一个生成器 G ,能够多领域映射。将带有领域标签 c 的输入图像 x 转换为输出图像 y,即G(x,c)]→y\rm G(x,c)] \rightarrow yG(x,c)]→y。随机生成目标领域标签 c 使得 G 能够灵活的转换输入图像,同时使用 D 控制多领域。这样 D 就在图像源和域标签上产生概率分布,即D:x→Dsrc(x)],Dcls(x)\rm D : x → {D src (x)],D cls (x)}D:x→Dsrc(x)],Dcls(x)。

3.1.1 对抗损失函数 (Adversarial Loss)

使用对抗损失函数提高生成图像质量,达到 D 无法区分出来输出图像和生成图像之间的差别:Ladv=Ex[logDsrc(x)]]+Ex,c[log(1−Dsrc(G(x,c))]L_{adv}=E_x[logD_{src}(x)]] + E_{x,c}[log(1 - D_{src}(G(x,c))]Ladv​=Ex​[logDsrc​(x)]]+Ex,c​[log(1−Dsrc​(G(x,c))]根据输入图像 x 和目标领域标签 c ,由 G 生成输出图像G(x,c)]\rm G(x,c)]G(x,c)],同时 D 区分出真实图像和生成图像。将Dsrc(x)]\rm D_{src}(x)]Dsrc​(x)]作为输入图像 x 经过 D 之后得到的可能性分布。生成器 G 使这个式子尽可能的小,而 D 则尽可能使其最大化。

3.1.2 目标域分类损失函数(Domain Classification Loss)

对于一个输入图像 x 和目标分布标签 c ,我们的目标是将 x 转换为输出图像 y后能够被正确分类为目标分布 c 。为了实现这一目标,我们在 D 之上添加一个辅助分类器,并在优化 G 和 D 时采用目标域分类损失函数。简单来说,我们将这个式子分解为两部分:一个真实图像的分布分类损失用于约束 D ,一个假图像的分布分类损失用于约束 G 。其表达式如下所示:Lclsr=Ex,c’[−logDcls(c’∣x)]]L^{r}_{cls} = E_{x,c’} [-logD_{cls}(c’|x)]]Lclsr​=Ex,c’​[−logDcls​(c’∣x)]]其中,Dcls(c’∣x)]D_{cls}(c’|x)]Dcls​(c’∣x)]代表 D 计算出来的领域标签的可能性分布。一方面,通过将这个式子最小化, D 将真实图像 x 正确分类到与其相关分布 c’ 。另一方面,假图像的分类分布的损失函数定义如下:Lclsf=Ex,c[−logDcls(c∣G(x,c)])]L_{cls}^f = E_{x,c}[-log D_{cls}(c|G(x,c)])]Lclsf​=Ex,c​[−logDcls​(c∣G(x,c)])]即 G 使这个式子最小化,使得生成的图像能够被 D 判别为目标领域 c。

3.1.3 重构误差(Reconstruction Loss)

通过最小化对抗损失和分类损失, G 训练生成的图像尽可能与真实图像一样,并且能够被分类到正确的目标领域。然而,最小化这两个损失函数不能保证 , 转换后的图像中,只改变领域差异的部分, 而保留输入图像中的其他内容 。故对 G 使用循环一致性损失函数 (cycle consistency loss) ,如下:Lrec=Ex,c,c’[∣∣x−G(G(x,c)],c’)∣∣1]L_{rec} = E_{x,c,c’} [||x - G(G(x,c)],c’)||_{1}]Lrec​=Ex,c,c’​[∣∣x−G(G(x,c)],c’)∣∣1​]其中: G 以生成图像 G(x,c) 以及原始输入图像领域标签 c’ 为输入,努力重构出原始图像 x 。我们选择L范数作为重构损失函数。注意到我们两次使用了同一个生成器,第一次将原始图像转换到目标领域的图像,然后将生成的图像重构回原始图像。

3.1.4 总体损失函数表示(Full Objective)

最终 G 和 D 的损失函数表示如下:LD=−Ladv+λclsLclsrL_D = -L_{adv} + \lambda_{cls}L^{r}_{cls}LD​=−Ladv​+λcls​Lclsr​LG=Ladv+λclsLclsf+λrecLrecL_G = L_{adv} + \lambda_{cls}L^{f}_{cls}+ \lambda_{rec}L_{rec}LG​=Ladv​+λcls​Lclsf​+λrec​Lrec​其中λcls\lambda_{cls}λcls​ 和 λrec\lambda_{rec}λrec​是控制分类误差和重构误差相对于对抗误差的相对权重的超参数。在所有实验中,我们设置λcls=1,λrec=10\lambda_{cls} = 1,\lambda_{rec} = 10λcls​=1,λrec​=10。

3.1.5 改进损失函数

为了 GAN 训练过程稳定,生成高质量的图像,论文中采用自定义梯度惩罚来代替对抗误差损失:Ladv=Ex[Dsrc(x)]]−Ex,c[Dsrc(G(x,c))]−λgpEx^[(∣∣∇x^Dsrc(x^)∣∣2−1)2]L_{adv}=E_x[D_{src}(x)]] - E_{x,c}[D_{src}(G(x,c))] - \lambda_{gp}E_{\hat{x}} [(||\nabla{\hat{x}}D_{src}(\hat{x})||_{2}-1)^2]Ladv​=Ex​[Dsrc​(x)]]−Ex,c​[Dsrc​(G(x,c))]−λgp​Ex^​[(∣∣∇x^Dsrc​(x^)∣∣2​−1)2]其中:x^\hat{x}x^表示真实和生成图像之间均匀采样的直线,试验时λgp=10\lambda_{gp}=10λgp​=10。

3.2 多数据集训练

starGAN 的一个重要优势在于它能够同时合并包含不同标签的不同数据集,使得其在测试阶段能够控制所有的标签。从多个数据集学习的问题在于标签信息对每一个数据集而言只是部分已知。在 CelebA 和 RaFD 的例子中,前一个数据集包含诸如发色,性别等信息,但它不包含任何后一个数据集中包含的诸如开心生气等表情标签。这会引起问题,因为在将 G(x,c) 重构回输入图像 x 时需要完整的标签信息 c’ 。

3.2.1 向量掩码(Mask Vector)

为了缓解这一问题,我们引入了向量掩码 m,使 StarGAN 模型能够忽略不确定的标签,专注于特定数据集提供的明确的已知标签。在 StarGAN 中我们使用 n 维的 one-hot 向量来代表 m ,n 表示数据集的数量。除此之外,我们将标签的同一版本定义为一个数组:c‾=[c1,…,cn,m]\rm \overline{c} = [c_1,…,c_n,m]c=[c1​,…,cn​,m]其中:[·]表示串联,其中 c表示第 i 个数据集的标签,已知标签 c 的向量能用二值标签表示二值属性或者用 one-hot 的形式表示多类属性。对于剩下的 n-1 个未 i 知标签我们简单的置为 0 。

3.2.2 训练策略

利用多数据集训练 StarGAN 时,我们使用上面定义的c‾\overline{c}c作为生成器的输入。如此,生成器学会忽略非特定的标签,而专注于指定的标签。除了输入标签c‾\overline{c}c,此处的生成器与单数据集训练的生成器网络结构一样。另一方面我们也扩展判别器的辅助分类器的分类类别到到所属聚集的所有标签。最后,我们将我们的模型按照多任务学习的方式进行训练,其中,判别器只将已知标签相关的分类误差最小化即可。

3.3 训练数据处理

以 celebA 数据为例,下载后的数据包括 label 文件和图像。

  • 文件的第一行为图像的总数:202599。
  • 第二行为数据处理的类别,共 40 种,如下:

(1, ‘5_o_Clock_Shadow’), (2, ‘Arched_Eyebrows’), (3, ‘Attractive’), (4, ‘Bags_Under_Eyes’), (5, ‘Bald’), (6, ‘Bangs’), (7, ‘Big_Lips’), (8, ‘Big_Nose’), (9, ‘Black_Hair’), (10, ‘Blond_Hair’), (11, ‘Blurry’), (12, ‘Brown_Hair’), (13, ‘Bushy_Eyebrows’), (14, ‘Chubby’), (15, ‘Double_Chin’), (16, ‘Eyeglasses’), (17, ‘Goatee’), (18, ‘Gray_Hair’), (19, ‘Heavy_Makeup’), (20, ‘High_Cheekbones’), (21, ‘Male’), (22, ‘Mouth_Slightly_Open’), (23, ‘Mustache’), (24, ‘Narrow_Eyes’), (25, ‘No_Beard’), (26, ‘Oval_Face’), (27, ‘Pale_Skin’), (28, ‘Pointy_Nose’), (29, ‘Receding_Hairline’), (30, ‘Rosy_Cheeks’), (31, ‘Sideburns’), (32, ‘Smiling’), (33, ‘Straight_Hair’), (34, ‘Wavy_Hair’), (35, ‘Wearing_Earrings’), (36, ‘Wearing_Hat’), (37, ‘Wearing_Lipstick’), (38, ‘Wearing_Necklace’), (39, ‘Wearing_Necktie’), (40, ‘Young’)

  • 第三行及之后的每行为,图像名,已经对应的 40 种类别的 label , label 值为 1 或 -1。

000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

4 总结与展望

通过本文学习,您应该初步了解 StarGAN 模型的网络结构和实现原理,以及关键部分代码的初步实现。如果您对深度学习 Tensorflow 比较了解,可以参考 Tensorflow版实现starGAN;如果您对pytorch框架比较熟悉,可以参考 pytorch实现starGAN;如果您想更深入的学习了解starGAN原理,可以参考 论文。

如果想体验项目效果,您可以登陆 Mo 平台,在 应用中心 中找到 StarGAN,可以体验以下五种特征[‘Black_Hair’, ‘Blond_Hair’, ‘Brown_Hair’, ‘Male’, ‘Young’] 的风格变换。考虑到代码较长,我们在StarGAN 项目源码中对相关代码做了详细解释。您在学习的过程中,遇到困难或者发现我们的错误,可以随时联系我们。

5 参考资料

1.论文:https://arxiv.org/pdf/1711.09020.pdf
2.博客:https://blog.csdn.net/stdcoutzyx/article/details/78829232
3.博客:https://www.cnblogs.com/Thinker-pcw/p/9785379.html
4.pytorch原版github地址:https://github.com/yunjey/StarGAN
5.tensorflow版github地址:https://github.com/taki0112/StarGAN-Tensorflow
6.Celeba数据集:https://www.dropbox.com/s/d1kjpkqklf0uw77/celeba.zip?dl=0


Mo(网址:momodel.cn)是一个支持 Python 的人工智能在线建模平台,能帮助你快速开发、训练并部署模型。


Mo 人工智能俱乐部 是由网站的研发与产品设计团队发起、致力于降低人工智能开发与使用门槛的俱乐部。团队具备大数据处理分析、可视化与数据建模经验,已承担多领域智能项目,具备从底层到前端的全线设计开发能力。主要研究方向为大数据管理分析与人工智能技术,并以此来促进数据驱动的科学研究。

目前俱乐部每周六在杭州举办以机器学习为主题的线下技术沙龙活动,不定期进行论文分享与学术交流。希望能汇聚来自各行各业对人工智能感兴趣的朋友,不断交流共同成长,推动人工智能民主化、应用普及化。

【Mo 人工智能技术博客】StarGAN——生成你的明星脸相关推荐

  1. 【Mo 人工智能技术博客】多标准中文分词 Multi-Criteria-CWS

    多标准中文分词 Multi-Criteria-CWS 作者:宋彤彤 自然语言处理(NLP)是人工智能中很重要且具有挑战性的方向,而自然语言处理的第一步就是分词,分词的效果直接决定和影响后续工作的效率. ...

  2. 【Mo 人工智能技术博客】胶囊网络——Capsule Network

    胶囊网络--Capsule Network 作者:林泽龙 1. 背景介绍 CNN 在处理图像分类问题上表现非常出色,已经完成了很多不可思议的任务,并且在一些项目上超过了人类,对整个机器学习的领域产生了 ...

  3. 【Mo 人工智能技术博客】现在最流行的图神经网络库 pytorch geometric 上手教学

    简介 Graph Neural Networks 简称 GNN,称为图神经网络.近年来 GNN 在学术界受到的关注越来越多,与之相关的论文数量呈上升趋势,GNN 通过对信息的传递,转换和聚合实现特征的 ...

  4. 【Mo 人工智能技术博客】采用 Python 机器学习预测足球比赛结果

    采用 Python 机器学习预测足球比赛结果 足球是世界上最火爆的运动之一,世界杯期间也往往是球迷们最亢奋的时刻.比赛狂欢季除了炸出了熬夜看球的铁杆粉丝,也让足球竞猜也成了大家茶余饭后最热衷的话题.甚 ...

  5. 【Mo 人工智能技术博客】基于耦合网络的推荐系统

    基于耦合网络的推荐系统 作者:陈东瑞 1.复杂网络基础知识 当我们拿起手机给家人.朋友或者同事拨打电话时,就不知不觉中参与到了社交网络形成的过程中:当我们登上高铁或者飞机时,就可以享受交通网络给我们带 ...

  6. 【Mo 人工智能技术博客】利用Logistic函数和LSTM分析疫情数据

    利用Logistic函数和LSTM分析疫情数据 作者:林泽龙 Mo 1. 背景 2019 新型冠状病毒 (SARS-CoV-2),曾用名 2019-nCoV,通用简称新冠病毒,是一种具有包膜的正链单股 ...

  7. 【Mo 人工智能技术博客】使用 Seq2Seq 实现中英文翻译

    1. 介绍 1.1 Deep NLP 自然语言处理(Natural Language Processing,NLP)是计算机科学.人工智能和语言学领域交叉的分支学科,主要让计算机处理或理解自然语言,如 ...

  8. 【Mo 人工智能技术博客】基于 Python 和 NLTK 的推特情感分析

    基于 Python 和 NLTK 的推特情感分析 作者:宋彤彤 1. 导读 NLTK 是 Python 的一个自然语言处理模块,其中实现了朴素贝叶斯分类算法.这次 Mo 来教大家如何通过 python ...

  9. 【Mo 人工智能技术博客】python玩转信号处理与机器学习入门

    python玩转信号处理与机器学习入门 作者:王镇 面对毫无规律的随机信号,看着杂乱无章的振动波形,你是否也像曾经的我一样一头雾水,不知从何处下手.莫慌,接下来小编就带你入门怎样用python处理这些 ...

最新文章

  1. 做完小程序项目、老板给我加了6k薪资~
  2. MySQL中interactive_timeout和wait_timeout的区别
  3. python 发送邮件不显示附件_python无法通过电子邮件发送附件文件
  4. zabbix学习小结
  5. 自我学习的技巧和建议
  6. easyui datagrid 浏览器像素及改变表、列宽问题
  7. Linux下git的使用——将已有项目放到github上
  8. winform通过ListView绑定数据库数据源
  9. [leetcode 70]Climbing Stairs
  10. 标签编辑新工具:如何使用控制台标签编辑器(Tag editor)
  11. Linux状态监控在root下可用,监控linux状态
  12. 排序算法--冒泡排序
  13. 锐捷网关交换机开启dhcp服务
  14. 3dmax联机分布式渲染方法技巧详解
  15. java实现支付宝网页扫码支付
  16. 未来互联网+大数据时代,DT革命互联网大数据应用简析
  17. 纪中训练5月23日提高组T1
  18. React 接入 Ueditor + xiumi
  19. Ingest Node Pipeline Processor
  20. 微信小程序之数据交互

热门文章

  1. 【Python技能树共建】selenium入手篇
  2. Linux环境搭建:CentOS7安装Oracle
  3. 网络驱动器问题:指定的网络文件夹目前是以其他用户名和密码进行映射的
  4. 给定一个十进制数,将其转化为N进制数-----17年滴滴笔试题
  5. log4cplus日志格式输出配置
  6. RAW 264.7 小鼠单核巨噬细胞白血病细胞培养解决方案
  7. 机器学习_深度学习毕设题目汇总——车辆车牌
  8. python 检查身份证号的正确性
  9. 【​观察】美国公有云“排位赛”结束 中国市场正“步其后尘”?
  10. 告别慢SQL,如何去写一手好SQL ?