©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

上周笔者写了《生成扩散模型漫谈:构建ODE的一般步骤(上)》,本以为已经窥见了构建 ODE 扩散模型的一般规律,结果不久后评论区大神 @gaohuazuo 就给出了一个构建格林函数更高效、更直观的方案,让笔者自愧不如。再联想起之前大神之前在《生成扩散模型漫谈:“硬刚”扩散ODE》同样也给出了一个关于扩散 ODE 的精彩描述(间接启发了上一篇文章的结果),大神的洞察力不得不让人叹服。

经过讨论和思考,笔者发现大神的思路本质上就是一阶偏微分方程的特征线法,通过构造特定的向量场保证初值条件,然后通过求解微分方程保证终值条件,同时保证了初值和终值条件,真的非常巧妙!最后,笔者将自己的收获总结成此文,作为上一篇的后续。

前情回顾

简单回顾一下上一篇文章的结果。假设随机变量  连续地变换成 ,其变化规律服从 ODE

那么对应的  时刻的分布  服从“连续性方程”:

记 ,那么连续性方程可以简写成

为了求解这个方程,可以用格林函数的思想,即先求解

那么

就是满足约束条件的解之一。

几何直观

所谓格林函数,其实思想很简单,它就是说我们先不要着急解决复杂数据生成,我们先假设要生成的数据只有一个点 ,先解决这单个数据点的生成问题。有的读者想这不是很简单吗?直接  就完事了?当然不是这么简单,我们需要的是连续的、渐变的生成,如下图所示,就是  上的任意一点 ,都沿着一条光滑轨迹运行到  的  上:

▲ 格林函数示意图。图中,在处的每个点,都沿着特定的轨迹运行到处的一个点,除了公共点外,轨迹之间无重叠,这些轨迹就是格林函数的场线

而我们的目的,只是构造一个生成模型出来,所以我们原则上并不在乎轨迹的形状如何,只要它们都穿过 ,那么,我们可以人为地选择我们喜欢的、经过  的一个轨迹簇,记为

再次强调,这代表着以  为起点、以  为终点的一个轨迹簇,轨迹自变量、因变量分别为 ,起点  是固定不变的,终点  是可以任意变化的,轨迹的形状是无所谓的,我们可以选择直线、抛物线等等。

现在我们对式(6)两边求导,由于  是可以随意变化的,它相当于微分方程的积分常数,对它求导就等于 ,于是我们有

对比式(1),我们就得到

这里将原本的记号  替换为了 ,以标记轨线具有公共点 。也就是说,这样构造出来的力场  所对应的 ODE 轨迹,必然是经过  的,这就保证了格林函数的初值条件。

特征线法

既然初值条件有保证了,那么我们不妨要求更多一点:再保证一下终值条件。终值条件也就是希望  时  的分布是跟  无关的简单分布。上一篇文章的求解框架的主要缺点,就是无法直接保证终值分布的简单性,只能通过事后分析来研究。这篇文章的思路则是直接通过设计特定的  来保证初值条件,然后就有剩余空间来保证终值条件了。而且,同时保证了初、终值后,在满足连续性方程(2)的前提下,积分条件是自然满足的。

用数学的方式说,我们就是要在给定  和  的前提下,去求解方程(2),这是一个一阶偏微分方程,可以通过“特征线法”求解,其理论介绍可以参考笔者之前写的《一阶偏微分方程的特征线法》[1]。首先,我们将方程(2)等价地改写成

同前面类似,由于接下来是在给定起点  进行求解,所以上式将  替换为 ,以标记这是起点为  的解。

特征线法的思路,是先在某条特定的轨迹上考虑偏微分方程的解,这可以将偏微分转化为常微分,降低求解难度。具体来说,我们假设  是  的函数,在方程(1)的轨线上求解。此时由于成立方程(1),将上式左端的  替换为  后,左端正好是  的全微分,所以此时有

注意,此时所有的 应当被替换为对应的 的函数,这理论上可以从轨迹方程(6)解出。替换后,上式的 、 都是纯粹 的函数,所以上式只是关于 的一个线性常微分方程,可以解得

代入终值条件 ,得到 ,即

把轨迹方程(6)的 代入,就得到一个只含有 的函数,便是最终要求解的格林函数 了,相应地有 。

训练目标

有了格林函数,我们就可以得到

于是

根据《生成扩散模型漫谈:一般框架之SDE篇》中构建得分匹配目标的方法,可以构建训练目标

它跟《Flow Matching for Generative Modeling》[2]所给出的“Conditional Flow Matching”形式上是一致的,后面我们还会看到,该论文的结果都可以从本文的方法推出。训练完成后,就可以通过求解方程  来生成样本了。从这个训练目标也可以看出,我们对  的要求是易于采样就行了。

一些例子

可能前面的抽象结果对大家来说还是不大好理解,接下来我们来给出一些具体例子,以便加深大家对这个框架的直观理解。至于特征线法本身,笔者在《一阶偏微分方程的特征线法》[1] 也说过,一开始笔者也觉得特征线法像是“变魔术”一样难以捉摸,按照步骤操作似乎不困难,但总把握不住关键之处,理解它需要一个反复斟酌的思考过程,无法进一步代劳了。

直线轨迹

作为最简单的例子,我们假设 是沿着直线轨迹变为 ,简单起见我们还可以将 T 设为 1,这不会损失一般性,那么 的方程可以写为

根据式(8),有

此时 ,根据式(12)就有

代入式(16)中的 ,得到

特别地,如果  是标准正态分布,那么上式实则意味着 ,这正好是常见的高斯扩散模型之一。这个框架的新结果,是允许我们选择更一般的先验分布 ,比如均匀分布。另外在介绍得分匹配(15)时也已经说了,对  我们只需要知道它的采样方式就行了,而上式告诉我们只需要先验分布易于采样就行,因为:

效果演示

注意,我们假设从 到 的轨迹是一条直线,这仅仅是对于单点生成的,也就是格林函数解。当通过格林函数叠加出一般分布对应的的力场 时,其生成轨迹就不再是直线了。

下图演示了先验分布为均匀分布时多点生成的轨线图:

▲ 单点生成

▲ 两点生成

▲ 三点生成

参考作图代码:

1import numpy as np2from scipy.integrate import odeint3import matplotlib4import matplotlib.pyplot as plt5matplotlib.rc('text', usetex=True)6matplotlib.rcParams['text.latex.preamble']=[r"\usepackage{amsmath}"]78prior = lambda x: 0.5 if 2 >= x >= 0 else 09p = lambda xt, x0, t: prior((xt - x0) / t + x0) / t
10f = lambda xt, x0, t: (xt - x0) / t
11
12def f_full(xt, t):
13    x0s = [0.5, 0.5, 1.2, 1.7]  # 0.5出现两次,代表其频率是其余的两倍
14    fs = np.array([f(xt, x0, t) for x0 in x0s]).reshape(-1)
15    ps = np.array([p(xt, x0, t) for x0 in x0s]).reshape(-1)
16    return (fs * ps).sum() / (ps.sum() + 1e-8)
17
18for x1 in np.arange(0.01, 1.99, 0.10999/2):
19    ts = np.arange(1, 0, -0.001)
20    xs = odeint(f_full, x1, ts).reshape(-1)[::-1]
21    ts = ts[::-1]
22    if abs(xs[0] - 0.5) < 0.1:
23        _ = plt.plot(ts, xs, color='skyblue')
24    elif abs(xs[0] - 1.2) < 0.1:
25        _ = plt.plot(ts, xs, color='orange')
26    else:
27        _ = plt.plot(ts, xs, color='limegreen')
28
29plt.xlabel('$t$')
30plt.ylabel(r'$\boldsymbol{x}$')
31plt.show()

一般推广

其实上面的结果还可以一般地推广到

这里的  是任意满足  的  函数, 是任意满足  的单调递增函数。根据式(8),有

这也等价于《Flow Matching for Generative Modeling》[2] 中的式(15),此时 ,根据式(12)就有

代入 ,最终结果是

这是关于线性 ODE 扩散的一般结果,包含高斯扩散,也允许使用非高斯的先验分布。

再复杂些? 

前面的例子,都是通过 (的某个变换)与 的简单线性插值(插值权重纯粹是 的函数)来构建 的变化轨迹。那么一个很自然的问题就是:可不可以考虑更复杂的轨迹呢?

理论上可以,但是更高的复杂度意味着隐含了更多的假设,而我们通常很难检验目标数据是否支持这些假设,因此通常都不考虑更复杂的轨迹了。此外,对于更复杂的轨迹,解析求解的难度通常也更高,不管是理论还是实验,都难以操作下去。

更重要的一点的,我们目前所假设的轨迹,仅仅是单点生成的轨迹而已,前面已经演示了,即便假设为直线,多点生成依然会导致复杂的曲线。所以,如果单点生成的轨迹都假设得不必要的复杂,那么可以想像多点生成的轨迹复杂度将会奇高,模型可能会极度不稳定。

文章小结

接着上一篇文章的内容,本文再次讨论了 ODE 式扩散模型的构建思路。这一次我们从几何直观出发,通过构造特定的向量场保证结果满足初值分布条件,然后通过求解微分方程保证终值分布条件,得到一个同时满足初值和终值条件的格林函数。特别地,该方法允许我们使用任意简单分布作为先验分布,摆脱以往对高斯分布的依赖来构建扩散模型。

参考文献

[1] https://kexue.fm/archives/4718

[2] https://arxiv.org/abs/2210.02747

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

​生成扩散模型漫谈:构建ODE的一般步骤(下)相关推荐

  1. 生成扩散模型漫谈:构建ODE的一般步骤(上)

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 书接上文,在<生成扩散模型漫谈:从万有引力到扩散模型>中,我们介绍了一个由万有引力 ...

  2. 生成扩散模型漫谈:一般框架之ODE篇

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 上一篇文章<生成扩散模型漫谈:一般框架之SDE篇>中,我们对宋飏博士的论文< ...

  3. 生成扩散模型漫谈:从万有引力到扩散模型

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 对于很多读者来说,生成扩散模型可能是他们遇到的第一个能够将如此多的数学工具用到深度学习上的模型 ...

  4. 生成扩散模型漫谈:统一扩散模型(应用篇)

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 在<生成扩散模型漫谈:统一扩散模型(理论篇)>中,笔者自称构建了一个统一的模型框架 ...

  5. 生成扩散模型漫谈:最优扩散方差估计(上)

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 对于生成扩散模型来说,一个很关键的问题是生成过程的方差应该怎么选择,因为不同的方差会明显影响生 ...

  6. 通俗理解DDPM:生成扩散模型

    说到生成模型,VAE.GAN可谓是"如雷贯耳",此外,还有一些比较小众的选择,如flow模型.VQ-VAE等,也颇有人气,尤其是VQ-VAE及其变体VQ-GAN,近期已经逐渐发展到 ...

  7. 【翻译】A Survey on Generative Diffusion Model(生成扩散模型的综述研究)

    写在开头: 1.本文作者:Hanqun Cao, Cheng Tan, Zhangyang Gao, Guangyong Chen, Pheng-Ann Heng, Senior Member, IE ...

  8. 多尺度生成扩散模型预测蛋白-配体复合物结构的动态骨架

    今天给大家介绍的是来自加州理工大学Zhuoran Qiao和NVIDIA团队发表在arxiv上的预印本<DYNAMIC-BACKBONE PROTEIN-LIGAND STRUCTURE PRE ...

  9. 扩散模型探索:DDIM 笔记与思考

    DIFFUSION系列笔记|DDIM 数学.思考与 ppdiffuser 代码探索 论文:DENOISING DIFFUSION IMPLICIT MODELS 该 notebook 主要对 DDIM ...

最新文章

  1. php的参数的乘除,关于PHP在企业中处理数字加减乘除和对比运算方案
  2. jQuery JavaScript库达到新的里程碑
  3. 知识回顾——构造函数
  4. python执行命令并返回结果集_Python接口测试结果集实现封装比较
  5. Android网络编程之使用HTTP訪问网络资源
  6. DotNet_Performance_Tuning_ANTS_Performance_Profiler
  7. dataset for person re-id
  8. pcm 转化为wav 文件
  9. 我们都笑了freeeim
  10. C#LeetCode刷题之#367-有效的完全平方数(Valid Perfect Square)
  11. SAP License:SAP会计凭证抬头的字段状态控制
  12. 网络渗透技术如何自学,自学黑客要多久
  13. js调用html文件上传,JavaScript里的文件上传API
  14. Android10修改电池图标,导航栏、信号及电池图标修改方法(新增视频教程)
  15. 零基础想学习大数据?(同样适合有一定基础想进阶的)跟着这几个步骤走
  16. 计算机培训教学准备,计算机教学计划锦集五篇
  17. hive 数据的导入导出
  18. 招才猫显示服务器开小差,梦幻西游:百区平转开启却抢不到服务器?教你几招助你顺利转区...
  19. 两电平变流器matlab仿真,基于H桥级联型五电平逆变器Matlab仿真分析.doc
  20. Linux svn使用

热门文章

  1. 关于tcp连接中timewait的作用
  2. 前端技术演进(九):参考文章
  3. 笑着笑着就哭了,睡着睡着就痛了:QQ伤感日志
  4. 考研英语核心词汇趣讲-导学(1)
  5. 《我不是药神》 观影感
  6. 十分钟搞定Java多线程-如何使用sleep()方法和TimeUnit暂停线程
  7. 数据有哪些重要的作用?
  8. 对象池——Smiple Pool For Unity
  9. linux系统是如何获取网卡的通信数据的
  10. 软件测试人员如何快速提升自我?