扩散的代码实现
本文承接 Pytorch实现扩散模型【DDPM代码解读篇1】http://t.csdnimg.cn/aDK0A
主要介绍“扩散是如何实现的”。代码逻辑清晰,可快速上手学习。
# 扩散的代码实现
# 扩散过程是训练部分的模型。它打开了一个采样接口,允许我们使用已经训练好的模型生成样本。
class DiffusionModel(nn.Module):
# 类变量,用于将字符串调度器名称映射到相应的调度函数
SCHEDULER_MAPPING = {
"linear": linear_beta_schedule,
"cosine": cosine_beta_schedule,
"sigmoid": sigmoid_beta_schedule,
}
def __init__(
self,
model: nn.Module,
image_size: int,
*,
beta_scheduler: str = "linear", # 调度器类型,默认为线性
timesteps: int = 1000,
schedule_fn_kwargs: dict | None = None, # 调度函数的关键字参数,默认为 None
auto_normalize: bool = True,
) -> None:
super().__init__()
self.model = model
self.channels = self.model.channels
self.image_size = image_size
# 从 SCHEDULER_MAPPING 字典中获取与 beta_scheduler 字符串相对应的调度函数
# 如果 beta_scheduler 字符串不存在于字典中,则返回 None
self.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)
# 检查获取到的调度函数是否为 None,即检查是否成功选择了β调度函数
# 如果调度函数为 None,则说明指定的 beta_scheduler 字符串不在预定义的调度函数列表中,于是抛出 ValueError 异常
if self.beta_scheduler_fn is None:
raise ValueError(f"unknown beta schedule {beta_scheduler}")
# 检查是否提供了调度函数的关键字参数。若未提供,将schedule_fn_kwargs 设置为空字典。
if schedule_fn_kwargs is None:
schedule_fn_kwargs = {}
# 用于计算扩散模型中的β调度函数,以及与β相关的其他参数,如α和后验方差:
betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs) # 生成一个包含β值的张量 betas
alphas = 1.0 - betas
# 对α值进行累积乘积,得到一个新的张量 alphas_cumprod,其形状与 betas 相同,包含了从0到 timesteps-1 时间步的所有α值的乘积。
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
'''
对 alphas_cumprod 进行填充操作,将其第一个元素用 1.0 填充,以确保在计算后验方差时不会出现除以零的情况。
F.pad 函数用于在张量的指定维度上进行填充,这里在维度 0 上进行填充,向左填充一个元素。
'''
# 计算后验方差
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
# 注册缓冲区(buffer),并将每个相关的张量转换为 torch.float32 类型
register_buffer = lambda name, val: self.register_buffer(
name, val.to(torch.float32)
)
register_buffer("betas", betas) # 包含 β 值的张量
register_buffer("alphas_cumprod", alphas_cumprod) # 包含 α 累积乘积的张量
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) # 包含 α 累积乘积的前一个时间步的张量
register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas)) # α 的倒数的平方根的张量
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) # α 累积乘积的平方根的张量
register_buffer(
"sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
)
register_buffer("posterior_variance", posterior_variance) # 后验方差的张量
timesteps, *_ = betas.shape
'''
这里使用了“*”操作符,它的作用是在变量解构(destructuring)中丢弃不需要的部分。因为 betas 张量是一维的,所以这里的“*”操作符实际上没有起到什么作用,只是为了让代码更具通用性。
'''
self.num_timesteps = int(timesteps) # 将时间步数转换为整数
self.sampling_timesteps = timesteps
# 归一化
# auto_normalize 为 True,则选择 normalize_to_neg_one_to_one 函数进行归一化;否则选择 identity 函数,即不进行归一化操作。
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
# auto_normalize 为 True,则选择 unnormalize_to_zero_to_one 函数进行反归一化;否则选择 identity 函数,即不进行反归一化操作。
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
@torch.inference_mode()
'''
可以将下面的函数或代码块置于推断模式中。这意味着,在装饰器声明的范围内,PyTorch 将禁用梯度计算,不会跟踪梯度,也不会进行任何与梯度相关的操作。这有助于提高推断速度,并且可以确保模型在进行推断时不会意外地进行训练相关的计算。
'''
def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:
# 使用了解构语法 *,将张量形状中的除最后一维之外的所有维度忽略掉,并将结果赋值给一个名为 _ 的临时变量,最后一个维度被赋值给 device
b, *_, device = *x.shape, x.device
batched_timestamps = torch.full(
(b,), timestamp, device=device, dtype=torch.long
)
# 创建了一个形状为 (b,) 的张量 batched_timestamps,用于存储批次中每个样本的时间戳。
# timestamp,其数据类型为 torch.long,并且张量存储在与输入张量相同的设备上
# 将输入张量 x 和时间戳张量 batched_timestamps 传递给模型 self.model,以获取预测值 preds
preds = self.model(x, batched_timestamps)
# 使用函数 extract 从预先计算的参数 self.betas 中提取与批次时间戳对应的β值
betas_t = extract(self.betas, batched_timestamps, x.shape)
sqrt_recip_alphas_t = extract(
self.sqrt_recip_alphas, batched_timestamps, x.shape
) # 提取与批次时间戳对应的α倒数的平方根
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape
) # 提取与批次时间戳对应的1减去α累积乘积的平方根
# 计算预测的样本均值
predicted_mean = sqrt_recip_alphas_t * (
x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t
)
#如果时间戳为零,直接返回预测的样本均值;否则,计算样本的后验方差并添加噪声,然后返回结果。
if timestamp == 0:
return predicted_mean
else:
posterior_variance = extract(
self.posterior_variance, batched_timestamps, x.shape
)
noise = torch.randn_like(x)
return predicted_mean + torch.sqrt(posterior_variance) * noise
@torch.inference_mode()
def p_sample_loop(
self, shape: tuple, return_all_timesteps: bool = False
) -> torch.Tensor:
batch, device = shape[0], "mps" # 从形状元组中获取批量大小 batch,并设置设备为 "mps"(多处理器尺寸)
img = torch.randn(shape, device=device) # 函数生成一个具有指定形状的随机张量 img,其值服从标准正态分布
# This cause me a RunTimeError on MPS device due to MPS back out of memory
# No ideas how to resolve it at this point
# imgs = [img]
'''
使用 tqdm 函数创建一个迭代进度条,迭代范围是从 0 到 self.num_timesteps 的逆序。
每个时间步长 t 都会调用 p_sample 方法进行样本采样,并更新 img 的值。
'''
for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):
img = self.p_sample(img, t)
# imgs.append(img)
'''
将每个时间步长的采样结果添加到一个列表 imgs 中。在循环中,每次迭代会生成一个新的采样结果,并将其添加到列表中
允许在函数结束后返回所有时间步长的采样结果,以便进一步分析或处理
'''
# 最终的采样结果
ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)
# 调用 unnormalize 方法将最终的采样结果反归一化,使其返回到原始数据范围内。
ret = self.unnormalize(ret)
return ret
# return_all_timesteps指定是否返回所有时间步长的样本,默认为 False,表示只返回最终时间步长的样本
def sample(
self, batch_size: int = 16, return_all_timesteps: bool = False
) -> torch.Tensor:
shape = (batch_size, self.channels, self.image_size, self.image_size)
return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)
# 用于在给定时间步长 t 上生成样本
def q_sample(
self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None
) -> torch.Tensor:
# 首先检查是否提供了噪声
if noise is None:
noise = torch.randn_like(x_start)
# 接着根据 t 从预先计算的参数中提取相应的系数
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
# 最后根据扩散过程的定义,计算并返回生成的样本
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
def p_loss(
self,
x_start: torch.Tensor,
t: int,
noise: torch.Tensor = None,
loss_type: str = "l2",
) -> torch.Tensor:
if noise is None:
noise = torch.randn_like(x_start)
x_noised = self.q_sample(x_start, t, noise=noise) # 在给定时间步长 t 上生成经过噪声处理的样本 x_noised
# 使用生成的 x_noised 作为输入,调用模型 self.model,并传入时间步长 t,以获取预测的噪声 predicted_noise。
predicted_noise = self.model(x_noised, t)
if loss_type == "l2": # 均方误差损失函数
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "l1": # 绝对值误差损失函数
loss = F.l1_loss(noise, predicted_noise)
else:
raise ValueError(f"unknown loss type {loss_type}")
return loss
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w, device, img_size = *x.shape, x.device, self.image_size
assert h == w == img_size, f"image size must be {img_size}"
# 解析输入 x 的形状,并确保输入的图像是正方形且大小与 image_size 相同。
# 生成一个随机的时间步长 timestamp,范围在 [0, num_timesteps) 内
timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)
x = self.normalize(x)
return self.p_loss(x, timestamp)
Life is a journey. We pursue love and light with purity.