Stanford CS336 | Assignment 1 - Transformer Language Model Architecture


All of my code for Assignment 1 is open-sourced at:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
If you find it helpful, please consider leaving a star!

The assignment handout is available at https://github.com/stanford-cs336/assignment1-basics

The second half of Assignment 1 asks you to implement a Transformer-based language model (LM) from scratch, which is a key hands-on exercise for understanding how modern large language models (LLMs) work internally.

This post walks through the full implementation of a Transformer language model, covering the core components (multi-head attention, rotary position embeddings (RoPE), RMS normalization, the SwiGLU feed-forward network) as well as the training utilities: a custom AdamW optimizer, learning-rate scheduling, batch sampling, and more, so you can follow the whole pipeline from model construction to training.

1. Project Overview

The core goal of this assignment is to build a decoder-only Transformer language model (GPT-style) that predicts the next token in a sequence. The model adopts design choices common in current LLMs (pre-normalization, RoPE, RMSNorm) to balance efficiency and performance, and it comes with a complete training pipeline that can be run directly on tokenized text.

Key features of the model and training framework:

  • Decoder-only Transformer architecture with multi-head self-attention
  • Rotary position embeddings (RoPE) for stronger positional modeling
  • RMSNorm in place of traditional LayerNorm for better training stability
  • SwiGLU activation in the feed-forward network, which outperforms ReLU-style activations
  • Custom AdamW optimizer, cosine learning-rate scheduler, and gradient clipping
  • Checkpoint saving and loading so training can resume after an interruption

2. Core Components: transformer.py

transformer.py contains all core modules of the Transformer language model. Each component is designed to be modular, which makes debugging easier and leaves room for later extensions.

2.1 Basic Building Blocks

These layers are the model's basic utilities: they are reused across multiple modules and implement the most elementary tensor transformations.

2.1.1 Bias-Free Linear Layer (Linear)

A simplified linear layer with the bias term removed. Weights use a Xavier-style initialization (truncated normal) for training stability, and the forward pass uses einops.einsum for an explicit dimension mapping, avoiding the confusion of manual reshapes.

import torch
import torch.nn as nn
from einops import einsum, rearrange  # einops is used throughout transformer.py

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        super().__init__()
        # Weight parameter of shape (out_features, in_features)
        self.weight = nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype))
        # Xavier-style init: std = sqrt(2 / (in_features + out_features)), truncated at ±3 std
        std = (2 / (in_features + out_features)) ** 0.5
        nn.init.trunc_normal_(self.weight, std=std, a=-3*std, b=3*std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x @ W^T: (..., in_features) -> (..., out_features)
        return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")
2.1.2 Embedding Layer (Embedding)

Maps discrete token IDs to dense vectors; the embedding dimension (embedding_dim) matches the model's hidden dimension (d_model). Weights are also initialized from a truncated normal distribution so the initial embeddings have a reasonable scale.

class Embedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,  # vocabulary size (total number of tokens)
        embedding_dim: int,   # embedding dimension (equal to d_model)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.vocab_size = num_embeddings
        self.d_model = embedding_dim
        # Embedding matrix of shape (vocab_size, embedding_dim)
        self.weight = nn.Parameter(torch.empty((self.vocab_size, self.d_model), device=device, dtype=dtype))
        nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        return self.weight[token_ids]  # plain indexing looks up each token's embedding
2.1.3 RMS Normalization (RMSNorm)

Compared with standard LayerNorm, RMSNorm drops the mean-centering step and normalizes only by the root mean square (RMS) of the input, which reduces computation while keeping training stable; it is the default normalization in models such as LLaMA.

class RMSNorm(nn.Module):
    def __init__(
        self,
        d_model: int,          # input dimension (equal to the model's hidden dimension)
        eps: float = 1e-5,     # small constant to avoid division by zero
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # Learnable gain of shape (d_model,), initialized to 1 (identity scaling)
        self.weight = nn.Parameter(torch.ones(self.d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, d_model) -> same shape
        in_dtype = x.dtype  # remember the input dtype
        x = x.to(dtype=torch.float32)  # compute in float32 for numerical stability

        # 1. Root mean square over the last dimension
        rms = (x.pow(2).mean(-1, keepdim=True) + self.eps) ** 0.5
        # 2. Normalize and apply the learnable gain
        out = x / rms * self.weight

        return out.to(dtype=in_dtype)  # cast back to the original dtype
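
In equation form, for each position with feature vector x of dimension d, the layer above computes:

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}\, g_i$$

where g is the learnable gain (self.weight).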

2.2 Activation Function and Feed-Forward Network

The feed-forward network (FFN) is the Transformer sublayer responsible for per-position feature transformation, and SwiGLU is one of the best-performing activation variants in current use; combining the two noticeably improves the model's expressive power.

2.2.1 SwiGLU Activation (SwiGLU)

SwiGLU is a variant of the GLU (Gated Linear Unit): a sigmoid-based gate filters the output of a linear transformation, which captures nonlinear interactions between features better than ReLU while avoiding vanishing gradients.

class SwiGLU(nn.Module):
    def __init__(
        self,
        d_model: int,    # input dimension (model hidden size)
        d_ff: int,       # inner FFN dimension (commonly about 8/3 of d_model for SwiGLU, vs. 4x for a plain FFN)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # Three linear layers: W1/W3 produce the gate and the candidate features, W2 projects back
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    # Helper: SiLU (sigmoid linear unit)
    def _silu(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    # Helper: gated linear unit (GLU)
    def _glu(self, x: torch.Tensor) -> torch.Tensor:
        return self._silu(self.w1(x)) * self.w3(x)  # SiLU gate × W3 branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, d_model) -> same shape
        return self.w2(self._glu(x))  # project the gated features back to d_model
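
In equation form, the block above computes:

$$\mathrm{SwiGLU}(x) = W_2\big(\mathrm{SiLU}(W_1 x) \odot W_3 x\big), \qquad \mathrm{SiLU}(x) = x \cdot \sigma(x)$$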

2.3 Positional Encoding: Rotary Position Embeddings (RoPE)

The Transformer itself has no sense of order, so positional information must be injected explicitly. RoPE encodes each token's position by rotating pairs of dimensions of its vectors, extrapolates reasonably to sequences longer than those seen in training, and is the dominant positional-encoding scheme today.

class ROPE(nn.Module):
    def __init__(
        self,
        theta: float,       # RoPE base frequency (typically 10000)
        d_k: int,           # per-head dimension (must be even: dimensions are rotated in pairs)
        max_seq_len: int,   # maximum supported sequence length
        device: torch.device | None = None
    ):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device

        # Precompute the cos/sin cache once at construction time
        # 1. Per-pair frequencies: shape (d_k//2,)
        freqs_d = 1 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))
        # 2. Positions: shape (max_seq_len,)
        pos_i = torch.arange(max_seq_len, device=device).float()
        # 3. Outer product of positions and frequencies: shape (max_seq_len, d_k//2)
        freqs = einsum(freqs_d, pos_i, "d_half, max_seq_len -> max_seq_len d_half")

        # Cache cos and sin values (indexed by position in forward())
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(
        self,
        x: torch.Tensor,                # input: (..., seq_len, d_k)
        token_positions: torch.Tensor   # position indices: (..., seq_len)
    ) -> torch.Tensor:
        # 1. Split the last dimension into even- and odd-indexed halves (d_k must be even)
        x_odd = x[..., 1::2]  # odd indices: 1, 3, 5, ...
        x_even = x[..., ::2]  # even indices: 0, 2, 4, ...

        # 2. Look up the cos/sin values for the given positions
        cos = self.cos_cached[token_positions]  # (..., seq_len, d_k//2)
        sin = self.sin_cached[token_positions]  # (..., seq_len, d_k//2)

        # 3. Apply the 2D rotation to each pair of dimensions
        out1 = cos * x_even - sin * x_odd  # rotated even components
        out2 = sin * x_even + cos * x_odd  # rotated odd components

        # 4. Interleave the two halves back into the original d_k layout
        out = torch.stack([out1, out2], dim=-1).flatten(-2)  # (..., seq_len, d_k)
        return out
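
In equation form, for each pair of dimensions (x_{2k}, x_{2k+1}) at position p, with per-pair frequency $\theta_k = \Theta^{-2k/d_k}$ (the freqs_d above), the code applies the rotation:

$$\begin{pmatrix} x'_{2k} \\ x'_{2k+1} \end{pmatrix} = \begin{pmatrix} \cos(p\,\theta_k) & -\sin(p\,\theta_k) \\ \sin(p\,\theta_k) & \cos(p\,\theta_k) \end{pmatrix} \begin{pmatrix} x_{2k} \\ x_{2k+1} \end{pmatrix}$$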

2.4 Attention: Multi-Head Self-Attention

Attention is the heart of the Transformer: it captures the dependencies between tokens in a sequence. Multi-head attention splits the hidden vector across several heads that compute attention in parallel, so different heads can capture different kinds of relationships.

2.4.1 Scaled Dot-Product Attention (helper function)

The basic attention computation derives attention weights from queries (Q) and keys (K) and applies them to values (V); dividing the scores by √d_k keeps them from growing too large and saturating the softmax.
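
In equation form, with a mask M that is $-\infty$ at disallowed positions and 0 elsewhere (the code applies the equivalent boolean mask via masked_fill):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$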

def scaled_dot_product_attention(
    query: torch.Tensor,  # Q: (batch_size, ..., seq_len_q, d_k)
    key: torch.Tensor,    # K: (batch_size, ..., seq_len_k, d_k)
    value: torch.Tensor,  # V: (batch_size, ..., seq_len_k, d_v)
    mask: torch.Tensor = None  # mask: (seq_len_q, seq_len_k), True = may attend
) -> torch.Tensor:
    d_k = query.shape[-1]
    # 1. Attention scores: Q·K^T, scaled by sqrt(d_k)
    attention_scores = einsum(query, key, "... seq_len_q d_k, ... seq_len_k d_k -> ... seq_len_q seq_len_k") / (d_k ** 0.5)

    # 2. Apply the mask (e.g. a causal mask that blocks attention to future tokens)
    if mask is not None:
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))  # masked positions get weight 0 after softmax

    # 3. Softmax over keys, then take the weighted sum of the values
    attention_weights = softmax(attention_scores, dim=-1)
    output = einsum(attention_weights, value, "... seq_len_q seq_len_k, ... seq_len_k d_v -> ... seq_len_q d_v")

    return output

# Helper: numerically stable softmax
def softmax(x: torch.Tensor, dim: int):
    x_max = torch.max(x, dim=dim, keepdim=True).values  # subtract the max to avoid overflow in exp
    x_exp = torch.exp(x - x_max)
    sum_exp = torch.sum(x_exp, dim=dim, keepdim=True)
    return x_exp / sum_exp
2.4.2 Multi-Head Self-Attention Module

The input is projected into Q, K, and V, split across multiple heads that compute attention in parallel, and the concatenated head outputs are projected back to the model dimension (d_model).

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int,
                theta: float | None = None,
                max_seq_len: int | None = None,
                ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension (d_model must be divisible by num_heads)
        self.d_v = d_model // num_heads  # simplification: V uses the same per-head dimension as Q/K

        # 1. Q/K/V projections: d_model -> num_heads * d_k (or d_v)
        self.q_proj = Linear(d_model, num_heads * self.d_k)
        self.k_proj = Linear(d_model, num_heads * self.d_k)
        self.v_proj = Linear(d_model, num_heads * self.d_v)
        # 2. Output projection: concatenated heads -> d_model
        self.output_proj = Linear(num_heads * self.d_v, d_model)

        # 3. If theta and max_seq_len are given, enable RoPE
        if theta is not None and max_seq_len is not None:
            self.rope = ROPE(theta, self.d_k, max_seq_len)

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor | None = None,
                token_positions: torch.Tensor | None = None) -> torch.Tensor:
        # input: (batch_size, seq_len, d_model)
        *batch_dims, seq_len, _ = x.shape  # leading batch dimensions (e.g. batch_size)

        # 1. Project to Q/K/V and split into heads
        x_q = self.q_proj(x)  # (batch_size, seq_len, num_heads*d_k)
        x_k = self.k_proj(x)  # (batch_size, seq_len, num_heads*d_k)
        x_v = self.v_proj(x)  # (batch_size, seq_len, num_heads*d_v)

        # Reshape to (batch_size, num_heads, seq_len, d_k)
        x_q = rearrange(x_q, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_k = rearrange(x_k, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_v = rearrange(x_v, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v",
                        num_heads=self.num_heads, d_v=self.d_v)

        # 2. Apply RoPE (if enabled)
        if hasattr(self, "rope"):
            # Default to positions 0 .. seq_len-1 if none are given
            if token_positions is None:
                token_positions = torch.arange(seq_len, device=x.device)
                # Expand token_positions to match the leading batch dimensions
                for _ in range(len(batch_dims)):
                    token_positions = token_positions.unsqueeze(0)
            # Rotate Q and K (V is left untouched)
            x_q = self.rope(x_q, token_positions)
            x_k = self.rope(x_k, token_positions)

        # 3. Build the mask (defaults to a causal mask that hides future tokens)
        if mask is None:
            # Causal mask: True on and below the diagonal (may attend), False above it
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
            # Add leading dims so the mask broadcasts over batch and heads
            for _ in range(len(batch_dims) + 1):  # +1 for the num_heads dimension
                mask = mask.unsqueeze(0)
        else:
            # Expand a user-provided (seq_len_q, seq_len_k) mask the same way
            for _ in range(len(batch_dims) + 1):
                mask = mask.unsqueeze(0)

        # 4. Scaled dot-product attention
        attn_output = scaled_dot_product_attention(x_q, x_k, x_v, mask)

        # 5. Concatenate the heads and project back to d_model
        attn_output = rearrange(attn_output, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
                              num_heads=self.num_heads, d_v=self.d_v)
        output = self.output_proj(attn_output)

        return output
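
A quick shape check of the module; the dimensions below are made up purely for illustration:

# Illustrative smoke test (tiny, arbitrary dimensions)
mha = MultiHeadSelfAttention(d_model=64, num_heads=4, theta=10000.0, max_seq_len=128)
x = torch.randn(2, 10, 64)   # (batch_size, seq_len, d_model)
print(mha(x).shape)          # torch.Size([2, 10, 64])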


2.5 Transformer Block (TransformerBlock)

A single Transformer block is the model's basic repeating unit: multi-head self-attention followed by a feed-forward network, with pre-normalization applied before each sublayer rather than after, a design that has been shown to noticeably improve training stability.

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                 theta: float | None = None,
                 max_seq_len: int | None = None,):
        super().__init__()
        # 1. Normalization layers (pre-norm design)
        self.ln1 = RMSNorm(d_model)  # norm before the attention sublayer
        self.ln2 = RMSNorm(d_model)  # norm before the feed-forward sublayer

        # 2. Feed-forward network
        self.ffn = SwiGLU(d_model, d_ff)

        # 3. Multi-head self-attention (RoPE is enabled when theta and max_seq_len are given)
        if theta is not None and max_seq_len is not None:
            self.attn = MultiHeadSelfAttention(d_model, num_heads, theta, max_seq_len)
        else:
            self.attn = MultiHeadSelfAttention(d_model, num_heads)

    def forward(self,
            x: torch.Tensor,
            mask: torch.Tensor | None = None,
            token_positions: torch.Tensor | None = None) -> torch.Tensor:
        # Residual + attention: x = x + Attention(LN(x))
        x = x + self.attn(self.ln1(x), mask=mask, token_positions=token_positions)
        # Residual + feed-forward: x = x + FFN(LN(x))
        x = x + self.ffn(self.ln2(x))
        return x

2.6 The Full Language Model (TransformerLM)

All components are assembled into the final Transformer language model: the token embedding layer, a stack of Transformer blocks, a final normalization layer, and the language-model head (LM head).

class TransformerLM(nn.Module):
    def __init__(self, vocab_size: int,
                 context_length: int,  # context window (maximum sequence length)
                 num_layers: int,       # number of Transformer blocks
                 d_model: int,          # model hidden dimension
                 num_heads: int,        # number of attention heads
                 d_ff: int,             # inner FFN dimension
                 theta: float | None = None,):  # RoPE base frequency
        super().__init__()

        self.vocab_size = vocab_size
        self.context_length = context_length

        # 1. Token embedding: token IDs -> d_model vectors
        self.token_embeddings = Embedding(vocab_size, d_model)

        # 2. Stack of Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, theta, context_length)
            for _ in range(num_layers)
        ])

        # 3. Final normalization layer
        self.ln_final = RMSNorm(d_model)

        # 4. LM head: d_model -> vocab_size (logits)
        self.lm_head = Linear(d_model, vocab_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len) -> (batch_size, seq_len, vocab_size)
        # 1. Token embedding
        x = self.token_embeddings(inputs)  # (batch_size, seq_len, d_model)

        # 2. Pass through all Transformer blocks
        for layer in self.layers:
            x = layer(x)  # each block preserves the shape (batch_size, seq_len, d_model)

        # 3. Final norm, then project to the vocabulary
        x = self.ln_final(x)
        logits = self.lm_head(x)  # (batch_size, seq_len, vocab_size)

        return logits
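
A minimal smoke test; the hyperparameters are illustrative assumptions, not the assignment's configuration:

# Illustrative only: tiny model, random tokens
model = TransformerLM(vocab_size=1000, context_length=128, num_layers=2,
                      d_model=64, num_heads=4, d_ff=256, theta=10000.0)
tokens = torch.randint(0, 1000, (2, 16))  # (batch_size, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 1000])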

3. Training Utilities: train.py

train.py contains the core tools needed for training: the loss function, the optimizer, learning-rate scheduling, batch sampling, gradient clipping, and checkpoint management.

3.1 Loss Function: Cross-Entropy

The language model's task is to predict the next token, so we use the cross-entropy loss between the predicted distribution and the true token. The implementation applies the usual numerical-stability trick of subtracting the maximum logit before exponentiating.

import math
import os
from collections.abc import Callable, Iterable
from typing import BinaryIO, IO

import numpy as np
import numpy.typing as npt
import torch


def cross_entropy(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Average cross-entropy loss over a batch.
    Inputs:
        inputs: (batch_size, ..., vocab_size) -> unnormalized logits
        targets: (batch_size, ...) -> indices of the true tokens
    Output:
        scalar tensor -> mean loss over the batch
    """
    # Flatten leading dimensions so (batch, vocab) and (batch, seq, vocab) are handled alike
    vocab_size = inputs.shape[-1]
    inputs = inputs.reshape(-1, vocab_size)
    targets = targets.reshape(-1)
    # Numerical stability: subtract the per-row maximum before exponentiating
    o_max = torch.max(inputs, dim=-1, keepdim=True).values
    o = inputs - o_max
    # Logit of the target token for each example
    target_logits = o[torch.arange(o.shape[0]), targets]
    # log(sum(exp(o)))
    logsumexp = torch.log(torch.sum(torch.exp(o), dim=-1))
    # Per-example loss: -target_logit + logsumexp
    loss = -target_logits + logsumexp
    # Mean over all examples
    return loss.mean()
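
A quick sanity check against PyTorch's built-in loss (purely illustrative):

import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
print(cross_entropy(logits, targets))   # should match F.cross_entropy closely
print(F.cross_entropy(logits, targets))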

3.2 Optimizer: AdamW

AdamW is a refinement of Adam that decouples weight decay from the gradient-based update, which improves generalization; it is the standard optimizer for training Transformers today.
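
In equation form, each step of the implementation below performs (with learning rate $\alpha$, decay rates $\beta_1, \beta_2$, weight decay $\lambda$, and gradient $g_t$):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$$

$$\alpha_t = \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}, \qquad \theta \leftarrow \theta - \alpha_t\,\frac{m_t}{\sqrt{v_t} + \epsilon} - \alpha\,\lambda\,\theta$$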

class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,          # initial learning rate
        betas=(0.9, 0.999),  # exponential decay rates for the first and second moments
        eps=1e-8,         # small constant to avoid division by zero
        weight_decay=0.01  # weight-decay coefficient
    ):
        # Validate hyperparameters
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid betas: {betas}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        # Optionally evaluate a closure (needed in some settings)
        loss = None if closure is None else closure()

        # Iterate over parameter groups (different groups may use different hyperparameters)
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]

            # Iterate over parameters
            for p in group["params"]:
                if p.grad is None:
                    continue  # skip parameters without gradients
                grad = p.grad.data  # gradient
                state = self.state[p]  # per-parameter optimizer state

                # Initialize state on the first update
                t = state.get("t", 1)  # step count (starts at 1)
                m = state.get("m", torch.zeros_like(grad))  # first-moment estimate (momentum)
                v = state.get("v", torch.zeros_like(grad))  # second-moment estimate

                # Update the first and second moments
                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * grad ** 2

                # Bias-corrected step size
                lr_t = lr * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)

                # Parameter update: Adam step first, then decoupled weight decay
                p.data = p.data - lr_t * m / (v ** 0.5 + eps)
                p.data = p.data - lr * weight_decay * p.data  # weight decay is independent of the gradient

                # Save the updated state
                state["t"] = t + 1
                state["m"] = m
                state["v"] = v

        return loss

3.3 Learning-Rate Schedule: Cosine Annealing with Warmup

Learning-rate scheduling is crucial for Transformer training. We implement cosine annealing with warmup: the learning rate first increases linearly (warmup) to avoid instability from a large learning rate early in training, then decays along a cosine curve to a minimum value, helping the model converge to a better solution.

def lr_cosine_schedule(t: int, lr_max: float, lr_min: float, warmup_iters: int, cosine_cycle_iters: int):
    """
    Cosine-annealing learning-rate schedule with linear warmup.
    Args:
        t: current iteration
        lr_max: maximum learning rate (reached at the end of warmup)
        lr_min: minimum learning rate (floor of the cosine decay)
        warmup_iters: number of warmup iterations
        cosine_cycle_iters: total iterations of the cosine cycle (including warmup)
    Returns:
        the learning rate for iteration t
    """
    if t < warmup_iters:
        # Warmup: linear increase
        lr = t / warmup_iters * lr_max
    elif t < cosine_cycle_iters:
        # Cosine decay: smoothly anneal from lr_max down to lr_min
        # phase goes from 0 to pi over the decay window
        phase = (t - warmup_iters) / (cosine_cycle_iters - warmup_iters) * math.pi
        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(phase))
    else:
        # After the cycle: stay at lr_min
        lr = lr_min
    return lr
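
A quick way to see the shape of the schedule; the numbers below are arbitrary, chosen only for illustration:

for t in [0, 50, 100, 500, 1000, 2000]:
    print(t, lr_cosine_schedule(t, lr_max=1e-3, lr_min=1e-4,
                                warmup_iters=100, cosine_cycle_iters=1000))
# linear ramp until t=100, cosine decay until t=1000, then flat at lr_min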

3.4 Gradient Clipping

During training, gradients can spike because of unusual samples or a large learning rate (gradient explosion), destabilizing the model. Gradient clipping bounds the total L2 norm of the gradients so updates stay within a reasonable range.
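
In equation form, with g the concatenation of all gradients and M the maximum allowed norm, the code below rescales every gradient by:

$$g \leftarrow g \cdot \min\!\left(1,\ \frac{M}{\lVert g \rVert_2 + \epsilon}\right)$$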

def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float, eps: float = 1e-6):
    """
    Clip the combined L2 norm of the gradients to prevent gradient explosion.
    Args:
        parameters: parameters whose gradients should be clipped
        max_l2_norm: upper bound on the combined gradient L2 norm
        eps: small constant to avoid division by zero
    """
    # Collect all parameters that actually have gradients
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return  # nothing to clip

    # Total L2 norm over all gradients
    l2_norm = 0.0
    for g in grads:
        l2_norm += torch.sum(g ** 2)  # accumulate squared entries
    l2_norm = torch.sqrt(l2_norm)  # square root gives the L2 norm

    # Clipping coefficient (scale down only if the norm exceeds the bound)
    clip_coef = min(1.0, max_l2_norm / (l2_norm + eps))

    # Apply the clipping in place
    for g in grads:
        g *= clip_coef

3.5 Batch Sampling: get_batch

Samples training batches from the tokenized dataset, producing input sequences x and label sequences y. For a language model, y is x shifted right by one position, so each position's label is the next token.

def get_batch(
    dataset: npt.NDArray, batch_size: int, context_length: int, device: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a batch from the dataset.
    Args:
        dataset: 1D numpy array of token IDs
        batch_size: number of sequences per batch
        context_length: length of each sequence
        device: device to place the tensors on ('cpu' or 'cuda')
    Returns:
        x: (batch_size, context_length) -> input sequences
        y: (batch_size, context_length) -> label sequences (x shifted right by one)
    """
    # Largest valid start index (each sample needs context_length + 1 tokens)
    max_start = len(dataset) - context_length - 1
    if max_start < 0:
        raise ValueError("Dataset is shorter than the requested context_length")

    # Sample batch_size random start indices
    starts = np.random.randint(0, max_start + 1, size=batch_size)

    x_batch = []
    y_batch = []
    for s in starts:
        # Slice out context_length + 1 tokens: [s, s + context_length + 1)
        seq = dataset[s : s + context_length + 1]
        x_batch.append(seq[:-1])  # input: the first context_length tokens
        y_batch.append(seq[1:])   # labels: the last context_length tokens (shifted by one)

    # Stack into arrays, convert to tensors, and move to the target device
    x = torch.tensor(np.stack(x_batch), dtype=torch.long, device=device)
    y = torch.tensor(np.stack(y_batch), dtype=torch.long, device=device)

    return x, y
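
A tiny illustration of the x/y offset; the toy data and shapes below are assumptions made only for this example:

data = np.arange(20)  # pretend this is a stream of token IDs
x, y = get_batch(data, batch_size=2, context_length=5, device="cpu")
print(x.shape, y.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
# each y[i] is x[i] shifted one position to the right in the original stream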

3.6 Checkpoint Management

When training large models, the training state (model parameters, optimizer state, current iteration) should be saved periodically so that training can resume after an interruption and the model can be evaluated later.

def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str | os.PathLike | BinaryIO | IO[bytes],
):
    """Save the training state to a file."""
    checkpoint = {
        'model_state': model.state_dict(),          # model parameters
        'optimizer_state': optimizer.state_dict(),  # optimizer state
        'iteration': iteration,                     # current iteration
    }
    torch.save(checkpoint, out)

def load_checkpoint(
    src: str | os.PathLike | BinaryIO | IO[bytes],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer
) -> int:
    """Load a training state from a file and restore the model and optimizer."""
    checkpoint = torch.load(src)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    return checkpoint['iteration']  # iteration at which the checkpoint was saved
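
An illustrative round trip; the file name "ckpt.pt" and the tiny model configuration are assumptions for this sketch:

model = TransformerLM(vocab_size=1000, context_length=128, num_layers=2,
                      d_model=64, num_heads=4, d_ff=256, theta=10000.0)
optimizer = AdamW(model.parameters(), lr=1e-3)
save_checkpoint(model, optimizer, iteration=1000, out="ckpt.pt")
print(load_checkpoint("ckpt.pt", model, optimizer))  # 1000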

4. Summary and Extensions

In this assignment we implemented a complete Transformer language model from scratch, covering the core components of modern LLMs:

  • Pre-normalized Transformer blocks for more stable training
  • Rotary position embeddings (RoPE) for positional information, with support for length extrapolation
  • Multi-head self-attention to capture dependencies between tokens
  • SwiGLU activation to strengthen the feed-forward network's expressiveness
  • Training utilities including a custom AdamW optimizer and cosine learning-rate schedule

Finally, I won't paste every adapter from adapter.py one by one; since someone asked about this in the comments last time, here is one example:

def run_transformer_lm(
    vocab_size: int,
    context_length: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
    rope_theta: float,
    weights: dict[str, Tensor],
    in_indices: Int[Tensor, " batch_size sequence_length"],
) -> Float[Tensor, " batch_size sequence_length vocab_size"]:
    """Given the weights of a Transformer language model and input indices,
    return the output of running a forward pass on the input indices.

    This function should use RoPE.

    Args:
        vocab_size (int): The number of unique items in the output vocabulary to be predicted.
        context_length (int): The maximum number of tokens to process at once.
        d_model (int): The dimensionality of the model embeddings and sublayer outputs.
        num_layers (int): The number of Transformer layers to use.
        num_heads (int): Number of heads to use in multi-headed attention. `d_model` must be
            evenly divisible by `num_heads`.
        d_ff (int): Dimensionality of the feed-forward inner layer (section 3.3).
        rope_theta (float): The RoPE $\Theta$ parameter.
        weights (dict[str, Tensor]): 
            State dict of our reference implementation. {num_layers} refers to an
            integer between `0` and `num_layers - 1` (the layer index).
            The keys of this dictionary are:
            - `token_embeddings.weight`
                Token embedding matrix. Shape is (vocab_size, d_model).
            - `layers.{num_layers}.attn.q_proj.weight`
                The query projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_k),
                so `attn.q_proj.weight == torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.k_proj.weight`
                The key projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_k),
                so `attn.k_proj.weight == torch.cat([k_heads.0.weight, ..., k_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.v_proj.weight`
                The value projections for all `num_heads` attention heads.
                Shape is (num_heads * (d_model / num_heads), d_model).
                The rows are ordered by matrices of shape (num_heads, d_v),
                so `attn.v_proj.weight == torch.cat([v_heads.0.weight, ..., v_heads.N.weight], dim=0)`.
            - `layers.{num_layers}.attn.output_proj.weight`
                Weight of the multi-head self-attention output projection
                Shape is ((d_model / num_heads) * num_heads, d_model).
            - `layers.{num_layers}.ln1.weight`
                Weights of affine transform for the first RMSNorm
                applied in the transformer block.
                Shape is (d_model,).
            - `layers.{num_layers}.ffn.w1.weight`
                Weight of the first linear transformation in the FFN.
                Shape is (d_model, d_ff).
            - `layers.{num_layers}.ffn.w2.weight`
                Weight of the second linear transformation in the FFN.
                Shape is (d_ff, d_model).
            - `layers.{num_layers}.ffn.w3.weight`
                Weight of the third linear transformation in the FFN.
                Shape is (d_model, d_ff).
            - `layers.{num_layers}.ln2.weight`
                Weights of affine transform for the second RMSNorm
                applied in the transformer block.
                Shape is (d_model,).
            - `ln_final.weight`
                Weights of affine transform for RMSNorm applied to the output of the final transformer block.
                Shape is (d_model, ).
            - `lm_head.weight`
                Weights of the language model output embedding.
                Shape is (vocab_size, d_model).
        in_indices (Int[Tensor, "batch_size sequence_length"]) Tensor with input indices to run the language model on. Shape is (batch_size, sequence_length), where
            `sequence_length` is at most `context_length`.

    Returns:
        Float[Tensor, "batch_size sequence_length vocab_size"]: Tensor with the predicted unnormalized
        next-word distribution for each token.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"  # pick a device for the forward pass
    model = TransformerLM(vocab_size, context_length, num_layers, d_model,
                          num_heads, d_ff, rope_theta).to(device)
    model.load_state_dict(weights)

    return model(in_indices.to(device=device))