audioLDM模型代码阅读(一)——指数移动平均EMA

发布于:2025-09-03 ⋅ 阅读:(12) ⋅ 点赞:(0)

指数移动平均EMA

指数移动平均(Exponential Moving Average,简称EMA)是一种用于平滑数据序列的统计方法,核心思想是对近期数据赋予更高权重,对远期数据赋予指数衰减的较低权重,从而在减少噪声干扰的同时,更灵敏地反映数据的近期变化趋势。

1. 数学定义

假设我们有一个数据序列x1,x2,…,xtx_1, x_2, \dots, x_tx1,x2,,xt(例如模型训练中的参数更新值),EMA 在第ttt 步的计算如下:
EMAt=α⋅xt+(1−α)⋅EMAt−1 \text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1} EMAt=αxt+(1α)EMAt1
其中:
-EMAt\text{EMA}_tEMAt 是第ttt 步的指数移动平均值;
-xtx_txt 是第ttt 步的当前值;
-EMAt−1\text{EMA}_{t-1}EMAt1 是第t−1t-1t1 步的指数移动平均值(初始值可设为x1x_1x1);
-α\alphaα 是平滑系数(0<α<10 < \alpha < 10<α<1),控制近期数据的权重:α\alphaα 越接近 1,近期数据影响越大,平滑效果越弱;α\alphaα 越接近 0,历史数据影响越大,平滑效果越强。

2. 与简单移动平均(SMA)的区别

简单移动平均(SMA)对一定窗口内的所有数据赋予相同权重(例如近 10 步的平均值),而 EMA 对数据的权重随时间呈指数衰减(近期数据权重更高)。因此:

  • EMA 对近期变化更敏感,能更快追踪数据趋势;
  • SMA 对噪声的平滑效果更稳定,但对趋势变化的响应更慢。

3. 在机器学习中的应用

在深度学习中,EMA 常被用于维护模型参数的“影子副本”(shadow parameters):

  • 训练过程中,模型参数会随梯度下降不断更新(可能因噪声或波动导致参数不稳定);
  • 同时,EMA 会根据当前参数和历史 EMA 参数动态更新“影子参数”,相当于对训练过程中的参数波动做了平滑;
  • 最终,用 EMA 得到的“影子参数”往往泛化能力更强(减少过拟合风险),在评估或推理时使用这些参数可能获得更优性能。

4. 关键特性

  • 动态适应性:通过调整α\alphaα 可平衡“对近期变化的敏感度”和“对历史趋势的依赖”;
  • 计算高效:仅需保存上一步的 EMA 值,无需存储完整历史数据,适合大规模模型训练;
  • 平滑波动:有效抑制训练中的参数震荡(如 batch 噪声导致的参数跳变),提升模型稳定性。

简单来说,EMA 就像给数据“加了一层平滑滤镜”,既保留近期变化的关键信息,又弱化了噪声干扰,在模型训练中是提升泛化能力的常用技巧。

代码解读

先给出完整的代码:

import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

这段代码实现了一个基于PyTorch的指数移动平均(Exponential Moving Average, EMA) 模块,名为LitEma。EMA是训练神经网络时常用的技术,通过维护模型参数的滑动平均值来提高模型的泛化能力和稳定性,尤其在生成模型(如GAN、扩散模型)和大规模深度学习任务中广泛应用。

核心功能

LitEma的主要作用是:

  1. 为目标模型的可训练参数维护一组"影子参数"(shadow parameters);
  2. 通过指数移动平均规则不断更新这些影子参数;
  3. 提供方法将影子参数复制到原模型(用于评估或保存);
  4. 支持临时保存和恢复原模型参数(避免EMA参数影响训练过程)。

代码逐部分解析

1. 类定义与初始化(__init__方法)
class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}  # 原参数名 -> 影子参数名的映射(处理名称中的'.')
        # 注册衰减率为buffer(不参与梯度计算,但会被保存到state_dict)
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        # 注册更新次数计数器(用于动态调整衰减率)
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
        )

        # 为模型中所有可训练参数创建影子参数(初始值为参数的副本)
        for name, p in model.named_parameters():
            if p.requires_grad:  # 只处理需要梯度的参数
                # 替换参数名中的'.'(因为buffer名称不允许包含'.')
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                # 注册影子参数(detach()确保不跟踪梯度)
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []  # 用于临时保存原模型参数
  • 参数说明

    • model:需要应用EMA的目标模型;
    • decay:EMA衰减率(通常接近1,如0.9999),控制历史参数的权重;
    • use_num_upates:是否根据训练步数动态调整衰减率(训练初期使用较小衰减率,加速影子参数收敛)。
  • 核心操作

    • 为模型中所有可训练参数(requires_grad=True)创建对应的"影子参数",并以buffer形式注册(buffer是PyTorch中不参与梯度计算但会被保存的参数);
    • m_name2s_name映射原参数名和影子参数名(处理名称中的.,因为buffer名称不允许包含.)。
2. 更新EMA影子参数(forward方法)
def forward(self, model):
    decay = self.decay

    # 动态调整衰减率(如果启用)
    if self.num_updates >= 0:
        self.num_updates += 1
        # 公式:min(decay, (1 + 步数)/(10 + 步数)),训练初期衰减率较小
        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

    one_minus_decay = 1.0 - decay  # 新参数的权重

    with torch.no_grad():  # 确保更新不影响梯度计算
        m_param = dict(model.named_parameters())  # 原模型参数
        shadow_params = dict(self.named_buffers())  # 影子参数

        # 遍历所有可训练参数,更新影子参数
        for key in m_param:
            if m_param[key].requires_grad:
                sname = self.m_name2s_name[key]  # 影子参数名
                # 确保影子参数与原参数类型一致
                shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                # EMA核心公式:影子参数 = 衰减率×旧影子参数 + (1-衰减率)×新参数
                shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
            else:
                # 非可训练参数不应出现在影子参数映射中
                assert not key in self.m_name2s_name
  • 核心逻辑
    每次调用forward时,根据EMA公式更新影子参数:
    shadow_param=decay×shadow_param+(1−decay)×model_param\text{shadow\_param} = \text{decay} \times \text{shadow\_param} + (1 - \text{decay}) \times \text{model\_param}shadow_param=decay×shadow_param+(1decay)×model_param
    其中,decay控制历史影子参数的权重(越接近1,历史参数影响越大)。

  • 动态衰减率
    use_num_upates=True,训练初期(步数少)会使用较小的decay(如第一步为(1+1)/(10+1)≈0.18),让影子参数快速跟上模型参数变化;随着步数增加,decay逐渐接近初始设定值(如0.9999)。

3. 将影子参数复制到原模型(copy_to方法)
def copy_to(self, model):
    m_param = dict(model.named_parameters())
    shadow_params = dict(self.named_buffers())
    for key in m_param:
        if m_param[key].requires_grad:
            # 将影子参数复制到原模型参数
            m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
        else:
            assert not key in self.m_name2s_name
  • 作用:将EMA维护的影子参数复制到原模型中,使模型在评估或推理时使用更稳定的EMA参数。
4. 临时保存和恢复原模型参数(storerestore方法)
def store(self, parameters):
    """保存当前参数,用于后续恢复"""
    self.collected_params = [param.clone() for param in parameters]

def restore(self, parameters):
    """恢复之前保存的参数"""
    for c_param, param in zip(self.collected_params, parameters):
        param.data.copy_(c_param.data)
  • 使用场景
    在验证或保存模型时,通常需要先用store保存原模型参数,再用copy_to应用EMA参数;完成后用restore恢复原参数,避免EMA参数影响后续训练(保证优化器仍作用于原参数)。

典型使用流程

  1. 初始化模型和LitEma

    model = MyModel()
    ema = LitEma(model, decay=0.999)
    
  2. 训练过程中,每次参数更新后调用ema.forward(model)更新影子参数:

    for epoch in range(num_epochs):
        for batch in dataloader:
            # 正常训练步骤(前向、计算损失、反向传播、优化器更新)
            loss = model(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # 更新EMA影子参数
            ema.forward(model)
    
  3. 验证/保存模型时,使用EMA参数:

    # 保存原模型参数
    ema.store(model.parameters())
    # 应用EMA参数到模型
    ema.copy_to(model)
    
    # 此时模型使用EMA参数,可进行验证或保存
    validate(model)
    torch.save(model.state_dict(), "ema_model.pth")
    
    # 恢复原模型参数,继续训练
    ema.restore(model.parameters())
    

总结

LitEma实现了一个灵活的EMA模块,通过维护参数的指数移动平均值,帮助模型在训练中保持稳定性,提升最终性能。其核心是影子参数的动态更新机制,以及方便的参数切换功能(copy_to/store/restore),使其能无缝集成到PyTorch训练流程中。