指数移动平均EMA
指数移动平均(Exponential Moving Average,简称EMA)是一种用于平滑数据序列的统计方法,核心思想是对近期数据赋予更高权重,对远期数据赋予指数衰减的较低权重,从而在减少噪声干扰的同时,更灵敏地反映数据的近期变化趋势。
1. 数学定义
假设我们有一个数据序列x1,x2,…,xtx_1, x_2, \dots, x_tx1,x2,…,xt(例如模型训练中的参数更新值),EMA 在第ttt 步的计算如下:
EMAt=α⋅xt+(1−α)⋅EMAt−1 \text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1} EMAt=α⋅xt+(1−α)⋅EMAt−1
其中:
-EMAt\text{EMA}_tEMAt 是第ttt 步的指数移动平均值;
-xtx_txt 是第ttt 步的当前值;
-EMAt−1\text{EMA}_{t-1}EMAt−1 是第t−1t-1t−1 步的指数移动平均值(初始值可设为x1x_1x1);
-α\alphaα 是平滑系数(0<α<10 < \alpha < 10<α<1),控制近期数据的权重:α\alphaα 越接近 1,近期数据影响越大,平滑效果越弱;α\alphaα 越接近 0,历史数据影响越大,平滑效果越强。
2. 与简单移动平均(SMA)的区别
简单移动平均(SMA)对一定窗口内的所有数据赋予相同权重(例如近 10 步的平均值),而 EMA 对数据的权重随时间呈指数衰减(近期数据权重更高)。因此:
- EMA 对近期变化更敏感,能更快追踪数据趋势;
- SMA 对噪声的平滑效果更稳定,但对趋势变化的响应更慢。
3. 在机器学习中的应用
在深度学习中,EMA 常被用于维护模型参数的“影子副本”(shadow parameters):
- 训练过程中,模型参数会随梯度下降不断更新(可能因噪声或波动导致参数不稳定);
- 同时,EMA 会根据当前参数和历史 EMA 参数动态更新“影子参数”,相当于对训练过程中的参数波动做了平滑;
- 最终,用 EMA 得到的“影子参数”往往泛化能力更强(减少过拟合风险),在评估或推理时使用这些参数可能获得更优性能。
4. 关键特性
- 动态适应性:通过调整α\alphaα 可平衡“对近期变化的敏感度”和“对历史趋势的依赖”;
- 计算高效:仅需保存上一步的 EMA 值,无需存储完整历史数据,适合大规模模型训练;
- 平滑波动:有效抑制训练中的参数震荡(如 batch 噪声导致的参数跳变),提升模型稳定性。
简单来说,EMA 就像给数据“加了一层平滑滤镜”,既保留近期变化的关键信息,又弱化了噪声干扰,在模型训练中是提升泛化能力的常用技巧。
代码解读
先给出完整的代码:
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
这段代码实现了一个基于PyTorch的指数移动平均(Exponential Moving Average, EMA) 模块,名为LitEma
。EMA是训练神经网络时常用的技术,通过维护模型参数的滑动平均值来提高模型的泛化能力和稳定性,尤其在生成模型(如GAN、扩散模型)和大规模深度学习任务中广泛应用。
核心功能
LitEma
的主要作用是:
- 为目标模型的可训练参数维护一组"影子参数"(shadow parameters);
- 通过指数移动平均规则不断更新这些影子参数;
- 提供方法将影子参数复制到原模型(用于评估或保存);
- 支持临时保存和恢复原模型参数(避免EMA参数影响训练过程)。
代码逐部分解析
1. 类定义与初始化(__init__
方法)
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {} # 原参数名 -> 影子参数名的映射(处理名称中的'.')
# 注册衰减率为buffer(不参与梯度计算,但会被保存到state_dict)
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
# 注册更新次数计数器(用于动态调整衰减率)
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
# 为模型中所有可训练参数创建影子参数(初始值为参数的副本)
for name, p in model.named_parameters():
if p.requires_grad: # 只处理需要梯度的参数
# 替换参数名中的'.'(因为buffer名称不允许包含'.')
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
# 注册影子参数(detach()确保不跟踪梯度)
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = [] # 用于临时保存原模型参数
参数说明:
model
:需要应用EMA的目标模型;decay
:EMA衰减率(通常接近1,如0.9999),控制历史参数的权重;use_num_upates
:是否根据训练步数动态调整衰减率(训练初期使用较小衰减率,加速影子参数收敛)。
核心操作:
- 为模型中所有可训练参数(
requires_grad=True
)创建对应的"影子参数",并以buffer形式注册(buffer是PyTorch中不参与梯度计算但会被保存的参数); - 用
m_name2s_name
映射原参数名和影子参数名(处理名称中的.
,因为buffer名称不允许包含.
)。
- 为模型中所有可训练参数(
2. 更新EMA影子参数(forward
方法)
def forward(self, model):
decay = self.decay
# 动态调整衰减率(如果启用)
if self.num_updates >= 0:
self.num_updates += 1
# 公式:min(decay, (1 + 步数)/(10 + 步数)),训练初期衰减率较小
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay # 新参数的权重
with torch.no_grad(): # 确保更新不影响梯度计算
m_param = dict(model.named_parameters()) # 原模型参数
shadow_params = dict(self.named_buffers()) # 影子参数
# 遍历所有可训练参数,更新影子参数
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key] # 影子参数名
# 确保影子参数与原参数类型一致
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
# EMA核心公式:影子参数 = 衰减率×旧影子参数 + (1-衰减率)×新参数
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
# 非可训练参数不应出现在影子参数映射中
assert not key in self.m_name2s_name
核心逻辑:
每次调用forward
时,根据EMA公式更新影子参数:
shadow_param=decay×shadow_param+(1−decay)×model_param\text{shadow\_param} = \text{decay} \times \text{shadow\_param} + (1 - \text{decay}) \times \text{model\_param}shadow_param=decay×shadow_param+(1−decay)×model_param
其中,decay
控制历史影子参数的权重(越接近1,历史参数影响越大)。动态衰减率:
若use_num_upates=True
,训练初期(步数少)会使用较小的decay
(如第一步为(1+1)/(10+1)≈0.18
),让影子参数快速跟上模型参数变化;随着步数增加,decay
逐渐接近初始设定值(如0.9999)。
3. 将影子参数复制到原模型(copy_to
方法)
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
# 将影子参数复制到原模型参数
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
- 作用:将EMA维护的影子参数复制到原模型中,使模型在评估或推理时使用更稳定的EMA参数。
4. 临时保存和恢复原模型参数(store
和restore
方法)
def store(self, parameters):
"""保存当前参数,用于后续恢复"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""恢复之前保存的参数"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
- 使用场景:
在验证或保存模型时,通常需要先用store
保存原模型参数,再用copy_to
应用EMA参数;完成后用restore
恢复原参数,避免EMA参数影响后续训练(保证优化器仍作用于原参数)。
典型使用流程
初始化模型和
LitEma
:model = MyModel() ema = LitEma(model, decay=0.999)
训练过程中,每次参数更新后调用
ema.forward(model)
更新影子参数:for epoch in range(num_epochs): for batch in dataloader: # 正常训练步骤(前向、计算损失、反向传播、优化器更新) loss = model(batch) loss.backward() optimizer.step() optimizer.zero_grad() # 更新EMA影子参数 ema.forward(model)
验证/保存模型时,使用EMA参数:
# 保存原模型参数 ema.store(model.parameters()) # 应用EMA参数到模型 ema.copy_to(model) # 此时模型使用EMA参数,可进行验证或保存 validate(model) torch.save(model.state_dict(), "ema_model.pth") # 恢复原模型参数,继续训练 ema.restore(model.parameters())
总结
LitEma
实现了一个灵活的EMA模块,通过维护参数的指数移动平均值,帮助模型在训练中保持稳定性,提升最终性能。其核心是影子参数的动态更新机制,以及方便的参数切换功能(copy_to
/store
/restore
),使其能无缝集成到PyTorch训练流程中。