定义SDE类

定义了7个子函数

T: End time of the SDE.
sde:
marginal_prob: Parameters to determine the marginal distribution of the SDE, pt(x)p_t(x)pt​(x).
prior_sampling: Generate one sample from the prior distribution, pT(x)p_T(x)pT​(x).
prior_logp: Compute log-density of the prior distribution.
discretize: Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probabiliy flow sampling.
Defaults to Euler-Maruyama discretization.
reverse: Create the reverse-time SDE/ODE.

reverse

class SDE(abc.ABC):"""SDE abstract class. Functions are designed for a mini-batch of inputs."""def __init__(self, N):"""Construct an SDE.Args:N: number of discretization time steps."""super().__init__()self.N = N@property@abc.abstractmethoddef T(self):"""End time of the SDE."""pass@abc.abstractmethoddef sde(self, x, t):pass@abc.abstractmethoddef marginal_prob(self, x, t):"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""pass@abc.abstractmethoddef prior_sampling(self, shape):"""Generate one sample from the prior distribution, $p_T(x)$."""pass@abc.abstractmethoddef prior_logp(self, z):"""Compute log-density of the prior distribution.Useful for computing the log-likelihood via probability flow ODE.Args:z: latent codeReturns:log probability density"""passdef discretize(self, x, t):"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.Useful for reverse diffusion sampling and probabiliy flow sampling.Defaults to Euler-Maruyama discretization.Args:x: a torch tensort: a torch float representing the time step (from 0 to `self.T`)Returns:f, G"""dt = 1 / self.Ndrift, diffusion = self.sde(x, t)f = drift * dtG = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))return f, Gdef reverse(self, score_fn, probability_flow=False):"""Create the reverse-time SDE/ODE.Args:score_fn: A time-dependent score-based model that takes x and t and returns the score.probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling."""N = self.NT = self.Tsde_fn = self.sdediscretize_fn = self.discretize

定义reverse-time SDE类

    # Build the class for reverse-time SDE.class RSDE(self.__class__):def __init__(self):self.N = Nself.probability_flow = probability_flow@propertydef T(self):return Tdef sde(self, x, t):"""Create the drift and diffusion functions for the reverse SDE/ODE."""drift, diffusion = sde_fn(x, t)score = score_fn(x, t)drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)# Set the diffusion function to zero for ODEs.diffusion = 0. if self.probability_flow else diffusionreturn drift, diffusiondef discretize(self, x, t):"""Create discretized iteration rules for the reverse diffusion sampler."""f, G = discretize_fn(x, t)rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)rev_G = torch.zeros_like(G) if self.probability_flow else Greturn rev_f, rev_Greturn RSDE()

定义VPSDE类

class VPSDE(SDE):def __init__(self, beta_min=0.1, beta_max=20, N=1000):"""Construct a Variance Preserving SDE.Args:beta_min: value of beta(0)beta_max: value of beta(1)N: number of discretization steps"""super().__init__(N)self.beta_0 = beta_minself.beta_1 = beta_maxself.N = Nself.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)self.alphas = 1. - self.discrete_betasself.alphas_cumprod = torch.cumprod(self.alphas, dim=0)self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)@propertydef T(self):return 1def sde(self, x, t):beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)drift = -0.5 * beta_t[:, None, None, None] * xdiffusion = torch.sqrt(beta_t)return drift, diffusiondef marginal_prob(self, x, t):log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0mean = torch.exp(log_mean_coeff[:, None, None, None]) * xstd = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))return mean, stddef prior_sampling(self, shape):return torch.randn(*shape)def prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.return logpsdef discretize(self, x, t):"""DDPM discretization."""timestep = (t * (self.N - 1) / self.T).long()beta = self.discrete_betas.to(x.device)[timestep]alpha = self.alphas.to(x.device)[timestep]sqrt_beta = torch.sqrt(beta)f = torch.sqrt(alpha)[:, None, None, None] * x - xG = sqrt_betareturn f, G

定义subVPSDE类

class subVPSDE(SDE):def __init__(self, beta_min=0.1, beta_max=20, N=1000):"""Construct the sub-VP SDE that excels at likelihoods.Args:beta_min: value of beta(0)beta_max: value of beta(1)N: number of discretization steps"""super().__init__(N)self.beta_0 = beta_minself.beta_1 = beta_maxself.N = N@propertydef T(self):return 1def sde(self, x, t):beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)drift = -0.5 * beta_t[:, None, None, None] * xdiscount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)diffusion = torch.sqrt(beta_t * discount)return drift, diffusion#边际概率函数,返回值是边际概率的mean和stddef marginal_prob(self, x, t):log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0mean = torch.exp(log_mean_coeff)[:, None, None, None] * xstd = 1 - torch.exp(2. * log_mean_coeff)return mean, stddef prior_sampling(self, shape):return torch.randn(*shape)def prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.

定义VESDE类

class VESDE(SDE):def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):"""Construct a Variance Exploding SDE.Args:sigma_min: smallest sigma.sigma_max: largest sigma.N: number of discretization steps"""super().__init__(N)self.sigma_min = sigma_minself.sigma_max = sigma_maxself.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))self.N = N@propertydef T(self):return 1def sde(self, x, t):sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** tdrift = torch.zeros_like(x)diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),device=t.device))return drift, diffusiondef marginal_prob(self, x, t):std = self.sigma_min * (self.sigma_max / self.sigma_min) ** tmean = xreturn mean, stddef prior_sampling(self, shape):return torch.randn(*shape) * self.sigma_maxdef prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)def discretize(self, x, t):"""SMLD(NCSN) discretization."""timestep = (t * (self.N - 1) / self.T).long()sigma = self.discrete_sigmas.to(t.device)[timestep]adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),self.discrete_sigmas[timestep - 1].to(t.device))f = torch.zeros_like(x)G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)return f, G

Score SDE 三种随机微分方程代码解读相关推荐

  1. html语言闪烁特效代码,css3 实现文字闪烁效果的三种方式示例代码

    1.通过改变透明度来实现文字的渐变闪烁,效果图: 文字闪烁 星星之火可以燎原 .myclass{ letter-spacing:5px;/*字间距*/ color: red; font-weight: ...

  2. python 随机请求头_为了爬虫换个头,我用python实现三种随机请求头方式!

    相信大家在爬虫中都设置过请求头 user-agent 这个参数吧? 在请求的时候,加入这个参数,就可以一定程度的伪装成浏览器,就不会被服务器直接识别为spider.demo.code ,据我了解的,我 ...

  3. java源代码实例倒计时_Java倒计时三种实现方式代码实例

    写完js倒计时,突然想用java实现倒计时,写了三种实现方式 一:设置时长的倒计时: 二:设置时间戳的倒计时: 三:使用java.util.Timer类实现的时间戳倒计时 代码如下: package ...

  4. 插入排序的三种算法(Java代码实现)

    目录 插入排序: 基本思想: 1:直接插入排序: 基本思想: 过程: 2:折半插入排序: 基本思想: 过程: 3:希尔排序: 基本思想: 过程: 插入排序: 基本思想: 每一趟将一个待排序的数,按照它 ...

  5. Java中遍历Set集合的三种方法(实例代码)

    哈喽,欢迎来到小朱课堂,下面开始你的学习吧! Java中遍历Set集合的三种方法 废话不多说,直接上代码 1.迭代遍历: Set set = new HashSet(); Iterator it = ...

  6. java倒计时_Java倒计时三种实现方式代码实例

    写完js倒计时,突然想用java实现倒计时,写了三种实现方式 一:设置时长的倒计时: 二:设置时间戳的倒计时: 三:使用java.util.Timer类实现的时间戳倒计时 代码如下: package ...

  7. php判断质数,php如何判断是否为素数?判断素数的三种方法(代码示例)

    本篇文章给大家带来的内容是介绍php如何判断是否为素数?判断素数的三种方法(代码示例).有一定的参考价值,有需要的朋友可以参考一下,希望对你们有所帮助. 什么是素数? 质数又称素数.一个大于1的自然数 ...

  8. storyboard搭建项目_简单谈谈ios程序界面实现的三种方式(代码创建,xib和storyboard)...

    一丶前言 实现ios界面总的来说,有三种方式,传统的是纯代码创建与xib创建,近年来,苹果官网一直推荐用storyboard管理项目界面,最新的xcode 创建的project也是默认为storybo ...

  9. Stanford cs224n 第三课: GloVe 代码解读

    Makefile Makefile是linux中特有的一种文件, 方便自动化编译. GloVe的源码是用C语言编写的, 在linux的环境当中需要编写一个Makefile文件来编译.关于Makefil ...

最新文章

  1. go build不从本地gopath获取_Go包管理GOPATH、vendor、go mod机制
  2. 全球及中国USB分路器行业发展布局与应用现状调研报告2022年
  3. SAP Spartacus的发布方式以及语义化版本管理机制
  4. python 示例_Python date isoweekday()方法与示例
  5. 优质淘宝产品描述页模板框架PSD分层模板,美工实用素材
  6. Python技术、爬虫、数据分析问题汇总【自用】
  7. AI 芯片为何遭遇滑铁卢?
  8. pytest与unittest区别
  9. bzoj 2527: [Poi2011]Meteors
  10. 【现代版】为人处世三十六计详解,真的很受益!
  11. 量子计算机 模拟,量子计算机首次模拟实现“时光倒流”
  12. arcgis拓扑几何,因缝隙太小而不能自动创建要素修复的处理办法
  13. android 模拟器加速,android开发怎么设置加速模拟器如真机运行
  14. 摩尔条纹拯救我的3D检测
  15. 淘宝客(springboot版本)从头开始搭建(二)
  16. 105套抖音快闪模板
  17. Boost中的协程—Boost.Asio中的coroutine类
  18. Java 在线编程编译工具上线,直接运行Java代码
  19. 微软培训和认证的建议
  20. 大白话说期权——除了买涨买跌,我们还能怎么交易?二元期权又是什么鬼?

热门文章

  1. 无广告软件下载网站有哪些?资源分享老司机分享经验【亲身经历】
  2. MNIST数据集的导入与预处理
  3. python win32api.sendmessage_最新版本:python win32api模拟了背景鼠标单击问题。
  4. 腾讯、阿里、字节跳动的简单比较
  5. 基于easyui 1.3.6设计的后台管理系统模板界面
  6. 《图解TCPIP》<6.3>tcp协议
  7. 后渗透篇:劫持技术(lpk.dll劫持游戏注入【Win7 实例】)
  8. Win10升级Win11必备的5款免费软件
  9. Python自动化办公:读取pdf文档
  10. C语言判断素数的两种方法