【Model Details】MHSA: Multi-head Self Attention Explained in Detail, with a PyTorch Code Example

Published: 2025-07-31

MHSA: Multi-head Self Attention is computed in four steps:

  1. Linear projection
    Three linear layers produce the query (Q), key (K) and value (V) matrices:
    $Q = W_q x, \quad K = W_k x, \quad V = W_v x$

  2. Split into heads
    Each matrix is split into $h$ heads:
    $Q \rightarrow [Q_1, Q_2, \dots, Q_h], \quad \text{each } Q_i \in \mathbb{R}^{d_k}$

  3. Compute attention scores
    Each head computes scaled dot-product attention (a minimal sketch of this step follows the list):
    $\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$

  4. Merge heads
    Concatenate the outputs of all heads and apply a final linear layer:
    $\text{MultiHead} = W_o \cdot [\text{head}_1; \text{head}_2; \dots; \text{head}_h]$
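
As a complement to step 3, here is a minimal sketch of scaled dot-product attention for a single head, written with plain tensor operations; the function name and the toy shapes below are illustrative assumptions rather than part of the original post:

import math
import torch
import torch.nn.functional as F

def single_head_attention(Q_i, K_i, V_i):
    # Q_i, K_i, V_i: (seq_len, d_k) tensors for one head
    d_k = Q_i.shape[-1]
    scores = Q_i @ K_i.transpose(-2, -1) / math.sqrt(d_k)  # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)                     # each row sums to 1
    return weights @ V_i                                    # (seq_len, d_k)

# Toy shapes: seq_len = 10, d_k = 64
Q_i, K_i, V_i = (torch.randn(10, 64) for _ in range(3))
print(single_head_attention(Q_i, K_i, V_i).shape)  # torch.Size([10, 64])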

Mathematical background:

Multi-head attention lets the model attend to information from different representation subspaces in parallel:
$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$
where each head is computed as:
$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$
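
In this formula each head has its own projection matrices $W_i^Q$, $W_i^K$, $W_i^V$, while the implementation below uses a single full-width linear layer per Q/K/V and then splits the result. The two views are equivalent: the per-head matrices are just row slices of the full projection. A small verification sketch, using assumed toy dimensions:

import torch

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
x = torch.randn(3, embed_dim)            # (seq_len, embed_dim)

W_q = torch.randn(embed_dim, embed_dim)  # one full-width query projection
Q = x @ W_q.T                            # (seq_len, embed_dim)

for i in range(num_heads):
    W_i = W_q[i * head_dim:(i + 1) * head_dim, :]  # W_i^Q: (head_dim, embed_dim)
    Q_i = x @ W_i.T                                # (seq_len, head_dim)
    # head i of the full projection equals the per-head projection
    assert torch.allclose(Q_i, Q[:, i * head_dim:(i + 1) * head_dim])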

Below is a PyTorch implementation of Multi-head Self Attention with detailed comments:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        embed_dim: dimension of the input vectors
        num_heads: number of attention heads
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # dimension of each head
        
        # embed_dim must be evenly divisible by num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        
        # Linear projection layers
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        """
        x: input tensor of shape (batch_size, seq_len, embed_dim)
        """
        batch_size = x.shape[0]  # e.g. x: [4, 10, 512]
        
        # 1. Linear projections
        Q = self.query(x)  # (batch_size, seq_len, embed_dim) -> [4, 10, 512]
        K = self.key(x)    # (batch_size, seq_len, embed_dim) -> [4, 10, 512]
        V = self.value(x)  # (batch_size, seq_len, embed_dim) -> [4, 10, 512]
        
        # 2. Split into heads
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [4, 8, 10, 64]
        # Shape is now: (batch_size, num_heads, seq_len, head_dim)
        
        # 3. Compute attention scores
        # Q·K^T / sqrt(d_k)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        # [4, 8, 10, 64] @ [4, 8, 64, 10] = [4, 8, 10, 10]
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        
        # 4. Softmax to obtain attention weights
        attention = F.softmax(energy, dim=-1)
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        
        # 5. Weighted sum of the values
        out = torch.matmul(attention, V)
        # [4, 8, 10, 10] @ [4, 8, 10, 64] = [4, 8, 10, 64]
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        
        # 6. Merge the heads
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(batch_size, -1, self.embed_dim)
        # Shape: (batch_size, seq_len, embed_dim)
        
        # 7. Final linear transformation
        out = self.fc_out(out)
        
        return out

# Usage example
if __name__ == "__main__":
    # Configuration
    embed_dim = 512  # input dimension
    num_heads = 8    # number of attention heads
    seq_len = 10     # sequence length
    batch_size = 4   # batch size
    
    # Create the multi-head attention module
    mha = MultiHeadAttention(embed_dim, num_heads)
    
    # Generate dummy input data
    input_data = torch.randn(batch_size, seq_len, embed_dim)
    
    # Forward pass
    output = mha(input_data)
    
    print("Input shape:", input_data.shape)
    print("Output shape:", output.shape)

Example output:

Input shape: torch.Size([4, 10, 512])
Output shape: torch.Size([4, 10, 512])
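
For reference: in PyTorch 2.0 and later, steps 3-5 of the forward pass (scores, softmax, weighted sum) can be collapsed into a single call to F.scaled_dot_product_attention, which applies the same 1/sqrt(head_dim) scaling. A small sketch using the shapes from the example above:

import torch
import torch.nn.functional as F

# (batch_size, num_heads, seq_len, head_dim), as after step 2 above
Q = torch.randn(4, 8, 10, 64)
K = torch.randn(4, 8, 10, 64)
V = torch.randn(4, 8, 10, 64)

out = F.scaled_dot_product_attention(Q, K, V)  # softmax(QK^T / sqrt(d_k)) V
print(out.shape)  # torch.Size([4, 8, 10, 64])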

This implementation keeps the input and output dimensions identical, so it can be dropped directly into a Transformer-style architecture.
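
PyTorch also provides a built-in nn.MultiheadAttention module that performs the same computation (including the output projection). With batch_first=True it accepts the same (batch_size, seq_len, embed_dim) layout; for self-attention, the same tensor is passed as query, key and value. A minimal usage sketch:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(4, 10, 512)        # (batch_size, seq_len, embed_dim)
out, attn_weights = mha(x, x, x)   # self-attention: query = key = value = x
print(out.shape)                   # torch.Size([4, 10, 512])
print(attn_weights.shape)          # torch.Size([4, 10, 10]), averaged over heads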

