improved-diffusion代码逐行理解之schedule_sampler

发布于:2024-07-09 ⋅ 阅读:(190) ⋅ 点赞:(0)

1、create_named_schedule_sampler

create_named_schedule_sampler来根据名称创建不同类型的 ScheduleSampler 实例,并给出了两个具体的子类示例,UniformSampler和 LossSecondMomentResampler

#image_train.py
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
#resample.py

def create_named_schedule_sampler(name, diffusion):

    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")

2、ScheduleSampler

这是一个用于扩散过程(diffusion process)中时间步(timesteps)分布的类,其目的是减少目标函数(objective function)的方差。在机器学习、特别是扩散模型(如扩散概率模型Diffusion Probabilistic Models)中,时间步的采样和重新加权是优化和估计模型参数的重要步骤。

这个类提供了一个基本的框架,其中包含了用于重要性采样(importance sampling)的sample()方法。重要性采样是一种技术,用于根据某种非均匀分布(即与原始数据分布不同的分布)来抽取样本,并通过调整样本的权重来估计原始数据分布下的统计量(如期望值)。

class ScheduleSampler(ABC):
    @abstractmethod
    def weights(self):

    def sample(self, batch_size, device):

        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights

3、UniformSampler

一致性采样类,权重相等。

class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights

4、LossAwareSampler

使用来自模型的损失来更新重新加权的参数。这个方法通常会在分布式训练的场景下被调用,其中不同的计算节点(或称为“rank”)会处理不同的数据批次,并计算每个时间步对应的损失。然后,这个方法会执行同步操作,以确保所有计算节点在重新加权参数上保持一致。


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):

        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):


5、LossSecondMomentResampler

LossSecondMomentResampler 类是一个继承自 LossAwareSampler,用于根据损失的第二矩(即损失的平方的平均值)来重新采样时间步。这种重新采样方法在训练过程中给予那些损失较高的时间步更多的关注,以尝试减少整体损失。
以下是该类的主要组件和功能的详细说明:

init 方法: 初始化类实例时,需要传入一个 diffusion 对象(可能是一个扩散模型或包含时间步信息的对象),history_per_term(每个时间步保留的损失历史记录的长度),以及 uniform_prob(在最终权重中保留给均匀采样的概率,以防止完全忽略任何时间步)。此外,还初始化了两个 NumPy 数组,_loss_history 用于存储每个时间步的损失历史记录,_loss_counts 用于跟踪每个时间步已经记录了多少次损失。
weights 方法: 根据损失历史记录计算每个时间步的权重。如果尚未完成预热(即所有时间步的损失历史记录都未达到指定的长度),则返回一个全为 1 的权重数组,表示使用均匀采样。否则,计算每个时间步损失平方的平均值的平方根,并将这些值归一化为权重。然后,根据 uniform_prob 调整权重,以确保至少有一部分权重是均匀分配给所有时间步的。
update_with_all_losses 方法: 使用新的时间步和对应的损失值来更新损失历史记录。如果某个时间步的损失历史记录已满,则移除最旧的损失值并添加新的损失值;否则,直接在历史记录中添加新的损失值,并更新该时间步的损失计数。
_warmed_up 方法: 一个私有方法,用于检查是否所有时间步的损失历史记录都已达到指定的长度。如果是,则返回 True,表示已经完成了预热过程;否则,返回 False。
LossSecondMomentResampler 类通过跟踪每个时间步的损失历史记录,并根据这些记录的统计信息(特别是损失的第二矩)来重新计算采样权重,从而在训练过程中更加智能地选择时间步。这种方法有助于在保持一定程度均匀性的同时,更加关注那些对模型性能影响较大的时间步。

class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()

网站公告

今日签到

点亮在社区的每一天
去签到