Setting Warmup_steps: Practical Experience

Published: 2025-08-17

What Is Warmup?

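Warmup is a learning-rate scheduling strategy that gradually increases the learning rate (LR) at the start of training instead of starting directly at the target LR. It addresses two key problems:

  • Avoiding early oscillation: the model parameters start from random initial values, so applying a high LR immediately produces unstable updates.
  • Stabilizing the Adam optimizer: Adam's moment estimates are unreliable in the first steps, so the effective update size needs to be ramped up gradually.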

Implementation Example

(Flowchart: initial LR → linear warmup phase → warmup_steps reached? → inverse-square-root decay phase)
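```python
def get_lr(step, warmup_steps, base_lr):
    # base_lr: peak learning rate, reached at step == warmup_steps (assumes warmup_steps >= 1)
    # 1. Warmup phase: LR grows linearly from 0 to base_lr
    if step < warmup_steps:
        return base_lr * (step / warmup_steps)

    # 2. Decay phase: inverse-square-root decay, i.e. base_lr * sqrt(warmup_steps / step);
    #    both min() terms equal warmup_steps ** -0.5 at the boundary, so the LR is continuous
    scale = (warmup_steps ** 0.5) * min(
        step ** (-0.5),
        step * (warmup_steps ** (-1.5)),
    )
    return base_lr * scale
```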
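The following sanity check evaluates the warmup + inverse-square-root schedule at a few steps to show its shape. It assumes `base_lr` is the peak learning rate and is passed in as an argument. The values `BASE_LR = 1e-3` and `WARMUP_STEPS = 4000` are arbitrary illustrative choices, not recommendations.

```python
import math

def get_lr(step, warmup_steps, base_lr):
    # Assumption: base_lr is the peak LR, reached at step == warmup_steps (warmup_steps >= 1)
    if step < warmup_steps:
        return base_lr * (step / warmup_steps)  # linear warmup from 0
    # inverse-square-root decay; equivalent to base_lr * sqrt(warmup_steps / step)
    scale = (warmup_steps ** 0.5) * min(step ** (-0.5), step * (warmup_steps ** (-1.5)))
    return base_lr * scale

BASE_LR = 1e-3        # illustrative peak learning rate (assumed)
WARMUP_STEPS = 4000   # illustrative warmup length (assumed)

for step in (0, 1000, 2000, 4000, 16000, 64000):
    lr = get_lr(step, WARMUP_STEPS, BASE_LR)
    print(f"step={step:>6}  lr={lr:.6f}")

# Both branches meet at the boundary, so the schedule has no jump there.
assert math.isclose(get_lr(WARMUP_STEPS, WARMUP_STEPS, BASE_LR), BASE_LR)
```

The LR climbs linearly to 1e-3 at step 4000. After that it halves every time the step count quadruples: 1e-3 at step 4000, 5e-4 at step 16000 and 2.5e-4 at step 64000.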

A Golden Rule for Setting Warmup Systematically

  • $N$ = total number of samples
  • $B$ = raw batch size per forward pass (per GPU)
  • $A$ = number of gradient accumulation steps (accum_grad)
  • $E$ = number of epochs
  • $W$ = number of GPUs (world_size; 1 for a single GPU)
  1. Forward passes per epoch (rounded up):

$$I = \left\lceil \frac{N}{B \times W} \right\rceil$$

  2. Optimizer updates per epoch (one update every $A$ forward passes, rounded up):

$$S = \left\lceil \frac{I}{A} \right\rceil$$

  3. Total optimizer steps (the total steps used by the scheduler):

$$T = S \times E$$

  4. Recommended number of warmup steps:

$$\text{warmup} = \begin{cases} \max\{\lceil 0.10 \times T \rceil,\ 10\}, & T < 4000 \\[4pt] \operatorname{clamp}(\lfloor 0.05 \times T \rceil,\ 4000,\ 20000), & T \ge 4000 \end{cases}$$

Here $\lfloor \cdot \rceil$ denotes rounding to the nearest integer and $\operatorname{clamp}(x, a, b) = \min(\max(x, a), b)$. Finally, ensure that $\text{warmup} \le T - 1$.

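Rationale: short runs use 10% of the total steps and long runs use 5%, with the result clamped to the 4k–20k range. Because of the 4000 floor, every run with $4000 \le T < 80{,}000$ gets a fixed warmup of 4000 (capped at $T - 1$) rather than 5% of $T$.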
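To see how the clamp shapes the result, the sketch below evaluates only the piecewise rule for a range of total-step counts $T$ and prints warmup as a fraction of $T$. The helper name `warmup_from_total_steps` is hypothetical. It computes $\lceil 0.10\,T \rceil$ with integer ceiling division instead of `math.ceil(0.10 * T)` to avoid floating-point rounding surprises; that substitution is an editorial choice, not part of the original rule.

```python
def warmup_from_total_steps(T):
    """Piecewise warmup rule for T total optimizer steps (hypothetical helper)."""
    if T < 4000:
        w = max(-(-T // 10), 10)                  # ceil(0.10 * T) in integer arithmetic, at least 10
    else:
        w = min(max(round(T / 20), 4000), 20000)  # round(0.05 * T), clamped to [4000, 20000]
    return min(w, T - 1)                          # warmup must end before training does

for T in (120, 3999, 4000, 5000, 40000, 80000, 93840, 400000, 937560):
    w = warmup_from_total_steps(T)
    print(f"T={T:>7}  warmup={w:>6}  warmup/T={w / T:.1%}")
```

The printout makes the jump at the threshold visible. $T = 3999$ gets 400 warmup steps, while $T = 4000$ gets 3999, which is almost the entire run. $T = 5000$ still spends 80% of training in warmup. The ratio settles at 5% only from about $T = 80{,}000$ up to $T = 400{,}000$, where the 20000 ceiling starts to apply.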


Worked Examples

Assume $B = 16$, $A = 8$, $E = 120$, $W = 1$.

  1. $N = 100{,}000$:
  • $I = \lceil 100000 / 16 \rceil = 6250$
  • $S = \lceil 6250 / 8 \rceil = 782$
  • $T = 782 \times 120 = 93{,}840$
  • $\text{warmup} = \lfloor 0.05 \times 93840 \rceil = 4692$, which already lies in the 4000–20000 range, so warmup = 4692
  2. $N = 1{,}000{,}000$:
  • $I = \lceil 1000000 / 16 \rceil = 62500$
  • $S = \lceil 62500 / 8 \rceil = 7813$
  • $T = 7813 \times 120 = 937{,}560$
  • $0.05 \times T = 46{,}878 > 20000$, so the clamp gives warmup = 20000
  3. $N = 100$ (a tiny dataset, for illustration only):
  • $I = \lceil 100 / 16 \rceil = 7$
  • $S = \lceil 7 / 8 \rceil = 1$
  • $T = 1 \times 120 = 120$
  • Since $T < 4000$, $\text{warmup} = \max(\lceil 0.1 \times 120 \rceil, 10) = 12$

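```python
import math

def suggest_warmup(N, B, A, E, W=1):
    """Suggest warmup steps from dataset size and training configuration.

    N: total samples, B: per-GPU batch size per forward pass,
    A: gradient accumulation steps, E: epochs, W: number of GPUs (world_size).
    """
    I = math.ceil(N / (B * W))  # forward passes per epoch (per GPU)
    S = math.ceil(I / A)        # optimizer updates per epoch
    T = S * E                   # total optimizer steps seen by the scheduler
    if T < 4000:
        # short run: 10% of total steps, at least 10
        w = max(math.ceil(0.10 * T), 10)
    else:
        # long run: 5% of total steps, clamped to [4000, 20000]
        w = round(0.05 * T)
        w = max(4000, min(w, 20000))
    w = min(w, T - 1)  # warmup must end before training does
    return {"iters_per_epoch": I, "opt_steps_per_epoch": S, "total_steps": T, "warmup": w}

print(suggest_warmup(403733, 16, 8, 120, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 789, 'total_steps': 94680, 'warmup': 4734}

print(suggest_warmup(403733, 16, 4, 100, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 1578, 'total_steps': 157800, 'warmup': 7890}
```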
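The two parts of this post can be chained. First compute `total_steps` and `warmup` from the data-size rule, then use them to drive the warmup + inverse-square-root schedule. The sketch below is self-contained and repeats both functions in condensed form. It assumes one scheduler step per optimizer update, which matches how $T$ is defined above. `PEAK_LR = 5e-4` is an illustrative assumption, not a recommendation.

```python
import math

def suggest_warmup(N, B, A, E, W=1):
    I = math.ceil(N / (B * W))   # forward passes per epoch (per GPU)
    S = math.ceil(I / A)         # optimizer updates per epoch
    T = S * E                    # total optimizer steps
    if T < 4000:
        w = max(math.ceil(0.10 * T), 10)
    else:
        w = max(4000, min(round(0.05 * T), 20000))
    return {"total_steps": T, "warmup": min(w, T - 1)}

def get_lr(step, warmup_steps, base_lr):
    if step < warmup_steps:
        return base_lr * (step / warmup_steps)       # linear warmup from 0
    # inverse-sqrt decay; equals the min()-based form whenever step >= warmup_steps
    return base_lr * (warmup_steps / step) ** 0.5

PEAK_LR = 5e-4  # illustrative peak LR (assumed)
cfg = suggest_warmup(403733, 16, 8, 120, 4)
total_steps, warmup = cfg["total_steps"], cfg["warmup"]
print(total_steps, warmup)  # 94680 4734

for step in (0, warmup // 2, warmup, 4 * warmup, total_steps - 1):
    print(step, f"{get_lr(step, warmup, PEAK_LR):.2e}")
```

With this configuration the LR peaks at 5e-4 at step 4734. By the last step it has decayed to roughly 22% of the peak, since $\sqrt{4734 / 94680} \approx 0.224$.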
