大模型Decoder-Only深入解析

发布于:2025-07-05 ⋅ 阅读:(21) ⋅ 点赞:(0)

Decoder-Only整体结构

我们以模型Llama-3.1-8B-Instruct为例,打印其结构如下(后面会慢慢解析每一部分,莫慌):

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)
          (o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)
          (attn): RadixAttention()
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)
          (down_proj): RowParallelLinear(input_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
  (logits_processor): LogitsProcessor()
  (pooler): Pooler()
)

Decoder-Only处理流程

我们以Llama-3.1-8B-Instruct模型为例,结合一个具体的聊天对话场景,详细说明Decoder-Only模型的处理流程,从用户输入到最终输出回答。整个过程会逐步拆解,并标注每个步骤的输入输出形状(假设batch_size=1,seq_len=10,hidden_dim=4096,词表大小=128000)。

1. 用户输入与聊天模板处理

场景:用户问:“如何做西红柿炒鸡蛋?”
模型需求:需要根据历史对话和当前问题生成回答。

聊天模板处理
  • 输入文本text:原始用户输入(如“如何做西红柿炒鸡蛋?”)
  • 模板化prompt:模型需要将输入包装成特定格式的prompt,例如:
    [系统指令]:你是一个烹饪助手,请回答以下问题。
    [用户]:如何做西红柿炒鸡蛋?
    [助手]:
    
  • 作用:模板化prompt让模型明确任务目标(如回答问题),并模拟对话上下文。

输入输出形状

  • 输入文本长度:假设为10个字符(实际长度取决于具体输入)。
  • 模板化后的prompt长度:假设为30个字符(包含系统指令、用户问题和占位符)。

2. Tokenizer处理:从prompt到input_ids

步骤

  1. Tokenization:将模板化prompt拆分为模型能理解的Token(如“西红柿”→“西红柿”,“炒”→“炒”)。
  2. 映射到input_ids:每个Token被映射为对应的ID(例如,“西红柿”→1234,“炒”→5678)。

示例
假设模板化Prompt被拆分为10个Token,其input_ids为:

[101, 1234, 5678, 8901, 2345, 6789, 102, 3456, 7890, 102]

(其中101和102是特殊标记,如<BOS><EOS>,表示开始和结束)

输入输出形状

  • input_ids的形状为 (batch_size, seq_len)(1, 10)
  • attention_mask(可选)的形状为 (1, 10),标记哪些位置是有效Token(1)或填充(0)。

3. 嵌入层:input_ids → hidden_states

步骤

  1. Token Embedding:将input_ids映射为高维向量(如4096维)。
  2. Positional Encoding:添加位置信息,让模型知道每个Token在序列中的位置。

示例

  • input_ids [101, 1234, 5678, ...] → 隐藏状态 hidden_states 的形状为 (1, 10, 4096)
  • 每个Token对应的向量包含其语义和位置信息(例如,“西红柿”对应的食物相关特征,以及它在句子中的位置)。

输入输出形状

  • hidden_states 的形状为 (batch_size, seq_len, hidden_dim)(1, 10, 4096)

4. Decoder Block处理:逐层计算

核心流程

  1. Masked Self-Attention(带掩码的自注意力)

    • 每个Token只能看到自己及之前的Token(防止“偷看”未来内容)。
    • 例如,在生成“西红柿炒鸡蛋”时,模型会先处理“西红柿”,再处理“炒”,确保生成逻辑连贯。
  2. 前馈网络(FFN)

    • 对每个Token的隐藏状态进行非线性变换,增强表达能力。

示例

  • 假设模型有32层Decoder Block,每层都会更新 hidden_states
  • 最终的 hidden_states 保留了完整的上下文信息(如“西红柿炒鸡蛋”的步骤描述)。

输入输出形状

  • 每层Decoder Block的输入输出形状不变,仍为 (1, 10, 4096)

5. LM Head:从hidden_states到下一个词

步骤

  1. 线性层:将最后一个Token的隐藏状态(形状为 (1, 10, 4096))映射到词表维度(128000)。
    • 例如,对最后一个位置(seq_len=9)的隐藏状态取值:hidden_states[:, 9, :] → 形状 (1, 4096)
  2. Softmax:将输出转换为概率分布(每个词的概率)。

示例

  • 假设模型预测下一个词是“步骤一”,其ID为9876,则概率分布中9876的值最高。

输入输出形状

  • 线性层输出形状:(1, 128000)
  • 概率分布形状:(1, 128000)

6. 采样策略:从概率分布到下一个词

方法

  • Top-k采样:从概率最高的前k个词(如k=50)中随机选一个。
  • Greedy Search:直接选概率最高的词(如“步骤一”)。

示例

  • 模型选择“步骤一”作为下一个词,并将其ID(9876)添加到 input_ids 中。
  • 新的 input_ids 变为:[101, 1234, 5678, ..., 9876](长度+1)。

输入输出形状

  • 新的 input_ids 形状为 (1, 11)

7. 迭代生成:重复步骤3-6直到完成

流程

  1. 将新的 input_idshidden_states 送回Decoder Block。
  2. 重复计算,逐步生成完整回答(如“步骤一:热锅凉油…”)。
  3. 直到生成终止标记(如<EOS>)或达到最大长度(如2048 Token)。

示例

  • 生成完整回答后,input_ids 的长度可能变为200(假设生成190个新Token)。
  • 最终的 input_ids 包含原始Prompt和生成的回答。

8. Tokenizer反向处理:从input_ids到用户文本

步骤

  1. 将生成的 input_ids(含prompt和回答)截取回答部分(去掉prompt)。
  2. 使用Tokenizer将 input_ids 转换回自然语言文本(如“步骤一:热锅凉油…”)。

输入输出形状

  • 截取后的 input_ids 形状为 (1, 190)
  • 最终输出文本长度取决于生成内容(如“步骤一:热锅凉油…”)

总结流程图

用户输入 → 模板化Prompt → Tokenizer → input_ids (1,10)  
          → 嵌入层 → hidden_states (1,10,4096)  
          → Decoder Block ×32 → hidden_states (1,10,4096)  
          → LM Head → 概率分布 (1,128000)  
          → 采样 → 新input_ids (1,11)  
          → 重复生成 → input_ids (1,200)  
          → Tokenizer反向 → 用户文本

LlamaForCausalLM结构分析

以模型Llama-3.1-8B-Instruct为例,将一部分子结构信息折叠起来,将显示如下:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(...)
    )
    (norm): RMSNorm()
  )
  (lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
  (logits_processor): LogitsProcessor()
  (pooler): Pooler()
)

可以看到LlamaForCausalLM主要由几个关键部分组成:model, lm_head, logits_processorpooler。这几个组件作用各不相同,我们现在来介绍一下他们。

1. model:核心解码器结构

(1) embed_tokens:词嵌入层
  • 作用:将输入的Token ID(如“西红柿”→ID=1234)映射为4096维的向量,表示Token的语义和位置信息。
  • 技术细节
    • 使用VocabParallelEmbedding(并行词嵌入,仅需了解,无需深入),支持分布式训练。
    • 词表大小为128256,覆盖多语言和特殊符号(如<BOS><EOS>)。
  • 输入输出形状
    • 输入:(batch_size, seq_len)(1, 10)(假设输入10个Token)
    • 输出:(batch_size, seq_len, hidden_dim)(1, 10, 4096)
(2) layers:32层Decoder Block
  • 核心结构
    • 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
      • 查询(Q)、键(K)、值(V)的维度:d_model=4096num_heads=32head_dim=128
      • GQA机制:将K/V头数减少为num_key_value_heads=8,降低计算开销。
    • 前馈网络(MLP):使用SwiGLU激活函数(Sigmoid + Gated Linear Unit),替代传统ReLU。
      • 输入:4096维 → 中间层:11008维 → 输出:4096维。
    • 归一化:每层使用RMSNorm(均方根归一化),稳定训练并加速收敛。
  • 输入输出形状
    • 每层输入/输出:(1, 10, 4096)(与输入形状一致)
(3) norm:最终归一化层
  • 作用:对32层Decoder Block的输出进行最后一次归一化,确保数值稳定性。
  • 技术细节
    • 使用RMSNorm,无需计算均值,直接对向量的模长标准化。
    • 公式:hidden_states = hidden_states / sqrt(variance + ε),其中ε=1e-6

2. lm_head:语言模型头部

  • 作用:将最终的隐藏状态(hidden_dim=4096)映射为词表大小(vocab_size=128256)的概率分布,预测下一个词。
  • 技术细节
    • 使用ParallelLMHead(并行线性层),加速大规模词表的计算。
    • 参数量:4096 × 128256 ≈ 5.16B(占模型总参数量的约76%)。
  • 输入输出形状
    • 输入:(1, 4096)(取最后一个位置的隐藏状态)
    • 输出:(1, 128256)(每个词的概率值)

3. logits_processor:概率分布处理器

  • 作用:对lm_head输出的概率分布进行后处理,控制生成策略。
  • 常用功能
    • 温度调节(Temperature):降低温度(<1)使输出更确定,升高温度(>1)增加多样性。
    • Top-k/Top-p采样:从概率最高的k个词或累积概率达p的词中随机选择,平衡质量和多样性。
    • 重复惩罚(Repetition Penalty):抑制重复生成相同词(如避免“西红柿西红柿”)。
  • 输入输出形状
    • 输入:(1, 128256)(原始概率分布)
    • 输出:(1, 128256)(处理后的概率分布)

4. pooler:池化层

  • 作用:将整个序列的隐藏状态压缩为固定长度的向量,用于下游任务(如分类、相似度计算)。
  • 技术细节
    • 默认取第一个Token(如<BOS>)的隐藏状态作为全局表示。
    • 或使用平均池化/最大池化,但Llama 3.1通常直接取<BOS>
  • 输入输出形状
    • 输入:(1, 10, 4096)(全序列隐藏状态)
    • 输出:(1, 4096)(固定长度的全局向量)

总结:组件协同工作流程

  1. 输入处理:用户输入文本 → 模板化Prompt → embed_tokens(1, 10, 4096)
  2. 特征提取:32层Decoder Block → hidden_states(1, 10, 4096)
  3. 归一化norm → 稳定输出
  4. 生成预测
    • lm_head(1, 128256) 概率分布
    • logits_processor → 调整概率分布
    • 采样生成下一个词 → 更新 input_ids
  5. 迭代生成:重复步骤1-4,直到生成终止标记(<EOS>)或达到最大长度。
  6. 任务适配pooler 提取全局向量 → 用于分类、相似度等任务。
  • model:像一个厨师,逐步处理食材(Token)并调整火候(注意力机制)。
  • lm_head:厨师的“味觉”,决定下一步该加什么调料(预测下一个词)。
  • logits_processor:厨房的“规则制定者”,确保菜谱不重复且口味可控。
  • pooler:食客的“总结笔记”,用一句话概括整道菜的风味(全局语义)。