All of the code for assignment 1 is open-sourced at:
https://github.com/ACEEE-1222/Standford-CS336-Assignment-1
If you find it helpful, please give the repo a star!
The assignment handout is at https://github.com/stanford-cs336/assignment1-basics
The second half of assignment 1 asks you to implement a Transformer-based language model (LM) from scratch, which is a key exercise for understanding the internals of modern large language models (LLMs).
This post walks through the complete implementation of the Transformer LM, covering the core components (multi-head attention, rotary position embeddings (RoPE), RMS normalization, the SwiGLU feed-forward network) as well as the training utilities (a custom AdamW optimizer, learning-rate scheduling, batch sampling), so that readers can follow the full pipeline from model construction to training.
1. Project Overview
The goal of this assignment is to build a decoder-only Transformer language model (GPT-style) that predicts the next token in a sequence. The model follows current mainstream design choices (pre-normalization, RoPE, RMSNorm) to balance efficiency and performance, and comes with a complete training pipeline that can be run directly on text data.
Key features of the model and training framework:
- Decoder-only Transformer architecture with multi-head self-attention
- Rotary position embeddings (RoPE) for stronger positional modeling
- RMSNorm in place of the traditional LayerNorm for better training stability
- SwiGLU activation in the feed-forward network, which outperforms traditional activations such as ReLU
- Custom AdamW optimizer, cosine learning-rate scheduler, and gradient clipping
- Saving and loading of training checkpoints, so training can resume after an interruption
2. Core Components: transformer.py
The transformer.py file contains all the core modules of the Transformer language model. Each component is designed to be modular, which makes debugging easier and leaves room for later extensions.
2.1 Basic Building Blocks
These layers are the model's basic utilities: they are reused across several modules and implement the most elementary tensor transformations.
2.1.1 Bias-Free Linear Layer (Linear)
A simplified linear layer with the bias term removed. The weights use a Xavier-style initialization (truncated normal) for training stability, and the forward pass uses the einops library to express the tensor contraction explicitly, avoiding the dimension bugs that manual reshapes tend to introduce.
```python
# Imports shared by the transformer.py snippets below
import torch
import torch.nn as nn
from einops import einsum, rearrange


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        super().__init__()
        # Weight parameter of shape (out_features, in_features)
        self.weight = nn.Parameter(torch.empty((out_features, in_features), device=device, dtype=dtype))
        # Xavier-style init: std = sqrt(2 / (in + out)) to avoid vanishing/exploding gradients
        std = (2 / (in_features + out_features)) ** 0.5
        nn.init.trunc_normal_(self.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x @ W^T (input shape ..., in_features -> output shape ..., out_features)
        return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")
```
2.1.2 Token Embedding Layer (Embedding)
Maps discrete token IDs to dense continuous vectors; the embedding dimension (embedding_dim) equals the model's hidden dimension (d_model). The weights are also initialized from a truncated normal distribution so that the initial embeddings are well behaved.
```python
class Embedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,    # vocabulary size (total number of tokens)
        embedding_dim: int,     # embedding dimension (must equal d_model)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.vocab_size = num_embeddings
        self.d_model = embedding_dim
        # Embedding matrix of shape (vocab_size, embedding_dim)
        self.weight = nn.Parameter(torch.empty((self.vocab_size, self.d_model), device=device, dtype=dtype))
        nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
        # Input: (batch_size, seq_len) -> output: (batch_size, seq_len, embedding_dim)
        return self.weight[token_ids]   # index directly to fetch each token's embedding
```
2.1.3 RMS Normalization Layer (RMSNorm)
Compared with the traditional LayerNorm, RMSNorm drops the mean-centering step and normalizes only by the root mean square (RMS) of the input. This reduces computation while keeping training stable, and it is the default normalization in models such as LLaMA.
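For reference, the operation implemented below can be written as (g is the learnable scale `weight`, ε the stabilizing constant `eps`):

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d_{\text{model}}}\sum_{j=1}^{d_{\text{model}}} x_j^2 + \epsilon}}\; g_i$$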
```python
class RMSNorm(nn.Module):
    def __init__(
        self,
        d_model: int,        # input dimension (must equal the model's hidden size)
        eps: float = 1e-5,   # small constant that keeps the denominator away from zero
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # Learnable scale of shape (d_model,), initialized to 1 (identity at the start)
        self.weight = nn.Parameter(torch.ones(self.d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model) -> output: same shape
        in_dtype = x.dtype                 # remember the input dtype
        x = x.to(dtype=torch.float32)      # compute in float32 for numerical stability
        # 1. Root mean square over the last dimension
        rms = (x.pow(2).mean(-1, keepdim=True) + self.eps) ** 0.5
        # 2. Normalize and apply the learnable scale
        out = x / rms * self.weight
        return out.to(dtype=in_dtype)      # cast back to the original dtype
```
2.2 Activation and Feed-Forward Network
The feed-forward network (FFN) is the part of the Transformer responsible for per-token feature transformation, and SwiGLU is one of the best-performing activation choices available today; combining the two noticeably improves the model's expressiveness.
2.2.1 SwiGLU Activation
SwiGLU is a variant of the GLU (Gated Linear Unit): a SiLU (Swish) gate filters the output of a linear transformation. Compared with ReLU it captures non-linear interactions between features better while mitigating vanishing gradients.
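In formula form, the module below computes (W1 and W3 project d_model → d_ff, W2 projects back to d_model, σ is the sigmoid):

$$\mathrm{SiLU}(z) = z\,\sigma(z), \qquad \mathrm{SwiGLU}(x) = W_2\big(\mathrm{SiLU}(W_1 x)\odot W_3 x\big)$$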
```python
class SwiGLU(nn.Module):
    def __init__(
        self,
        d_model: int,   # input dimension (model hidden size)
        d_ff: int,      # FFN inner dimension (4 x d_model for a classic FFN; with SwiGLU often ~8/3 x d_model to keep the parameter count comparable)
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # Three linear layers: W1/W3 produce the gate and candidate features, W2 projects back
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    # Helper: sigmoid linear unit (SiLU)
    def _silu(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    # Helper: gated linear unit (GLU)
    def _glu(self, x: torch.Tensor) -> torch.Tensor:
        return self._silu(self.w1(x)) * self.w3(x)   # SiLU gate x W3 candidate features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model) -> output: same shape
        return self.w2(self._glu(x))   # project the gated features back to d_model
```
2.3 Positional Encoding: Rotary Position Embeddings (RoPE)
The Transformer itself has no notion of order, so positional information must be injected explicitly. RoPE encodes positions by rotating pairs of embedding dimensions by position-dependent angles; it also extrapolates to sequences longer than those seen in training, which is why it has become the mainstream choice.
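Concretely, for a token at position m, each dimension pair (2i, 2i+1) of the query/key vector is rotated by a position-dependent angle; this is exactly what the cos/sin cache below precomputes (Θ is the base frequency theta, d_k the per-head dimension):

$$\theta_{m,i} = \frac{m}{\Theta^{2i/d_k}}, \qquad \begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos\theta_{m,i} & -\sin\theta_{m,i} \\ \sin\theta_{m,i} & \cos\theta_{m,i} \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$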
```python
class ROPE(nn.Module):
    def __init__(
        self,
        theta: float,        # RoPE base frequency (commonly 10000)
        d_k: int,            # per-head dimension (must be even, since dimensions are rotated in pairs)
        max_seq_len: int,    # maximum supported sequence length
        device: torch.device | None = None
    ):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device
        # Precompute the cos/sin cache once at construction time
        # 1. Frequencies, shape (d_k // 2,)
        freqs_d = 1 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))
        # 2. Positions, shape (max_seq_len,)
        pos_i = torch.arange(max_seq_len, device=device).float()
        # 3. Outer product of positions and frequencies, shape (max_seq_len, d_k // 2)
        freqs = einsum(freqs_d, pos_i, "d_half, max_seq_len -> max_seq_len d_half")
        # Cache the cos and sin values (indexed directly in forward)
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(
        self,
        x: torch.Tensor,                 # input: (..., seq_len, d_k)
        token_positions: torch.Tensor    # position indices: (..., seq_len)
    ) -> torch.Tensor:
        # 1. Split the last dimension into even/odd indices (d_k must be even)
        x_odd = x[..., 1::2]    # odd indices: 1, 3, 5, ...
        x_even = x[..., ::2]    # even indices: 0, 2, 4, ...
        # 2. Look up cos/sin for the given positions
        cos = self.cos_cached[token_positions]   # (..., seq_len, d_k // 2)
        sin = self.sin_cached[token_positions]   # (..., seq_len, d_k // 2)
        # 3. Apply the rotation, mixing position information into the vectors
        out1 = cos * x_even - sin * x_odd   # rotated even components
        out2 = sin * x_even + cos * x_odd   # rotated odd components
        # 4. Interleave the two halves back into the original d_k dimension
        out = torch.stack([out1, out2], dim=-1).flatten(-2)   # (..., seq_len, d_k)
        return out
```
2.4 Attention: Multi-Head Self-Attention
Attention is the heart of the Transformer: it captures the dependencies between tokens in a sequence. Multi-head attention splits the hidden vector across several heads that compute attention in parallel, so different heads can specialize in different kinds of dependencies.
2.4.1 Scaled Dot-Product Attention (helper function)
The basic attention computation: attention weights are derived from the query (Q), key (K), and value (V) tensors, with a scaling factor of √d_k so that large dot products do not saturate the softmax.
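In matrix form (M is an additive mask that is 0 where attention is allowed and −∞ where it is blocked):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V$$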
```python
def scaled_dot_product_attention(
    query: torch.Tensor,        # Q: (batch_size, ..., seq_len_q, d_k)
    key: torch.Tensor,          # K: (batch_size, ..., seq_len_k, d_k)
    value: torch.Tensor,        # V: (batch_size, ..., seq_len_k, d_v)
    mask: torch.Tensor = None   # mask: (seq_len_q, seq_len_k), True = may attend
) -> torch.Tensor:
    d_k = query.shape[-1]
    # 1. Dot product of Q and K (attention scores), scaled by sqrt(d_k)
    attention_scores = einsum(query, key, "... seq_len_q d_k, ... seq_len_k d_k -> ... seq_len_q seq_len_k") / (d_k ** 0.5)
    # 2. Apply the mask (e.g. a causal mask that hides future tokens)
    if mask is not None:
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))  # -inf becomes weight 0 after softmax
    # 3. Softmax to get the attention weights, then take a weighted sum over V
    attention_weights = softmax(attention_scores, dim=-1)
    output = einsum(attention_weights, value, "... seq_len_q seq_len_k, ... seq_len_k d_v -> ... seq_len_q d_v")
    return output


# Helper: numerically stable softmax
def softmax(x: torch.Tensor, dim: int):
    x_max = torch.max(x, dim=dim, keepdim=True).values  # subtract the max to avoid exp overflow
    x_exp = torch.exp(x - x_max)
    sum_exp = torch.sum(x_exp, dim=dim, keepdim=True)
    return x_exp / sum_exp
```
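As a quick sanity check, the helper can be compared against PyTorch's built-in F.scaled_dot_product_attention (available in PyTorch 2.x). The snippet below is just a sketch one might run locally, not part of the assignment code; it assumes the definitions above are importable:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)   # (batch, heads, seq_len, d_k)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))  # True = may attend

ours = scaled_dot_product_attention(q, k, v, causal)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(torch.allclose(ours, ref, atol=1e-5))  # expected: True
```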
2.4.2 Multi-Head Self-Attention Module
The input is projected into Q, K, and V, split across heads for parallel attention, and the per-head outputs are then concatenated and projected back to the model dimension (d_model).
```python
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int,
                 theta: float | None = None,
                 max_seq_len: int | None = None,
                 ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads   # per-head dimension (d_model must be divisible by num_heads)
        self.d_v = d_model // num_heads   # simplification: V uses the same per-head dimension as Q/K
        # 1. Q/K/V projections: map d_model to num_heads x d_k (or d_v)
        self.q_proj = Linear(d_model, num_heads * self.d_k)
        self.k_proj = Linear(d_model, num_heads * self.d_k)
        self.v_proj = Linear(d_model, num_heads * self.d_v)
        # 2. Output projection: map the concatenated heads back to d_model
        self.output_proj = Linear(num_heads * self.d_v, d_model)
        # 3. If theta and max_seq_len are given, set up RoPE
        if theta is not None and max_seq_len is not None:
            self.rope = ROPE(theta, self.d_k, max_seq_len)

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor | None = None,
                token_positions: torch.Tensor | None = None) -> torch.Tensor:
        # Input: (batch_size, seq_len, d_model)
        *batch_dims, seq_len, _ = x.shape   # leading batch dimensions (e.g. batch_size)
        # 1. Q/K/V projections and head split
        x_q = self.q_proj(x)   # (batch_size, seq_len, num_heads * d_k)
        x_k = self.k_proj(x)   # (batch_size, seq_len, num_heads * d_k)
        x_v = self.v_proj(x)   # (batch_size, seq_len, num_heads * d_v)
        # Split into heads: (batch_size, num_heads, seq_len, d_k)
        x_q = rearrange(x_q, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_k = rearrange(x_k, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
                        num_heads=self.num_heads, d_k=self.d_k)
        x_v = rearrange(x_v, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v",
                        num_heads=self.num_heads, d_v=self.d_v)
        # 2. Apply RoPE (if it was initialized)
        if hasattr(self, "rope"):
            # Default to positions 0 .. seq_len-1 if none are given
            if token_positions is None:
                token_positions = torch.arange(seq_len, device=x.device)
            # Add leading dims so the positions broadcast over the batch dimensions
            for _ in range(len(batch_dims)):
                token_positions = token_positions.unsqueeze(0)
            # RoPE is applied to Q and K only (V is not rotated)
            x_q = self.rope(x_q, token_positions)
            x_k = self.rope(x_k, token_positions)
        # 3. Build the mask (defaults to a causal mask that hides future tokens)
        if mask is None:
            # Causal mask: upper triangle False (blocked), lower triangle and diagonal True (allowed)
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
            # Add leading dims so the mask broadcasts over batch and head dimensions
            for _ in range(len(batch_dims) + 1):   # +1 accounts for the num_heads dimension
                mask = mask.unsqueeze(0)
        else:
            # Expand a user-provided (seq_len, seq_len) mask in the same way
            for _ in range(len(batch_dims) + 1):
                mask = mask.unsqueeze(0)
        # 4. Scaled dot-product attention
        attn_output = scaled_dot_product_attention(x_q, x_k, x_v, mask)
        # 5. Concatenate the heads and project back to d_model
        attn_output = rearrange(attn_output, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
                                num_heads=self.num_heads, d_v=self.d_v)
        output = self.output_proj(attn_output)
        return output
```
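A minimal shape check (the sizes here are illustrative, not the training configuration) confirms that the module preserves the (batch, seq_len, d_model) shape and builds the causal mask and token positions itself:

```python
torch.manual_seed(0)
mha = MultiHeadSelfAttention(d_model=64, num_heads=4, theta=10000.0, max_seq_len=128)
x = torch.randn(2, 16, 64)   # (batch_size, seq_len, d_model)
out = mha(x)                 # causal mask and token positions are created internally
print(out.shape)             # torch.Size([2, 16, 64])
```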
2.5 Transformer Block (TransformerBlock)
A single Transformer block is the model's repeating unit, consisting of multi-head self-attention followed by a feed-forward network. It uses **pre-normalization**: normalization is applied before the attention and feed-forward sublayers rather than after them, which has been shown to noticeably improve training stability.
```python
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                 theta: float | None = None,
                 max_seq_len: int | None = None,):
        super().__init__()
        # 1. Normalization layers (pre-norm design)
        self.ln1 = RMSNorm(d_model)   # norm before the attention sublayer
        self.ln2 = RMSNorm(d_model)   # norm before the feed-forward sublayer
        # 2. Feed-forward network
        self.ffn = SwiGLU(d_model, d_ff)
        # 3. Multi-head self-attention (RoPE is enabled when theta and max_seq_len are given)
        if theta is not None and max_seq_len is not None:
            self.attn = MultiHeadSelfAttention(d_model, num_heads, theta, max_seq_len)
        else:
            self.attn = MultiHeadSelfAttention(d_model, num_heads)

    def forward(self,
                x: torch.Tensor,
                mask: torch.Tensor | None = None,
                token_positions: torch.Tensor | None = None) -> torch.Tensor:
        # Residual connection + attention: x = x + Attention(LN(x))
        x = x + self.attn(self.ln1(x), mask=mask, token_positions=token_positions)
        # Residual connection + feed-forward: x = x + FFN(LN(x))
        x = x + self.ffn(self.ln2(x))
        return x
```
2.6 The Full Language Model (TransformerLM)
All the components are assembled into the final Transformer language model: the token embedding layer, a stack of Transformer blocks, a final normalization layer, and the language-model head (LM head).
```python
class TransformerLM(nn.Module):
    def __init__(self, vocab_size: int,
                 context_length: int,   # context window (maximum sequence length)
                 num_layers: int,       # number of Transformer blocks
                 d_model: int,          # model hidden size
                 num_heads: int,        # number of attention heads
                 d_ff: int,             # FFN inner dimension
                 theta: float | None = None,):   # RoPE base frequency
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        # 1. Token embedding: maps token IDs to d_model-dimensional vectors
        self.token_embeddings = Embedding(vocab_size, d_model)
        # 2. Stack of Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, theta, context_length)
            for _ in range(num_layers)
        ])
        # 3. Final normalization
        self.ln_final = RMSNorm(d_model)
        # 4. LM head: maps d_model to vocabulary-size logits
        self.lm_head = Linear(d_model, vocab_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Input: (batch_size, seq_len) -> output: (batch_size, seq_len, vocab_size)
        # 1. Token embeddings
        x = self.token_embeddings(inputs)   # (batch_size, seq_len, d_model)
        # 2. Pass through all Transformer blocks
        for layer in self.layers:
            x = layer(x)   # each layer keeps the shape (batch_size, seq_len, d_model)
        # 3. Final norm + projection to the vocabulary
        x = self.ln_final(x)
        logits = self.lm_head(x)   # (batch_size, seq_len, vocab_size)
        return logits
```
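A toy forward pass (the hyperparameters here are small illustrative values, not the training configuration) shows the expected input and output shapes:

```python
torch.manual_seed(0)
lm = TransformerLM(vocab_size=1000, context_length=64, num_layers=2,
                   d_model=64, num_heads=4, d_ff=128, theta=10000.0)
tokens = torch.randint(0, 1000, (2, 32))        # (batch_size, seq_len)
logits = lm(tokens)
print(logits.shape)                             # torch.Size([2, 32, 1000])
print(sum(p.numel() for p in lm.parameters()))  # total parameter count
```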
3. Training Utilities: train.py
The train.py file contains the core tools needed to train the model: the loss function, the optimizer, learning-rate scheduling, batch sampling, gradient clipping, and checkpoint management.
3.1 Loss Function: Cross-Entropy
A language model is trained to predict the next token, so the loss is the cross-entropy between the predicted distribution and the true token. The implementation uses the standard numerical-stability trick of subtracting the row-wise maximum before exponentiating.
```python
# Imports shared by the train.py snippets below
import math
import os
from collections.abc import Callable, Iterable
from typing import IO, BinaryIO

import numpy as np
import numpy.typing as npt
import torch


def cross_entropy(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Average cross-entropy loss over a batch.
    Inputs:
        inputs: (batch_size, vocab_size) unnormalized logits
                (flatten any leading dimensions into batch_size first)
        targets: (batch_size,) indices of the true tokens
    Returns:
        scalar tensor with the mean loss over the batch
    """
    batch_size = inputs.shape[0]
    # Numerical stability: subtract the max so exp() cannot overflow
    o_max = torch.max(inputs, dim=-1, keepdim=True).values
    o = inputs - o_max
    # Pick out the logit of the target token for each example
    target_logits = o[torch.arange(batch_size), targets]
    # log(sum(exp(o))) over the vocabulary
    logsumexp = torch.log(torch.sum(torch.exp(o), dim=-1))
    # Per-example loss: -target_logit + logsumexp
    loss = -target_logits + logsumexp
    # Mean over the batch
    return loss.mean(dim=0)
```
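Since the formula reduces to standard cross-entropy, it can be sanity-checked against torch.nn.functional.cross_entropy; this comparison is just an illustration, not part of the assignment tests:

```python
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 1000)              # (batch_size, vocab_size)
targets = torch.randint(0, 1000, (8,))
print(cross_entropy(logits, targets))      # custom implementation
print(F.cross_entropy(logits, targets))    # should be (numerically) identical
```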
3.2 Optimizer: AdamW
AdamW improves on Adam by decoupling weight decay from the gradient-based update, which improves generalization; it is the de-facto standard optimizer for training Transformers. The update rule implemented below is summarized first.
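For a parameter θ with gradient g_t at step t (t starts at 1 in this implementation), learning rate α (lr), moment decay rates β1, β2 (betas), and weight-decay coefficient λ (weight_decay):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 \\
\alpha_t &= \alpha\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \\
\theta &\leftarrow \theta - \alpha_t\,\frac{m_t}{\sqrt{v_t}+\epsilon} \\
\theta &\leftarrow \theta - \alpha\,\lambda\,\theta
\end{aligned}
$$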
```python
class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,                # initial learning rate
        betas=(0.9, 0.999),     # exponential decay rates for the first and second moments
        eps=1e-8,               # small constant that keeps the denominator away from zero
        weight_decay=0.01       # weight-decay coefficient
    ):
        # Sanity checks on the hyperparameters
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid betas: {betas}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure: Callable | None = None):
        # Optional closure (used in some special training setups)
        loss = None if closure is None else closure()
        # Iterate over parameter groups (different groups may use different learning rates)
        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            # Iterate over the parameters in the group
            for p in group["params"]:
                if p.grad is None:
                    continue   # skip parameters without gradients
                grad = p.grad.data          # gradient
                state = self.state[p]       # per-parameter optimizer state
                # Initialize state on the first update
                t = state.get("t", 1)                          # step count (starts at 1)
                m = state.get("m", torch.zeros_like(grad))     # first-moment estimate (momentum)
                v = state.get("v", torch.zeros_like(grad))     # second-moment estimate
                # Update the (biased) first and second moment estimates
                m = beta1 * m + (1 - beta1) * grad
                v = beta2 * v + (1 - beta2) * grad ** 2
                # Bias-corrected step size
                lr_t = lr * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)
                # Parameter update: gradient step first, then decoupled weight decay
                p.data = p.data - lr_t * m / (v ** 0.5 + eps)
                p.data = p.data - lr * weight_decay * p.data   # weight decay is independent of the gradient
                # Save the updated state
                state["t"] = t + 1
                state["m"] = m
                state["v"] = v
        return loss
```
3.3 Learning-Rate Schedule: Cosine Annealing with Warmup
The learning-rate schedule matters a great deal for Transformer training. We implement cosine annealing with a linear warmup: the learning rate first ramps up linearly (warmup), avoiding the instability that a large learning rate causes at the very start of training, and then decays along a cosine curve down to a minimum value, helping the model converge to a better solution. Concretely:
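With warmup length T_w (warmup_iters), total cosine length T_c (cosine_cycle_iters), and bounds α_max, α_min, the function below returns:

$$
\alpha_t =
\begin{cases}
\dfrac{t}{T_w}\,\alpha_{\max}, & t < T_w \\[6pt]
\alpha_{\min} + \dfrac{1}{2}\Big(1+\cos\!\big(\pi\,\tfrac{t-T_w}{T_c-T_w}\big)\Big)(\alpha_{\max}-\alpha_{\min}), & T_w \le t < T_c \\[6pt]
\alpha_{\min}, & t \ge T_c
\end{cases}
$$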
```python
def lr_cosine_schedule(t: int, lr_max: float, lr_min: float, warmup_iters: int, cosine_cycle_iters: int):
    """
    Cosine-annealing learning-rate schedule with linear warmup.
    Args:
        t: current iteration
        lr_max: maximum learning rate (reached at the end of warmup)
        lr_min: minimum learning rate (floor of the cosine decay)
        warmup_iters: number of warmup iterations
        cosine_cycle_iters: total number of iterations in the cosine cycle (including warmup)
    Returns:
        learning rate for the current iteration
    """
    if t < warmup_iters:
        # Warmup phase: linear ramp-up
        lr = t / warmup_iters * lr_max
    elif t < cosine_cycle_iters:
        # Cosine decay phase: smooth decay from lr_max to lr_min
        # Phase runs from 0 to pi over the decay window
        phase = (t - warmup_iters) / (cosine_cycle_iters - warmup_iters) * math.pi
        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(phase))
    else:
        # After the cycle: stay at lr_min
        lr = lr_min
    return lr
```
3.4 Gradient Clipping
During training, gradients can spike because of unusual samples or a large learning rate (gradient explosion), destabilizing the model. Gradient clipping bounds the global L2 norm of the gradients to keep updates within a reasonable range.
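With M the norm ceiling (max_l2_norm) and ‖g‖₂ the global L2 norm over all parameter gradients, each gradient is rescaled as:

$$g \leftarrow g \cdot \min\!\left(1,\; \frac{M}{\lVert g\rVert_2 + \epsilon}\right)$$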
```python
def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float, eps: float = 1e-6):
    """
    Clip the global L2 norm of the gradients to prevent gradient explosion.
    Args:
        parameters: parameters whose gradients should be clipped
        max_l2_norm: upper bound on the global gradient L2 norm
        eps: small constant that keeps the denominator away from zero
    """
    # Collect the gradients of all parameters that have one
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return   # nothing to clip
    # Global L2 norm over all gradients
    l2_norm = 0.0
    for g in grads:
        l2_norm += torch.sum(g ** 2)   # accumulate squared entries
    l2_norm = torch.sqrt(l2_norm)      # square root gives the L2 norm
    # Clipping coefficient (scale down only if the norm exceeds the bound)
    clip_coef = min(1.0, max_l2_norm / (l2_norm + eps))
    # Apply the clipping in place
    for g in grads:
        g *= clip_coef
```
3.5 Batch Sampling: get_batch
Samples a training batch from the dataset, producing input sequences x and the corresponding label sequences y. For a language model, y is simply x shifted one position to the right, i.e. the next-token targets.
```python
def get_batch(
    dataset: npt.NDArray, batch_size: int, context_length: int, device: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a batch from the dataset.
    Args:
        dataset: 1D numpy array containing the token IDs of the corpus
        batch_size: number of sequences per batch
        context_length: length of each sequence
        device: device to place the tensors on (e.g. 'cpu' or 'cuda')
    Returns:
        x: (batch_size, context_length) input sequences
        y: (batch_size, context_length) target sequences (x shifted right by one)
    """
    # Largest valid starting index (so that the window stays inside the dataset)
    max_start = len(dataset) - context_length - 1
    if max_start <= 0:
        raise ValueError("dataset is too short for the given context_length")
    # Sample batch_size random starting indices
    starts = np.random.randint(0, max_start + 1, size=batch_size)
    x_batch = []
    y_batch = []
    for s in starts:
        # Take the window [s, s + context_length + 1)
        seq = dataset[s : s + context_length + 1]
        x_batch.append(seq[:-1])   # input: the first context_length tokens
        y_batch.append(seq[1:])    # target: the last context_length tokens (shifted by one)
    # Convert to PyTorch tensors and move to the requested device
    x = torch.tensor(x_batch, dtype=torch.long, device=device)
    y = torch.tensor(y_batch, dtype=torch.long, device=device)
    return x, y
```
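A typical way to use this, assuming the corpus has already been tokenized and saved as a flat array of token IDs (the filename and dtype below are illustrative, not from the assignment code):

```python
import numpy as np

# memory-map the pre-tokenized corpus so large files are not loaded into RAM at once
data = np.memmap("train_tokens.bin", dtype=np.uint16, mode="r")
x, y = get_batch(data, batch_size=4, context_length=256, device="cpu")
print(x.shape, y.shape)                    # torch.Size([4, 256]) torch.Size([4, 256])
print(torch.equal(x[:, 1:], y[:, :-1]))    # True: y is x shifted by one position
```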
3.6 Checkpoint Management
When training large models you need to periodically save the training state (model parameters, optimizer state, current iteration) so that training can be resumed after an interruption or the model evaluated later.
```python
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str | os.PathLike | BinaryIO | IO[bytes],
):
    """Save the training state to a file."""
    checkpoint = {
        'model_state': model.state_dict(),           # model parameters
        'optimizer_state': optimizer.state_dict(),   # optimizer state
        'iteration': iteration,                      # current iteration
    }
    torch.save(checkpoint, out)


def load_checkpoint(
    src: str | os.PathLike | BinaryIO | IO[bytes],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer
) -> int:
    """Load a training state from a file and restore the model and optimizer."""
    checkpoint = torch.load(src)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    return checkpoint['iteration']   # return the iteration at which the checkpoint was saved
```
4. Summary and Extensions
In this assignment we implemented a complete Transformer language model from scratch, covering the core components of modern LLMs:
- Transformer blocks with pre-normalization for more stable training
- Rotary position embeddings (RoPE) for positional information with length extrapolation
- Multi-head self-attention to capture dependencies between tokens
- The SwiGLU activation to strengthen the feed-forward network
- Supporting training utilities: the AdamW optimizer, cosine learning-rate scheduling, gradient clipping, and checkpointing
A minimal sketch of how these pieces fit together in a training loop is shown right after this list.
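The sketch below wires the components together end to end. The hyperparameters, batch size, and dataset path are illustrative placeholders rather than values prescribed by the assignment, and a real training script would add logging and validation on top of this:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
data = np.memmap("train_tokens.bin", dtype=np.uint16, mode="r")   # pre-tokenized corpus (illustrative path)

model = TransformerLM(vocab_size=10000, context_length=256, num_layers=4,
                      d_model=512, num_heads=16, d_ff=1344, theta=10000.0).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

max_iters, warmup_iters = 5000, 200
for it in range(max_iters):
    # set this step's learning rate according to the cosine schedule
    lr = lr_cosine_schedule(it, lr_max=3e-4, lr_min=3e-5,
                            warmup_iters=warmup_iters, cosine_cycle_iters=max_iters)
    for group in optimizer.param_groups:
        group["lr"] = lr

    x, y = get_batch(data, batch_size=32, context_length=256, device=device)
    logits = model(x)                                         # (B, T, vocab_size)
    loss = cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    gradient_clipping(model.parameters(), max_l2_norm=1.0)
    optimizer.step()

    if it % 1000 == 0:
        print(f"iter {it}: loss {loss.item():.4f}")
        save_checkpoint(model, optimizer, it, f"checkpoint_{it}.pt")
```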
Finally, I won't paste the adapter.py wrappers one by one; since someone asked about them in the comments last time, here is one example:
```python
def run_transformer_lm(
vocab_size: int,
context_length: int,
d_model: int,
num_layers: int,
num_heads: int,
d_ff: int,
rope_theta: float,
weights: dict[str, Tensor],
in_indices: Int[Tensor, " batch_size sequence_length"],
) -> Float[Tensor, " batch_size sequence_length vocab_size"]:
"""Given the weights of a Transformer language model and input indices,
return the output of running a forward pass on the input indices.
This function should use RoPE.
Args:
vocab_size (int): The number of unique items in the output vocabulary to be predicted.
context_length (int): The maximum number of tokens to process at once.
d_model (int): The dimensionality of the model embeddings and sublayer outputs.
num_layers (int): The number of Transformer layers to use.
num_heads (int): Number of heads to use in multi-headed attention. `d_model` must be
evenly divisible by `num_heads`.
d_ff (int): Dimensionality of the feed-forward inner layer (section 3.3).
rope_theta (float): The RoPE $\Theta$ parameter.
weights (dict[str, Tensor]):
State dict of our reference implementation. {num_layers} refers to an
integer between `0` and `num_layers - 1` (the layer index).
The keys of this dictionary are:
- `token_embeddings.weight`
Token embedding matrix. Shape is (vocab_size, d_model).
- `layers.{num_layers}.attn.q_proj.weight`
The query projections for all `num_heads` attention heads.
Shape is (num_heads * (d_model / num_heads), d_model).
The rows are ordered by matrices of shape (num_heads, d_k),
so `attn.q_proj.weight == torch.cat([q_heads.0.weight, ..., q_heads.N.weight], dim=0)`.
- `layers.{num_layers}.attn.k_proj.weight`
The key projections for all `num_heads` attention heads.
Shape is (num_heads * (d_model / num_heads), d_model).
The rows are ordered by matrices of shape (num_heads, d_k),
so `attn.k_proj.weight == torch.cat([k_heads.0.weight, ..., k_heads.N.weight], dim=0)`.
- `layers.{num_layers}.attn.v_proj.weight`
The value projections for all `num_heads` attention heads.
Shape is (num_heads * (d_model / num_heads), d_model).
The rows are ordered by matrices of shape (num_heads, d_v),
so `attn.v_proj.weight == torch.cat([v_heads.0.weight, ..., v_heads.N.weight], dim=0)`.
- `layers.{num_layers}.attn.output_proj.weight`
Weight of the multi-head self-attention output projection
Shape is ((d_model / num_heads) * num_heads, d_model).
- `layers.{num_layers}.ln1.weight`
Weights of affine transform for the first RMSNorm
applied in the transformer block.
Shape is (d_model,).
- `layers.{num_layers}.ffn.w1.weight`
Weight of the first linear transformation in the FFN.
Shape is (d_model, d_ff).
- `layers.{num_layers}.ffn.w2.weight`
Weight of the second linear transformation in the FFN.
Shape is (d_ff, d_model).
- `layers.{num_layers}.ffn.w3.weight`
Weight of the third linear transformation in the FFN.
Shape is (d_model, d_ff).
- `layers.{num_layers}.ln2.weight`
Weights of affine transform for the second RMSNorm
applied in the transformer block.
Shape is (d_model,).
- `ln_final.weight`
Weights of affine transform for RMSNorm applied to the output of the final transformer block.
Shape is (d_model, ).
- `lm_head.weight`
Weights of the language model output embedding.
Shape is (vocab_size, d_model).
in_indices (Int[Tensor, "batch_size sequence_length"]) Tensor with input indices to run the language model on. Shape is (batch_size, sequence_length), where
`sequence_length` is at most `context_length`.
Returns:
Float[Tensor, "batch_size sequence_length vocab_size"]: Tensor with the predicted unnormalized
next-word distribution for each token.
"""
    # NOTE: the original snippet referenced a module-level `device`; define it explicitly here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TransformerLM(vocab_size, context_length, num_layers, d_model,
                          num_heads, d_ff, rope_theta).to(device)
    model.load_state_dict(weights)
    return model(in_indices.to(device=device))
```