Diffusion扩散模型学习1——Pytorch搭建DDPM利用深度卷积神经网络实现图片生成

  • 学习前言
  • 源码下载地址
  • 网络构建
    • 一、什么是Diffusion
      • 1、加噪过程
      • 2、去噪过程
    • 二、DDPM网络的构建(Unet网络的构建)
    • 三、Diffusion的训练思路
  • 利用DDPM生成图片
    • 一、数据集的准备
    • 二、数据集的处理
    • 三、模型训练

学习前言

我又死了我又死了我又死了!

源码下载地址

https://github.com/bubbliiiing/ddpm-pytorch

喜欢的可以点个star噢。

网络构建

一、什么是Diffusion


如上图所示。DDPM模型主要分为两个过程:
1、Forward加噪过程(从右往左),数据集的真实图片中逐步加入高斯噪声,最终变成一个杂乱无章的高斯噪声,这个过程一般发生在训练的时候。加噪过程满足一定的数学规律。
2、Reverse去噪过程(从左往右),指对加了噪声的图片逐步去噪,从而还原出真实图片,这个过程一般发生在预测生成的时候。尽管在这里说的是加了噪声的图片,但实际去预测生成的时候,是随机生成一个高斯噪声来去噪。去噪的时候不断根据XtX_tXt的图片生成Xt−1X_{t-1}Xt1的噪声,从而实现图片的还原。

1、加噪过程


Forward加噪过程主要符合如下的公式:
xt=αtxt−1+1−αtz1x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z_{1} xt=αt

xt1+1αt

z1

其中αt\sqrt{\alpha_t}αt

是预先设定好的超参数,被称为Noise schedule,通常是小于1的值,在论文中αt\alpha_tαt的值从0.9999到0.998。ϵt−1∼N(0,1)\epsilon_{t-1} \sim N(0, 1)ϵt1N(0,1)是高斯噪声。由公式(1)迭代推导。

xt=at(at−1xt−2+1−αt−1z2)+1−αtz1=atat−1xt−2+(at(1−αt−1)z2+1−αtz1)x_t=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} z_2\right)+\sqrt{1-\alpha_t} z_1=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} z_2+\sqrt{1-\alpha_t} z_1\right)xt=at

(at1

xt2+1αt1

z2)
+1αt

z1=
atat1

xt2+
(at(1αt1)

z2+1αt

z1)

其中每次加入的噪声都服从高斯分布 z1,z2,…∼N(0,1)z_1, z_2, \ldots \sim \mathcal{N}(0, 1)z1,z2,N(0,1),两个高斯分布的相加高斯分布满足公式:N(0,σ12)+N(0,σ22)∼N(0,(σ12+σ22))\mathcal{N}\left(0, \sigma_1^2 \right)+\mathcal{N}\left(0, \sigma_2^2 \right) \sim \mathcal{N}\left(0,\left(\sigma_1^2+\sigma_2^2\right) \right)N(0,σ12)+N(0,σ22)N(0,(σ12+σ22)),因此,得到xtx_txt的公式为:
xt=atat−1xt−2+1−αtαt−1z2x_t = \sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} z_2 xt=atat1

xt2+1αtαt1

z2

因此不断往里面套,就能发现规律了,其实就是累乘
可以直接得出x0x_0x0xtx_txt的公式:
xt=αt‾x0+1−αt‾ztx_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_t xt=αt

x0+
1αt

zt

其中αt‾=∏itαi\overline{\alpha_t}=\prod_i^t \alpha_iαt=itαi,这是随Noise schedule设定好的超参数,zt−1∼N(0,1)z_{t-1} \sim N(0, 1)zt1N(0,1)也是一个高斯噪声。通过上述两个公式,我们可以不断的将图片进行破坏加噪。

2、去噪过程


反向过程就是通过估测噪声,多次迭代逐渐将被破坏的xtx_txt恢复成x0x_0x0,在恢复时刻,我们已经知道的是xtx_txt,这是图片在ttt时刻的噪声图。一下子从xtx_txt恢复成x0x_0x0是不可能的,我们只能一步一步的往前推,首先从xtx_txt恢复成xt−1x_{t-1}xt1。根据贝叶斯公式,已知xtx_txt反推xt−1x_{t-1}xt1
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt−1∣x0)q(xt∣x0)q\left(x_{t-1} \mid x_t, x_0\right)=q\left(x_t \mid x_{t-1}, x_0\right) \frac{q\left(x_{t-1} \mid x_0\right)}{q\left(x_t \mid x_0\right)} q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)
右边的三个东西都可以从x_0开始推得到:
q(xt−1∣x0)=aˉt−1x0+1−aˉt−1z∼N(aˉt−1x0,1−aˉt−1)q\left(x_{t-1} \mid x_0\right)=\sqrt{\bar{a}_{t-1}} x_0+\sqrt{1-\bar{a}_{t-1}} z \sim \mathcal{N}\left(\sqrt{\bar{a}_{t-1}} x_0, 1-\bar{a}_{t-1}\right) q(xt1x0)=aˉt1

x0+1aˉt1

z
N(aˉt1

x0,1aˉt1)

q(xt∣x0)=aˉtx0+1−αˉtz∼N(aˉtx0,1−αˉt)q\left(x_t \mid x_0\right) = \sqrt{\bar{a}_t} x_0+\sqrt{1-\bar{\alpha}_t} z \sim \mathcal{N}\left(\sqrt{\bar{a}_t} x_0 , 1-\bar{\alpha}_t\right) q(xtx0)=aˉt

x0+
1αˉt

z
N(aˉt

x0,1αˉt)

q(xt∣xt−1,x0)=atxt−1+1−αtz∼N(atxt−1,1−αt)q\left(x_t \mid x_{t-1}, x_0\right)=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} z \sim \mathcal{N}\left(\sqrt{a_t} x_{t-1}, 1-\alpha_t\right) \\ q(xtxt1,x0)=at

xt1+
1αt

z
N(at

xt1,1αt)

因此,由于右边三个东西均满足正态分布,q(xt−1∣xt,x0)q\left(x_{t-1} \mid x_t, x_0\right)q(xt1xt,x0)满足分布如下:
∝exp⁡(−12((xt−αtxt−1)2βt+(xt−1−αˉt−1x0)21−αˉt−1−(xt−αˉtx0)21−αˉt))\propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) exp(21(βt(xtαt

xt1)
2
+1αˉt1(xt1αˉt1

x0)
2
1αˉt(xtαˉt

x0)
2
)
)

把标准正态分布展开后,乘法就相当于加,除法就相当于减,把他们汇总
接下来继续化简,咱们现在要求的是上一时刻的分布
∝exp⁡(−12((xt−αtxt−1)2βt+(xt−1−αˉt−1x0)21−αˉt−1−(xt−αˉtx0)21−αˉt))=exp⁡(−12(xt2−2αtxtxt−1+αtxt−12βt+xt−12−2αˉt−1x0xt−1+αˉt−1x021−αˉt−1−(xt−αˉtx0)21−αˉt))=exp⁡(−12((αtβt+11−αˉt−1)xt−12−(2αtβtxt+2αˉt−11−αˉt−1x0)xt−1+C(xt,x0)))\begin{aligned} & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{x_t^2-2 \sqrt{\alpha_t} x_t x_{t-1}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} x_0 x_{t-1}+\bar{\alpha}_{t-1} x_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) x_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} x_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0\right) x_{t-1}+C\left(x_t, x_0\right)\right)\right) \end{aligned} exp(21(βt(xtαt

xt1)
2
+1αˉt1(xt1αˉt1

x0)
2
1αˉt(xtαˉt

x0)
2
)
)
=exp(21(βtxt22αt

xtxt1+αtxt12
+1αˉt1xt122αˉt1

x0xt1+αˉt1x02
1αˉt(xtαˉt

x0)
2
)
)
=exp(21((βtαt+1αˉt11)xt12(βt2αt

xt+1αˉt12αˉt1

x0)
xt1+C(xt,x0))
)

正态分布满足公式,exp⁡(−(x−μ)22σ2)=exp⁡(−12(1σ2x2−2μσ2x+μ2σ2))\exp \left(-\frac{(x-\mu)^2}{2 \sigma^2}\right)=\exp \left(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)\right)exp(2σ2(xμ)2)=exp(21(σ21x2σ22μx+σ2μ2)),其中σ\sigmaσ就是方差,μ\muμ就是均值,配方后我们就可以获得均值和方差。

此时的均值为:μ~t(xt,x0)=αt(1−αˉt−1)1−αˉtxt+αˉt−1βt1−αˉtx0\tilde{\mu}_t\left(x_t, x_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} x_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} x_0μ~t(xt,x0)=1αˉtαt

(1αˉt1)xt+1αˉtαˉt1

βt
x0
。根据之前的公式,xt=αt‾x0+1−αt‾ztx_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_txt=αt

x0+
1αt

zt
,我们可以使用xtx_txt反向估计x0x_0x0得到x0x_0x0满足分布x0=1αˉt(xt−1−αˉtzt)x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathrm{x}_t-\sqrt{1-\bar{\alpha}_t} z_t\right)x0=αˉt

1
(xt1αˉt

zt)
。最终得到均值为μ~t=1at(xt−βt1−aˉtzt)\tilde{\mu}_t=\frac{1}{\sqrt{a_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{a}_t}} z_t\right)μ~t=at

1
(xt1aˉt

βt
zt)
ztz_tzt代表t时刻的噪音是什么。由ztz_tzt无法直接获得,网络便通过当前时刻的xtx_txt经过神经网络计算ztz_tztϵθ(xt,t)\epsilon_\theta\left(x_t, t\right)ϵθ(xt,t)也就是上面提到的ztz_tztϵθ\epsilon_\thetaϵθ代表神经网络。
xt−1=1αt(xt−1−αt1−αˉtϵθ(xt,t))+σtzx_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta\left(x_t, t\right)\right)+\sigma_t z xt1=αt

1
(xt1αˉt

1αt
ϵθ(xt,t))
+
σtz

由于加噪过程中的真实噪声ϵ\epsilonϵ在复原过程中是无法获得的,因此DDPM的关键就是训练一个由xtx_txtttt估测橾声的模型 ϵθ(xt,t)\epsilon_\theta\left(x_t, t\right)ϵθ(xt,t),其中θ\thetaθ就是模型的训练参数,σt\sigma_tσt 也是一个高斯噪声 σt∼N(0,1)\sigma_t \sim N(0,1)σtN(0,1),用于表示估测与实际的差距。在DDPM中,使用U-Net作为估测噪声的模型。

本质上,我们就是训练这个Unet模型,该模型输入为xtx_txtttt,输出为xtx_txt时刻的高斯噪声。即利用xtx_txtttt预测这一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

二、DDPM网络的构建(Unet网络的构建)


上图是典型的Unet模型结构,仅仅作为示意图,里面具体的数字同学们无需在意,和本文的学习无关。在本文中,Unet的输入和输出shape相同,通道均为3(一般为RGB三通道),宽高相同。

本质上,DDPM最重要的工作就是训练Unet模型,该模型输入为xtx_txtttt,输出为xt−1x_{t-1}xt1时刻的高斯噪声。即利用xtx_txtttt预测上一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

假设我们需要生成一个[64, 64, 3]的图像,在ttt时刻,我们有一个xtx_txt噪声图,该噪声图的的shape也为[64, 64, 3],我们将它和ttt一起输入到Unet中。Unet的输出为xt−1x_{t-1}xt1时刻的[64, 64, 3]的噪声。

实现代码如下,代码中的特征提取模块为残差结构,方便优化:

import mathimport torch
import torch.nn as nn
import torch.nn.functional as Fdef get_norm(norm, num_channels, num_groups):if norm == "in":return nn.InstanceNorm2d(num_channels, affine=True)elif norm == "bn":return nn.BatchNorm2d(num_channels)elif norm == "gn":return nn.GroupNorm(num_groups, num_channels)elif norm is None:return nn.Identity()else:raise ValueError("unknown normalization type")#------------------------------------------#
#   计算时间步长的位置嵌入。
#   一半为sin,一半为cos。
#------------------------------------------#
class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device      = x.devicehalf_dim    = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb#------------------------------------------#
#   下采样层,一个步长为2x2的卷积
#------------------------------------------#
class Downsample(nn.Module):def __init__(self, in_channels):super().__init__()self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)def forward(self, x, time_emb, y):if x.shape[2] % 2 == 1:raise ValueError("downsampling tensor height should be even")if x.shape[3] % 2 == 1:raise ValueError("downsampling tensor width should be even")return self.downsample(x)#------------------------------------------#
#   上采样层,Upsample+卷积
#------------------------------------------#
class Upsample(nn.Module):def __init__(self, in_channels):super().__init__()self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),nn.Conv2d(in_channels, in_channels, 3, padding=1),)def forward(self, x, time_emb, y):return self.upsample(x)#------------------------------------------#
#   使用Self-Attention注意力机制
#   做一个全局的Self-Attention
#------------------------------------------#
class AttentionBlock(nn.Module):def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w  = x.shapeq, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention   = torch.softmax(dot_products, dim=-1)out         = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out         = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x#------------------------------------------#
#   用于特征提取的残差结构
#------------------------------------------#
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,norm="gn", num_groups=32, use_attention=False,):super().__init__()self.activation = activationself.norm_1 = get_norm(norm, in_channels, num_groups)self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)self.norm_2 = get_norm(norm, out_channels, num_groups)self.conv_2 = nn.Sequential(nn.Dropout(p=dropout), nn.Conv2d(out_channels, out_channels, 3, padding=1),)self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else Noneself.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else Noneself.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)def forward(self, x, time_emb=None, y=None):out = self.activation(self.norm_1(x))# 第一个卷积out = self.conv_1(out)# 对时间time_emb做一个全连接,施加在通道上if self.time_bias is not None:if time_emb is None:raise ValueError("time conditioning was specified but time_emb is not passed")out += self.time_bias(self.activation(time_emb))[:, :, None, None]# 对种类y_emb做一个全连接,施加在通道上if self.class_bias is not None:if y is None:raise ValueError("class conditioning was specified but y is not passed")out += self.class_bias(y)[:, :, None, None]out = self.activation(self.norm_2(out))# 第二个卷积+残差边out = self.conv_2(out) + self.residual_connection(x)# 最后做个Attentionout = self.attention(out)return out#------------------------------------------#
#   Unet模型
#------------------------------------------#
class UNet(nn.Module):def __init__(self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,):super().__init__()# 使用到的激活函数,一般为SILUself.activation = activation# 是否对输入进行paddingself.initial_pad = initial_pad# 需要去区分的类别数self.num_classes = num_classes# 对时间轴输入的全连接层self.time_mlp = nn.Sequential(PositionalEmbedding(base_channels, time_emb_scale),nn.Linear(base_channels, time_emb_dim),nn.SiLU(),nn.Linear(time_emb_dim, time_emb_dim),) if time_emb_dim is not None else None# 对输入图片的第一个卷积self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)# self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征# 然后利用Downsample降低特征图的高宽self.downs      = nn.ModuleList()self.ups        = nn.ModuleList()# channels指的是每一个模块处理后的通道数# now_channels是一个中间变量,代表中间的通道数channels        = [base_channels]now_channels    = base_channelsfor i, mult in enumerate(channel_mults):out_channels = base_channels * multfor _ in range(num_res_blocks):self.downs.append(ResidualBlock(now_channels, out_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelschannels.append(now_channels)if i != len(channel_mults) - 1:self.downs.append(Downsample(now_channels))channels.append(now_channels)# 可以看作是特征整合,中间的一个特征提取模块self.mid = nn.ModuleList([ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=True,),ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=False,),])# 进行上采样,进行特征融合for i, mult in reversed(list(enumerate(channel_mults))):out_channels = base_channels * multfor _ in range(num_res_blocks + 1):self.ups.append(ResidualBlock(channels.pop() + now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelsif i != 0:self.ups.append(Upsample(now_channels))assert len(channels) == 0self.out_norm = get_norm(norm, base_channels, num_groups)self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)def forward(self, x, time=None, y=None):# 是否对输入进行paddingip = self.initial_padif ip != 0:x = F.pad(x, (ip,) * 4)# 对时间轴输入的全连接层if self.time_mlp is not None:if time is None:raise ValueError("time conditioning was specified but tim is not passed")time_emb = self.time_mlp(time)else:time_emb = Noneif self.num_classes is not None and y is None:raise ValueError("class conditioning was specified but y is not passed")# 对输入图片的第一个卷积x = self.init_conv(x)# skips用于存放下采样的中间层skips = [x]for layer in self.downs:x = layer(x, time_emb, y)skips.append(x)# 特征整合与提取for layer in self.mid:x = layer(x, time_emb, y)# 上采样并进行特征融合for layer in self.ups:if isinstance(layer, ResidualBlock):x = torch.cat([x, skips.pop()], dim=1)x = layer(x, time_emb, y)# 上采样并进行特征融合x = self.activation(self.out_norm(x))x = self.out_conv(x)if self.initial_pad != 0:return x[:, :, ip:-ip, ip:-ip]else:return x

三、Diffusion的训练思路

Diffusion的训练思路比较简单,首先随机给每个batch里每张图片都生成一个t,代表我选择这个batch里面第t个时刻的噪声进行拟合。代码如下:

t = torch.randint(0, self.num_timesteps, (b,), device=device)

生成batch_size个噪声,计算施加这个噪声后模型在t个时刻的噪声图片是怎么样的,如下所示:

def perturb_x(self, x, t, noise):return (extract(self.sqrt_alphas_cumprod, t,  x.shape) * x +extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise)   def get_losses(self, x, t, y):# x, noise [batch_size, 3, 64, 64]noise           = torch.randn_like(x)perturbed_x     = self.perturb_x(x, t, noise)

之后利用这个噪声图片、t和网络模型计算预测噪声,利用预测噪声和实际噪声进行拟合。

def get_losses(self, x, t, y):# x, noise [batch_size, 3, 64, 64]noise           = torch.randn_like(x)perturbed_x     = self.perturb_x(x, t, noise)estimated_noise = self.model(perturbed_x, t, y)if self.loss_type == "l1":loss = F.l1_loss(estimated_noise, noise)elif self.loss_type == "l2":loss = F.mse_loss(estimated_noise, noise)return loss

利用DDPM生成图片

DDPM的库整体结构如下:

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。

训练过程中,可在results文件夹内查看训练效果:

Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成相关推荐

  1. Diffusion扩散模型学习2——Stable Diffusion结构解析-以文本生成图像(文生图,txt2img)为例

    Diffusion扩散模型学习2--Stable Diffusion结构解析-以文本生成图像(文生图,txt2img)为例 学习前言 源码下载地址 网络构建 一.什么是Stable Diffusion ...

  2. Diffusion 扩散模型(DDPM)详解及torch复现

    文章目录 torch复现 第1步:正向过程=噪声调度器 Step 2: 反向传播 = U-Net Step 3: 损失函数 采样 Training 我公众号文章目录综述: https://wanggu ...

  3. Stable diffusion扩散模型相关

    时隔两年半(2年4个月),我又回来研究生成技术了.以前学习研究GAN没结果,不管是技术上,还是应用产品上,结果就放弃了,现在基于diffusion的技术又把生成技术带上了一个新的高度.现在自己又来研究 ...

  4. Stable Diffusion 2.0 上线,却痛失黄暴图片生成能力

    文章授权  新智元  编辑:编辑部 [新智元导读]大火的文本到图像模型Stable Diffusion2.0版本来了,然而因为这个原因,广大网友们闹起来了. 昨天,Stability AI宣布,Sta ...

  5. 【深度学习】——pytorch搭建模型及相关模型

    目录 1.搭建模型的流程 1)步骤 2)完整代码--手写minist数据集为例(这里使用的数据集是自带的) 2.搭建模型的四种方法 1)方法一--利用nn.Sequential() 2)方法二--利用 ...

  6. Seq2Seq模型学习(pytorch)

    在看pytorch的官方英文例子,做些笔记,如有纰漏请指正,原文:https://pytorch.org/tutorials/beginner/chatbot_tutorial.html 数据准备 首 ...

  7. 扩散模型(Diffusion Model,DDPM,GLIDE,DALLE2,Stable Diffusion)

    随着最近DALLE2和stable diffusion的大火,扩散模型的出色表现丝毫不逊色VAE和GAN,已经形成生成领域的三大方向:VAE.GAN和Diffusion,如上图可以简要看出几类主线模型 ...

  8. 深度学习:Diffusion Models in Vision: A Survey视觉中的扩散模型:综述

    Diffusion Models in Vision: A Survey视觉中的扩散模型:综述 0.摘要 1.概述 2.通用模型架构 2.1.Denoising Diffusion Probabili ...

  9. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

最新文章

  1. 大白话讲解Promise(二)理解Promise规范
  2. 您能否提供一些示例,说明为什么用正则表达式很难解析XML和HTML? [关闭]
  3. 敲诈勒索比特币不断,企业用户如何防“山寨”钓鱼邮件
  4. Cpp 对象模型探索 / delete 运算符内部调用过程分析
  5. boost::iostreams::file_descriptor_source用法的测试程序
  6. SAP CRM WebClient UI Date time format in BSP UI
  7. 【HDU 1150】Machine Schedule(二分图匹配)
  8. Android5.0水波纹效果ripple实现
  9. iOS 图形处理 翻译
  10. IIS访问要求输入用户名密码
  11. linux 安装 xpdf csdn,Centos安装xpdf 解析pdf文件
  12. mac软件意外退出怎么解决_Mac 软件常见问题解决方法汇总
  13. c语言根据日期求星期蔡勒公式,利用蔡勒公式获得给定日期的星期数
  14. 编写一个程序求解字谜游戏问题
  15. Java 8:那些Java8的常见写法
  16. 免费ICP域名备案查接口
  17. 三极管构成的电流负反馈放大器
  18. m1芯片Mac如何玩ios手游
  19. 百万调音师—Audition初识
  20. 使用wangEditor富文本编辑器遇到的问题总结

热门文章

  1. 【CSRF02】跨站请求伪造——DVWA靶场实操(含CSRF+XSS协同攻击实验)
  2. Perl 语言学习笔记 (一)
  3. git导出代码的方法~archive
  4. 阿里云服务器中挖矿木马处理过程
  5. VC 出错 msdev.exe错误
  6. Delphi 2010 调用WebService接口
  7. VMware Workstation 14 虚拟机配置xp系统
  8. 对于大量工控软件,IFIX 组态王等的深层解密分析,曲线
  9. python遗传算法实例:求一元二次方程实例
  10. c语言成绩管理系统教程,C语言学生成绩管理系统教程.doc