MHSA:使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:
线性投影
通过三个线性层分别生成查询(Q)、键(K)、值(V)矩阵:
Q=Wq⋅x,K=Wk⋅x,V=Wv⋅xQ = W_q·x, \quad K = W_k·x, \quad V = W_v·xQ=Wq⋅x,K=Wk⋅x,V=Wv⋅x分割多头
将每个矩阵分割为 hhh 个头部:
Q→[Q1,Q2,...,Qh],每个Qi∈Rdk\text{Q} \rightarrow [Q_1, Q_2, ..., Q_h], \quad \text{每个} Q_i \in \mathbb{R}^{d_k}Q→[Q1,Q2,...,Qh],每个Qi∈Rdk计算注意力分数
对每个头部计算缩放点积注意力:
Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_iAttention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi合并多头
拼接所有头部的输出并通过线性层:
MultiHead=Wo⋅[head1;head2;...;headh]\text{MultiHead} = W_o·[\text{head}_1; \text{head}_2; ... ; \text{head}_h]MultiHead=Wo⋅[head1;head2;...;headh]
数学原理:
多头注意力允许模型同时关注不同表示子空间的信息:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个头的计算为:
headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)
以下是一个使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
"""
embed_dim: 输入向量维度
num_heads: 注意力头的数量
"""
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads # 每个头的维度
# 检查维度是否可整除
assert self.head_dim * num_heads == embed_dim
# 定义线性变换层
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.fc_out = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
"""
x: 输入张量,形状为 (batch_size, seq_len, embed_dim)
"""
batch_size = x.shape[0] #[4,10,512]
# 1. 线性投影
Q = self.query(x) # (batch_size, seq_len, embed_dim) #[4,10,512]
K = self.key(x) # (batch_size, seq_len, embed_dim) #[4,10,512]
V = self.value(x) # (batch_size, seq_len, embed_dim) #[4,10,512]
# 2. 分割多头
Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]
K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]
V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]
# 现在形状: (batch_size, num_heads, seq_len, head_dim)
# 3. 计算注意力分数
# 计算 Q·K^T / sqrt(d_k)
energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
#[4,8,10,64]* #[4,8,64,10] = [4,8,10,10]
# 形状: (batch_size, num_heads, seq_len, seq_len)
# 4. 应用softmax获取注意力权重
attention = F.softmax(energy, dim=-1)
# 形状: (batch_size, num_heads, seq_len, seq_len)
# 5. 计算加权和
out = torch.matmul(attention, V)
#[4,8,10,10]* [4,8,10,64] = [4,8,10,64]
# 形状: (batch_size, num_heads, seq_len, head_dim)
# 6. 合并多头
out = out.permute(0, 2, 1, 3).contiguous()
out = out.view(batch_size, -1, self.embed_dim)
# 形状: (batch_size, seq_len, embed_dim)
# 7. 最终线性变换
out = self.fc_out(out)
return out
# 使用示例
if __name__ == "__main__":
# 参数设置
embed_dim = 512 # 输入维度
num_heads = 8 # 注意力头数
seq_len = 10 # 序列长度
batch_size = 4 # 批大小
# 创建多头注意力模块
mha = MultiHeadAttention(embed_dim, num_heads)
# 生成模拟输入数据
input_data = torch.randn(batch_size, seq_len, embed_dim)
# 前向传播
output = mha(input_data)
print("输入形状:", input_data.shape)
print("输出形状:", output.shape)
输出示例:
输入形状: torch.Size([4, 10, 512])
输出形状: torch.Size([4, 10, 512])
此实现保持了输入输出维度一致,可直接集成到Transformer等架构中。