Pytorch实现扩散模型【DDPM代码解读篇2】

发布于:2024-05-06 ⋅ 阅读:(29) ⋅ 点赞:(0)

扩散的代码实现

本文承接  Pytorch实现扩散模型【DDPM代码解读篇1】http://t.csdnimg.cn/aDK0A

主要介绍“扩散是如何实现的”。代码逻辑清晰,可快速上手学习。

# 扩散的代码实现
# 扩散过程是训练部分的模型。它打开了一个采样接口,允许我们使用已经训练好的模型生成样本。
class DiffusionModel(nn.Module):
	# 类变量,用于将字符串调度器名称映射到相应的调度函数
    SCHEDULER_MAPPING = {
        "linear": linear_beta_schedule,
        "cosine": cosine_beta_schedule,
        "sigmoid": sigmoid_beta_schedule,
    }
 
    def __init__(
        self,
        model: nn.Module,
        image_size: int,
        *,
        beta_scheduler: str = "linear",  # 调度器类型,默认为线性
        timesteps: int = 1000,
        schedule_fn_kwargs: dict | None = None,  # 调度函数的关键字参数,默认为 None
        auto_normalize: bool = True,
    ) -> None:
        super().__init__()
        self.model = model
 
        self.channels = self.model.channels
        self.image_size = image_size
 
 		# 从 SCHEDULER_MAPPING 字典中获取与 beta_scheduler 字符串相对应的调度函数
 		# 如果 beta_scheduler 字符串不存在于字典中,则返回 None
        self.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)
        # 检查获取到的调度函数是否为 None,即检查是否成功选择了β调度函数
        # 如果调度函数为 None,则说明指定的 beta_scheduler 字符串不在预定义的调度函数列表中,于是抛出 ValueError 异常
        if self.beta_scheduler_fn is None:
            raise ValueError(f"unknown beta schedule {beta_scheduler}")
 		# 检查是否提供了调度函数的关键字参数。若未提供,将schedule_fn_kwargs 设置为空字典。
        if schedule_fn_kwargs is None:
            schedule_fn_kwargs = {}
 		
 		# 用于计算扩散模型中的β调度函数,以及与β相关的其他参数,如α和后验方差:
        betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs)  # 生成一个包含β值的张量 betas
        alphas = 1.0 - betas

        # 对α值进行累积乘积,得到一个新的张量 alphas_cumprod,其形状与 betas 相同,包含了从0到 timesteps-1 时间步的所有α值的乘积。
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        '''
        	对 alphas_cumprod 进行填充操作,将其第一个元素用 1.0 填充,以确保在计算后验方差时不会出现除以零的情况。
			F.pad 函数用于在张量的指定维度上进行填充,这里在维度 0 上进行填充,向左填充一个元素。
        '''

        # 计算后验方差
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
 
 		# 注册缓冲区(buffer),并将每个相关的张量转换为 torch.float32 类型
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )
 
        register_buffer("betas", betas)  # 包含 β 值的张量
        register_buffer("alphas_cumprod", alphas_cumprod)  # 包含 α 累积乘积的张量
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)  # 包含 α 累积乘积的前一个时间步的张量
        register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))  # α 的倒数的平方根的张量
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))  # α 累积乘积的平方根的张量
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("posterior_variance", posterior_variance)  # 后验方差的张量
 
        timesteps, *_ = betas.shape
        '''
        	这里使用了“*”操作符,它的作用是在变量解构(destructuring)中丢弃不需要的部分。因为 betas 张量是一维的,所以这里的“*”操作符实际上没有起到什么作用,只是为了让代码更具通用性。
        '''
        self.num_timesteps = int(timesteps)  # 将时间步数转换为整数
 
        self.sampling_timesteps = timesteps
 
 		# 归一化
 		# auto_normalize 为 True,则选择 normalize_to_neg_one_to_one 函数进行归一化;否则选择 identity 函数,即不进行归一化操作。
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity

        #  auto_normalize 为 True,则选择 unnormalize_to_zero_to_one 函数进行反归一化;否则选择 identity 函数,即不进行反归一化操作。
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
 
    @torch.inference_mode()
    '''
    	可以将下面的函数或代码块置于推断模式中。这意味着,在装饰器声明的范围内,PyTorch 将禁用梯度计算,不会跟踪梯度,也不会进行任何与梯度相关的操作。这有助于提高推断速度,并且可以确保模型在进行推断时不会意外地进行训练相关的计算。
    '''

    def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:
    	# 使用了解构语法 *,将张量形状中的除最后一维之外的所有维度忽略掉,并将结果赋值给一个名为 _ 的临时变量,最后一个维度被赋值给 device
        b, *_, device = *x.shape, x.device

        batched_timestamps = torch.full(
            (b,), timestamp, device=device, dtype=torch.long
        )
        # 创建了一个形状为 (b,) 的张量 batched_timestamps,用于存储批次中每个样本的时间戳。
        # timestamp,其数据类型为 torch.long,并且张量存储在与输入张量相同的设备上
 		
 		# 将输入张量 x 和时间戳张量 batched_timestamps 传递给模型 self.model,以获取预测值 preds
        preds = self.model(x, batched_timestamps)
 		
 		# 使用函数 extract 从预先计算的参数 self.betas 中提取与批次时间戳对应的β值
        betas_t = extract(self.betas, batched_timestamps, x.shape)
        sqrt_recip_alphas_t = extract(
            self.sqrt_recip_alphas, batched_timestamps, x.shape
        )  # 提取与批次时间戳对应的α倒数的平方根
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape
        )  # 提取与批次时间戳对应的1减去α累积乘积的平方根
 		
 		# 计算预测的样本均值
        predicted_mean = sqrt_recip_alphas_t * (
            x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t
        )
 		
 		#如果时间戳为零,直接返回预测的样本均值;否则,计算样本的后验方差并添加噪声,然后返回结果。
        if timestamp == 0:
            return predicted_mean
        else:
            posterior_variance = extract(
                self.posterior_variance, batched_timestamps, x.shape
            )
            noise = torch.randn_like(x)
            return predicted_mean + torch.sqrt(posterior_variance) * noise
 
    @torch.inference_mode()
    def p_sample_loop(
        self, shape: tuple, return_all_timesteps: bool = False
    ) -> torch.Tensor:
        batch, device = shape[0], "mps"  # 从形状元组中获取批量大小 batch,并设置设备为 "mps"(多处理器尺寸)
 
        img = torch.randn(shape, device=device) # 函数生成一个具有指定形状的随机张量 img,其值服从标准正态分布
        # This cause me a RunTimeError on MPS device due to MPS back out of memory
        # No ideas how to resolve it at this point
 
        # imgs = [img]
 		
 		'''
 			使用 tqdm 函数创建一个迭代进度条,迭代范围是从 0 到 self.num_timesteps 的逆序。
 			每个时间步长 t 都会调用 p_sample 方法进行样本采样,并更新 img 的值。
 		'''
        for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):
            img = self.p_sample(img, t)
            # imgs.append(img)
            '''
            	将每个时间步长的采样结果添加到一个列表 imgs 中。在循环中,每次迭代会生成一个新的采样结果,并将其添加到列表中
            	允许在函数结束后返回所有时间步长的采样结果,以便进一步分析或处理
            '''
 		
 		# 最终的采样结果
        ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)
 		
 		# 调用 unnormalize 方法将最终的采样结果反归一化,使其返回到原始数据范围内。
        ret = self.unnormalize(ret)
        return ret
 
 	# return_all_timesteps指定是否返回所有时间步长的样本,默认为 False,表示只返回最终时间步长的样本
    def sample(
        self, batch_size: int = 16, return_all_timesteps: bool = False  
    ) -> torch.Tensor:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)
 
 	# 用于在给定时间步长 t 上生成样本
    def q_sample(
        self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None
    ) -> torch.Tensor:
    	# 首先检查是否提供了噪声
        if noise is None:
            noise = torch.randn_like(x_start)

 		# 接着根据 t 从预先计算的参数中提取相应的系数
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
 		
 		# 最后根据扩散过程的定义,计算并返回生成的样本
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
 
    def p_loss(
        self,
        x_start: torch.Tensor,
        t: int,
        noise: torch.Tensor = None,
        loss_type: str = "l2",
    ) -> torch.Tensor:
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noised = self.q_sample(x_start, t, noise=noise)  # 在给定时间步长 t 上生成经过噪声处理的样本 x_noised

        # 使用生成的 x_noised 作为输入,调用模型 self.model,并传入时间步长 t,以获取预测的噪声 predicted_noise。
        predicted_noise = self.model(x_noised, t)
 
        if loss_type == "l2":  # 均方误差损失函数
            loss = F.mse_loss(noise, predicted_noise)
        elif loss_type == "l1":  # 绝对值误差损失函数
            loss = F.l1_loss(noise, predicted_noise)
        else:
            raise ValueError(f"unknown loss type {loss_type}")
        return loss
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w, device, img_size = *x.shape, x.device, self.image_size
        assert h == w == img_size, f"image size must be {img_size}"
 		# 解析输入 x 的形状,并确保输入的图像是正方形且大小与 image_size 相同。

 		# 生成一个随机的时间步长 timestamp,范围在 [0, num_timesteps) 内
        timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)
        x = self.normalize(x)
        return self.p_loss(x, timestamp)

Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。