主流的Attention Backend介绍

发布于:2025-06-29 ⋅ 阅读:(15) ⋅ 点赞:(0)

Attention Backend 技术背景

注意力(Attention)机制在深度学习中扮演着关键角色,它帮助模型在处理序列数据时,有选择地关注输入中的重要信息。然而,传统的注意力计算往往受到内存访问和算力分配的双重制约,导致在大型模型或长序列场景下效率低下。为了解决这一瓶颈,研究者们提出了“Attention Backend”这一概念,即在计算框架之外,专门设计高效的底层实现和调度策略,以充分利用硬件资源,实现更快、更省内存的注意力运算。

Attention Backend 的出现意义在于,它将注意力计算从通用的算子调用中剥离出来,通过针对性地优化内存布局、并行调度和数值稳定性等关键环节,大幅提高推理和训练的性能。例如,FlashAttention 通过重设计内存访问路径和分块处理,将计算复杂度和显存占用同时降低;Triton Backend 则利用自定义 CUDA 内核与编译时调度,进一步压缩执行时间。这些底层创新让模型能够在相同硬件上承载更长的序列、更大的批次,从而加速实际应用中的部署速度与响应效率。

总的来说,Attention Backend 就像是一条高效的高速公路,承载着注意力机制在硬件层面的高速运转。它不仅推动了大规模预训练模型在推理阶段的普及,也为未来更复杂的序列处理任务提供了坚实的性能保障。

1. FlashInfer Backend

FlashInfer 最初是华盛顿大学 Paul G. Allen 计算机科学院、卡耐基梅隆大学及陈天奇的创业公司 OctoAI 共同发起的合作研究项目,旨在创建一个灵活的大语言模型(LLM)推理内核库,提供 LLM GPU 内核的高性能实现,如 FlashAttention、SparseAttention、PageAttention、Sampling 等。

核心技术特点:
  • 分页 KV 缓存(Paged KV Cache)
    • 将 Key-Value 缓存(KV Cache)划分为固定大小的块(类似操作系统中的分页机制),按需加载和存储。
    • 减少内存碎片化,提升显存利用率。
  • Radix Tree 前缀匹配
    • 通过 Radix Tree 结构快速匹配前缀相同的序列,共享缓存块,减少重复计算。
    • 适用于批量推理(如多个请求共享相同前缀的情况)。
  • Block-Sparse 稀疏化
    • 将 KV Cache 表示为稀疏矩阵,仅存储非零元素,降低显存占用。
  • Wrapper 模式
    • FlashInfer 将注意力计算分为 Decode WrapperPrefill Wrapper 两种模式:
      • Decode Wrapper:用于生成阶段的单步推理(如解码新 token)。
      • Prefill Wrapper:用于预填充阶段(如处理完整输入序列)。
    • 通过分步配置(init_forward_metadata)和执行(forward)分离,提升工程灵活性。
适用场景:
  • LLM 推理服务:适合大规模并发请求场景(如聊天机器人、API 服务)。
  • 长文本处理:通过分页和稀疏化技术,支持更长的上下文长度。
优势:
  • 显存占用降低 40% 以上。
  • 推理延迟降低 50% 以上(相比传统注意力实现)。
  • 支持灵活的硬件定制(如通过 JIT 编译优化)。

2. Triton Backend

Triton Backend是专门为高性能计算设计的一种后端实现,它通过Triton语言(一种针对GPU编程的低级语言)和编译器优化,能够更精细地控制硬件资源(如显存和计算单元),从而加速注意力机制的计算过程。与传统的PyTorch或TensorFlow实现不同,Triton允许开发者直接编写高度定制化的内核代码,针对注意力计算中的矩阵乘法、softmax操作等关键步骤进行优化,减少不必要的内存拷贝和计算冗余,尤其适合处理Transformer模型中长序列的注意力操作。

3. FA3 (FlashAttention v3)

FA3是FlashAttention的第三个版本,继承了前两个版本的设计理念,并对性能进行了进一步的优化。FA3相对于FlashInfer在算法和硬件优化上有了更多的创新,特别是在大规模训练时,能够显著提升吞吐量。

FlashAttention-3(FA3)作为最新一代注意力机制优化技术,在FlashAttention-1(FA1)和FlashAttention-2(FA2)的基础上进行了多项关键改进,显著提升了计算效率、硬件利用率和低精度性能。

推理框架中的Atttention Backend

如下代码为SGLang推理框架里面投机采样技术的Attention Backend选择代码,值得参考:

    # 初始化目标模型和草稿模型各自的Attention Backend
    # 草稿模型和目标模型使用不同的Attention Backend
    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                    FlashInferMultiStepDraftBackend,
                )

                print("self.draft_attn_backend = FlashInferMultiStepDraftBackend")
                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
                print("self.draft_extend_attn_backend = FlashInferAttnBackend")
                self.draft_extend_attn_backend = FlashInferAttnBackend(
                    self.draft_model_runner,
                    skip_prefill=False,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
                    self.draft_model_runner,
                    skip_prefill=False,
                )
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonAttnBackend,
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = TritonAttnBackend(
                self.draft_model_runner,
                skip_prefill=False,
            )
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = FlashAttentionBackend(
                self.draft_model_runner,
                skip_prefill=False,
            )
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import (
                FlashMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.has_prefill_wrapper_verify = False
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

网站公告

今日签到

点亮在社区的每一天
去签到