模拟配置类
class MockConfig:
    """Minimal stand-in for a model config, holding only RoPE-related fields.

    Every value is hard-coded in ``__init__``; ``head_dim`` is derived as
    ``hidden_size // num_attention_heads`` and ``rope_scaling`` is ``None``.
    """

    def __init__(self):
        width, heads = 512, 8

        self.max_position_embeddings = 2048
        self.rope_theta = 10000.0
        self.hidden_size = width
        self.num_attention_heads = heads
        # Per-head feature size.
        self.head_dim = width // heads
        self.rope_scaling = None
Qwen3MoeRotaryEmbedding模块
import torch
import torch.nn as nn
def default_rope_init(config, device=None):
    """Compute the inverse frequencies for default (unscaled) RoPE.

    Args:
        config: Object exposing ``rope_theta`` and either ``head_dim`` or both
            ``hidden_size`` and ``num_attention_heads``.
        device: Device the returned tensor is moved to. ``None`` leaves the
            tensor where ``torch.arange`` created it.

    Returns:
        A tuple ``(inv_freq, attention_scaling)``. ``inv_freq`` is a float32
        tensor with one entry per even index in ``[0, dim)``, equal to
        ``1 / rope_theta ** (index / dim)``. ``attention_scaling`` is always
        ``1.0`` for this variant.
    """
    # RoPE rotates features inside a single attention head, so the rotary
    # dimension is the per-head size. When the config does not carry an
    # explicit head_dim (missing or None), derive it from the model width
    # rather than using the full hidden_size.
    dim = getattr(config, "head_dim", None)
    if dim is None:
        dim = config.hidden_size // config.num_attention_heads

    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (config.rope_theta ** exponents)
    print("inv_freq:", inv_freq.shape)
    return inv_freq.to(device), 1.0
# Registry mapping a RoPE variant name to the function that builds its
# inverse frequencies; only the "default" variant is registered here.
ROPE_INIT_FUNCTIONS = dict(default=default_rope_init)
class Qwen3MoeRotaryEmbedding(nn.Module):
    """Produce the cos/sin tables consumed by rotary position embeddings.

    The inverse frequencies are computed once at construction time by the
    initializer registered in ``ROPE_INIT_FUNCTIONS`` under ``rope_type`` and
    stored as the non-persistent buffer ``inv_freq``.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: MockConfig, device=None):
        """Select the RoPE initializer and build the ``inv_freq`` buffer.

        Args:
            config: Config providing ``max_position_embeddings`` plus whatever
                the selected initializer reads. If ``rope_scaling`` is a dict,
                its ``"rope_type"`` entry (or, failing that, ``"type"``)
                selects the initializer; otherwise ``"default"`` is used.
            device: Forwarded unchanged to the initializer.

        Raises:
            KeyError: If the resolved ``rope_type`` is not a key of
                ``ROPE_INIT_FUNCTIONS``.
        """
        super().__init__()

        scaling_cfg = getattr(config, "rope_scaling", None)
        if isinstance(scaling_cfg, dict):
            # "rope_type" wins; "type" is only consulted as its fallback.
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        else:
            self.rope_type = "default"

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # persistent=False keeps the buffer out of the module's state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for the given positions, cast to ``x.dtype``.

        Args:
            x: Tensor used only for its ``device`` and ``dtype``.
            position_ids: Position indices; its first axis sets the batch
                size of the result (assumed to be (batch, seq) -- the
                indexing below requires at least two axes).

        Returns:
            Tuple ``(cos, sin)``; each has the per-position angles repeated
            twice along the last axis and is multiplied by
            ``attention_scaling``.
        """
        batch = position_ids.shape[0]

        # Column of inverse frequencies per batch row: (batch, n_freq, 1).
        inv_freq_col = self.inv_freq[None, :, None].float()
        inv_freq_col = inv_freq_col.expand(batch, -1, 1).to(x.device)
        # Row of positions per batch row: (batch, 1, seq).
        positions_row = position_ids[:, None, :].float()

        # autocast is entered with "cpu" when the device type is "mps" or is
        # not a string.
        dev_type = x.device.type
        if not isinstance(dev_type, str) or dev_type == "mps":
            dev_type = "cpu"

        # Autocast is disabled so the products and trig run on the float32
        # operands prepared above.
        with torch.autocast(device_type=dev_type, enabled=False):
            # Outer product frequency x position, then move seq ahead of freq.
            angles = torch.matmul(inv_freq_col, positions_row).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = torch.cos(emb) * self.attention_scaling
            sin = torch.sin(emb) * self.attention_scaling

        print("cos:", cos)
        print("sin:", sin)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
示例
# Demo: build the rotary embedding and compute cos/sin for a toy batch.
config = MockConfig()
rope = Qwen3MoeRotaryEmbedding(config)

batch_size, seq_length = 2, 8
num_heads, head_dim = config.num_attention_heads, config.head_dim

# Random queries and keys laid out as (batch, seq, heads, head_dim).
# q is drawn before k so the random stream is consumed in the same order.
q = torch.randn(batch_size, seq_length, num_heads, head_dim)
k = torch.randn(batch_size, seq_length, num_heads, head_dim)

# Every batch row uses the same positions 0 .. seq_length - 1.
position_ids = torch.arange(seq_length)[None, :].expand(batch_size, -1)

cos, sin = rope(q, position_ids)

print(f"\nRoPE输出:")
print(f" - cos: {cos.shape}")
print(f" - sin: {sin.shape}")
inv_freq: torch.Size([32])
cos: tensor([[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 0.5403, 0.7318, 0.8460, ..., 1.0000, 1.0000, 1.0000],
[-0.4161, 0.0709, 0.4315, ..., 1.0000, 1.0000, 1.0000],
...,
[ 0.2837, -0.8209, -0.9461, ..., 1.0000, 1.0000, 1.0000],
[ 0.9602, -0.2114, -0.9731, ..., 1.0000, 1.0000, 1.0000],
[ 0.7539, 0.5114, -0.7004, ..., 1.0000, 1.0000, 1.0000]],
[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 0.5403, 0.7318, 0.8460, ..., 1.0000, 1.0000, 1.0000],
[-0.4161, 0.0709, 0.4315, ..., 1.0000, 1.0000, 1.0000],
...,
[ 0.2837, -0.8209, -0.9461, ..., 1.0000, 1.0000, 1.0000],
[ 0.9602, -0.2114, -0.9731, ..., 1.0000, 1.0000, 1.0000],
[ 0.7539, 0.5114, -0.7004, ..., 1.0000, 1.0000, 1.0000]]])
sin: tensor([[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 8.4147e-01, 6.8156e-01, 5.3317e-01, ..., 2.3714e-04,
1.7783e-04, 1.3335e-04],
[ 9.0930e-01, 9.9748e-01, 9.0213e-01, ..., 4.7427e-04,
3.5566e-04, 2.6670e-04],
...,
[-9.5892e-01, -5.7113e-01, 3.2394e-01, ..., 1.1857e-03,
8.8914e-04, 6.6676e-04],
[-2.7942e-01, -9.7740e-01, -2.3037e-01, ..., 1.4228e-03,
1.0670e-03, 8.0011e-04],
[ 6.5699e-01, -8.5931e-01, -7.1372e-01, ..., 1.6600e-03,
1.2448e-03, 9.3346e-04]],
[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 8.4147e-01, 6.8156e-01, 5.3317e-01, ..., 2.3714e-04,
1.7783e-04, 1.3335e-04],
[ 9.0930e-01, 9.9748e-01, 9.0213e-01, ..., 4.7427e-04,
3.5566e-04, 2.6670e-04],
...,
[-9.5892e-01, -5.7113e-01, 3.2394e-01, ..., 1.1857e-03,
8.8914e-04, 6.6676e-04],
[-2.7942e-01, -9.7740e-01, -2.3037e-01, ..., 1.4228e-03,
1.0670e-03, 8.0011e-04],
[ 6.5699e-01, -8.5931e-01, -7.1372e-01, ..., 1.6600e-03,
1.2448e-03, 9.3346e-04]]])
RoPE输出:
- cos: torch.Size([2, 8, 64])
- sin: torch.Size([2, 8, 64])
应用RoPE到查询和键
def rotate_half(x):
    """Swap the two halves of the last axis, negating the trailing half.

    The last axis is split at ``x.shape[-1] // 2``; the result is the negated
    trailing part followed by the leading part.
    """
    half = x.shape[-1] // 2
    leading = x[..., :half]
    trailing = x[..., half:]
    return torch.cat((trailing.neg(), leading), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors using precomputed RoPE cos/sin tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            before it is multiplied with ``q`` and ``k``.
        sin: Sine table, handled the same way as ``cos``.
        position_ids: Accepted but not used by this function.
        unsqueeze_dim: Axis at which the singleton axis is inserted into
            ``cos`` and ``sin``. It should be the axis of ``q``/``k`` that
            the tables lack (presumably the heads axis -- confirm against
            the caller's layout). Broadcasting will not flag a wrong choice
            when two axes of ``q``/``k`` happen to have the same size.

    Returns:
        Tuple ``(q_embed, k_embed)`` computed as
        ``t * cos + rotate_half(t) * sin`` for each input ``t``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Elementwise 2-D rotation expressed with the half-swapped copy.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
# q and k were built as (batch, seq, heads, head_dim) while cos/sin are
# (batch, seq, head_dim), so the broadcast axis has to be inserted at dim 2
# (the heads axis). With the default unsqueeze_dim=1 the position axis of
# cos/sin lines up with the heads axis of q/k; that only avoids a shape error
# here because seq_length == num_heads, and it rotates each head by its head
# index instead of each token by its position.
q_rotated, k_rotated = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

print(f"\n应用RoPE后:")
print(f" - 旋转后的查询 (q_rotated): {q_rotated.shape}")
print(f" - 旋转后的键 (k_rotated): {k_rotated.shape}")

print(f"\n=== RoPE性质验证 ===")
assert q_rotated.shape == q.shape, "查询张量形状不一致"
assert k_rotated.shape == k.shape, "键张量形状不一致"
print("✓ 查询和键张量形状保持一致")

print(f"\n=== 不同位置的RoPE值示例 ===")
print("位置0的cos值前5维:", cos[0, 0, :5].tolist())
print("位置0的sin值前5维:", sin[0, 0, :5].tolist())
print("位置3的cos值前5维:", cos[0, 3, :5].tolist())
print("位置3的sin值前5维:", sin[0, 3, :5].tolist())

print(f"\n=== 正交性验证 ===")
# A rotation preserves the inner product of two vectors rotated by the same
# angle, so compare q . k taken at one and the same position (position 1,
# head 0). Vectors at different positions are rotated by different angles, so
# their inner product is expected to change and is not a valid check.
original_inner_prod = torch.sum(q[0, 1, 0, :] * k[0, 1, 0, :])
rotated_inner_prod = torch.sum(q_rotated[0, 1, 0, :] * k_rotated[0, 1, 0, :])
print(f"位置1处q和k的原始内积: {original_inner_prod:.6f}")
print(f"位置1处q和k的旋转后内积: {rotated_inner_prod:.6f}")
print(f"差异: {abs(original_inner_prod - rotated_inner_prod):.6f}")

print(f"\n=== 示例完成 ===")
print("RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码")
应用RoPE后:
- 旋转后的查询 (q_rotated): torch.Size([2, 8, 8, 64])
- 旋转后的键 (k_rotated): torch.Size([2, 8, 8, 64])
=== RoPE性质验证 ===
✓ 查询和键张量形状保持一致
=== 不同位置的RoPE值示例 ===
位置0的cos值前5维: [1.0, 1.0, 1.0, 1.0, 1.0]
位置0的sin值前5维: [0.0, 0.0, 0.0, 0.0, 0.0]
位置3的cos值前5维: [-0.9899924993515015, -0.6279267072677612, -0.11596616357564926, 0.3009673058986664, 0.5827536582946777]
位置3的sin值前5维: [0.14112000167369843, 0.7782725095748901, 0.9932531714439392, 0.9536344408988953, 0.8126488924026489]
=== 正交性验证 ===
位置0和1的原始内积: -1.464770
位置0和1的旋转后内积: -1.464770
差异: 0.000000
=== 示例完成 ===
RoPE模块成功处理了查询和键张量,保持了它们的形状并应用了旋转位置编码