直白的讲,DDMP类的扩散模型方法,就是训练一个深度神经网络(如Unet)取学习、拟合逆扩散过程
$q(x_{t-1}\mid x_t)$ 的分布;推理阶段就能直接随机采样一个噪声,然后使用Unet一步一步采样图片。主要的模型架构涉及到Unet模型和一个扩散过程类,本笔记就是针对官方代码中实现此部分的代码进行注释解析,会涉及到IDDPM在DDPM基础上提出的四个改善点中的两个,分别是扩散过程中余弦加噪方法和设置可学习方差涉及到的
$L_{vlb}$ 损失的计算。IDDPM代码需要与论文对照学习才能有效理解,不然很多代码都会搞不清楚为什么那样写,可通过此链接浏览论文笔记IDDPM论文阅读辅助理解。
self.posterior_mean_coef1 ==> $\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}$;论文公式11中第一个系数
self.posterior_mean_coef2 ==> $\frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}$;论文公式11中第二个系数
具体代码及注释如此:
import enum
import math
import numpy as np
import torch as th
from.nn import mean_flat
from.losses import normal_kl, discretized_gaussian_log_likelihood
# Build a noising (beta) schedule; the cosine schedule proposed by IDDPM
# works better than DDPM's linear one.
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.  num_diffusion_timesteps is not necessarily 1000,
        # so `scale` rescales the schedule endpoints accordingly.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        # Split [beta_start, beta_end] into num_diffusion_timesteps equal steps.
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            # 0.008 here is the offset `s` of Eq. 17 in the IDDPM paper.
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


# Concrete cosine schedule, corresponding to Eq. 17 of the IDDPM paper.
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # With f(t) = alpha_bar(t), Eq. 17 gives alpha_bar_t = f(t)/f(0), so
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1} = 1 - f(t)/f(t-1).
    betas = []
    for step in range(num_diffusion_timesteps):
        t1 = step / num_diffusion_timesteps
        t2 = (step + 1) / num_diffusion_timesteps
        # Cap each beta at max_beta to avoid singularities near t = 1.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


# What kind of mean the Unet predicts; enum.auto() assigns 1, 2, 3.
class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts the mean of x_{t-1}
    START_X = enum.auto()  # the model predicts x_0 directly
    EPSILON = enum.auto()  # the model predicts the added noise epsilon


# How the model's variance is produced.
class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()  # the variance is learned directly
    FIXED_SMALL = enum.auto()  # fixed to the posterior variance (beta-tilde_t)
    FIXED_LARGE = enum.auto()  # fixed to beta_t
    LEARNED_RANGE = enum.auto()  # learn an interpolation between the two fixed choices


# Because the predicted quantity differs, the loss type differs as well.
class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound only
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        # Whether this loss type is a variational-bound loss.
        return self == LossType.KL or self == LossType.RESCALED_KL
# Original diffusion process class: utilities for training and sampling
# diffusion models.
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,  # the beta schedule over the training timesteps
        model_mean_type,  # which quantity the model predicts as the mean
        model_var_type,  # whether the variance is learned or fixed
        loss_type,  # which loss to compute
        rescale_timesteps=False,  # False for training; may be True at sampling time
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        # betas must be a 1-D vector with values in (0, 1].
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])  # original number of diffusion steps

        # Everything below precomputes the fixed constants of the Gaussian
        # diffusion process.
        alphas = 1.0 - betas  # alpha_t
        self.alphas_cumprod = np.cumprod(alphas, axis=0)  # \bar{\alpha}_t
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])  # \bar{\alpha}_{t-1}
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)  # \bar{\alpha}_{t+1}
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0); Eq. 10 in the paper
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # Log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain: the t=0 entry is replaced by t=1.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        # The two coefficients of the posterior mean; Eq. 11 in the paper.
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    # Corresponds to Eqs. 8/9: compute the distribution of x_t given x_0 and t.
    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs (x_0).
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        # _extract_into_tensor picks out the t-th entry of the schedule and
        # broadcasts it against x_start.
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    # Sample x_t from q(x_t | x_0).
    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start, i.e. x_t.
        """
        if noise is None:
            # Draw standard-normal noise with the same shape as x_0.
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        # Reparameterized sampling of x_t via Eq. 9.
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    # Full Eqs. 10 and 11: mean and variance of the diffusion posterior.
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # Use the model (Unet) to predict the mean and variance of x_{t-1} given
    # x_t (the reverse process); also predicts x_0.
    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised. (clip_denoised/denoised_fn are also used by DDIM.)
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]  # batch size, channel count
        # Each image in the batch has its own timestep, so t is (batch_size,).
        assert t.shape == (B,)
        # The Unet output has a fixed shape but its meaning depends on what
        # the model was trained to predict.
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        # --- variance and log-variance ---
        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            # Learned variance: the channel count is doubled (the Unet is
            # built the same way), half for the mean and half for the variance.
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                # The Unet directly predicts the log-variance.
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                # The Unet predicts an interpolation coefficient in [-1, 1];
                # see Eq. 15 in the paper.
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )  # log beta-tilde_t
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)  # log beta_t
                # The model_var_values is [-1, 1] for [min_var, max_var];
                # frac is the `v` of Eq. 14, rescaled into [0, 1].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            # Fixed variance: pick the appropriate schedule by variance type.
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            # Gather the fixed (log-)variance at timestep t.
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            # Optional post-processing of the x_0 prediction.
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                # BUGFIX: the transcribed source read `x.clamp-(1, 1)`, which
                # evaluates `x.clamp - (1, 1)` and raises a TypeError.
                return x.clamp(-1, 1)
            return x

        # --- mean and x_0 prediction ---
        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            # The model predicts the mean of x_{t-1}; invert Eq. 11 for x_0.
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                # The model predicts x_0 directly.
                pred_xstart = process_xstart(model_output)
            else:
                # The model predicts epsilon; recover x_0 via Eq. 9.
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            # When the mean is not predicted directly, compute it from the
            # predicted x_0, x_t and t via Eq. 11.
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    # Rearranged Eq. 9: compute x_0 from x_t and the noise epsilon.
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    # Rearranged Eq. 11: compute x_0 from the (predicted) posterior mean
    # `xprev` and x_t.
    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    # Rearranged Eq. 9: compute epsilon from x_t and (predicted) x_0.
    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    # Optionally rescale t back onto the original [0, 1000) training range.
    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    # One reverse step: sample x_{t-1} from x_t (inference).
    def p_sample(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        # Compute the mean and variance of x_{t-1}.
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        # Reparameterized sample: x = mu + sigma * eps.
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        # Every reverse step also produces an x_0 prediction.
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    # Run the full reverse chain with a trained Unet and return the final image.
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,  # the standard noise at time T (first reverse step)
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        # Walk the timesteps in reverse order: the reverse process runs
        # opposite to the forward diffusion.
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():  # no gradients needed while sampling
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    # BUGFIX: the transcribed source dropped model_kwargs here,
                    # silently discarding conditioning during sampling.
                    model_kwargs=model_kwargs,
                )
                yield out  # emit each intermediate sample
                img = out["sample"]  # x_{t-1} becomes the next step's input

    # Sampling step proposed by the DDIM paper.
    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )
        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    # Compute the L_vlb loss, i.e. the KL divergence being optimized.
    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        # Mean/variance of x_{t-1} from the true x_0, x_t and t, i.e. the
        # distribution q(x_{t-1} | x_t, x_0).
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        # Mean/variance of x_{t-1} from x_t, t and the network's predicted
        # x_0, i.e. the distribution p_theta(x_{t-1} | x_t).
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        # KL between q and p_theta: the L_{t-1} term.
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        # The L_0 term: discretized Gaussian likelihood of the data.
        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)).
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    # Training loss; three variants: VB only (L_vlb), MSE only (DDPM's
    # L_simple), or both together (IDDPM's L_hybrid).
    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            # Used with x_0 in the forward process to produce x_t.
            noise = th.randn_like(x_start)
        # Sample x_t from x_0, an arbitrary t, and the noise.
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            # Pure L_vlb training.
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

            # With a learnable variance, L_vlb must still be computed.
            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction: detach the mean half.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    # frozen_out was already computed by `model` above, so a
                    # lambda that simply returns it is equivalent to re-running
                    # the model inside _vb_terms_bpd.
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            # Pick the MSE target according to the configured mean type.
            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                # `noise` is exactly the noise added when producing x_t above.
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]  # L_hybrid
            else:
                terms["loss"] = terms["mse"]  # L_simple
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep.
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }


# Gather the values at the given timesteps from a 1-D schedule array.
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from.fp16_util import convert_module_to_f16, convert_module_to_f32
from.nn import(
SiLU,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
# Base class for modules whose forward pass also consumes a timestep
# embedding (in this UNet, only ResBlock inherits from it).
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
"""classTimestepEmbedSequential(nn.Sequential, TimestepBlock):"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.一个顺序模块,将时间步嵌入传递给支持它作为额外输入的子模块
"""defforward(self, x, emb):for layer in self:# self是有多个模块顺序连接而成# 如果当前遍历的layer是TimestepBlock类,就要使用时间步emb进行计算# 其实Unet架构中只有ResBlock是继承TimestepBlock的,故只有在ResBlock中会传入embifisinstance(layer, TimestepBlock):
x = layer(x, emb)else:
x = layer(x)return x
# Upsampling module: nearest-neighbor interpolation, optionally followed by
# a 3x3 convolution that keeps spatial size and channel count unchanged.
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, channels, channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # For 3-D signals only the inner two (spatial) dims are doubled;
            # the first (e.g. temporal) dim is left untouched.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
# Downsampling module: halves the spatial size with either a strided
# convolution (learned) or average pooling.
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3-D signals only the inner two dims are halved.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            # Strided 3x3 convolution (learned downsampling).
            self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
        else:
            # Average pooling with kernel size equal to the stride.
            self.op = avg_pool_nd(stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
# Residual block conditioned on the timestep embedding; can optionally
# change the channel count between input and output.
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        # Default: keep the channel count unchanged.
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),  # GroupNorm over channels
            SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            SiLU(),
            linear(
                emb_channels,
                # Scale-shift norm consumes two tensors (scale and shift),
                # hence twice the channels.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized final conv so the block starts out as an
            # identity mapping.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            # Channel count unchanged: identity skip connection.
            self.skip_connection = nn.Identity()
        elif use_conv:
            # Spatial 3x3 conv to change channels in the skip path.
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            # Cheap 1x1 conv to change channels in the skip path.
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # Fuse the embedding as norm(h) * (1 + scale) + shift, then run
            # the remaining out_layers (SiLU, dropout, conv).
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # Simple additive conditioning.
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
# Spatial self-attention block: every spatial position attends to all others.
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    but adapted to the N-d case.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        # 1-D conv producing concatenated q, k, v: [N, 3C, T].
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        # Zero-initialized projection: the block starts out as an identity.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        # Flatten spatial dims into a single token axis: [b, c, w*h].
        x = x.reshape(b, c, -1)
        # GroupNorm, then project to q/k/v stacked on the channel axis.
        qkv = self.qkv(self.norm(x))  # [b, 3c, w*h]
        # Fold heads into the batch dim: [b*num_heads, 3c/num_heads, w*h].
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)  # [b*num_heads, c/num_heads, w*h]
        h = h.reshape(b, -1, h.shape[-1])  # back to [b, c, w*h]
        h = self.proj_out(h)
        # Residual connection, restored to the original spatial shape.
        return (x + h).reshape(b, c, *spatial)
# Scaled dot-product attention over pre-split q/k/v tensors.
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3  # per-head channel count
        q, k, v = th.split(qkv, ch, dim=1)
        # Scale q and k by ch**-0.25 each so their product carries the usual
        # 1/sqrt(ch) factor; pre-scaling is more stable in f16 than dividing
        # the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        # einsum "bct,bcs->bts": attention logits between tokens t and s.
        attn = th.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax in float32 for stability, then cast back.
        attn = th.softmax(attn.float(), dim=-1).type(attn.dtype)
        # Weighted sum of values: output keeps the [N x C x T] layout.
        return th.einsum("bts,bcs->bct", attn, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )
        """
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # Two matmuls of equal cost: the logits (QK^T) and the value mix.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])
# UNet with multi-head self-attention at selected resolutions and timestep
# (plus optional class) conditioning.
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        # The ResBlocks condition on the diffusion timestep, so every forward
        # pass first embeds t into a time_embed_dim-wide vector.
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            # Class-conditional generation: a learned label embedding is added
            # to the timestep embedding in forward().
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---- Encoder (left half of the U): input conv + downsampling ----
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        # Track each encoder stage's output channels so the decoder can size
        # its skip-connection concatenations.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,  # widen channels
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # Add self-attention at this downsampling rate.
                    layers.append(
                        AttentionBlock(
                            ch, use_checkpoint=use_checkpoint, num_heads=num_heads
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # All levels except the last end with a spatial downsample
                # (the ResBlocks above only changed the channel count).
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        # ---- Bottleneck: spatial size and channel count unchanged ----
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        # ---- Decoder (right half of the U): traverse levels in reverse ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        # Decoder input is concatenated with the matching
                        # encoder skip output, hence the extra channels.
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                        )
                    )
                # On every level except the topmost, the last ResBlock of the
                # level is followed by a spatial upsample.
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # ---- Output head ----
        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            # FIX: the final conv previously hard-coded `model_channels` as its
            # input width, which crashes whenever channel_mult[0] != 1 (the
            # decoder ends with ch == model_channels * channel_mult[0]). Using
            # `ch` is identical for the default (1, 2, 4, 8) multipliers.
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # y must be given exactly when the model is class-conditional.
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []  # encoder outputs saved for skip connections
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)  # add class conditioning

        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            # Concatenate the matching encoder output on the channel axis.
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                   block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, in_channels, *args, **kwargs):
        # Input channels are doubled: the noisy high-res image is concatenated
        # with the upsampled low-res conditioning image on the channel axis.
        super().__init__(in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        # Bilinearly upsample the low-res conditioning image to match x.
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)

    def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
        # FIX: x is NCHW (as forward() above unpacks it), so the spatial dims
        # are the last two axes; previously this unpacked
        # `_, new_height, new_width, _`, resizing to the wrong target shape.
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().get_feature_vectors(x, timesteps, **kwargs)