The GQA implementations I have come across online all look rather convoluted and far from concise, so here is a readable GQA implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # total K/V width: num_groups * head_dim
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)        # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)   # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)   # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out


# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])
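A quick note on the design choice: num_groups is exactly the knob that moves this layer along the MHA / GQA / MQA spectrum. Reusing the class defined above (the dimensions below are purely illustrative):

# num_groups == num_heads: every query head gets its own K/V head, i.e. ordinary multi-head attention
mha_like = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=8)
# num_groups == 1: all query heads share a single K/V head, i.e. multi-query attention (MQA)
mqa_like = GroupedQueryAttention(embed_dim=64, num_heads=8, num_groups=1)
x = torch.randn(2, 10, 64)
print(mha_like(x).shape, mqa_like(x).shape)  # both torch.Size([2, 10, 64])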
To make sense of GQA, I suggest first getting familiar with how MQA is implemented; reading in that order makes everything flow much more naturally.
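In case a quick refresher helps, here is a minimal MQA sketch of my own, written in the same style as the GQA class above and reusing the same imports (it is an illustration of the idea, not a reference implementation). In MQA all query heads share a single key head and a single value head, so k_proj and v_proj only project down to head_dim and broadcasting does the sharing:

class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)      # queries keep num_heads heads
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # one shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # one shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # unsqueeze(1) gives k and v a head dimension of size 1, so they broadcast across all query heads
        k = self.k_proj(x).unsqueeze(1)  # Shape: (batch_size, 1, seq_len, head_dim)
        v = self.v_proj(x).unsqueeze(1)  # Shape: (batch_size, 1, seq_len, head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(out)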
Once MQA clicks, the GQA implementation follows almost exactly the same idea; the only extra ingredient is the somewhat less common function tensor.repeat_interleave. For that function, just follow the link to my related article, it is quite easy to grasp.
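For readers who would rather not click away, here is a tiny illustration of what repeat_interleave does along the group dimension (the shapes and numbers are made up purely for the example):

kv = torch.arange(2).view(1, 2, 1, 1)  # pretend (batch=1, num_groups=2, seq_len=1, head_dim=1)
print(kv.flatten().tolist())                               # [0, 1]
print(kv.repeat_interleave(3, dim=1).flatten().tolist())   # [0, 0, 0, 1, 1, 1]
# Unlike tensor.repeat, which tiles the whole dimension as [0, 1, 0, 1, 0, 1],
# repeat_interleave copies each group's K/V block in place, so group g ends up next to
# the query heads it serves: heads g * heads_per_group through (g + 1) * heads_per_group - 1.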