从零构建TransformerP2-新闻分类Demo

发布于:2025-08-09 ⋅ 阅读:(16) ⋅ 点赞:(0)

欢迎来到啾啾的博客🐱。
记录学习点滴。分享工作思考和实用技巧,偶尔也分享一些杂谈💬。
有很多很多不足的地方,欢迎评论交流,感谢您的阅读和评论😄。

引言

AI使用声明:在内容整理、结构优化和语言表达的过程中,我使用了人工智能(AI)工具作为辅助。

如果以LLM应用工程师为目标,其实我们并不需要熟练掌握PyTorch,熟练掌握Transformer,但是我们必须对这两者与其背后的信息有基本的了解诶,进而更好的团队协作,以及微调模型。

本篇是一个完整的从0开始构建Transformer的Demo。

代码由QWen3-Coder生成,可以运行调试。

1 一个完整的Transformer模型

![[从零构建TransformerP2-新闻分类Demo.png]]

2 需要准备的“工具包”

工具 作用
nn.Embedding 词嵌入
nn.Linear 投影层
F.softmax, F.relu 激活函数
torch.matmul 矩阵乘法(注意力核心)
mask(triu, masked_fill) 实现因果注意力
LayerNorm, Dropout 稳定训练
nn.ModuleList 堆叠多层
DataLoader 批量加载数据

3 Demo

"""  
基于Transformer的新闻分类模型  
严格按照设计流程实现,每个组件都有明确设计依据  
"""  
  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import math  
from torch.utils.data import Dataset, DataLoader  
from typing import Dict, List, Optional, Tuple  
  
  
# ==============================================  
# 第一部分:基础组件设计(根据设计决策选择)  
# ==============================================  
  
class TokenEmbedding(nn.Module):  
    """  
    词嵌入层:将输入的词ID映射为密集向量表示  
  
    设计依据:  
    - 文本任务需要词嵌入表示语义  
    - 乘以sqrt(d_model)稳定初始化方差(原论文做法)  
    """  
    def __init__(self, vocab_size: int, d_model: int):  
        super().__init__()  
        self.embedding = nn.Embedding(vocab_size, d_model)  
        self.d_model = d_model  
  
    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        """  
        前向传播  
  
        参数:  
            x: 输入词ID张量,形状为(batch_size, seq_len)  
  
        返回:  
            嵌入后的张量,形状为(batch_size, seq_len, d_model)  
        """        # 原论文建议乘以sqrt(d_model)来稳定方差  
        return self.embedding(x) * math.sqrt(self.d_model)  
  
  
class PositionalEncoding(nn.Module):  
    """  
    位置编码:为输入序列添加位置信息  
  
    设计依据:  
    - Transformer没有顺序感知能力,必须添加位置信息  
    - 选择可学习位置编码(更灵活,适合变长序列)  
    """  
    def __init__(self, d_model: int, max_len: int = 512):  
        super().__init__()  
        self.pos_embedding = nn.Embedding(max_len, d_model)  
  
    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        """  
        前向传播  
  
        参数:  
            x: 输入张量,形状为(batch_size, seq_len, d_model)  
  
        返回:  
            添加位置编码后的张量  
        """        batch_size, seq_len = x.size(0), x.size(1)  
        # 生成位置ID: [0, 1, 2, ..., seq_len-1]  
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)  
        return x + self.pos_embedding(positions)  
  
  
class MultiHeadAttention(nn.Module):  
    """  
    多头注意力机制  
  
    设计依据:  
    - 需要建模词与词之间的关系(自注意力)  
    - 多头机制允许模型在不同子空间关注不同关系  
    """  
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):  
        super().__init__()  
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"  
  
        self.d_model = d_model  
        self.num_heads = num_heads  
        self.d_k = d_model // num_heads  
  
        # 线性变换层  
        self.W_q = nn.Linear(d_model, d_model)  
        self.W_k = nn.Linear(d_model, d_model)  
        self.W_v = nn.Linear(d_model, d_model)  
        self.W_o = nn.Linear(d_model, d_model)  
  
        self.dropout = nn.Dropout(dropout)  
  
    def scaled_dot_product_attention(  
            self,  
            q: torch.Tensor,  
            k: torch.Tensor,  
            v: torch.Tensor,  
            mask: Optional[torch.Tensor] = None  
    ) -> Tuple[torch.Tensor, torch.Tensor]:  
        """  
        缩放点积注意力  
  
        参数:  
            q: 查询张量,形状为(batch_size, num_heads, seq_len, d_k)  
            k: 键张量,形状为(batch_size, num_heads, seq_len, d_k)  
            v: 值张量,形状为(batch_size, num_heads, seq_len, d_k)  
            mask: 注意力掩码,用于屏蔽padding或未来位置  
  
        返回:  
            attention_output: 注意力输出  
            attention_weights: 注意力权重(可用于可视化)  
        """        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  
  
        if mask is not None:  
            # 将mask为0的位置设为极小值,使softmax后为0  
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)  
  
        attn_probs = F.softmax(attn_scores, dim=-1)  
        attn_probs = self.dropout(attn_probs)  
        output = torch.matmul(attn_probs, v)  
  
        return output, attn_probs  
  
    def split_heads(self, x: torch.Tensor) -> torch.Tensor:  
        """将输入拆分为多个头"""  
        batch_size = x.size(0)  
        x = x.view(batch_size, -1, self.num_heads, self.d_k)  
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)  
  
    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:  
        """将多个头合并回原始形状"""  
        batch_size = x.size(0)  
        x = x.transpose(1, 2).contiguous()  
        return x.view(batch_size, -1, self.d_model)  
  
    def forward(  
            self,  
            q: torch.Tensor,  
            k: torch.Tensor,  
            v: torch.Tensor,  
            mask: Optional[torch.Tensor] = None  
    ) -> torch.Tensor:  
        """  
        前向传播  
  
        参数:  
            q, k, v: 查询、键、值张量,形状为(batch_size, seq_len, d_model)  
            mask: 注意力掩码  
  
        返回:  
            多头注意力输出,形状为(batch_size, seq_len, d_model)  
        """        q = self.split_heads(self.W_q(q))  
        k = self.split_heads(self.W_k(k))  
        v = self.split_heads(self.W_v(v))  
  
        attn_output, _ = self.scaled_dot_product_attention(q, k, v, mask)  
        output = self.W_o(self.combine_heads(attn_output))  
        return output  
  
  
class FeedForward(nn.Module):  
    """  
    前馈神经网络  
  
    设计依据:  
    - 每个位置独立处理,增强模型表示能力  
    - 通常d_ff = 4 * d_model(原论文比例)  
    """  
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):  
        super().__init__()  
        self.fc1 = nn.Linear(d_model, d_ff)  
        self.fc2 = nn.Linear(d_ff, d_model)  
        self.dropout = nn.Dropout(dropout)  
  
    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        x = F.gelu(self.fc1(x))  
        x = self.dropout(x)  
        x = self.fc2(x)  
        return x  
  
  
class EncoderLayer(nn.Module):  
    """  
    编码器层  
  
    设计依据:  
    - 新闻分类需要双向上下文理解  
    - 残差连接和层归一化提升训练稳定性  
    """  
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):  
        super().__init__()  
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)  
        self.ffn = FeedForward(d_model, d_ff, dropout)  
  
        self.norm1 = nn.LayerNorm(d_model)  
        self.norm2 = nn.LayerNorm(d_model)  
        self.dropout = nn.Dropout(dropout)  
  
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:  
        # 自注意力 + 残差连接 + 层归一化  
        attn_output = self.self_attn(x, x, x, mask)  
        x = self.norm1(x + self.dropout(attn_output))  
  
        # 前馈网络 + 残差连接 + 层归一化  
        ffn_output = self.ffn(x)  
        x = self.norm2(x + self.dropout(ffn_output))  
  
        return x  
  
  
# ==============================================  
# 第二部分:完整模型组装(根据设计决策)  
# ==============================================  
  
class NewsClassifier(nn.Module):  
    """  
    新闻分类Transformer模型  
    设计决策回顾:  
    - 任务类型:文本分类(Encoder-only)  
    - 输入:新闻文本序列  
    - 输出:新闻类别(体育、科技、娱乐等)  
    - 架构选择:Encoder-only(无需生成能力)  
    - 输入表示:Token Embedding + 可学习位置编码  
    - 输出头:[CLS] token + 分类层  
    """    def __init__(  
            self,  
            vocab_size: int,  
            d_model: int = 768,  
            num_heads: int = 12,  
            num_layers: int = 6,  
            d_ff: int = 3072,  
            num_classes: int = 10,  
            max_len: int = 512,  
            dropout: float = 0.1  
    ):  
        """  
        参数:  
            vocab_size: 词汇表大小  
            d_model: 模型维度(默认768,与BERT-base一致)  
            num_heads: 注意力头数(默认12,与BERT-base一致)  
            num_layers: 编码器层数(默认6,平衡性能与计算成本)  
            d_ff: FFN隐藏层维度(默认3072=4*d_model)  
            num_classes: 分类类别数  
            max_len: 最大序列长度  
            dropout: dropout概率  
        """        super().__init__()  
        self.d_model = d_model  
        # 1. 特殊token(设计依据:BERT-style分类需要[CLS])  
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))  
        # 2. 词嵌入层  
        self.token_embedding = TokenEmbedding(vocab_size, d_model)  
        # 3. 位置编码(设计依据:选择可学习位置编码)  
        self.pos_encoding = PositionalEncoding(d_model, max_len)  
        # 4. 编码器层堆叠  
        self.encoder_layers = nn.ModuleList([  
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)  
        ])  
        # 5. 分类头(设计依据:使用[CLS] token进行分类)  
        self.classifier = nn.Sequential(  
            nn.Linear(d_model, d_model),  
            nn.GELU(),  
            nn.Linear(d_model, num_classes)  
        )  
        self.dropout = nn.Dropout(dropout)  
        # 权重初始化(设计依据:稳定训练)  
        self._init_weights()  
  
    def _init_weights(self):  
        """初始化模型权重"""  
        for module in self.modules():  
            if isinstance(module, nn.Linear):  
                nn.init.xavier_uniform_(module.weight)  
                if module.bias is not None:  
                    nn.init.zeros_(module.bias)  
            elif isinstance(module, nn.Embedding):  
                nn.init.normal_(module.weight, mean=0.0, std=0.02)  
            elif isinstance(module, nn.LayerNorm):  
                nn.init.ones_(module.weight)  
                nn.init.zeros_(module.bias)  
  
    def add_cls_token(self, x: torch.Tensor) -> torch.Tensor:  
        """  
        在序列开头添加[CLS] token  
        设计依据:BERT-style分类使用[CLS]聚合全局信息  
        参数:  
            x: 输入张量,形状为(batch_size, seq_len, d_model)  
        返回:  
            添加[CLS]后的张量,形状为(batch_size, seq_len+1, d_model)  
        """        batch_size = x.size(0)  
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  
        return torch.cat((cls_tokens, x), dim=1)  
  
    def create_padding_mask(self, input_ids: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:  
        """  
        创建padding掩码  
        设计依据:处理变长序列,忽略padding位置  
        参数:  
            input_ids: 输入ID张量,形状为(batch_size, seq_len)  
            pad_idx: padding token的ID  
        返回:  
            掩码张量,形状为(batch_size, 1, 1, seq_len)  
            True表示有效位置,False表示padding位置 (BoolTensor)        """        # 创建布尔掩码,非pad为True  
        mask = (input_ids != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)  
        return mask.bool() # 确保返回的是布尔类型  
  
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:  
        """  
        前向传播  
        参数:  
            input_ids: 输入词ID,形状为(batch_size, original_seq_len)  
            attention_mask: 可选的注意力掩码,形状为(batch_size, original_seq_len)。  
                           1.0 表示有效位置,0.0 表示padding位置。  
                           如果提供,应为浮点类型 (如 torch.float) 或布尔类型。  
                           如果为 None,则根据 input_ids 自动创建。  
        返回:  
            分类logits,形状为(batch_size, num_classes)  
        """        batch_size, original_seq_len = input_ids.size()  
  
        # 1. 词嵌入  
        x = self.token_embedding(input_ids)  # (batch_size, original_seq_len, d_model)  
  
        # 2. 添加[CLS] token  
        x = self.add_cls_token(x)  # (batch_size, original_seq_len + 1, d_model)  
        new_seq_len = x.size(1) # 获取添加[CLS]后的序列长度  
  
        # 3. 位置编码  
        x = self.pos_encoding(x)  
        x = self.dropout(x)  
  
        # 4. 准备注意力掩码 (用于屏蔽padding)  
        if attention_mask is not None:  
            # 如果提供了 attention_mask,确保其为四维且为布尔类型  
            # 预期输入形状: (batch_size, original_seq_len)  
            # 目标形状: (batch_size, 1, 1, original_seq_len)  
            if attention_mask.dim() == 2:  
                 # 假设非零值为有效位置  
                attention_mask_for_padding = (attention_mask != 0).unsqueeze(1).unsqueeze(2)  
            elif attention_mask.dim() == 4:  
                attention_mask_for_padding = (attention_mask.squeeze(1).squeeze(1) != 0).unsqueeze(1).unsqueeze(2)  
            else:  
                raise ValueError(f"attention_mask must be 2D or 4D, but got {attention_mask.dim()}D")  
        else:  
            # 如果没有提供,根据 input_ids 自动创建  
            # 形状: (batch_size, 1, 1, original_seq_len)  
            attention_mask_for_padding = self.create_padding_mask(input_ids)  
  
        # --- 关键修复:正确扩展 mask 以适应添加了 [CLS] token 后的新序列长度 ---        # 创建一个针对新序列长度 (new_seq_len = original_seq_len + 1) 的掩码  
        # [CLS] token (索引 0) 应该总是被 attend 到,所以我们需要扩展 mask        # 1. 初始化一个全为 True 的新掩码,形状 (batch_size, 1, 1, new_seq_len)        expanded_mask = torch.ones((batch_size, 1, 1, new_seq_len), dtype=torch.bool, device=x.device)  
  
        # 2. 将原始 padding mask 复制到新 mask 的 [1:] 位置 (跳过 [CLS])        #    原始 mask 形状: (batch_size, 1, 1, original_seq_len)  
        #    新 mask 的 [1:] 部分形状: (batch_size, 1, 1, original_seq_len)  
        expanded_mask[:, :, :, 1:] = attention_mask_for_padding  
  
        # 最终用于注意力的掩码,形状 (batch_size, 1, 1, new_seq_len)        # 在 MultiHeadAttention 中,这个掩码会被广播用于屏蔽 key (src_seq) 的 padding 位置  
        final_attention_mask = expanded_mask  
  
        # 5. 通过编码器层  
        # 将扩展后的 mask 传递给每一层,以屏蔽 padding        for layer in self.encoder_layers:  
            x = layer(x, final_attention_mask)  # 传递匹配新序列长度的 mask  
        # 6. 取[CLS] token作为句子表示  
        cls_output = x[:, 0, :]  # (batch_size, d_model)  
  
        # 7. 分类  
        logits = self.classifier(cls_output)  
        return logits  
  
  
# ==============================================  
# 第三部分:训练流程(根据设计决策)  
# ==============================================  
  
def train_news_classifier():  
    """新闻分类模型训练流程"""  
  
    # 1. 超参数设置(根据设计决策)  
    config = {  
        "vocab_size": 30000,  # 词汇表大小(设计依据:新闻领域常用词)  
        "d_model": 768,  # 模型维度(设计依据:平衡性能与计算成本)  
        "num_heads": 12,  # 注意力头数(设计依据:与d_model匹配)  
        "num_layers": 6,  # 编码器层数(设计依据:足够捕捉复杂关系)  
        "d_ff": 3072,  # FFN维度(设计依据:4*d_model)  
        "num_classes": 10,  # 分类类别数(设计依据:新闻类别数量)  
        "max_len": 512,  # 最大序列长度(设计依据:覆盖大多数新闻)  
        "dropout": 0.1,  # dropout概率(设计依据:防止过拟合)  
        "batch_size": 32,  # 批量大小(设计依据:GPU内存限制)  
        "learning_rate": 2e-5,  # 学习率(设计依据:微调预训练模型常用值)  
        "epochs": 3,  # 训练轮数(设计依据:避免过拟合)  
        "warmup_steps": 500,  # warmup步数(设计依据:稳定训练初期)  
        "weight_decay": 0.01  # 权重衰减(设计依据:正则化)  
    }  
  
    # 2. 创建模型  
    print("✅ 创建新闻分类模型...")  
    model = NewsClassifier(  
        vocab_size=config["vocab_size"],  
        d_model=config["d_model"],  
        num_heads=config["num_heads"],  
        num_layers=config["num_layers"],  
        d_ff=config["d_ff"],  
        num_classes=config["num_classes"],  
        max_len=config["max_len"],  
        dropout=config["dropout"]  
    )  
  
    # 3. 设备选择  
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
    model.to(device)  
    print(f"   模型将运行在: {device}")  
  
    # 4. 伪造数据集(实际应用中替换为真实数据)  
    class NewsDataset(Dataset):  
        def __init__(self, num_samples: int = 1000, max_len: int = 512):  
            self.num_samples = num_samples  
            self.max_len = max_len  
  
        def __len__(self):  
            return self.num_samples  
  
        def __getitem__(self, idx):  
            # 伪造新闻文本(词ID)  
            seq_len = min(500, 100 + idx % 400)  # 变长序列  
            input_ids = torch.randint(1, 30000, (seq_len,))  
  
            # 伪造类别标签(0-9)  
            label = torch.tensor(idx % 10, dtype=torch.long)  
  
            return input_ids, label  
  
    # 5. 数据加载器(处理变长序列的关键)  
    def collate_fn(batch):  
        """处理变长序列的collate函数"""  
        input_ids, labels = zip(*batch)  
  
        # 找出最大长度  
        max_len = max(len(ids) for ids in input_ids)  
  
        # padding  
        padded_ids = []  
        for ids in input_ids:  
            padding = torch.zeros(max_len - len(ids), dtype=torch.long)  
            padded_ids.append(torch.cat([ids, padding]))  
  
        input_ids = torch.stack(padded_ids)  
        labels = torch.stack(labels)  
  
        return input_ids, labels  
  
    print("✅ 创建数据集和数据加载器...")  
    train_dataset = NewsDataset(num_samples=1000)  
    train_loader = DataLoader(  
        train_dataset,  
        batch_size=config["batch_size"],  
        shuffle=True,  
        collate_fn=collate_fn  
    )  
  
    # 6. 损失函数和优化器  
    print("✅ 配置训练组件...")  
    loss_fn = nn.CrossEntropyLoss()  
    optimizer = torch.optim.AdamW(  
        model.parameters(),  
        lr=config["learning_rate"],  
        weight_decay=config["weight_decay"]  
    )  
  
    # 7. 学习率调度器(设计依据:warmup + linear decay)  
    total_steps = len(train_loader) * config["epochs"]  
    warmup_steps = config["warmup_steps"]  
  
    def lr_lambda(current_step: int):  
        if current_step < warmup_steps:  
            return float(current_step) / float(max(1, warmup_steps))  
        return max(  
            0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps))  
        )  
  
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)  
  
    # 8. 训练循环  
    print("🚀 开始训练...")  
    for epoch in range(config["epochs"]):  
        model.train()  
        total_loss = 0  
  
        for batch_idx, (input_ids, labels) in enumerate(train_loader):  
            input_ids = input_ids.to(device)  
            labels = labels.to(device)  
  
            # 前向传播  
            optimizer.zero_grad()  
            logits = model(input_ids)  
            loss = loss_fn(logits, labels)  
  
            # 反向传播  
            loss.backward()  
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 梯度裁剪  
            optimizer.step()  
            scheduler.step()  
  
            total_loss += loss.item()  
  
            # 打印进度  
            if batch_idx % 50 == 0:  
                avg_loss = total_loss / (batch_idx + 1)  
                current_lr = optimizer.param_groups[0]['lr']  
                print(f"Epoch [{epoch + 1}/{config['epochs']}] | "  
                      f"Batch [{batch_idx}/{len(train_loader)}] | "  
                      f"Loss: {avg_loss:.4f} | "                      f"LR: {current_lr:.2e}")  
  
        print(f"✅ Epoch {epoch + 1} 完成 | Average Loss: {total_loss / len(train_loader):.4f}")  
  
    # 9. 保存模型  
    torch.save(model.state_dict(), "news_classifier.pth")  
    print("💾 模型已保存至 news_classifier.pth")  
  
  
# ==============================================  
# 第四部分:推理示例  
# ==============================================  
  
def predict_news_category(text: str, model: NewsClassifier, tokenizer, device: torch.device):  
    """  
    新闻分类推理  
  
    设计依据:  
    - 使用与训练相同的预处理流程  
    - 取[CLS] token进行分类  
  
    参数:  
        text: 新闻文本  
        model: 训练好的模型  
        tokenizer: 文本分词器  
        device: 设备  
  
    返回:  
        预测类别和概率  
    """    model.eval()  
  
    # 1. 文本预处理  
    input_ids = tokenizer.encode(text, max_length=512, truncation=True, padding="max_length")  
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)  
  
    # 2. 前向传播  
    with torch.no_grad():  
        logits = model(input_ids)  
        probs = F.softmax(logits, dim=-1)  
  
    # 3. 获取结果  
    predicted_class = torch.argmax(probs, dim=-1).item()  
    confidence = probs[0, predicted_class].item()  
  
    return predicted_class, confidence  
  
  
if __name__ == "__main__":  
    # 这里只是演示结构,实际运行需要完整实现  
    print("=" * 50)  
    print("Transformer新闻分类模型设计与实现")  
    print("=" * 50)  
    print("\n本示例演示了如何根据任务需求设计并实现一个Transformer模型")  
    print("设计流程严格遵循:问题分析 → 架构选择 → 组件设计 → 训练实现")  
    print("\n关键设计决策:")  
    print("- 选择Encoder-only架构(分类任务无需生成能力)")  
    print("- 使用[CLS] token进行分类(BERT-style)")  
    print("- 可学习位置编码(更适合变长新闻文本)")  
    print("- 6层编码器(平衡性能与计算成本)")  
    print("\n要运行完整训练,请取消注释train_news_classifier()调用")  
    train_news_classifier()

网站公告

今日签到

点亮在社区的每一天
去签到