什么是 Warmup?
Warmup 是一种学习率调度策略,在训练初期逐步增加学习率(LR),而不是直接使用目标学习率。它解决了两个关键问题:
- 避免早期震荡:模型参数初始化为随机值,直接高LR会导致不稳定更新。
- 稳定Adam优化器:Adam的动量估计在初始阶段不准确,需要渐进调整。
实现示例
def get_lr(step, warmup_steps, d_model):
# 1. 预热阶段:线性增长
if step < warmup_steps:
return base_lr * (step / warmup_steps)
# 2. 衰减阶段:反平方根衰减
scale = (warmup_steps ** 0.5) * min(
step ** (-0.5),
step * (warmup_steps ** (-1.5))
return base_lr * scale
科学设置 Warmup 的黄金法则
- N N N = 总样本数(条)
- B B B = 每次 forward 的原始 batch(每卡)
- A A A = 梯度累积步数(
accum_grad
) - E E E = epoch 数
- W W W = 卡数(
world_size
,单卡 = 1)
- 每 epoch 的 forward 次数(向上取整):
I = ⌈ N / ( B × W ) ⌉ I = \lceil N / (B \times W) \rceil I=⌈N/(B×W)⌉
- 每 epoch 的 optimizer 更新次数(每 A A A 次 forward 做一次 update,向上取整):
S = ⌈ I / A ⌉ S = \lceil I / A \rceil S=⌈I/A⌉
- 总 optimizer step(也就是 scheduler 用的 total steps):
T = S × E T = S \times E T=S×E
- 推荐的 warmup 步数:
warmup = { max { ⌈ 0.10 × T ⌉ , 10 } , T < 4000 clamp ( ⌊ 0.05 × T ⌉ , 4000 , 20000 ) , T ≥ 4000 \text{warmup} = \begin{cases} \max\{\lceil 0.10 \times T \rceil, 10\}, & T < 4000\\[4pt] \operatorname{clamp}(\, \lfloor 0.05 \times T \rceil,\; 4000,\; 20000 \,), & T \ge 4000 \end{cases} warmup={max{⌈0.10×T⌉,10},clamp(⌊0.05×T⌉,4000,20000),T<4000T≥4000
并且最终确保 warmup ≤ T − 1 \text{warmup} \le T-1 warmup≤T−1。
解释:小训练用 10%,大训练用 5%,并在 4k–20k 之间限制
直观例子
假设 B = 16 , A = 8 , E = 120 , W = 1 B=16, A=8, E=120, W=1 B=16,A=8,E=120,W=1:
- 若 N = 100,000 N = 100{,}000 N=100,000:
- I = ⌈ 100000 / 16 ⌉ = 6250 I=\lceil100000/16\rceil=6250 I=⌈100000/16⌉=6250
- S = ⌈ 6250 / 8 ⌉ = 782 S=\lceil6250/8\rceil=782 S=⌈6250/8⌉=782
- T = 782 × 120 = 93,840 T=782\times120=93{,}840 T=782×120=93,840
- warmup ≈ ⌊ 0.05 × 93840 ⌉ = 4,692 \lfloor0.05\times93840\rceil=4{,}692 ⌊0.05×93840⌉=4,692(取 4000–20000 区间内 → 4692)
- 若 N = 1,000,000 N = 1{,}000{,}000 N=1,000,000:
- I = ⌈ 1000000 / 16 ⌉ = 62500 I=\lceil1000000/16\rceil=62500 I=⌈1000000/16⌉=62500
- S = ⌈ 62500 / 8 ⌉ = 7813 S=\lceil62500/8\rceil=7813 S=⌈62500/8⌉=7813
- T = 7813 × 120 = 937,560 T=7813\times120=937{,}560 T=7813×120=937,560
- 0.05×T ≫ 20000 → clamp → 20000
- 若 N = 100 N = 100 N=100(极小样本、仅作示例):
- I = ⌈ 100 / 16 ⌉ = 7 I=\lceil100/16\rceil=7 I=⌈100/16⌉=7
- S = ⌈ 7 / 8 ⌉ = 1 S=\lceil7/8\rceil=1 S=⌈7/8⌉=1
- T = 1 × 120 = 120 T=1\times120=120 T=1×120=120
- 因为 T < 4000 T<4000 T<4000,warmup = max(ceil(0.1×120),10) = 12
import math
def suggest_warmup(N,B,A,E,W=1):
I = math.ceil(N / (B*W))
S = math.ceil(I / A)
T = S * E
if T < 4000:
w = max(math.ceil(0.10*T), 10)
else:
w = round(0.05*T)
w = max(4000, min(w, 20000))
w = min(w, T-1)
return {"iters_per_epoch":I, "opt_steps_per_epoch":S, "total_steps":T, "warmup":w}
print(suggest_warmup(403733, 16, 8, 120, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 789, 'total_steps': 94680, 'warmup': 4734}
print(suggest_warmup(403733, 16, 4, 100, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 1578, 'total_steps': 157800, 'warmup': 7890}