Qwen3_moe模型代码解析

发布于:2025-09-01 ⋅ 阅读:(18) ⋅ 点赞:(0)

Qwen3_moe模型代码解析

1) 顶层:Qwen3MoeModel.forward(Embedding → 多层解码器 → RMSNorm)

输入
input_ids (B,S) 或 inputs_embeds (B,S,H)
attention_mask (B,S) 可选
position_ids (1,B,S) 可选
past_key_values (Cache) 可选
只提供其一?
(input_ids XOR inputs_embeds)
若 inputs_embeds 为 None:
embed_tokens(input_ids) → inputs_embeds (B,S,H)
抛错: 必须且只能指定其一
若 use_cache and past is None:
past_key_values = DynamicCache(...)
若 cache_position 为 None:
根据 past 长度构造 cache_position (S,)
若 position_ids 为 None:
position_ids = cache_position.unsqueeze(0) → (1,B,S)
是否 sliding_window?
create_causal_mask(...) → causal_mask (B,1,S,S)
create_sliding_window_causal_mask(...) → causal_mask (B,1,S,S)
Qwen3MoeRotaryEmbedding:
(cos,sin) = rotary_emb(inputs_embeds, position_ids)
for layer in 1..L:
Qwen3MoeDecoderLayer.forward(...)
最终 RMSNorm → last_hidden_state (B,S,H)
返回 MoeModelOutputWithPast:
last_hidden_state, past_key_values

要点

  • 位置编码:MoE 是 1D RoPE(position_ids 形状 (1,B,S))。
  • 掩码:根据配置可能是标准因果或滑窗因果。
  • Cache:默认用 DynamicCache(用于增量解码)。

2) 解码器层:Qwen3MoeDecoderLayer.forward(Self-Attn → 残差 → MoE/MLP → 残差)

输入 hidden_states (B,S,H)
input_layernorm (RMSNorm)
Qwen3MoeAttention.forward(...)
使用 (cos,sin) 施加 RoPE
残差: hidden = residual + attn_out
post_attention_layernorm (RMSNorm)
本层是否 MoE 层?
(num_experts>0 且 (layer_idx+1) % decoder_sparse_step == 0 且不在 mlp_only_layers)
Qwen3MoeSparseMoeBlock.forward(hidden) → (B,S,H), router_logits
Qwen3MoeMLP.forward(hidden) → (B,S,H)
若返回为 tuple, 取 hidden 并保留 router_logits(外部收集)
残差: hidden = residual + mlp_out 或 moe_out
输出 (B,S,H) 返回上一层

要点

  • MoE 层与普通 MLP 层互斥;MoE 层额外产生 router_logits(被上层输出记录用于负载均衡辅助损失)。

3) 注意力:Qwen3MoeAttention.forward(Q/K/V → RoPE → 注意力 → 合并头)

hidden_states (B,S,H)
q_proj: (B,S,H) → (B,S,n_heads*head_dim)
k_proj: (B,S,H) → (B,S,n_kv*head_dim)
v_proj: (B,S,H) → (B,S,n_kv*head_dim)
reshape/view → (B,S,n_heads,head_dim) → transpose → Q (B,n_heads,S,d) → q_norm(按头维)
reshape/view → (B,S,n_kv,head_dim) → transpose → K (B,n_kv,S,d) → k_norm(按头维)
reshape/view → (B,S,n_kv,head_dim) → transpose → V (B,n_kv,S,d)
apply_rotary_pos_emb(Q,K, cos,sin) → Q',K'
选择注意力实现: eager/SDPA/FA2/…
eager 路径内部:repeat_kv(K/V) 到 n_heads;(Q·K^T)/sqrt(d) + mask → softmax → dropout → softmax·V
(B,n_heads,S,d) → 合并头 → o_proj → (B,S,H)

要点

  • 与 Qwen2 系列类似,GQA:n_heads = num_attention_headsn_kv = num_key_value_heads,通过 repeat_kv 对齐。
  • 层内对 Q/K 施加 RMSNorm(按头维) 再做 RoPE(这是 Qwen3 MoE 和一些实现的一个小差别)。

4) 稀疏 MoE:Qwen3MoeSparseMoeBlock.forward(Gating → Top-k 路由 → 专家并行 → 汇聚)

输入 hidden_states (B,S,H)
展平 tokens: (B*S,H) 便于逐 token 路由
gate: Linear(H→n_experts) → router_logits (B*S,n_experts)
softmax → routing_weights (B*S,n_experts)
topk: 选出每 token 的 top_k 专家
得到 selected_experts (B*S,top_k)
和对应权重 routing_weights_topk (B*S,top_k)
norm_topk_prob ?
routing_weights_topk /= sum(topk)
保持原权重
构造 expert_mask (n_experts,top_k,B*S) 便于按专家分组索引
对每个被命中的专家 e:
取出该专家命中的 tokens 切片 → expert_e(hidden) → MLP_e 输出 × 对应权重
index_add_ 汇聚回 final_hidden_states (B*S,H)
reshape 回 (B,S,H) 并返回
同时返回 router_logits 供上层辅助损失

要点

  • 该实现是token-level routing;每 token 选 top-k 专家;支持 norm_topk_prob 对 top-k 权重归一化。
  • 通过 index_add_ 将各专家输出按原 token 位置汇聚。

5) RoPE:Qwen3MoeRotaryEmbedding.forward(1D 位置 → cos/sin)

输入 x: (B,S,H) 仅用于 dtype/device 对齐
position_ids (1,B,S)
根据 rope_scaling 选择 init 函数
得到 inv_freq (d/2) 与 attention_scaling
freqs = (inv_freq @ position_ids)ᵀ → (B,S,d/2)
emb = concat(freqs,freqs) → (B,S,d)
cos = cos(emb)*attention_scaling;
sin = sin(emb)*attention_scaling
返回 cos,sin (与 Q/K 广播匹配)

要点

  • 与标准 1D RoPE 一致(没有多模态 3D 拆段)。
  • dynamic_rope_update 允许动态扩展(取决于 rope_scaling 策略)。

6) 语言建模头:Qwen3MoeForCausalLM.forward(LM Head & Router Loss)

调用内部 Qwen3MoeModel.forward(...) → last_hidden_state (B,S,H)
lm_head: Linear(H→V) 只对末 K 步算 logits (B,K,V)
是否提供 labels?
交叉熵损失 (masked)
不计算主损失
是否 output_router_logits ?
load_balancing_loss_func(router_logits)
得到 aux_loss 并按系数加到 loss
不计算 aux_loss
返回 MoeCausalLMOutputWithPast:
loss/aux_loss/logits/past_key_values/...

要点

  • logits_to_keep:只在末 K 个时间步计算 lm_head,显存友好。
  • aux_loss负载均衡损失,鼓励专家使用更均匀。

7) 掩码构造与缓存(顶层)

inputs_embeds (B,S,H), attention_mask (B,S) 可选
sliding_window 为 None?
create_causal_mask → (B,1,S,S) 上三角 -inf
create_sliding_window_causal_mask(W) → (B,1,S,S) 仅保留最近 W 个可见
传入各层 self_attn

8) 形状清单(常用变量)

  • B: batch size;S: 当前序列长度;H: hidden_size;V: vocab_size

  • 头部:n_heads = num_attention_headsn_kv = num_key_value_headsd = head_dim = H / n_heads

  • 注意力中:

    • Q: (B,n_heads,S,d);K/V: (B,n_kv,S,d)repeat_kv(B,n_heads,S,d)
    • 权重 (B,n_heads,S,S);输出 (B,S,H)

9) 常见坑与对策

  1. inputs 选择(input_ids is None) XOR (inputs_embeds is not None) 必须成立,否则抛错。
  2. RoPE 维度position_ids 必须 (1,B,S);若用 cache,需要正确设置 cache_position 使位置连续。
  3. 滑窗注意力:窗口 W 太小会影响长程依赖;太大则近似全因果。确保与训练/推理对齐。
  4. MoE 路由num_experts_per_tok (top_k) 影响吞吐与均衡;norm_topk_prob=True 时要注意与训练策略匹配。
  5. 负载均衡损失output_router_logits=True 时才会收集所有层的 router_logits;注意与 attention_mask 一起计算避免 padding 干扰。
  6. 精度:注意力 softmax 强制 float32 再 cast 回来,避免数值不稳。

10) 端到端数字化算例(便于核对)

假设:

  • B=2, S=128, H=4096, V=151936n_heads=32 → d=128n_kv=8 → num_key_value_groups=4
  • sliding_window=None(全因果);use_cache=True 首次前向 past=None
  • 第 4 层是 MoE 层:num_experts=8, top_k=2, moe_intermediate_size=11008,其它层为密集 MLP:intermediate_size=11008

流程

  1. inputs_embeds = embed_tokens(input_ids)(2,128,4096)

  2. cache_position = [0..127]position_ids=(1,2,128)

  3. causal_mask=(2,1,128,128) 上三角 -inf

  4. 进入第 1 层:

    • Q/K/V 线性:(2,128,4096) → Q:(2,128,4096) K/V:(2,128,1024)
    • 视图→Q:(2,32,128,128)K/V:(2,8,128,128)repeat_kv(2,32,128,128)
    • RoPE:apply_rotary_pos_emb
    • 注意力:权重 (2,32,128,128) → 输出 (2,32,128,128) → 合并头 (2,128,4096)
    • 残差 + MLP (SwiGLU):(2,128,4096)(2,128,11008)(2,128,4096)
  5. 第 4 层(MoE):

    • Gate:(B*S,H)=(256,4096) → (256,8) → softmax → top2
    • 对被命中专家的 tokens 送入各自 MLP_e(N_e,4096)→(N_e,11008)→(N_e,4096),乘以各 token 对应权重
    • index_add_ 汇聚回 (256,4096) → reshape (2,128,4096)
    • 残差
    • router_logits 记录(供 loss)
  6. L 层结束 → RMSNormlast_hidden_state=(2,128,4096)

  7. lm_head:(4096→V) 只算末 K=32 步 → logits=(2,32,V)

  8. 如有 labels:交叉熵 + 若 output_router_logits=True 再加 aux_loss(乘以 router_aux_loss_coef)。


网站公告

今日签到

点亮在社区的每一天
去签到