Attention Mechanism Implementation in Qwen3


Import the necessary libraries

import torch
import torch.nn as nn
from typing import Callable, Optional
from typing import Unpack  # Python 3.11+; on older versions: from typing_extensions import Unpack
from dataclasses import dataclass
from transformers.utils import TransformersKwargs
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # registry used for non-eager attention backends

Rotary position embedding helpers

def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
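
As a quick sanity check (a minimal sketch of my own, not part of the original script), rotate_half splits the last dimension into two halves and swaps them with a sign flip:

# Illustrative only: rotate_half on a tiny 1-D tensor.
x = torch.arange(4.0)      # tensor([0., 1., 2., 3.]) -> x1 = [0., 1.], x2 = [2., 3.]
print(rotate_half(x))      # tensor([-2., -3.,  0.,  1.])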


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies rotary position embedding to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)  # [2, 8, 64] -> [2, 1, 8, 64], broadcasts over the head dimension
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
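
Equivalently, with the half-split layout used by rotate_half, each pair of dimensions (i, i + d/2) of q (and likewise k) is rotated by a position-dependent angle:

$$
\begin{pmatrix} q'_i \\ q'_{i+d/2} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_{m,i} & -\sin\theta_{m,i} \\ \sin\theta_{m,i} & \cos\theta_{m,i} \end{pmatrix}
\begin{pmatrix} q_i \\ q_{i+d/2} \end{pmatrix},
\qquad \theta_{m,i} = m \cdot \mathrm{inv\_freq}_i
$$

where m is the token position. This is also why cos and sin have width head_dim rather than head_dim/2: the same angle appears at index i and at index i + d/2.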

Key-value head repetition

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeats the KV heads n_rep times: (batch, num_kv_heads, seq_len, head_dim)
    -> (batch, num_kv_heads * n_rep, seq_len, head_dim). Equivalent to
    torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape  # e.g. [2, 4, 8, 64] with n_rep = 2
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
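
An illustrative shape check (not part of the original post), matching the GQA layout used later where 4 KV heads are shared by 8 query heads:

kv = torch.randn(2, 4, 8, 64)     # [batch, num_kv_heads, seq_len, head_dim]
print(repeat_kv(kv, 2).shape)     # torch.Size([2, 8, 8, 64])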

Eager attention forward pass

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor, # [2, 8, 8, 64]
    key: torch.Tensor,  # [2, 4, 8, 64]
    value: torch.Tensor, # [2, 4, 8, 64]
    attention_mask: Optional[torch.Tensor],  # [2, 1, 8, 8]
    scaling: float, # 0.125 
    dropout: float = 0.0, 
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)  # [2, 8, 8, 64]
    value_states = repeat_kv(value, module.num_key_value_groups) # [2, 8, 8, 64] 
    
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling   # [2,8, 8, 8]
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]   # [2, 1, 8, 8]
        print("causal_mask:",causal_mask.shape)  
        attn_weights = attn_weights + causal_mask  # [2, 8, 8, 8]

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    print("attn_weights:",attn_weights.shape) # [2, 8, 8, 8]
    attn_output = torch.matmul(attn_weights, value_states)  # [2, 8, 8, 8] @ [2, 8, 8, 64] -> [2, 8, 8, 64]
    attn_output = attn_output.transpose(1, 2).contiguous()  # [2, seq=8, heads=8, 64]

    return attn_output, attn_weights
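
For reference, the same computation can be reproduced with PyTorch's built-in scaled_dot_product_attention. This is only a sketch, assuming PyTorch >= 2.1 (for the scale argument); unlike the eager function above it does not return the attention weights, and its output corresponds to attn_output before the final transpose:

q = torch.randn(2, 8, 8, 64)   # [batch, heads, seq, head_dim]
k = torch.randn(2, 4, 8, 64)
v = torch.randn(2, 4, 8, 64)
out = torch.nn.functional.scaled_dot_product_attention(
    q, repeat_kv(k, 2), repeat_kv(v, 2), is_causal=True, scale=0.125
)
print(out.shape)  # torch.Size([2, 8, 8, 64])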

RoPE position embedding implementation

def default_rope_init(config, device=None):
    """默认的RoPE初始化函数"""
    dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size # 64

    inv_freq = 1.0 / (
        config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    )   # 10000.0 ** (torch.arange(0, 64, 2) / 64) -> 32
    print("inv_freq:",inv_freq.shape)
    return inv_freq.to(device), 1.0  # inv_freq, attention_scaling

# Register the available RoPE initialization functions
ROPE_INIT_FUNCTIONS = {
    "default": default_rope_init,
}
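
With head_dim = 64 and rope_theta = 10000 this yields 32 inverse frequencies, decaying from 1.0 down to roughly 1.3e-4. A throwaway check (the SimpleNamespace config here is purely illustrative, not part of the original script):

from types import SimpleNamespace
cfg = SimpleNamespace(head_dim=64, rope_theta=10000.0)
inv_freq, attention_scaling = default_rope_init(cfg)
print(inv_freq[0].item(), inv_freq[-1].item())   # 1.0  ~1.33e-04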

class Qwen3MoeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
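
To make the forward pass concrete, here is a hedged manual reconstruction of the cos/sin tables for batch = 2, seq_len = 8, head_dim = 64 (variable names are illustrative; forward() computes the same values per batch element via a batched matmul):

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))  # [32]
positions = torch.arange(8, dtype=torch.float32)                                  # [8]
freqs = torch.outer(positions, inv_freq)                                          # [8, 32]
emb = torch.cat((freqs, freqs), dim=-1)                                           # [8, 64]
cos, sin = emb.cos(), emb.sin()   # forward() returns these broadcast per batch: [2, 8, 64]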

Qwen3MoeAttention implementation

class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx                   # 0
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)  # 64
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads  # 2 
        self.scaling = self.head_dim**-0.5   # 0.125
        self.attention_dropout = config.attention_dropout #  0.0
        self.is_causal = True 

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )  # 512 8*64  
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )  # 512 4*64
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )  # 512 4*64
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )  # 8*64 512
        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,    # [2, 8, 512] 
        position_embeddings: tuple[torch.Tensor, torch.Tensor],  # (cos, sin), each [2, 8, 64]
        attention_mask: Optional[torch.Tensor],   # [2, 1, 8, 8]
        past_key_values: Optional[object] = None,  # e.g. a transformers Cache instance
        cache_position: Optional[torch.LongTensor] = None, 
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]   # [2, 8]
        hidden_shape = (*input_shape, -1, self.head_dim) # [2, 8, -1 , 64] 

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
                                # [2, 8, 512]  [2, 8, 512] -> [2, 8, 8, 64] -> [2, 8, 8, 64]

        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
                                # [2, 8, 512]  [2, 8, 256] -> [2, 8, 4 ,64] -> [2, 4, 8, 64]

        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                                # [2, 8, 512]  [2, 8, 256] -> [2, 8, 4 ,64] -> [2, 4, 8, 64]

        cos, sin = position_embeddings  # [2, 8, 64], [2, 8, 64]
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 
        print("query_states:",query_states.shape) # [2, 8, 4, 64]
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward  
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window, 
            **kwargs,
        )  # attn_output: [2, seq=8, heads=8, 64], attn_weights: [2, 8, 8, 8]

        attn_output = attn_output.reshape(*input_shape, -1).contiguous() # [2, 8, 512]
        attn_output = self.o_proj(attn_output) # [2, 8, 512]
        return attn_output, attn_weights  # [2, 8, 512], [2, 8, 8, 8]
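
A quick look at the projection shapes this GQA layout implies (a sketch that reuses the MockConfig and Qwen3MoeRMSNorm defined in the next section): with hidden_size = 512, 8 query heads, 4 KV heads and head_dim = 64, q_proj maps 512 -> 512 while k_proj/v_proj map 512 -> 256, and each KV head is shared by num_key_value_groups = 2 query heads.

attn = Qwen3MoeAttention(MockConfig(), layer_idx=0)
print(attn.q_proj)                # Linear(in_features=512, out_features=512, bias=False)
print(attn.k_proj)                # Linear(in_features=512, out_features=256, bias=False)
print(attn.num_key_value_groups)  # 2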
 

Mock config and RMSNorm implementation

@dataclass
class MockConfig:
    hidden_size: int = 512
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 64
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-6
    attention_bias: bool = False
    attention_dropout: float = 0.0
    _attn_implementation: str = "eager"

class Qwen3MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """RMS normalization without mean centering (equivalent to T5LayerNorm)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

Main function: test code

if __name__ == "__main__":
    # Configuration
    config = MockConfig()

    attention_layer = Qwen3MoeAttention(config, layer_idx=0)

    rotary_emb = Qwen3MoeRotaryEmbedding(config)

    batch_size = 2
    seq_length = 8
    hidden_size = config.hidden_size  # 512

    hidden_states = torch.randn(batch_size, seq_length, hidden_size) # [2, 8, 512] 

    position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, -1)  # [2, 8]

    cos, sin = rotary_emb(hidden_states, position_ids)
    print(f"Position embeddings:")
    print(f"  - cos shape: {cos.shape}")  # [2, 8, 64]
    print(f"  - sin shape: {sin.shape}")  # [2, 8, 64]

    # Build an additive causal mask: 0.0 at positions a token may attend to,
    # a very large negative value (float32 min) at future positions
    attention_mask = torch.tril(torch.ones(batch_size, 1, seq_length, seq_length))  # [2, 1, 8, 8]
    attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min        # [2, 1, 8, 8]

    attention_output, attention_weights = attention_layer(
        hidden_states=hidden_states,
        position_embeddings=(cos, sin),
        attention_mask=attention_mask,
    )
    
    print(f"\nAttention results:")
    print(f"  - Input shape: {hidden_states.shape}")      # [2, 8, 512]
    print(f"  - Output shape: {attention_output.shape}")   # [2, 8, 512]
    print(f"  - Attention weights shape: {attention_weights.shape}")  # [2, 8, 8, 8]
Program output:

inv_freq: torch.Size([32])
Position embeddings:
  - cos shape: torch.Size([2, 8, 64])
  - sin shape: torch.Size([2, 8, 64])
query_states: torch.Size([2, 8, 8, 64])
causal_mask: torch.Size([2, 1, 8, 8])
attn_weights: torch.Size([2, 8, 8, 8])

Attention results:
  - Input shape: torch.Size([2, 8, 512])
  - Output shape: torch.Size([2, 8, 512])
  - Attention weights shape: torch.Size([2, 8, 8, 8])
