作者丨苏剑林

单位丨广州火焰信息科技有限公司

研究方向丨NLP,神经网络

个人主页丨kexue.fm

我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。

现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。

如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,这也就是 GAN 的同步训练。

注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。

如果在TF中

如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver 和 G_solver ),那么直接执行:

sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})

就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。

更通用的方法

但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。

所以,下面我们介绍一个通用的技巧,只需要定义单一一个 loss,然后扔给优化器,就能够实现 GAN 的训练。同时,从这个技巧中,我们还可以学习到如何更加灵活地操作 loss 来控制梯度。

判别器的优化

我们以 GAN 的 hinge loss 为例子,它的形式是:

注意意味着要固定 G,因为 G 本身也是有优化参数的,不固定的话就应该是

为了固定G,除了“把 G 的参数从优化器中去掉”这个方法之外,我们也可以利用 stop_gradient 去手动固定:

这里:

这样一来,在式 (2) 中,我们虽然同时放开了 D,G 的权重,但是不断地优化式 (2),会变的只有 D,而 G 是不会变的,因为我们用的是基于梯度下降的优化器,而 G 的梯度已经被停止了,换句话说,我们可以理解为 G 的梯度被强行设置为 0,所以它的更新量一直都是 0。

生成器的优化

现在解决了 D 的优化,那么 G 呢? stop_gradient 可以很方便地放我们固定里边部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的优化是要我们去固定外边的 D,没有函数实现它。但不要灰心,我们可以用一个数学技巧进行转化。

首先,我们要清楚,我们想要 D(G(z)) 里边的 G 的梯度,不想要 D 的梯度,如果直接对 D(G(z)) 求梯度,那么同时会得到 D,G 的梯度。如果直接求的梯度呢?只能得到 D 的梯度,因为 G 已经被停止了。那么,重点来了,将这两个相减,不就得到单纯的 G 的梯度了吗!

现在优化式 (4) ,那么 D 是不会变的,改变的是 G。

值得一提的是,直接输出这个式子,结果是恒等于 0,因为两部分都是一样的,直接相减自然是 0,但它的梯度不是 0。也就是说,这是一个恒等于 0 的 loss,但是梯度却不恒等于 0。

合成单一loss 

好了,现在式 (2) 和式 (4) 都同时放开了 D,G,大家都是 arg min,所以可以将两步合成一个 loss:

写出这个 loss,就可以同时完成判别器和生成器的优化了,而不需要交替训练,但是效果基本上等效于 1:1 的交替训练。引入 λ 的作用,相当于让判别器和生成器的学习率之比为 1:λ。

参考代码:

https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py

文章小结

文章主要介绍了实现 GAN 的一个小技巧,允许我们只写单个模型、用单个 loss 就实现 GAN 的训练。它本质上就是用 stop_gradient 来手动控制梯度的技巧,在其他任务上也可能用得到它。

所以,以后我写 GAN 都用这种写法了,省力省时。当然,理论上这种写法需要多耗些显存,这也算是牺牲空间换时间吧。

点击以下标题查看作者其他文章:

  • 变分自编码器VAE:原来是这么一回事 | 附开源代码

  • 再谈变分自编码器VAE:从贝叶斯观点出发

  • 变分自编码器VAE:这样做为什么能成?

  • 从变分编码、信息瓶颈到正态分布:论遗忘的重要性

  • 深度学习中的互信息:无监督提取特征

  • 全新视角:用变分推断统一理解生成模型

  • 细水长flow之NICE:流模型的基本概念与实现

  • 细水长flow之f-VAEs:Glow与VAEs的联姻

  • 深度学习中的Lipschitz约束:泛化与生成模型

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site

• 所有文章配图,请单独在附件中发送

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

?

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

▽ 点击 | 阅读原文 | 查看作者博客

巧断梯度:单个loss实现GAN模型(附开源代码)相关推荐

  1. 哈佛NLP组论文解读:基于隐变量的注意力模型 | 附开源代码

    作者丨邓云天 学校丨哈佛大学NLP组博士生 研究方向丨自然语言处理 摘要 Attention 注意力模型在神经网络中被广泛应用.在已有的工作中,Attention 机制一般是决定性的而非随机变量.我们 ...

  2. 优化切尔诺贝利灾难模型——附matlab代码

    优化切尔诺贝利灾难模型--附matlab代码 切尔诺贝利核电站事故是人类历史上最严重的核事故之一,对环境和人类健康造成了极大的影响.针对这样的事故,科学家们开发了许多模型用于预测和优化应对措施.本文将 ...

  3. 在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)

    来源:机器之心 本文长度为3071字,建议阅读6分钟 本文在 MNIST 上对VAE和GAN这两类生成模型的性能进行了对比测试. 项目链接:https://github.com/kvmanohar22 ...

  4. CVPR2019 Oral!伯克利、麻省理工GAN图像合成最新成果(附开源代码)!

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 摘要 本文提出了一种空间自适应归一化的简单但有效的层,在给定输入语义布局的情况下合成真实照片图像. 以前的方法直接将 ...

  5. 深度长文 | 从FM推演各深度CTR预估模型(附开源代码)

    作者丨龙心尘 & 寒小阳 研究方向丨机器学习,数据挖掘 题记:多年以后,当资深算法专家们看着无缝对接用户需求的广告收入节节攀升时,他们可能会想起自己之前痛苦推导 FM 与深度学习公式的某个夜晚 ...

  6. 10分钟搭建你的第一个图像识别模型 | 附完整代码

    (图片由AI科技大本营付费下载自视觉中国) 作者 | Pulkit Sharma 译者 | 王威力 来源 | 数据派THU(ID:DatapiTHU) [导读]本文介绍了图像识别的深度学习模型的建立过 ...

  7. 【无人机】基于蒙特卡洛算法实现无人机任务分配模型附matlab代码

    1 简介 注意:所谓的实时分配指的是实时分配用户位置-根据实时分配的位置更新无人机的位置进而优化最优的分配任务 ​ 2 部分代码 clcclose allclear alldisp('无人机优化模型' ...

  8. 【微电网优化】基于粒子群求解CHP机组、气网、电网、储热罐和电锅炉微电网优化模型附matlab代码

    1 简介 近年来随着全球性的环境污染问题与能源危机日益突出,人们的环保意识与节能意识不断提高,使得微电网成为了电力系统领域的研究热点之一.相对于传统的大电网,微电网具有自身的特点和优势,发电过程产生的 ...

  9. 【储能优化】基于粒子群求解考虑分时电价-需求响应后的风光柴油储能优化配置模型附matlab代码

    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.

最新文章

  1. idea编译的文件怎么用cmd打开_JAVA学习册|基础语法|cmd输出HelloWorld
  2. 《DSP using MATLAB》示例Example7.20
  3. 破站www.2637.cn居然敢截持偶的IE!
  4. 对flex-grow和flex-shrink的深入理解
  5. 奈飞文化手册_奈飞文化手册学习笔记
  6. c语言中c4700在哪个位置,C语言单链表问题。。高手来啊warning C4700
  7. strace 简单用法
  8. java this()函数_java中this关键字的三种用法
  9. postgre 表被加锁无法解锁问题
  10. ubuntu 安装J2EE环境
  11. Kruskal/Prim/Dijkstra模板
  12. java中的triple_无法在使用Java加密的.NET中使用TripleDES进行解密
  13. Sketch中的快捷键总结
  14. php文件目录教程,详谈PHP文件目录基础操作_PHP教程
  15. Android开发我音乐App
  16. WPS无法加载EndNote加载项
  17. 优秀的计算机简历,计算机优秀简历范文
  18. arcgis显示后台错误_ArcGIS后台服务器抛出异常的解决方法
  19. poj 2456 Aggressive cows 【二分+最大化最小值】
  20. 2018年高教社杯全国大学生数学建模竞赛D题解题思路

热门文章

  1. 以两台Linux主机在docker中实现mysql主主备份以用nginx实现mysql高可用
  2. NOIP2013Day1T3 表示只能过一个点
  3. 算法代码[置顶] 机器学习实战之KNN算法详解
  4. 控制好你的 Wordpress 侧边栏
  5. java外部类调用内部类_java中的外部类和内部类 | 学步园
  6. html手机pc不同页面,PC端和手机端如何同时生成静态页
  7. css 图文 上下 居中,CSS垂直居中的6种方法
  8. 查linux有哪些task_Java面试手册:Linux高频考点
  9. 壳体有矩理论与实用计算机方法,《薄壳计算和理论》.pdf
  10. unity 手机 模糊效果_GUI背景模糊效果优化