DUOATTENTION:结合检索与流式注意力机制的高效长上下文大语言模型推理方法

发布于:2025-08-10 ⋅ 阅读:(11) ⋅ 点赞:(0)

温馨提示:
本篇文章已同步至"AI专题精讲" DUOATTENTION:结合检索与流式注意力机制的高效长上下文大语言模型推理方法

摘要

部署长上下文的大语言模型(LLMs)是必要的,但也带来了显著的计算与内存挑战。将所有注意力头的 Key 和 Value(KV)状态进行缓存会占用大量内存。现有的 KV 缓存裁剪方法要么破坏了 LLM 的长上下文能力,要么仅带来有限的效率提升。本文发现,只有部分注意力头(即“检索头”)在处理长上下文时至关重要,需要对所有 token 执行完整注意力;而其余的注意力头主要关注最近的 token 和注意力汇聚点(attention sinks),我们称之为“流式头”,它们并不需要完整注意力机制。

基于这一洞察,我们提出 DuoAttention 框架:该框架仅对检索头使用完整的 KV 缓存,而对流式头则采用轻量级、固定长度的 KV 缓存,从而在不影响模型长上下文能力的前提下,降低 LLM 的解码和预填充时的内存占用与延迟。DuoAttention 借助轻量级、基于优化的算法,并利用合成数据准确识别检索头。我们的方法在保持精度几乎无损的前提下,将长上下文推理的内存使用量最多降低至原来的 2.55×(用于 MHA 模型)和 1.67×(用于 GQA 模型),同时解码速度最多提升至 2.18× 和 1.50×,预填充速度最多提升至 1.73× 和 1.63×。值得注意的是,结合量化技术后,DuoAttention 使得 Llama-3-8B 在单个 A100 GPU 上支持 330 万 token 的上下文长度推理成为可能。代码地址详见文末链接。

1 引言

大型语言模型(LLMs)(Touvron et al., 2023a;b;OpenAI, 2023;Black et al., 2022)位于当前人工智能革命的前沿,推动了诸如多轮对话(Schulman et al., 2022;Taori et al., 2023;Chiang et al., 2023)、长文档摘要(Goyal & Durrett, 2020;Zhang et al., 2023a),以及视觉和视频理解等多模态任务(Liu et al., 2023b;Lin et al., 2023)等高级应用的发展。这些应用通常需要处理大量的上下文 token。例如,想要对整个《哈利·波特》系列小说进行摘要,大约需要处理一百万个 token。而在视觉语言模型(VLM)中,单张 224×224 的图像就对应 256 个 token(Liu et al., 2023b),而一段 3 分钟的视频(24 帧每秒)则会产生大约 110 万个 token。

在此类应用中部署 LLM 时,面临的一个关键问题是长上下文推理问题。完整的注意力机制要求每一个 token 都必须关注所有之前的 token,以获得准确的表示,导致解码延迟线性增长、预填充延迟按平方增长。与此同时,Key-Value(KV)缓存技术需要存储所有历史 token 的键值信息,使得内存使用量随上下文长度线性增长。随着序列的延长,KV 缓存占用的内存越来越多,从而对注意力机制造成了巨大计算负担。例如,在 Llama-3-8B(Dubey et al., 2024)模型结构中,若以 FP16 精度缓存 100 万 token 的 KV 信息,则至少需要 137GB 内存,远超单张 80GB GPU 的容量。此外,在如此长的上下文下进行预填充与解码,其延迟也极为显著,严重影响了 LLM 在长上下文任务中的实际可用性。
在这里插入图片描述

尽管已经有许多尝试旨在解决长上下文推理中注意力机制所面临的挑战,但显著的计算与内存问题仍然存在。一些结构上的改进方法,例如 Grouped-Query Attention(GQA)(Ainslie et al., 2023),需要模型在预训练阶段就进行改动,且无法降低计算开销。线性注意力方法(Gu & Dao, 2023;Poli et al., 2023)虽然在计算和内存上的需求较低,但在长上下文场景中的表现往往不如标准的 transformer 模型。近似注意力机制,如 H2O(Zhang et al., 2023b)、StreamingLLM(Xiao et al., 2023b)、TOVA(Oren et al., 2024)和 FastGen(Ge et al., 2024),通常会在处理长上下文时牺牲准确性,而且无法兼容诸如 GQA 等关键的 KV 缓存优化技术。KV 缓存量化方法(Liu et al., 2024;Hooper et al., 2024)虽然有所帮助,但并不能减少注意力机制的计算时间。系统层级的优化方法,包括 FlashAttention(Dao et al., 2022;Dao, 2023)、FlashDecoding(Hong et al., 2024)和 PagedAttention(Kwon et al., 2023),虽然高效,但无法减小 KV 缓存的尺寸,在长上下文下依然需要大量计算资源。

这些限制凸显了进一步发展的必要性,以便部署能够处理百万级上下文长度的模型。本文提出了一个关键观察:LLM 中的注意力头可以被划分为两种截然不同的类型,如图 1 所示:Retrieval Heads(Wu et al., 2024)和 Streaming Heads。Retrieval Heads 仅占所有注意力头的一小部分,但对于处理长上下文至关重要,需要对所有 token 执行完整注意力。而其余大部分注意力头——称为 Streaming Heads——主要关注最近的 token 和注意力汇聚点(attention sinks)(Xiao et al., 2023b),在仅包含最近 token 和注意力汇聚点的精简 KV 缓存下也能正常工作。

基于这种“检索头-流式头”二元划分,我们提出了一种通用、简洁、易集成的方法 DuoAttention,它能显著加速 LLM 的解码与预填充,并减少内存占用,尤其适用于长上下文场景。DuoAttention 的核心创新是一种轻量级的、基于优化的过程,利用合成数据精确识别无法压缩的检索头。与依赖注意力模式分析的方法(Wu et al., 2024;Ge et al., 2024;Tang et al., 2024a)不同,DuoAttention 直接通过测量 token 丢弃带来的输出偏差进行评估,从而实现更高的压缩率与更优的部署效率。

DuoAttention 的设计兼顾简洁性与效率:每一层 transformer 拥有两个 KV 缓存 —— 为关键的 Retrieval Heads 提供完整的 KV 缓存,为 Streaming Heads 提供一个仅包含注意力汇聚点与最近 token 的固定长度 KV 缓存。该设计可显著减少内存使用,并提升在 Llama-2/3 和 Mistral 等模型中的解码速度,在 MHA 模型中内存最多减少 2.55×、在 GQA 模型中减少 1.67×,同时解码速度分别提升至 2.18× 和 1.50×,预填充加速至 1.73× 和 1.63×,且在精度上与完整注意力相比仅有极小的损失。

此外,DuoAttention 可与 GQA 和量化等重要优化技术完全兼容。我们进一步展示,当 DuoAttention 与 8-bit 权重和 4-bit KV 缓存量化结合使用时,可使 Llama-3-8B 模型在单个 A100 GPU 上处理最多 330 万个上下文 token,相较于标准 FP16 全注意力部署,容量提升了 6.4 倍。DuoAttention 为将 LLM 应用于需要处理百万级上下文的场景铺平了道路。
在这里插入图片描述

2 DuoAttention

2.1 检索头与流式头

检索头(Retrieval Heads)
在基于 Transformer 的大语言模型(LLM)中,不同的注意力头通常展现出稳定且各具特色的注意力模式,反映出它们所执行的特定功能(Clark et al., 2019;Xiao et al., 2023b;Wu et al., 2024)。图 1 展示了在 Llama-2-7B-32K-Instruct 模型中,输入句子 “The best fruit is orange. What is the best fruit? Orange” 下两种注意力头的可视化效果。图左面板显示的注意力头在解码阶段强调了与当前词语相关的 token:例如,在解码第二次出现的 “best fruit” 时,该头重点关注了第一次出现的 “best fruit”,在生成第二个 “orange” 时,也显著关注了前面提到的 “orange”。我们将这类注意力头称为检索头(Retrieval Heads),它们在上下文建模中至关重要,能够捕捉语义相关的 token。如果压缩这类注意力头的 KV 缓存,就会导致重要上下文信息的丢失,因此它们必须对全部 token 执行完整注意力。

流式头(Streaming Heads)
与之相对,图中间面板所示的注意力头主要关注最近的 token 和注意力汇聚点(attention sinks,Xiao et al., 2023b),而不会特别强调上下文中早先出现的相关 token。我们将这类注意力头称为流式头(Streaming Heads)。对流式头的 KV 缓存进行压缩是可行的,因为移除其未关注的中间 token 并不会显著影响其注意力输出。因此,我们可以仅保留流式头对注意力汇聚点和最近 token 的 KV 状态,在不牺牲长上下文处理能力的前提下进行优化。

Token 剪除对检索头与流式头的影响
图右面板展示了一个初步的 passkey 检索实验,结果表明:当剪除检索头 KV 缓存中的中间 token(即用流式注意力机制替代)时,模型性能会明显下降;而对流式头剪除中间 token 则对 passkey 检索的准确率影响甚微。这一观察说明,我们可以在不牺牲模型长上下文理解能力的前提下提升计算效率:只需对流式头移除中间 token,而对检索头保留完整注意力,即可将流式头的内存需求降至 O ( 1 ) O(1) O(1),从而提升长上下文处理的效率。
在这里插入图片描述

2.2 基于优化的检索头识别

检索头的定义
第 2.1 节中我们从定性角度区分了 retrieval 头与 streaming 头,而要进行精确识别,还需要一个明确且量化的定义。本文将 “检索头(retrieval heads)” 定义为:

当其 attention 仅限于 recent token 与 attention sink 时,模型输出发生显著变化的注意力头。

我们据此标准将检索头与流式头区分开来。需要注意的是,这一定义不同于现有的一些方法(Ge et al., 2024;Wu et al., 2024;Tang et al., 2024a),后者仅依赖 attention score 来判断哪些头是 retrieval 头,这种方法忽略了以下几点:1)对特定注意力头进行 KV cache 压缩的端到端影响;2)value 状态在注意力计算中的作用;3)不同层和注意力头之间的 attention 分布差异性。相比之下,本文的方法直接衡量压缩后的输出偏差,即便 attention score 不明显,也能识别出对长上下文处理至关重要的注意力头。我们将在第 3.5 节通过消融实验对该方法进行进一步支持。

基于优化的识别方法
我们采用一种基于优化的方式识别 retrieval 头,该方法灵感来源于 CNN filter 剪枝研究(Liu et al., 2017),如图 2 所示。首先,我们为每一个 key-value(KV)头分配一个门控值 α i , j \alpha_{i,j} αi,j,该值直观地表示第 i i i 层中第 j j j 个 KV 头在处理长上下文信息时的重要性。需要注意的是,在采用 GQA(Grouped-Query Attention)的模型中,一个 KV 头可能会被多个 attention 头共享,因此我们的方法也适用于对一组 attention 头的 KV cache 统一压缩。

我们的方法直接评估仅保留 sink 和 recent token 后对每个 KV 头的输出影响。首先,为每个 KV 头初始化门控值 α i , j ∈ [ 0 , 1 ] \alpha_{i,j} \in [0, 1] αi,j[0,1],默认所有注意力头初始状态都是 retrieval 头(即 α = 1 \alpha=1 α=1)。在优化过程中,LLM 的所有原始参数保持冻结,仅优化这些门控值,因此总的可训练参数为 N × H N \times H N×H(其中 N N N 为层数, H H H 为每层的 KV 头数),从而避免对原模型性能造成干扰。

在前向传播过程中,对于每个 KV 头的输出,我们将 full attention(对所有 token 的注意)与 streaming attention(仅对 sink 和 recent token 的注意)混合,使用门控值作为加权因子,混合方式如下所示:
a t t n i , j = α i , j ⋅ f u l l _ a t t n + ( 1 − α i , j ) ⋅ s t r e a m i n g _ a t t n \mathtt { a t t n } _ { i , j } = \alpha _ { i , j } \cdot \mathtt { f u l l \_ a t t n } + \left( 1 - \alpha _ { i , j } \right) \cdot \mathtt { s t r e a m i n g \_ a t t n } attni,j=αi,jfull_attn+(1αi,j)streaming_attn
其中注意力计算被定义为:
f u 11 _ a t t n = s o f t m a x ( Q K T ⊙ M c a u s a l ) V ,     s t r e a m i n g _ a t t n = s o f t m a x ( Q K T ⊙ M s t r e a m i n g ) V , \begin{array} { r } { \mathrm { f u 1 1 \_ a t t n = s o f t m a x } ( Q K ^ { T } \odot M _ { \mathrm { c a u s a l } } ) V , \ \ \ } \\ { \mathrm { s t r e a m i n g \_ a t t n = s o f t m a x } ( Q K ^ { T } \odot M _ { \mathrm { s t r e a m i n g } } ) V , } \end{array} fu11_attn=softmax(QKTMcausal)V,   streaming_attn=softmax(QKTMstreaming)V,
其中, M _ causal M\_{\text{causal}} M_causal 是因果注意力掩码(即下三角矩阵),而 M _ streaming M\_{\text{streaming}} M_streaming 表示 Λ 形掩码(Han et al., 2023;Xiao et al., 2023b),该掩码仅关注 recent token 与起始 token。

用于识别检索头的合成数据集
然而,仅依赖自然语言建模目标是无法有效识别 retrieval 头的,因为自然文本中需要进行长距离推理的监督信号非常稀疏,大多数 token 都可以通过局部上下文推理得出。为了解决这一问题,我们设计了一个专门用于强化模型长上下文检索能力的合成数据集,从而有效识别出哪些 KV 头可以被压缩而不影响模型性能。

如图 3 所示,我们构建了一个 passkey 检索数据集(passkey-retrieval dataset):在一个非常长的上下文中,随机插入 10 个长度为 s s s 的 passkey 序列(在实验中 s = 32 s = 32 s=32),每个序列出现在随机位置。模型的任务是,在上下文末尾,能够回忆出这 10 个 passkey 序列。

在这里插入图片描述

训练与损失函数
我们优化的是蒸馏损失(distillation loss),具体为完整注意力模型的最后隐藏状态 H full H_{\text{full}} Hfull 与使用 DuoAttention 的模型输出 H mixed H_{\text{mixed}} Hmixed 之间的 L2 差异,仅在整个长度为 T T T 的输入序列中,关注最后 l l l 个 passkey token:
L d i s t i l l = 1 N ∑ i = 1 N ∑ j = T − l + 1 T ( H f u l l ( i ) [ j ] − H m i x e d ( i ) [ j ] ) 2 ( 1 ) \mathcal { L } _ { \mathrm { d i s t i l l } } = \frac { 1 } { N } \sum _ { i = 1 } ^ { N } \sum _ { j = T - l + 1 } ^ { T } ( H _ { \mathrm { f u l l } } ^ { ( i ) } [ j ] - H _ { \mathrm { m i x e d } } ^ { ( i ) } [ j ] ) ^ { 2 }\quad(1) Ldistill=N1i=1Nj=Tl+1T(Hfull(i)[j]Hmixed(i)[j])2(1)
我们的合成数据集确保每个监督信号都与最终的压缩策略相关,使得该过程在信息检索准确率方面实现无损。实验表明,该方法比单纯使用自然语言建模目标更为有效(详见第13节的消融研究)。我们采用 L1 正则项(即 Lasso(Tibshirani,1996))来促进门控值的稀疏性:
L r e g = ∑ i = 1 L ∑ j = 1 H ∣ α i , j ∣ ( 2 ) \mathcal { L } _ { \mathrm { r e g } } = \sum _ { i = 1 } ^ { L } \sum _ { j = 1 } ^ { H } \left| \alpha _ { i , j } \right|\quad(2) Lreg=i=1Lj=1Hαi,j(2)
最终的训练损失由蒸馏损失和正则化损失加权组合而成,权重超参数为 λ λ λ,在我们的实验中设为 0.05:
L = L d i s t i l l + λ L r e g . ( 3 ) \begin{array} { r } { \mathcal { L } = \mathcal { L } _ { \mathrm { d i s t i l l } } + \lambda \mathcal { L } _ { \mathrm { r e g } } . } \end{array}\quad(3) L=Ldistill+λLreg.(3)
由于可训练参数总数仅为数千个浮点数,该优化过程相当快速,通常只需约 2000 步训练即可完成。本文中的所有训练实验均可在 8 张 NVIDIA A100 GPU 服务器上完成。

2.3 使用 DuoAttention 部署大语言模型

注意力策略的二值化实现
在推理阶段,我们仅对通过训练阶段优化门控值识别出的检索头(Retrieval Heads)应用完整注意力机制(如图4所示)。我们基于一个阈值 τ \tau τ 对每个注意力头的策略进行二值化,该阈值由指定的稀疏分位数确定,用以区分检索头与流式头(Streaming Heads):
a t t n i , j = { f u l l _ a t t n i f    α i , j > τ s t r e a m i n g _ a t t n o t h e r w i s e ( 4 ) \mathsf { a t t n } _ { i , j } = \left\{ \begin{array} { l l } { \mathtt { f u l l \_ a t t n } } & { \mathrm { i f } \; \alpha _ { i , j } > \tau } \\ { \mathtt { s t r e a m i n g \_ a t t n } } & { \mathrm { o t h e r w i s e } } \end{array} \right.\quad(4) attni,j={full_attnstreaming_attnifαi,j>τotherwise(4)
在这里插入图片描述
注意力头重排序
在部署之前,我们会对模型进行预处理,按照注意力头的分类对 Query、Key、Value 投影权重的输出通道进行重排序。该重排序将检索头(Retrieval Heads)和流式头(Streaming Heads)分别聚集成两个连续的簇,这样在管理每层的 KV 缓存时,可以高效地进行切片和拼接操作,避免使用开销较大的分散(scatter)与聚集(gather)操作。

解码过程
如图5所示,在 LLM 的每一层解码过程中,我们为检索头和流式头分别分配两个 KV 缓存:检索头缓存保存所有历史的 Key 和 Value,流式头缓存仅保存注意力汇聚点和最近的 token,缓存大小保持恒定。当处理新 token 时,其对应的 Query、Key 和 Value 向量会沿注意力头维度拆分,分别计算检索头的完整注意力和流式头的流式注意力,结果随后沿注意力头维度拼接,传入输出投影层。

分块预填充
我们使用 FlashAttention-2(Dao, 2023)对检索头和流式头的 KV 缓存进行预填充。对于长上下文 LLM,分块预填充是一种常用做法(Agrawal et al., 2023;Kwon et al., 2023),即将提示文本分割成固定长度的块,依次预填充 KV 缓存。此技术显著降低峰值内存使用(参见表10),将线性层的峰值中间激活大小从序列长度降至块大小。

DuoAttention 完全兼容分块预填充,且流式头的预填充过程时间复杂度为线性,内存复杂度为常数,无需依赖特殊内核。正如图5所示,一旦一层的 KVs 计算完成,流式头的 KV 缓存即被即时修剪,只保留注意力汇聚点和最近 token。接下来的新块 token 在预填充时仅关注恒定数量的上下文 token。令序列长度为 L L L,块大小为 K K K,则流式头的预填充时间复杂度由 O ( L 2 ) O(L^2) O(L2) 优化至 O ( L K ) O(LK) O(LK),内存复杂度由 O ( L ) O(L) O(L) 降至 O ( K ) O(K) O(K)

值得注意的是,DuoAttention 的设计非常适合批量操作,这可进一步提升 LLM 在大批量请求场景下的服务效率。

3 实验

3.1 实验设置

模型、数据集与基线
我们在长上下文和短上下文基准任务上评估 DuoAttention,以验证该方法在保证模型在需要长上下文和短上下文任务中的性能的同时,显著提升效率。在长上下文评测中,我们采用 Needle-in-a-Haystack(NIAH)基准(Kamradt, 2024)和 LongBench(Bai et al., 2023)。在短上下文评测中,我们使用 MMLU(Hendrycks et al., 2021)、MBPP(Austin et al., 2021)和 MT-Bench(Zheng et al., 2023)等数据集。所用模型均为当前最先进的开源模型,包括 Llama-2-7B-chat(Touvron et al., 2023b)及其长上下文变体 Llama-2-7B-32K-Instruct(Together, 2023)、Llama-3-[8,70]B-Instruct(及其长上下文变体 Llama-3-8B-Instruct-Gradient-1048k*),以及 Mistral-7B-v0.2-Instruct(Jiang et al., 2023)。我们的方法与多种 KV 缓存压缩算法进行了对比,包括 H2O(Zhang et al., 2023b)、TOVA(Oren et al., 2024)、FastGen(Ge et al., 2024)以及 StreamingLLM(Xiao et al., 2023b)。
在这里插入图片描述
温馨提示:
阅读全文请访问"AI深语解构" DUOATTENTION:结合检索与流式注意力机制的高效长上下文大语言模型推理方法


网站公告

今日签到

点亮在社区的每一天
去签到