目录
1.2.1、局部窗口稀疏(Local Window Sparse)
1.2.2、基于内容的稀疏选择(Content-Based Sparse)
3、多查询注意力(Multi-Query Attention, MQA)
4、多头潜在注意力(Multi-Head Latent Attention)
6.1、 查询(Query)、键(Key)、值(Value)的分工
1、稀疏注意力机制(Sparse Attention)
1.1、核心问题:传统注意力的 “效率瓶颈”
传统的缩放点积注意力(Scaled Dot-Product Attention)计算复杂度是 O(n²)(n 为序列长度),当处理长序列(如文档、视频帧,n=10000 以上)时,计算量和内存占用会爆炸式增长(例如 n=10000 时,n²=1 亿,n=10 万时 n²=1 万亿),根本无法训练或推理。
稀疏注意力机制的核心是:只计算序列中 “重要的少数” 元素之间的注意力,忽略大部分无关元素,将复杂度从 O (n²) 降到 O (n) 或 O (n log n),同时保留关键信息。
核心思想:
传统注意力就像 “逐字阅读一本书”,每句话都要和其他所有句子对比,效率很低。
稀疏注意力则像 “跳读”:只关注重要的部分(如标题、图表、关键词),忽略无关内容,大幅提高阅读速度。生活化比喻:
你在图书馆找一本关于 “人工智能” 的书。
- 传统注意力:把整个图书馆的书都翻一遍,对比每本书和 “人工智能” 的关联;
- 稀疏注意力:直接去计算机科学区(局部窗口),或者只看封面带 “AI” 标签的书(内容选择),忽略其他区域。
适用场景:
长文本(如论文、小说)、长视频分析、大规模数据处理
1.2、具体稀疏策略(详细计算逻辑)
1.2.1、局部窗口稀疏(Local Window Sparse)
- 原理:每个元素只关注自身周围固定窗口内的元素(类似人类 “视野有限”)。
- 计算步骤:
① 将序列分成多个不重叠或重叠的窗口(如窗口大小为 w);
② 每个位置 i 只与 [i-w/2, i+w/2] 范围内的位置计算注意力;
③ 窗口外的位置注意力权重直接设为 0。 - 例:Longformer 模型用的 “滑动窗口 + 全局令牌”,窗口大小通常设为 512,同时对特殊令牌(如 [CLS])计算全局注意力,兼顾局部细节和全局依赖。
1.2.2、基于内容的稀疏选择(Content-Based Sparse)
- 原理:根据内容相似度动态选择少数 “相关元素”(如只关注与当前元素语义相似的 top-k 个)。
- 计算步骤:
① 对每个元素 i,计算与其他元素 j 的相似度(如);
② 只保留相似度最高的 k 个 j(k 远小于 n),其余权重设为 0;
③ 对保留的 k 个权重做 softmax 归一化。 - 例:RNN + 注意力的改进模型中,常通过这种方式减少长序列计算量。
1.2.3、块稀疏(Block Sparse)
- 原理:将序列分成若干块,只在部分块之间计算注意力(块内或跨块的稀疏交互)。
- 计算步骤:
① 序列分块:n = b×m(b 为块数,m 为块大小);
② 定义块间交互矩阵(如对角线块内计算,少数跨块计算);
③ 块内元素间计算注意力,跨块只在允许的块间计算。 - 例:BigBird 模型的 “块稀疏 + 随机稀疏 + 全局稀疏” 混合策略,既高效又保留全局依赖。
1.3、优缺点
- 优点:大幅降低长序列计算成本,可处理 10 万级长度序列;
- 缺点:稀疏模式设计依赖先验(如窗口大小、k 值),可能丢失重要依赖;实现复杂(需特殊掩码处理)。
1.4、测试代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import numpy as np
# 设置中文显示
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
# 实现一个简化版的稀疏注意力机制
class SparseAttention(nn.Module):
def __init__(self, embed_dim, num_heads, window_size=5, random_size=0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.window_size = window_size # 局部窗口大小
self.random_size = random_size # 随机选择的元素数量
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# 计算Q, K, V
# self.q_proj(x) 就像比较 “苹果” 和 “橙子” 的甜度,需要先将它们的特征(如糖分含量)转换到同一度量单位(如克 / 100g),
# 否则 “一个苹果” 和 “一个橙子” 的直接对比没有意义。
q = (self.q_proj(x) #线性投影:将输入x映射到查询空间
.view(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2))
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 创建稀疏注意力掩码
mask = torch.zeros(seq_len, seq_len, device=x.device)
# 1. 局部窗口注意力
for i in range(seq_len):
start = max(0, i - self.window_size)
end = min(seq_len, i + self.window_size + 1)
mask[i, start:end] = 1
# 2. 随机稀疏注意力(可选)
if self.random_size > 0 and self.random_size < seq_len:
for i in range(seq_len):
random_indices = torch.randperm(seq_len, device=x.device)[:self.random_size]
mask[i, random_indices] = 1
# 确保对角线始终为1(自己关注自己)
mask.fill_diagonal_(1)
# 计算注意力得分并应用掩码
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
masked_attn_scores = attn_scores.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, -1e9)
attn_weights = F.softmax(masked_attn_scores, dim=-1)
# 应用注意力权重
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
return self.out_proj(output)
# 实现标准注意力机制作为对比
class StandardAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, _ = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
return self.out_proj(output)
# 测试函数
def test_attention():
# 设置测试参数
embed_dim = 512
num_heads = 8
seq_lens = [100, 500, 1000, 2000, 3000] # 测试不同序列长度
window_sizes = [5, 10, 20, 50] # 测试不同窗口大小
# 存储结果
sparse_times = {ws: [] for ws in window_sizes}
standard_times = []
mem_usage = {ws: [] for ws in window_sizes}
for seq_len in seq_lens:
# 创建随机输入
x = torch.randn(1, seq_len, embed_dim)
# 测试标准注意力
standard_attn = StandardAttention(embed_dim, num_heads)
start_time = time.time()
with torch.no_grad():
standard_output = standard_attn(x)
standard_times.append(time.time() - start_time)
# 测试不同窗口大小的稀疏注意力
for ws in window_sizes:
sparse_attn = SparseAttention(embed_dim, num_heads, window_size=ws)
start_time = time.time()
with torch.no_grad():
sparse_output = sparse_attn(x)
sparse_times[ws].append(time.time() - start_time)
# 计算内存占用(以参数数量近似)
mem_usage[ws].append(seq_len * seq_len * ws / (seq_len * seq_len) * 100) # 稀疏度百分比
# 绘制结果
plt.figure(figsize=(12, 5))
# 绘制时间对比图
plt.subplot(1, 2, 1)
plt.plot(seq_lens, standard_times, 'o-', label='标准注意力')
for ws in window_sizes:
plt.plot(seq_lens, sparse_times[ws], 'o-', label=f'稀疏注意力 (窗口={ws})')
plt.xlabel('序列长度')
plt.ylabel('计算时间 (秒)')
plt.title('不同序列长度下的注意力计算时间')
plt.legend()
plt.grid(True)
# 绘制稀疏度对比图
plt.subplot(1, 2, 2)
for ws in window_sizes:
plt.plot(seq_lens, mem_usage[ws], 'o-', label=f'窗口={ws}')
plt.axhline(y=100, color='r', linestyle='--', label='标准注意力')
plt.xlabel('序列长度')
plt.ylabel('相对内存占用 (%)')
plt.title('不同窗口大小的稀疏度')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('attention_comparison.png')
plt.show()
# 打印一些关键结果
print("序列长度为3000时的计算时间对比:")
print(f"标准注意力: {standard_times[-1]:.4f}秒")
for ws in window_sizes:
print(f"稀疏注意力 (窗口={ws}): {sparse_times[ws][-1]:.4f}秒")
print(f" 速度提升: {standard_times[-1] / sparse_times[ws][-1]:.2f}倍")
print(f" 内存占用: {mem_usage[ws][-1]:.2f}%")
if __name__ == "__main__":
test_attention()
1.5、实验结果
左:窗口越小,计算时间越短,但可能损失部分全局信息(需在效率与性能间权衡)。
右:窗口大小直接决定了稀疏程度,窗口 = 5 的稀疏度远高于窗口 = 50。
2、FlashAttention
2.1、核心问题:传统注意力的 “内存瓶颈”
传统注意力计算时,会产生中间大矩阵(如 是 n×n 矩阵),当 n=1 万时,该矩阵占用约 400MB(float32),若 n=10 万则达 40GB,远超 GPU 显存。即使能计算,频繁的内存读写也会拖慢速度(内存带宽比计算速度慢得多)。
FlashAttention 的核心是:通过 “分块计算 + 内存高效调度”,避免存储完整中间矩阵,在有限显存内高效计算注意力,同时保持结果与传统注意力一致。
核心思想:
传统注意力计算时,会频繁在 “草稿纸”(高速内存)和 “书架”(低速内存)之间搬数据,浪费时间。
FlashAttention重新设计了 “打草稿” 的顺序,让你一次性在草稿纸上算完所有步骤,再放回书架,减少来回折腾。生活化比喻:
你要做一顿饭:
- 传统方法:每切一个菜,就把刀放回刀架,再从冰箱拿食材,切完又放回去,反复跑冰箱和操作台;
- FlashAttention 方法:一次性把所有需要的食材从冰箱拿出来放在操作台上,切完所有菜再统一收拾,减少来回跑的时间。
效果:
速度提升 2-4 倍,内存占用减少,尤其适合处理超长序列(如 10 万词的文档)。
2.2、详细计算逻辑(内存优化关键)
2.2.1、瓦片(Tiling)技术
- 将 Q、K、V 分块(如切成大小为 B 的瓦片),每次只处理一小块数据,避免完整矩阵加载。
- 例:Q∈R^(n×d),切成 Q1, Q2, ..., Qp(每块 B×d);K、V 同理切成 K1~Kp, V1~Vp。
2.2.2、分块计算注意力
- 传统注意力:
→ softmax → 与 V 相乘;
- FlashAttention 分两步:
① 计算 “块级”:对每个
,逐块计算与
的相似度(
),同时实时计算 softmax 的中间值(最大值和总和),避免存储完整
;
② 分块更新输出:用块级 softmax 结果与相乘,逐步累加得到最终输出
。
2.2.3、数值稳定性优化
- 传统 softmax 可能因数值溢出导致精度问题,FlashAttention 在分块计算时实时跟踪每块的最大值,通过 “减最大值” 避免指数爆炸,同时保留足够精度。
2.2.4、显存复用
- 中间结果(如块级 QK^T、softmax 中间值)只在寄存器 / 共享内存中临时存储,计算完立即释放,不占用全局显存。
2.3、数学公式
2.4、性能提升
- 速度:比 PyTorch 原生注意力快 2-4 倍(长序列时更明显);
- 内存:可处理 n=16 万的序列(传统注意力在 n=1 万时就会 OOM);
- 精度:通过数值优化,结果与传统注意力误差小于 1e-5。
2.5、完整代码
"""
文件名: 2.3.2
作者: 墨尘
日期: 2025/7/19
项目名: dl_env
备注:
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity
# 设置中文显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
# 尝试导入 FlashAttention(需先安装 flash-attn 库)
try:
from flash_attn.flash_attention import FlashAttention
flash_available = True
print("FlashAttention 库已成功导入")
except ImportError:
flash_available = False
print("未找到 FlashAttention 库,请通过 'pip install flash-attn' 安装")
# --------------------------- 1. 标准注意力机制 ---------------------------
# 先计算完整的注意力权重矩阵,再一次性与 V 相乘:
class StandardAttention(nn.Module):
"""标准缩放点积注意力,用于与 FlashAttention 对比"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 线性投影层
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, q, k, v, attn_mask=None, dropout_p=0.0):
"""
标准注意力计算流程:
1. 计算 Q、K 的点积得到注意力分数
2. 应用掩码(如果有)
3. 应用 softmax 转换为概率分布
4. 对 V 进行加权聚合
"""
batch_size, seq_len_q, _ = q.shape
seq_len_k = k.shape[1]
# 计算注意力分数
attn_scores = torch.matmul(
q.view(batch_size, seq_len_q, self.num_heads, self.head_dim),
k.view(batch_size, seq_len_k, self.num_heads, self.head_dim).transpose(1, 2)
) / (self.head_dim ** 0.5) # 缩放防止梯度消失
# 应用掩码(如果提供)
if attn_mask is not None:
attn_scores = attn_scores.masked_fill(attn_mask.unsqueeze(1) == 0, -1e9)
# 应用 softmax 和 dropout
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training)
# 加权聚合 V
output = torch.matmul(
attn_weights,
v.view(batch_size, seq_len_k, self.num_heads, self.head_dim)
)
# 重塑并通过输出投影层
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.embed_dim)
return self.out_proj(output)
# --------------------------- 2. FlashAttention 包装器 ---------------------------
# 计算完一个分块的注意力权重后,立即与对应分块的 V 相乘并累加结果
class FlashAttentionWrapper(nn.Module):
"""FlashAttention 包装器,保持与标准注意力相同的接口"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# FlashAttention 模块
self.flash_attn = FlashAttention(causal=False) # 非因果注意力
# 线性投影层(与标准注意力一致)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, q, k, v, attn_mask=None, dropout_p=0.0):
"""
FlashAttention 前向传播:
1. 将输入投影到 Q、K、V 空间
2. 调整维度顺序以适应 FlashAttention 接口
3. 通过 FlashAttention 计算注意力
4. 重塑并通过输出投影层
"""
batch_size, seq_len_q, _ = q.shape
# 投影到 Q、K、V 空间
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# 调整维度为 (batch, seq_len, num_heads, head_dim)
q = q.view(batch_size, seq_len_q, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len_q, self.num_heads, self.head_dim) # 假设 seq_len_k == seq_len_q
v = v.view(batch_size, seq_len_q, self.num_heads, self.head_dim)
# 转换为 FlashAttention 所需的格式 (batch, seq_len, num_heads, head_dim)
q = q.transpose(1, 2) # (batch, num_heads, seq_len, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# 计算 FlashAttention
# 注意:FlashAttention 输入格式为 (batch, seq_len, num_heads, head_dim)
# 在实际使用 FlashAttention 时,分块大小(block size)通常不需要我们手动设定,
# 而是由库内部根据硬件(如 GPU 型号)和序列长度自动优化选择。
# 输出格式也相同
output, _ = self.flash_attn(
q, k, v,
dropout_p=dropout_p if self.training else 0.0
)
# 重塑并通过输出投影层
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.embed_dim)
return self.out_proj(output)
# --------------------------- 3. 测试函数 ---------------------------
def test_flash_attention():
"""测试并对比 FlashAttention 和标准注意力的性能"""
if not flash_available:
print("无法运行测试:未找到 FlashAttention 库")
return
# 设置测试参数
embed_dim = 512
num_heads = 8
head_dim = embed_dim // num_heads
batch_size = 4
seq_lens = [100, 500, 1000, 2000, 4000, 8000] # 测试不同序列长度
dropout = 0.1
# 设备选择(FlashAttention 在 GPU 上效果最佳)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 存储结果
standard_times = []
flash_times = []
speedups = []
# 创建模型
standard_attn = StandardAttention(embed_dim, num_heads).to(device)
flash_attn = FlashAttentionWrapper(embed_dim, num_heads).to(device)
# 设置为评估模式
standard_attn.eval()
flash_attn.eval()
# 预热(让 CUDA 初始化)
x = torch.randn(batch_size, 100, embed_dim, device=device)
with torch.no_grad():
_ = standard_attn(x, x, x)
_ = flash_attn(x, x, x)
# 测试不同序列长度
for seq_len in seq_lens:
print(f"\n测试序列长度: {seq_len}")
# 创建随机输入
q = torch.randn(batch_size, seq_len, embed_dim, device=device)
k = torch.randn(batch_size, seq_len, embed_dim, device=device)
v = torch.randn(batch_size, seq_len, embed_dim, device=device)
# 测试标准注意力
torch.cuda.synchronize() # 同步 GPU
start_time = time.time()
with torch.no_grad():
for _ in range(10): # 多次运行取平均
_ = standard_attn(q, k, v, dropout_p=dropout)
torch.cuda.synchronize() # 同步 GPU
standard_time = (time.time() - start_time) / 10
standard_times.append(standard_time)
print(f"标准注意力耗时: {standard_time:.6f} 秒")
# 测试 FlashAttention
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
for _ in range(10): # 多次运行取平均
_ = flash_attn(q, k, v, dropout_p=dropout)
torch.cuda.synchronize()
flash_time = (time.time() - start_time) / 10
flash_times.append(flash_time)
print(f"FlashAttention 耗时: {flash_time:.6f} 秒")
# 计算加速比
speedup = standard_time / flash_time
speedups.append(speedup)
print(f"加速比: {speedup:.2f}x")
# 使用 PyTorch Profiler 分析内存和计算量
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("standard_attention"):
_ = standard_attn(q, k, v)
print("\n标准注意力性能分析:")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
with record_function("flash_attention"):
_ = flash_attn(q, k, v)
print("\nFlashAttention 性能分析:")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
# 绘制性能对比图
plt.figure(figsize=(12, 5))
# 绘制时间对比图
plt.subplot(1, 2, 1)
plt.plot(seq_lens, standard_times, 'o-', label='标准注意力')
plt.plot(seq_lens, flash_times, 'o-', label='FlashAttention')
plt.xlabel('序列长度')
plt.ylabel('计算时间 (秒)')
plt.title('不同序列长度下的计算时间对比')
plt.legend()
plt.grid(True)
# 绘制加速比图
plt.subplot(1, 2, 2)
plt.plot(seq_lens, speedups, 'o-', color='green')
plt.xlabel('序列长度')
plt.ylabel('加速比 (标准/Flash)')
plt.title('FlashAttention 相对于标准注意力的加速比')
plt.grid(True)
plt.tight_layout()
plt.savefig('flash_attention_comparison.png')
plt.show()
# 打印总结
print("\n===== 性能总结 =====")
for i, seq_len in enumerate(seq_lens):
print(f"序列长度 {seq_len}:")
print(f" 标准注意力: {standard_times[i]:.6f} 秒")
print(f" FlashAttention: {flash_times[i]:.6f} 秒")
print(f" 加速比: {speedups[i]:.2f}x")
# --------------------------- 4. 主函数 ---------------------------
if __name__ == "__main__":
test_flash_attention()
3、多查询注意力(Multi-Query Attention, MQA)
3.1、核心问题:多头注意力的 “参数与推理瓶颈”
传统多头注意力(Multi-Head Attention, MHA)中,每个头有独立的 Q、K、V 投影矩阵(共 3h×d 参数,h 为头数),且推理时每个头需独立计算 K、V,导致:
- 参数多:h=16 时,K、V 投影参数是 MQA 的 16 倍;
- 推理慢:生成式模型解码时,每次需处理 h 组 K、V 缓存,内存占用大,并行效率低。
核心思想:
传统多头注意力就像 “10 个人同时查资料”,每个人都带一套完整的工具(Q、K、V),浪费资源。
多查询注意力让 10 个人共享同一套 “K 和 V 工具”,只保留各自的 “Q 工具”,既节省资源,又不影响效率。生活化比喻:
10 个学生做小组作业,需要查资料、整理笔记、写报告:
- 传统方法:每个学生都带一套完整的词典、笔记本、电脑(Q、K、V);
- MQA 方法:10 个学生共用一套词典和笔记本(K、V),但每人保留自己的电脑(Q),分工协作。
优势:
参数减少,推理速度提升(尤其适合生成式模型,如 ChatGPT),节省显存。
3.2、详细改进逻辑
MQA 的核心:多个头共享同一组 K 和 V,只保留多头 Q,大幅减少参数和计算量。
3.2.1、计算步骤对比
- 传统 MHA:
① 多头投影:;
② 每个头计算注意力:;
③ 拼接所有,投影输出。
- MQA:
① 多头 Q 投影:Q_h = Q・W_Qh(h=1..H);
② 共享 K、V 投影:K = K・W_K, V = V・W_V(仅 1 组);
③ 每个头用共享的 K、V 计算:A_h = softmax (Q_hK^T/√d) V;
④ 拼接 A_h,投影输出。
3.2.2、效率提升本质
- 参数:K、V 投影参数从 H×d² 降为 d²(减少 H 倍);
- 推理缓存:生成式模型中,K、V 缓存从 H 组降为 1 组,内存占用减为 1/H,解码速度提升(因缓存读写减少)。
3.3、与 GQA 的关系
Grouped-Query Attention(GQA)是 MQA 的折中:将 H 个头分成 G 组,每组共享 1 组 K、V(MQA 是 G=1 的特例,MHA 是 G=H 的特例)。例如 H=16, G=4,则 4 组 K、V,兼顾效率和性能。
3.4、优缺点
- 优点:参数少、推理快(尤其生成任务),适合大模型部署;
- 缺点:共享 K、V 可能损失部分表达能力(多头多样性降低),需通过调优补偿(如增加头数 H)。
- 应用:PaLM、GPT-4、LLaMA 2 等大模型广泛采用(GQA 更常见,平衡效率和性能)。
3.5、示例代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiQueryAttention(nn.Module):
"""
多查询注意力 (Multi-Query Attention, MQA) 模块
与标准多头注意力不同,MQA中所有查询头共享相同的键和值投影矩阵,
从而显著减少参数量和内存占用,同时保持模型性能。
论文参考: "Fast Transformer Decoding: One Write-Head is All You Need"
https://arxiv.org/abs/1911.02150
"""
def __init__(self,
embed_dim: int, # 输入嵌入维度
num_heads: int, # 查询头数量
head_dim: int = None, # 每个头的维度
dropout: float = 0.0, # Dropout概率
bias: bool = True, # 是否使用偏置项
):
super().__init__()
# 检查参数有效性
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads
# 确保维度匹配
assert self.head_dim * num_heads == embed_dim, "embed_dim必须能被num_heads整除"
# 查询投影: 为每个头创建独立的投影矩阵
self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim, bias=bias)
# 键和值投影: 所有头共享相同的投影矩阵
# 这是MQA与标准多头注意力的核心区别
self.k_proj = nn.Linear(embed_dim, self.head_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, self.head_dim, bias=bias)
# 输出投影
self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim, bias=bias)
# Dropout层
self.dropout = nn.Dropout(dropout)
# 缩放因子 (用于缩放点积注意力)
self.scale = self.head_dim ** -0.5
def forward(self,
query: torch.Tensor, # 查询张量 [batch_size, seq_len, embed_dim]
key: torch.Tensor, # 键张量 [batch_size, seq_len, embed_dim]
value: torch.Tensor, # 值张量 [batch_size, seq_len, embed_dim]
attn_mask: torch.Tensor = None, # 注意力掩码 [batch_size, seq_len, seq_len]
):
"""
前向传播过程
"""
batch_size, seq_len, _ = query.shape
# 1. 线性投影
# 查询投影后形状: [batch_size, seq_len, num_heads * head_dim]
q = self.q_proj(query)
# 键和值投影后形状: [batch_size, seq_len, head_dim]
k = self.k_proj(key)
v = self.v_proj(value)
# 2. 重塑查询张量为多头形式
# 形状变为: [batch_size, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
# 3. 调整维度顺序以便计算注意力分数
# 查询形状: [batch_size, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
# 键和值形状: [batch_size, seq_len, head_dim]
# 注意: 键和值不需要多头维度,所有头共享相同的键值矩阵
# 4. 计算注意力分数 (点积)
# 形状: [batch_size, num_heads, seq_len, seq_len]
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# 5. 应用注意力掩码 (如果提供)
if attn_mask is not None:
# 确保掩码维度匹配
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0).unsqueeze(1) # [1, 1, seq_len, seq_len]
elif attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
# 将掩码位置的值设为负无穷 (softmax后接近0)
attn_scores = attn_scores.masked_fill(attn_mask == 0, -1e9)
# 6. 应用softmax获取注意力权重
# 形状: [batch_size, num_heads, seq_len, seq_len]
attn_weights = F.softmax(attn_scores, dim=-1)
# 7. 应用dropout
attn_weights = self.dropout(attn_weights)
# 8. 加权聚合值
# 值形状: [batch_size, seq_len, head_dim]
# 输出形状: [batch_size, num_heads, seq_len, head_dim]
output = torch.matmul(attn_weights, v.unsqueeze(1)) # 扩展维度以匹配多头
# 9. 重塑输出并通过线性层
# 形状: [batch_size, seq_len, num_heads * head_dim]
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
# 最终输出形状: [batch_size, seq_len, embed_dim]
return self.out_proj(output)
4、多头潜在注意力(Multi-Head Latent Attention)
4.1、核心问题:传统注意力的 “显式依赖局限”
传统注意力(包括多头、稀疏版)依赖 “显式成对交互”( 与
的相似度),但:
- 长序列中,显式交互仍可能遗漏全局隐式依赖(如 “猫” 和 “狗” 的关联不通过直接相似,而通过 “动物” 这个隐概念);
- 多头注意力的 “头” 是独立的,缺乏对 “头间关联” 的建模。
多头潜在注意力的核心:引入 “潜在变量”(Latent Variable)捕捉全局隐式依赖,同时用多头机制建模不同维度的潜在结构。
核心思想:
传统注意力只关注 “表面关联”(如 “苹果” 和 “水果”),忽略 “隐藏关联”(如 “苹果” 和 “健康” 通过 “维生素” 关联)。
多头潜在注意力引入 “潜在变量”,就像在大脑中创建 “隐藏文件夹”,专门存放这些隐藏关联。生活化比喻:
你整理照片:
- 传统方法:按 “人物”“风景”“美食” 分类(显式标签);
- 多头潜在方法:除了显式分类,还创建 “隐藏文件夹”,自动关联 “运动→健康→健身房”“旅行→相机→回忆” 等隐藏关系。
作用:
捕捉更深层的语义关联,提升复杂任务(如长文本理解、跨模态推理)的效果。
4.2、详细计算逻辑
4.2.1、潜在变量的作用
- 潜在变量 z∈R^k(k 远小于 n):压缩全局信息,作为 “隐式中介” 传递序列中不直接交互的元素依赖。
- 例:z 可理解为 “全局语义向量”,每个元素既关注显式相似元素,也关注 z 包含的隐式全局信息。
4.2.2、多头潜在机制
- 每个头有独立的潜在变量 z_h(h=1..H),建模不同维度的隐式依赖;
- 计算步骤:
① 多头投影:,
,
(同 MHA);
② 显式注意力:;
③ 潜在注意力:(
通过学习捕捉全局模式);
④ 融合:(或通过门控机制融合);
⑤ 拼接多头结果,输出最终序列表示。
4.2.3、潜在变量的学习
通常通过 “重构损失” 学习:让
能辅助重构原始序列信息;
- 或结合变分推断:
服从某种分布(如高斯分布),通过 KL 散度正则化,增强泛化能力。
4.3、优缺点
- 优点:捕捉显式 + 隐式依赖,提升长序列全局建模能力;多头潜在变量增加表达多样性;
- 缺点:引入潜在变量增加模型复杂度(需学习 z_h 的先验 / 分布);训练不稳定(潜在变量难优化)。
- 应用:少样本学习、长文本理解(如文档摘要)、跨模态建模(如图文隐式关联)。
4.4、示例代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadLatentAttention(nn.Module):
"""
多头潜在注意力 (Multi-Head Latent Attention) 模块
与标准多头注意力不同,MLA引入了可学习的潜在变量 (latent variables),
这些潜在变量作为查询 (Query) 来关注输入序列,使模型能够从输入中提取
更抽象的表示。常用于变分自编码器 (VAE)、生成对抗网络 (GAN) 等生成模型。
核心思想: 使用可学习的潜在变量作为"探针",主动从输入中提取信息,而非
仅依赖输入自身的交互。
"""
def __init__(self,
embed_dim: int, # 输入嵌入维度
num_heads: int, # 注意力头数量
num_latents: int, # 潜在变量数量
latent_dim: int = None, # 潜在变量维度
dropout: float = 0.0, # Dropout概率
):
super().__init__()
# 参数校验
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_latents = num_latents
self.latent_dim = latent_dim if latent_dim is not None else embed_dim
# 确保维度可被头数整除
assert self.latent_dim % num_heads == 0, "latent_dim必须能被num_heads整除"
self.head_dim = self.latent_dim // num_heads
# 初始化可学习的潜在变量
# 形状: [num_latents, latent_dim]
self.latents = nn.Parameter(torch.randn(num_latents, self.latent_dim))
# 投影层
self.q_proj = nn.Linear(self.latent_dim, self.latent_dim) # 潜在变量投影为查询
self.k_proj = nn.Linear(embed_dim, self.latent_dim) # 输入投影为键
self.v_proj = nn.Linear(embed_dim, self.latent_dim) # 输入投影为值
self.out_proj = nn.Linear(self.latent_dim, embed_dim) # 输出投影
# Dropout和缩放因子
self.dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
def forward(self,
x: torch.Tensor, # 输入序列 [batch_size, seq_len, embed_dim]
mask: torch.Tensor = None # 可选的注意力掩码 [batch_size, seq_len]
) -> torch.Tensor:
"""
前向传播过程
"""
batch_size, seq_len, _ = x.shape
# 1. 准备查询 (Query): 从潜在变量生成
# 形状: [batch_size, num_latents, latent_dim]
q = self.q_proj(self.latents).unsqueeze(0).expand(batch_size, -1, -1)
# 2. 准备键 (Key) 和值 (Value): 从输入生成
# 形状: [batch_size, seq_len, latent_dim]
k = self.k_proj(x)
v = self.v_proj(x)
# 3. 将张量重塑为多头形式
# 形状: [batch_size, num_heads, num_latents, head_dim]
q = q.view(batch_size, self.num_latents, self.num_heads, self.head_dim).transpose(1, 2)
# 形状: [batch_size, num_heads, seq_len, head_dim]
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 4. 计算注意力分数 (点积)
# 形状: [batch_size, num_heads, num_latents, seq_len]
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# 5. 应用注意力掩码 (如果提供)
if mask is not None:
# 扩展掩码维度以匹配注意力分数
mask = mask.unsqueeze(1).unsqueeze(1) # [batch_size, 1, 1, seq_len]
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
# 6. 应用softmax获取注意力权重
# 形状: [batch_size, num_heads, num_latents, seq_len]
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 7. 加权聚合值
# 形状: [batch_size, num_heads, num_latents, head_dim]
output = torch.matmul(attn_weights, v)
# 8. 重塑并通过输出投影层
# 形状: [batch_size, num_latents, latent_dim]
output = output.transpose(1, 2).contiguous().view(batch_size, self.num_latents, self.latent_dim)
# 最终输出形状: [batch_size, num_latents, embed_dim]
return self.out_proj(output)
5、四种注意力的总结
机制 | 核心优化点 | 类比场景 | 典型优势 |
---|---|---|---|
稀疏注意力 | 减少计算量(只关注重要部分) | 跳读一本书 | 长序列处理效率提升 |
FlashAttention | 优化内存访问顺序 | 一次性准备好所有食材再做饭 | 速度快、省显存 |
多查询注意力 | 共享参数(K/V) | 小组作业共享工具 | 推理速度快、参数少 |
多头潜在注意力 | 捕捉隐藏关联 | 创建隐藏文件夹整理照片 | 深层语义理解能力更强 |
- 稀疏注意力:少看(只看关键部分)—— 像读长文章只看段落首尾句,抓重点省时间。
- FlashAttention:快算(不改逻辑只提速)—— 像用计算器算算术,和手算结果一样,但速度快 10 倍。
- 多查询注意力:共享算(共用参数)—— 像办公室共用打印机,多人用一台也不耽误事,还省成本。
- 多头潜在注意力:压缩算(先提炼核心再处理)—— 像把长视频先转成文字摘要,再根据摘要找片段,既懂全局又抓细节。
6、信息处理:分离角色
6.1、 查询(Query)、键(Key)、值(Value)的分工
- 查询(Q):表示 “当前 token 在找什么”,类似于 “问题”。
- 键(K):表示 “每个 token 有什么”,类似于 “答案的索引”。
- 值(V):表示 “每个 token 实际携带的信息”,类似于 “答案内容”。
6.2、为什么需要分离?
类比搜索引擎:
- 查询(Q):用户输入的搜索关键词(如 “深度学习”)。
- 键(K):网页的标签或索引(如标题、关键词)。
- 值(V):网页的实际内容。
搜索引擎通过比较 Q 和 K 的相似度,从 V 中提取相关信息。注意力机制同理:通过 Q 和 K 的点积计算相似度,从 V 中加权聚合信息。