Attention Backend 技术背景
注意力(Attention)机制在深度学习中扮演着关键角色,它帮助模型在处理序列数据时,有选择地关注输入中的重要信息。然而,传统的注意力计算往往受到内存访问和算力分配的双重制约,导致在大型模型或长序列场景下效率低下。为了解决这一瓶颈,研究者们提出了“Attention Backend”这一概念,即在计算框架之外,专门设计高效的底层实现和调度策略,以充分利用硬件资源,实现更快、更省内存的注意力运算。
Attention Backend 的出现意义在于,它将注意力计算从通用的算子调用中剥离出来,通过针对性地优化内存布局、并行调度和数值稳定性等关键环节,大幅提高推理和训练的性能。例如,FlashAttention 通过重设计内存访问路径和分块处理,将计算复杂度和显存占用同时降低;Triton Backend 则利用自定义 CUDA 内核与编译时调度,进一步压缩执行时间。这些底层创新让模型能够在相同硬件上承载更长的序列、更大的批次,从而加速实际应用中的部署速度与响应效率。
总的来说,Attention Backend 就像是一条高效的高速公路,承载着注意力机制在硬件层面的高速运转。它不仅推动了大规模预训练模型在推理阶段的普及,也为未来更复杂的序列处理任务提供了坚实的性能保障。
1. FlashInfer Backend
FlashInfer 最初是华盛顿大学 Paul G. Allen 计算机科学院、卡耐基梅隆大学及陈天奇的创业公司 OctoAI 共同发起的合作研究项目,旨在创建一个灵活的大语言模型(LLM)推理内核库,提供 LLM GPU 内核的高性能实现,如 FlashAttention、SparseAttention、PageAttention、Sampling 等。
核心技术特点:
- 分页 KV 缓存(Paged KV Cache):
- 将 Key-Value 缓存(KV Cache)划分为固定大小的块(类似操作系统中的分页机制),按需加载和存储。
- 减少内存碎片化,提升显存利用率。
- Radix Tree 前缀匹配:
- 通过 Radix Tree 结构快速匹配前缀相同的序列,共享缓存块,减少重复计算。
- 适用于批量推理(如多个请求共享相同前缀的情况)。
- Block-Sparse 稀疏化:
- 将 KV Cache 表示为稀疏矩阵,仅存储非零元素,降低显存占用。
- Wrapper 模式:
- FlashInfer 将注意力计算分为 Decode Wrapper 和 Prefill Wrapper 两种模式:
- Decode Wrapper:用于生成阶段的单步推理(如解码新 token)。
- Prefill Wrapper:用于预填充阶段(如处理完整输入序列)。
- 通过分步配置(
init_forward_metadata
)和执行(forward
)分离,提升工程灵活性。
- FlashInfer 将注意力计算分为 Decode Wrapper 和 Prefill Wrapper 两种模式:
适用场景:
- LLM 推理服务:适合大规模并发请求场景(如聊天机器人、API 服务)。
- 长文本处理:通过分页和稀疏化技术,支持更长的上下文长度。
优势:
- 显存占用降低 40% 以上。
- 推理延迟降低 50% 以上(相比传统注意力实现)。
- 支持灵活的硬件定制(如通过 JIT 编译优化)。
2. Triton Backend
Triton Backend是专门为高性能计算设计的一种后端实现,它通过Triton语言(一种针对GPU编程的低级语言)和编译器优化,能够更精细地控制硬件资源(如显存和计算单元),从而加速注意力机制的计算过程。与传统的PyTorch或TensorFlow实现不同,Triton允许开发者直接编写高度定制化的内核代码,针对注意力计算中的矩阵乘法、softmax操作等关键步骤进行优化,减少不必要的内存拷贝和计算冗余,尤其适合处理Transformer模型中长序列的注意力操作。
3. FA3 (FlashAttention v3)
FA3是FlashAttention的第三个版本,继承了前两个版本的设计理念,并对性能进行了进一步的优化。FA3相对于FlashInfer在算法和硬件优化上有了更多的创新,特别是在大规模训练时,能够显著提升吞吐量。
FlashAttention-3(FA3)作为最新一代注意力机制优化技术,在FlashAttention-1(FA1)和FlashAttention-2(FA2)的基础上进行了多项关键改进,显著提升了计算效率、硬件利用率和低精度性能。
推理框架中的Atttention Backend
如下代码为SGLang推理框架里面投机采样技术的Attention Backend选择代码,值得参考:
# 初始化目标模型和草稿模型各自的Attention Backend
# 草稿模型和目标模型使用不同的Attention Backend
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
if self.server_args.attention_backend == "flashinfer":
if not global_server_args_dict["use_mla_backend"]:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
FlashInferMultiStepDraftBackend,
)
print("self.draft_attn_backend = FlashInferMultiStepDraftBackend")
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
print("self.draft_extend_attn_backend = FlashInferAttnBackend")
self.draft_extend_attn_backend = FlashInferAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
FlashInferMLAMultiStepDraftBackend,
)
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonAttnBackend,
TritonMultiStepDraftBackend,
)
self.draft_attn_backend = TritonMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = TritonAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
FlashAttentionMultiStepBackend,
)
self.draft_attn_backend = FlashAttentionMultiStepBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = FlashAttentionBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import (
FlashMLAMultiStepDraftBackend,
)
self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.has_prefill_wrapper_verify = False
else:
raise ValueError(
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
)
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend