【AIGC】DDPM scheduler解析:扩散模型里的“调度器”到底在调什么?

发布于:2025-08-18 ⋅ 阅读:(13) ⋅ 点赞:(0)

扩散模型里的“调度器”到底在调什么?

—— 以 DDPM 仓库为例,一行行拆解 scheduler 预计算代码

如果你第一次接触扩散模型(Diffusion Model),最绕不开的一个词就是 scheduler。它到底在“调度”什么?为什么论文里一大堆 α、β、ᾱ,代码里又跑出来一堆 alphas_cumprodsqrt_recip_alphas

本文就带你把DDPM里的 scheduler 预计算代码彻底拆开,告诉你每一个张量背后对应的公式与直觉。读完你不仅能秒懂 forward_noising.py 在干什么,还能自己手写一个 scheduler。


1. 先给结论:scheduler 在“提前算好每一步的权重”

扩散模型的前向过程(加噪)和反向过程(去噪)都依赖大量随时间步 t 变化的系数
如果每一步都在现场算,训练/采样就会被拖垮。于是 DDPM 的做法是:

一次性把 0~T 步要用到的所有系数都算出来,放进张量。后面直接按 t 索引即可。

这些系数就是 scheduler 的全部工作。


2. 代码全景

以下代码位于 forward_noising.py,删掉了注释后不到 20 行,却把整个前向过程需要的张量全部算完:

import torch
import torch.nn.functional as F

# 1. 线性 β 调度器
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

T = 300
betas = linear_beta_schedule(T)

# 2. 预计算所有中间量
alphas             = 1.0 - betas
alphas_cumprod     = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

sqrt_recip_alphas  = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

3. 逐行拆解:从符号到直觉

3.1 T = 300 —— 扩散“总步数”

把一张清晰图片变成纯高斯噪声,需要 300 步小步快跑。步数越多,单步扰动越小,数值稳定性越好。


3.2 betas —— 每一步加多少噪声

代码 公式 直觉
betas = linear_beta_schedule(T) βt\beta_tβt 线性从 0.0001 → 0.02 控制第 t 步注入噪声的方差。β 越大,破坏越狠。

3.3 alphas ——“我还剩多少原图”

代码 公式 直觉
alphas = 1.0 - betas αt=1−βt\alpha_t = 1 - \beta_tαt=1βt 原图保留比例。因为 β 很小,α 非常接近 1。

3.4 alphas_cumprod —— 一口气算到第 t 步的“累积保留率”

代码 公式 直觉
alphas_cumprod = torch.cumprod(alphas, 0) αˉt=∏i=1tαi\bar\alpha_t = \prod_{i=1}^{t}\alpha_iαˉt=i=1tαi 如果你想直接x0x_0x0 跳到 xtx_txt,就靠它。DDPM 的“闭式采样”核心。

3.5 sqrt_alphas_cumprod & sqrt_one_minus_alphas_cumprod —— 前向公式里的“两根魔法棒”

代码 公式 直觉
sqrt_alphas_cumprod αˉt\sqrt{\bar\alpha_t}αˉt 原图 x0x_0x0 的缩放系数
sqrt_one_minus_alphas_cumprod 1−αˉt\sqrt{1 - \bar\alpha_t}1αˉt 噪声 ε\varepsilonε 的缩放系数

把两者组合起来就是 DDPM 论文最经典的前向公式:

xt=αˉt x0+1−αˉt ε,ε∼N(0,I) x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \varepsilon,\quad \varepsilon\sim\mathcal N(0,I) xt=αˉt x0+1αˉt ε,εN(0,I)


3.6 alphas_cumprod_prev —— 上一步的 ᾱ

代码 直觉
F.pad(..., value=1.0) αˉt−1\bar\alpha_{t-1}αˉt1 对齐到 t 的索引,方便后续向量运算。第 0 步之前补 1.0(αˉ0=1\bar\alpha_0=1αˉ0=1)。

3.7 sqrt_recip_alphas —— 反向去噪“放大器”

代码 直觉
torch.sqrt(1.0 / alphas) 在反向公式里把模型预测的 εθ(xt,t)\varepsilon_\theta(x_t,t)εθ(xt,t) 再乘回去,恢复信号。

3.8 posterior_variance —— 反向“再抖一点点”的方差

代码 公式 直觉
betas * (1 - ᾱ_prev) / (1 - ᾱ) β~t=1−αˉt−11−αˉtβt\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_tβ~t=1αˉt1αˉt1βt xtx_txt 预测 xt−1x_{t-1}xt1 时,需要再采样一点点噪声,方差就是 β~t\tilde\beta_tβ~t。DDPM 论文里把它固定住,不交给网络学,简化训练。

4. 一张图总结

把上面所有张量按时间轴画出来(T=300):

时间步 t β_t α_t ᾱ_t √ᾱ_t √(1-ᾱ_t) β̃_t
0 1.000 1.000 0.000
1 0.0001 0.9999 0.9999 0.99995 0.0100 0.0001
300 0.0200 0.9800 ~0.002 0.045 0.999 0.0199

可以看到:

  • ᾱ_t 从 1 一路降到接近 0,解释“原图逐渐消失”。
  • √(1-ᾱ_t) 从 0 升到接近 1,解释“噪声逐渐占满”。
  • β̃_t 始终跟 β_t 差不多,但略小,保证反向过程的方差不会爆炸。

5. 小结 & 下一步

  1. scheduler 不是玄学,只是一次性把“每一步要乘的数”算好。
  2. 核心就 8 个张量,对应论文里 4 个公式,背下来就能手写扩散。
  3. 后面不管是 DDIM、PLMS 还是 DPM-Solver,都只是在换 β 的调度策略采样公式;预计算的思路一模一样。

网站公告

今日签到

点亮在社区的每一天
去签到