Score SDE 三种随机微分方程代码解读
定义SDE类
定义了7个子函数
T: End time of the SDE.
sde:
marginal_prob: Parameters to determine the marginal distribution of the SDE, pt(x)p_t(x)pt(x).
prior_sampling: Generate one sample from the prior distribution, pT(x)p_T(x)pT(x).
prior_logp: Compute log-density of the prior distribution.
discretize: Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probabiliy flow sampling.
Defaults to Euler-Maruyama discretization.
reverse: Create the reverse-time SDE/ODE.
reverse
class SDE(abc.ABC):"""SDE abstract class. Functions are designed for a mini-batch of inputs."""def __init__(self, N):"""Construct an SDE.Args:N: number of discretization time steps."""super().__init__()self.N = N@property@abc.abstractmethoddef T(self):"""End time of the SDE."""pass@abc.abstractmethoddef sde(self, x, t):pass@abc.abstractmethoddef marginal_prob(self, x, t):"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""pass@abc.abstractmethoddef prior_sampling(self, shape):"""Generate one sample from the prior distribution, $p_T(x)$."""pass@abc.abstractmethoddef prior_logp(self, z):"""Compute log-density of the prior distribution.Useful for computing the log-likelihood via probability flow ODE.Args:z: latent codeReturns:log probability density"""passdef discretize(self, x, t):"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.Useful for reverse diffusion sampling and probabiliy flow sampling.Defaults to Euler-Maruyama discretization.Args:x: a torch tensort: a torch float representing the time step (from 0 to `self.T`)Returns:f, G"""dt = 1 / self.Ndrift, diffusion = self.sde(x, t)f = drift * dtG = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))return f, Gdef reverse(self, score_fn, probability_flow=False):"""Create the reverse-time SDE/ODE.Args:score_fn: A time-dependent score-based model that takes x and t and returns the score.probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling."""N = self.NT = self.Tsde_fn = self.sdediscretize_fn = self.discretize
定义reverse-time SDE类
# Build the class for reverse-time SDE.class RSDE(self.__class__):def __init__(self):self.N = Nself.probability_flow = probability_flow@propertydef T(self):return Tdef sde(self, x, t):"""Create the drift and diffusion functions for the reverse SDE/ODE."""drift, diffusion = sde_fn(x, t)score = score_fn(x, t)drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)# Set the diffusion function to zero for ODEs.diffusion = 0. if self.probability_flow else diffusionreturn drift, diffusiondef discretize(self, x, t):"""Create discretized iteration rules for the reverse diffusion sampler."""f, G = discretize_fn(x, t)rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)rev_G = torch.zeros_like(G) if self.probability_flow else Greturn rev_f, rev_Greturn RSDE()
定义VPSDE类
class VPSDE(SDE):def __init__(self, beta_min=0.1, beta_max=20, N=1000):"""Construct a Variance Preserving SDE.Args:beta_min: value of beta(0)beta_max: value of beta(1)N: number of discretization steps"""super().__init__(N)self.beta_0 = beta_minself.beta_1 = beta_maxself.N = Nself.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)self.alphas = 1. - self.discrete_betasself.alphas_cumprod = torch.cumprod(self.alphas, dim=0)self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)@propertydef T(self):return 1def sde(self, x, t):beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)drift = -0.5 * beta_t[:, None, None, None] * xdiffusion = torch.sqrt(beta_t)return drift, diffusiondef marginal_prob(self, x, t):log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0mean = torch.exp(log_mean_coeff[:, None, None, None]) * xstd = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))return mean, stddef prior_sampling(self, shape):return torch.randn(*shape)def prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.return logpsdef discretize(self, x, t):"""DDPM discretization."""timestep = (t * (self.N - 1) / self.T).long()beta = self.discrete_betas.to(x.device)[timestep]alpha = self.alphas.to(x.device)[timestep]sqrt_beta = torch.sqrt(beta)f = torch.sqrt(alpha)[:, None, None, None] * x - xG = sqrt_betareturn f, G
定义subVPSDE类
class subVPSDE(SDE):def __init__(self, beta_min=0.1, beta_max=20, N=1000):"""Construct the sub-VP SDE that excels at likelihoods.Args:beta_min: value of beta(0)beta_max: value of beta(1)N: number of discretization steps"""super().__init__(N)self.beta_0 = beta_minself.beta_1 = beta_maxself.N = N@propertydef T(self):return 1def sde(self, x, t):beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)drift = -0.5 * beta_t[:, None, None, None] * xdiscount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)diffusion = torch.sqrt(beta_t * discount)return drift, diffusion#边际概率函数,返回值是边际概率的mean和stddef marginal_prob(self, x, t):log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0mean = torch.exp(log_mean_coeff)[:, None, None, None] * xstd = 1 - torch.exp(2. * log_mean_coeff)return mean, stddef prior_sampling(self, shape):return torch.randn(*shape)def prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.
定义VESDE类
class VESDE(SDE):def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):"""Construct a Variance Exploding SDE.Args:sigma_min: smallest sigma.sigma_max: largest sigma.N: number of discretization steps"""super().__init__(N)self.sigma_min = sigma_minself.sigma_max = sigma_maxself.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))self.N = N@propertydef T(self):return 1def sde(self, x, t):sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** tdrift = torch.zeros_like(x)diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),device=t.device))return drift, diffusiondef marginal_prob(self, x, t):std = self.sigma_min * (self.sigma_max / self.sigma_min) ** tmean = xreturn mean, stddef prior_sampling(self, shape):return torch.randn(*shape) * self.sigma_maxdef prior_logp(self, z):shape = z.shapeN = np.prod(shape[1:])return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)def discretize(self, x, t):"""SMLD(NCSN) discretization."""timestep = (t * (self.N - 1) / self.T).long()sigma = self.discrete_sigmas.to(t.device)[timestep]adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),self.discrete_sigmas[timestep - 1].to(t.device))f = torch.zeros_like(x)G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)return f, G
Score SDE 三种随机微分方程代码解读相关推荐
- html语言闪烁特效代码,css3 实现文字闪烁效果的三种方式示例代码
1.通过改变透明度来实现文字的渐变闪烁,效果图: 文字闪烁 星星之火可以燎原 .myclass{ letter-spacing:5px;/*字间距*/ color: red; font-weight: ...
- python 随机请求头_为了爬虫换个头,我用python实现三种随机请求头方式!
相信大家在爬虫中都设置过请求头 user-agent 这个参数吧? 在请求的时候,加入这个参数,就可以一定程度的伪装成浏览器,就不会被服务器直接识别为spider.demo.code ,据我了解的,我 ...
- java源代码实例倒计时_Java倒计时三种实现方式代码实例
写完js倒计时,突然想用java实现倒计时,写了三种实现方式 一:设置时长的倒计时: 二:设置时间戳的倒计时: 三:使用java.util.Timer类实现的时间戳倒计时 代码如下: package ...
- 插入排序的三种算法(Java代码实现)
目录 插入排序: 基本思想: 1:直接插入排序: 基本思想: 过程: 2:折半插入排序: 基本思想: 过程: 3:希尔排序: 基本思想: 过程: 插入排序: 基本思想: 每一趟将一个待排序的数,按照它 ...
- Java中遍历Set集合的三种方法(实例代码)
哈喽,欢迎来到小朱课堂,下面开始你的学习吧! Java中遍历Set集合的三种方法 废话不多说,直接上代码 1.迭代遍历: Set set = new HashSet(); Iterator it = ...
- java倒计时_Java倒计时三种实现方式代码实例
写完js倒计时,突然想用java实现倒计时,写了三种实现方式 一:设置时长的倒计时: 二:设置时间戳的倒计时: 三:使用java.util.Timer类实现的时间戳倒计时 代码如下: package ...
- php判断质数,php如何判断是否为素数?判断素数的三种方法(代码示例)
本篇文章给大家带来的内容是介绍php如何判断是否为素数?判断素数的三种方法(代码示例).有一定的参考价值,有需要的朋友可以参考一下,希望对你们有所帮助. 什么是素数? 质数又称素数.一个大于1的自然数 ...
- storyboard搭建项目_简单谈谈ios程序界面实现的三种方式(代码创建,xib和storyboard)...
一丶前言 实现ios界面总的来说,有三种方式,传统的是纯代码创建与xib创建,近年来,苹果官网一直推荐用storyboard管理项目界面,最新的xcode 创建的project也是默认为storybo ...
- Stanford cs224n 第三课: GloVe 代码解读
Makefile Makefile是linux中特有的一种文件, 方便自动化编译. GloVe的源码是用C语言编写的, 在linux的环境当中需要编写一个Makefile文件来编译.关于Makefil ...
最新文章
- go build不从本地gopath获取_Go包管理GOPATH、vendor、go mod机制
- 全球及中国USB分路器行业发展布局与应用现状调研报告2022年
- SAP Spartacus的发布方式以及语义化版本管理机制
- python 示例_Python date isoweekday()方法与示例
- 优质淘宝产品描述页模板框架PSD分层模板,美工实用素材
- Python技术、爬虫、数据分析问题汇总【自用】
- AI 芯片为何遭遇滑铁卢?
- pytest与unittest区别
- bzoj 2527: [Poi2011]Meteors
- 【现代版】为人处世三十六计详解,真的很受益!
- 量子计算机 模拟,量子计算机首次模拟实现“时光倒流”
- arcgis拓扑几何,因缝隙太小而不能自动创建要素修复的处理办法
- android 模拟器加速,android开发怎么设置加速模拟器如真机运行
- 摩尔条纹拯救我的3D检测
- 淘宝客(springboot版本)从头开始搭建(二)
- 105套抖音快闪模板
- Boost中的协程—Boost.Asio中的coroutine类
- Java 在线编程编译工具上线,直接运行Java代码
- 微软培训和认证的建议
- 大白话说期权——除了买涨买跌,我们还能怎么交易?二元期权又是什么鬼?
热门文章
- 无广告软件下载网站有哪些?资源分享老司机分享经验【亲身经历】
- MNIST数据集的导入与预处理
- python win32api.sendmessage_最新版本:python win32api模拟了背景鼠标单击问题。
- 腾讯、阿里、字节跳动的简单比较
- 基于easyui 1.3.6设计的后台管理系统模板界面
- 《图解TCPIP》<6.3>tcp协议
- 后渗透篇:劫持技术(lpk.dll劫持游戏注入【Win7 实例】)
- Win10升级Win11必备的5款免费软件
- Python自动化办公:读取pdf文档
- C语言判断素数的两种方法