巧断梯度:单个loss实现GAN模型(附开源代码)
作者丨苏剑林
单位丨广州火焰信息科技有限公司
研究方向丨NLP,神经网络
个人主页丨kexue.fm
我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。
现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。
如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,这也就是 GAN 的同步训练。
注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。
如果在TF中
如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver 和 G_solver ),那么直接执行:
sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})
就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。
更通用的方法
但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。
所以,下面我们介绍一个通用的技巧,只需要定义单一一个 loss,然后扔给优化器,就能够实现 GAN 的训练。同时,从这个技巧中,我们还可以学习到如何更加灵活地操作 loss 来控制梯度。
判别器的优化
我们以 GAN 的 hinge loss 为例子,它的形式是:
注意意味着要固定 G,因为 G 本身也是有优化参数的,不固定的话就应该是。
为了固定G,除了“把 G 的参数从优化器中去掉”这个方法之外,我们也可以利用 stop_gradient 去手动固定:
这里:
这样一来,在式 (2) 中,我们虽然同时放开了 D,G 的权重,但是不断地优化式 (2),会变的只有 D,而 G 是不会变的,因为我们用的是基于梯度下降的优化器,而 G 的梯度已经被停止了,换句话说,我们可以理解为 G 的梯度被强行设置为 0,所以它的更新量一直都是 0。
生成器的优化
现在解决了 D 的优化,那么 G 呢? stop_gradient 可以很方便地放我们固定里边部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的优化是要我们去固定外边的 D,没有函数实现它。但不要灰心,我们可以用一个数学技巧进行转化。
首先,我们要清楚,我们想要 D(G(z)) 里边的 G 的梯度,不想要 D 的梯度,如果直接对 D(G(z)) 求梯度,那么同时会得到 D,G 的梯度。如果直接求的梯度呢?只能得到 D 的梯度,因为 G 已经被停止了。那么,重点来了,将这两个相减,不就得到单纯的 G 的梯度了吗!
现在优化式 (4) ,那么 D 是不会变的,改变的是 G。
值得一提的是,直接输出这个式子,结果是恒等于 0,因为两部分都是一样的,直接相减自然是 0,但它的梯度不是 0。也就是说,这是一个恒等于 0 的 loss,但是梯度却不恒等于 0。
合成单一loss
好了,现在式 (2) 和式 (4) 都同时放开了 D,G,大家都是 arg min,所以可以将两步合成一个 loss:
写出这个 loss,就可以同时完成判别器和生成器的优化了,而不需要交替训练,但是效果基本上等效于 1:1 的交替训练。引入 λ 的作用,相当于让判别器和生成器的学习率之比为 1:λ。
参考代码:
https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py
文章小结
文章主要介绍了实现 GAN 的一个小技巧,允许我们只写单个模型、用单个 loss 就实现 GAN 的训练。它本质上就是用 stop_gradient 来手动控制梯度的技巧,在其他任务上也可能用得到它。
所以,以后我写 GAN 都用这种写法了,省力省时。当然,理论上这种写法需要多耗些显存,这也算是牺牲空间换时间吧。
点击以下标题查看作者其他文章:
变分自编码器VAE:原来是这么一回事 | 附开源代码
变分自编码器VAE:这样做为什么能成?
从变分编码、信息瓶颈到正态分布:论遗忘的重要性
深度学习中的互信息:无监督提取特征
全新视角:用变分推断统一理解生成模型
细水长flow之NICE:流模型的基本概念与实现
细水长flow之f-VAEs:Glow与VAEs的联姻
深度学习中的Lipschitz约束:泛化与生成模型
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
? 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
? 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
?
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。
▽ 点击 | 阅读原文 | 查看作者博客
巧断梯度:单个loss实现GAN模型(附开源代码)相关推荐
- 哈佛NLP组论文解读:基于隐变量的注意力模型 | 附开源代码
作者丨邓云天 学校丨哈佛大学NLP组博士生 研究方向丨自然语言处理 摘要 Attention 注意力模型在神经网络中被广泛应用.在已有的工作中,Attention 机制一般是决定性的而非随机变量.我们 ...
- 优化切尔诺贝利灾难模型——附matlab代码
优化切尔诺贝利灾难模型--附matlab代码 切尔诺贝利核电站事故是人类历史上最严重的核事故之一,对环境和人类健康造成了极大的影响.针对这样的事故,科学家们开发了许多模型用于预测和优化应对措施.本文将 ...
- 在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)
来源:机器之心 本文长度为3071字,建议阅读6分钟 本文在 MNIST 上对VAE和GAN这两类生成模型的性能进行了对比测试. 项目链接:https://github.com/kvmanohar22 ...
- CVPR2019 Oral!伯克利、麻省理工GAN图像合成最新成果(附开源代码)!
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 摘要 本文提出了一种空间自适应归一化的简单但有效的层,在给定输入语义布局的情况下合成真实照片图像. 以前的方法直接将 ...
- 深度长文 | 从FM推演各深度CTR预估模型(附开源代码)
作者丨龙心尘 & 寒小阳 研究方向丨机器学习,数据挖掘 题记:多年以后,当资深算法专家们看着无缝对接用户需求的广告收入节节攀升时,他们可能会想起自己之前痛苦推导 FM 与深度学习公式的某个夜晚 ...
- 10分钟搭建你的第一个图像识别模型 | 附完整代码
(图片由AI科技大本营付费下载自视觉中国) 作者 | Pulkit Sharma 译者 | 王威力 来源 | 数据派THU(ID:DatapiTHU) [导读]本文介绍了图像识别的深度学习模型的建立过 ...
- 【无人机】基于蒙特卡洛算法实现无人机任务分配模型附matlab代码
1 简介 注意:所谓的实时分配指的是实时分配用户位置-根据实时分配的位置更新无人机的位置进而优化最优的分配任务 2 部分代码 clcclose allclear alldisp('无人机优化模型' ...
- 【微电网优化】基于粒子群求解CHP机组、气网、电网、储热罐和电锅炉微电网优化模型附matlab代码
1 简介 近年来随着全球性的环境污染问题与能源危机日益突出,人们的环保意识与节能意识不断提高,使得微电网成为了电力系统领域的研究热点之一.相对于传统的大电网,微电网具有自身的特点和优势,发电过程产生的 ...
- 【储能优化】基于粒子群求解考虑分时电价-需求响应后的风光柴油储能优化配置模型附matlab代码
✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.
最新文章
- idea编译的文件怎么用cmd打开_JAVA学习册|基础语法|cmd输出HelloWorld
- 《DSP using MATLAB》示例Example7.20
- 破站www.2637.cn居然敢截持偶的IE!
- 对flex-grow和flex-shrink的深入理解
- 奈飞文化手册_奈飞文化手册学习笔记
- c语言中c4700在哪个位置,C语言单链表问题。。高手来啊warning C4700
- strace 简单用法
- java this()函数_java中this关键字的三种用法
- postgre 表被加锁无法解锁问题
- ubuntu 安装J2EE环境
- Kruskal/Prim/Dijkstra模板
- java中的triple_无法在使用Java加密的.NET中使用TripleDES进行解密
- Sketch中的快捷键总结
- php文件目录教程,详谈PHP文件目录基础操作_PHP教程
- Android开发我音乐App
- WPS无法加载EndNote加载项
- 优秀的计算机简历,计算机优秀简历范文
- arcgis显示后台错误_ArcGIS后台服务器抛出异常的解决方法
- poj 2456 Aggressive cows 【二分+最大化最小值】
- 2018年高教社杯全国大学生数学建模竞赛D题解题思路
热门文章
- 以两台Linux主机在docker中实现mysql主主备份以用nginx实现mysql高可用
- NOIP2013Day1T3 表示只能过一个点
- 算法代码[置顶] 机器学习实战之KNN算法详解
- 控制好你的 Wordpress 侧边栏
- java外部类调用内部类_java中的外部类和内部类 | 学步园
- html手机pc不同页面,PC端和手机端如何同时生成静态页
- css 图文 上下 居中,CSS垂直居中的6种方法
- 查linux有哪些task_Java面试手册:Linux高频考点
- 壳体有矩理论与实用计算机方法,《薄壳计算和理论》.pdf
- unity 手机 模糊效果_GUI背景模糊效果优化