Grouped Query Attention (GQA): A PyTorch Implementation

Published: 2025-04-21

Most of the GQA implementations I've come across online look oddly convoluted and are anything but concise, so here is a clean, readable implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # Total K/V projection dim: num_groups * head_dim
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)  # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out

# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])
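
As a quick sanity check on the projection sizes (a small inspection sketch using the model instance created above): with embed_dim = 64, num_heads = 8 and num_groups = 4, we get head_dim = 8 and group_dim = 32, so the key/value projections are half the size of the query projection. Note that nn.Linear stores its weight as (out_features, in_features).

print(model.q_proj.weight.shape)  # torch.Size([64, 64]): embed_dim -> embed_dim
print(model.k_proj.weight.shape)  # torch.Size([32, 64]): embed_dim -> group_dim
print(model.v_proj.weight.shape)  # torch.Size([32, 64]): embed_dim -> group_dim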

To follow the GQA code, I recommend first looking at an MQA (Multi-Query Attention) implementation; reading the two in that order makes everything flow much more smoothly.
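
A quick way to see the connection: MQA is the num_groups = 1 special case of the class above (all query heads share a single key/value group), while standard multi-head attention is the num_groups = num_heads case (every head gets its own group). A minimal sketch reusing the input x from the test above:

# MQA: one shared K/V group for all 8 query heads
mqa = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=1)

# Standard MHA: one K/V group per query head
mha = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=8)

print(mqa(x).shape, mha(x).shape)  # both torch.Size([2, 10, 64])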

Once you understand MQA, the GQA implementation follows almost exactly the same idea; the only addition is a somewhat less commonly used function, tensor.repeat_interleave. For that function, see my separate article on it; it is quite easy to follow.
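
A minimal standalone sketch of what repeat_interleave does: it repeats each slice along the chosen dimension consecutively (unlike tensor.repeat, which tiles the whole tensor), which is exactly how each K/V group gets lined up with its contiguous block of query heads above.

import torch

k = torch.arange(4).view(1, 2, 2)     # pretend shape: (batch=1, num_groups=2, feature=2)
print(k.repeat_interleave(3, dim=1))  # each group repeated 3 times in place, shape (1, 6, 2):
# tensor([[[0, 1],
#          [0, 1],
#          [0, 1],
#          [2, 3],
#          [2, 3],
#          [2, 3]]])
# By contrast, k.repeat(1, 3, 1) would tile the groups as [g0, g1, g0, g1, g0, g1].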