线性注意力(Linear Attention)
线性注意力是一种改进的注意力机制,旨在解决传统自注意力(Self-Attention)在处理长序列时计算和内存复杂度过高的问题。传统自注意力的计算复杂度是 O(N2)O(N^2)O(N2),而线性注意力通过一系列数学技巧将其降低到 O(N)O(N)O(N),更适合处理长文本或高分辨率图像。
🔍 一、什么是线性注意力?为什么是线性?
在标准自注意力中,核心操作如下:
Attention(Q,K,V)=softmax(QKTd)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(dQKT)V
其中:
- Q,K,VQ, K, VQ,K,V 是查询(Query)、键(Key)、值(Value)矩阵,维度都是 N×dN \times dN×d
- 这个公式涉及到计算 QKTQK^TQKT,它是一个 N×NN \times NN×N 的矩阵 —— 所以计算复杂度是 O(N2)O(N^2)O(N2)。
而线性注意力的基本想法是:将 softmax 或其他非线性函数 近似替换成内积可以拆分的形式,也就是说将其写成:
Attention(Q,K,V)≈(ϕ(Q)(ϕ(K)TV)) \text{Attention}(Q, K, V) \approx (\phi(Q) (\phi(K)^T V)) Attention(Q,K,V)≈(ϕ(Q)(ϕ(K)TV))
其中 ϕ(⋅)\phi(\cdot)ϕ(⋅) 是一个正定映射函数(如 ReLU、ELU、exp、kernel trick等),可以将原本的 QKTQK^TQKT 操作变成:
- 先计算 KTVK^T VKTV,维度是 d×dd \times dd×d,成本是 O(Nd2)O(Nd^2)O(Nd2)
- 然后再乘上 ϕ(Q)\phi(Q)ϕ(Q),成本是 O(Nd2)O(Nd^2)O(Nd2)
从而整体注意力过程从原本的 O(N2d)O(N^2d)O(N2d) 降为 O(Nd2)O(Nd^2)O(Nd2),在 N≫dN \gg dN≫d 的时候是巨大的提升。
🧠 二、与传统注意力的区别
项目 | 传统自注意力 | 线性注意力 |
---|---|---|
复杂度 | O(N2)O(N^2)O(N2) | O(N)O(N)O(N) |
可并行性 | 差 | 好 |
对长序列的支持 | 差(容易OOM) | 强 |
精度 | 高 | 略有损失(视实现而定) |
使用softmax | 是 | 否(用可拆解核函数近似) |
✅ 三、实际应用场景
线性注意力已经应用于以下几个方向:
- 长文档生成(如 GPT-类模型压缩长上下文)
- 高分辨率图像建模(Vision Transformers)
- 语音识别与语音建模
- 嵌入设备(低功耗硬件)上运行 Transformer
典型模型包括:
- Performer
- Linformer
- Linear Transformers(本文提到的 ICML 2020 论文)
📊 四、可视化比较
标准注意力计算流程:
Q (N x d)
|
v
QK^T (N x N) <---- K (N x d)
|
v
softmax
|
v
Weighted sum with V (N x d)
线性注意力计算流程:
φ(Q) (N x d)
|
v
φ(K)^T V (d x d) <--- φ(K) (N x d), V (N x d)
|
v
Final output = φ(Q) × (φ(K)^T V)
不再显式构造 N×NN \times NN×N 的中间矩阵,节省空间和计算。
🔧 五、Python代码对比(PyTorch)
✅ 标准注意力代码(简化版):
import torch
import torch.nn.functional as F
def standard_attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k**0.5
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, V)
✅ 线性注意力代码(基于ELU核):
def elu_feature_map(x):
return F.elu(x) + 1 # 保证非负
def linear_attention(Q, K, V):
Q_prime = elu_feature_map(Q) # φ(Q)
K_prime = elu_feature_map(K) # φ(K)
KV = torch.einsum('nld,nle->lde', K_prime, V) # φ(K)^T V
Z = 1 / (torch.einsum('nld,ld->nl', Q_prime, K_prime.sum(dim=0)) + 1e-6) # 正则项
output = torch.einsum('nld,lde,nl->nle', Q_prime, KV, Z) # 最终输出
return output
🧾 六、总结
- 线性注意力本质上是通过数学变换将注意力矩阵“内积核”结构分解,避免显式计算 N×NN \times NN×N 的矩阵。
- 它能显著提高模型效率和可扩展性,特别适用于长序列任务。
- 虽然可能略微损失精度,但对于许多工程实际来说是可以接受的。