一文读懂:用PyTorch从零搭建一个Transformer模型

发布于:2025-08-30 ⋅ 阅读:(16) ⋅ 点赞:(0)

  2017年,Vaswani 等人在论文《Attention Is All You Need》中提出了 Transformer 架构,这可以说是自然语言处理领域的一次“范式转移”。具体可详见之前写的一篇文章:深入解析Transformer架构

  从那以后,无论是BERT、GPT,还是后来的大模型,几乎都建立在它的基础之上。说实话,刚接触这个结构的时候我也觉得有点抽象,尤其是多头注意力和位置编码这些设计,乍一看不太直观。但当你亲手实现一遍,很多疑惑就会慢慢解开。

  今天这篇文章,我们就从最基础的模块开始,一步步用 PyTorch 实现一个完整的 Transformer 模型。我们不调用现成的 nn.Transformer,而是自己动手写每一个组件:位置编码、注意力机制、编码器、解码器……最后还会跑一个简单的训练流程。目的不是为了替代现有的高效实现,而是帮你真正“看见”模型内部是怎么运作的。

  整个过程我会尽量讲清楚每一步的设计思路,代码也会配上详细的注释。如果你正在学习深度学习或者准备面试,相信这套“手搓”流程会对理解有很大帮助。

完整代码实现

  我们可以参考如下Transformer架构图。

Image

  下面是我们将要实现的完整代码。我会在关键部分插入解释,帮助你理解每个模块的作用和背后的逻辑。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader

# 固定随机种子,确保每次运行结果一致
torch.manual_seed(42)

小贴士:做实验时固定随机种子是个好习惯,不然你改了个小地方,结果波动很大,容易怀疑人生。

1. 位置编码(Positional Encoding)

  Transformer 没有像RNN那样的时序结构,所以它不知道词的位置顺序。为了解决这个问题,作者引入了位置编码,把位置信息加到词向量里。

  这里用的是正弦和余弦函数交替的形式,好处是能表达相对位置关系,而且即使遇到比训练更长的序列,也能外推。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维用sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维用cos
        
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)  # 不参与梯度更新
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

个人体会:刚开始看公式时总觉得复杂,但其实本质就是构造一个固定的“位置模板”,然后加到每个样本上。这种设计既简单又有效,是Transformer里让我印象很深的一个巧思。

2. 缩放点积注意力(Scaled Dot-Product Attention)

  这是整个模型的核心。注意力机制的本质是:给定一组键值对,通过查询来决定应该关注哪些值。

公式是这样的:
 

其中除以  是为了避免点积过大导致 softmax 梯度太小。

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # 把无效位置设为负无穷
        
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, value)
        
        return output, attn_weights

注意点:这里的 mask 很关键,后面我们会用它来屏蔽填充符(padding)和防止解码器偷看未来信息。

3. 多头注意力(Multi-Head Attention)

  单头注意力只能关注一种模式,而多头允许模型在不同子空间里并行地学习多种表示。你可以把它理解为“多个专家投票”。

  实现上,就是把输入投影到多个头,分别做注意力,最后再拼起来。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        Q = self.w_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.w_o(attn_output)
        return output

经验分享view 和 transpose 这些操作初学容易出错,建议打印中间张量的 shape 来调试。比如 (batch, seq_len, d_model) 变成 (batch, n_heads, seq_len, d_k) 的过程要理清楚。

4. 前馈网络(Position-wise Feed-Forward)

  这个模块比较简单,对序列中每一个位置独立地应用相同的两层全连接网络。

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.linear2(self.dropout(self.relu(self.linear1(x))))

虽然叫“前馈”,但它在整个结构中起到了非线性变换和特征增强的作用,不可或缺。

5. 编码器层(Encoder Layer)

  每个编码器层包含两个子层:多头自注意力 + 前馈网络。每个子层都有残差连接和层归一化(LayerNorm),这是稳定训练的关键。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # 自注意力 + 残差 + 归一化
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # 前馈网络 + 残差 + 归一化
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

小提醒:残差连接放在归一化之前还是之后?原始论文是“post-norm”,但后来很多工作发现“pre-norm”更稳定。这里我们按原始结构实现。

6. 解码器层(Decoder Layer)

  解码器比编码器多了一个“编码器-解码器注意力”模块,用来融合源端的信息。

  另外,自注意力部分要加上序列掩码,防止当前位置看到后面的词。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask, tgt_mask):
        # 掩码自注意力
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # 编码器-解码器注意力
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # 前馈网络
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x

7. 编码器 & 解码器

  把多个层堆起来,再加上词嵌入和位置编码,就构成了完整的编码器和解码器。

class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

注意这里有个小技巧:词嵌入乘以 ,是为了让后续的位置编码不会“淹没”原始信号。

解码器结构类似,就不重复贴了。

8. 完整的 Transformer 模型

  把编码器、解码器和输出层组合起来:

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
        self.linear = nn.Linear(d_model, tgt_vocab_size)
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        output = self.linear(dec_output)
        return output

9. 掩码的生成

  掩码是训练过程中的一个重要细节:

  • • 填充掩码(Padding Mask):忽略 <pad> 标记

  • • 序列掩码(Look-ahead Mask):防止解码器看到未来信息

def create_mask(src, tgt, pad_idx):
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    
    tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_len = tgt.size(1)
    tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool()
    tgt_mask = tgt_pad_mask & tgt_sub_mask
    
    return src_mask, tgt_mask

这个 torch.tril 生成下三角矩阵的操作,是实现“只能看到前面”的关键。

10. 数据处理与训练流程

  我们定义一个简单的数据集类,并用 DataLoader 批量加载数据。

训练时,目标序列要错开一位:输入是 <s> 你 好,输出是 你 好 </s>

损失函数使用交叉熵,并忽略填充位置。

def train_transformer(model, dataloader, optimizer, criterion, pad_idx, device, n_epochs):
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0
        for batch_idx, (src, tgt) in enumerate(dataloader):
            src, tgt = src.to(device), tgt.to(device)
            src_mask, tgt_mask = create_mask(src, tgt, pad_idx)
            src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device)

            optimizer.zero_grad()
            output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])
            
            loss = criterion(
                output.contiguous().view(-1, output.size(-1)),
                tgt[:, 1:].contiguous().view(-1)
            )
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{n_epochs}, Average Loss: {avg_loss:.4f}')

  最后在 main() 函数中配置超参数、初始化模型并开始训练。

关键模块再梳理

模块

作用

位置编码

给词向量注入位置信息,弥补无时序结构的缺陷

缩放点积注意力

实现“查询-键-值”机制,动态分配关注权重

多头注意力

多角度捕捉不同语义模式,提升表达能力

残差连接 + LayerNorm

缓解深层网络梯度问题,加速收敛

掩码机制

控制信息流动,保证训练合理性

  这些设计看似独立,实则环环相扣。比如没有残差连接,6层以上的Transformer几乎训不动;没有位置编码,模型就失去了顺序感知能力。

一点总结

  这篇文章我们从零实现了一个标准的 Transformer 模型。虽然用的是人工构造的小数据集,无法真正完成翻译任务,但整个流程涵盖了:

  • • 模型结构搭建

  • • 数据预处理

  • • 掩码机制

  • • 训练逻辑

  我已经尽可能让代码简洁明了,方便你理解和修改。如果你打算进一步扩展,可以考虑加入:

  • • 学习率调度器(如 NoamOpt

  • • 梯度裁剪

  • • Beam Search 解码

  • • 更真实的双语数据集(如 WMT)

  说实话,当我第一次跑通这个模型时,心里还挺激动的。不是因为结果多好,而是终于把那些公式和结构图变成了实实在在能运行的代码。这种“亲手造出来”的感觉,是读论文很难替代的。

附注:完整代码已测试通过,可在 CPU/GPU 上运行。如需进一步优化或扩展,请根据实际任务调整超参数和数据流程。


网站公告

今日签到

点亮在社区的每一天
去签到