大模型常用位置编码方式

发布于:2025-05-14 ⋅ 阅读:(16) ⋅ 点赞:(0)

深度学习中常见的位置编码方式及其Python实现:


一、固定位置编码(Sinusoidal Positional Encoding)
原理
通过不同频率的正弦和余弦函数生成位置编码,使模型能够捕捉绝对位置和相对位置信息。公式为:

公式标准数学表达
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d model ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) \begin{aligned} PE_{(pos,2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\ PE_{(pos,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \end{aligned} PE(pos,2i)PE(pos,2i+1)=sin(100002i/dmodelpos)=cos(100002i/dmodelpos)

公式解析

  1. 变量定义
    pos:token在序列中的绝对位置(从0开始计数)

    i:位置编码向量的维度索引(范围:0 ≤ i < d_model/2)

    d_model:模型嵌入维度(如Transformer默认的512)

  2. 核心设计
    • 交替使用正弦/余弦:偶数维度用正弦函数,奇数维度用余弦函数,形成周期性编码。

    • 频率衰减特性:维度越高(i增大),分母指数项 2 i d model \frac{2i}{d_{\text{model}}} dmodel2i越大,导致频率 1000 0 − 2 i / d model 10000^{-2i/d_{\text{model}}} 100002i/dmodel越小,编码的周期性波长越长。

    • 位置唯一性:每个位置pos的编码向量唯一,且相邻位置的编码差异与相对距离成比例。

  3. 数学特性
    • 相对位置捕捉:通过三角恒等式,任意两个位置的编码内积仅与相对距离pos_i - pos_j相关,隐含相对位置信息。

    • 外推能力:周期性设计使模型能处理超过训练时最大长度的序列。

关键参数作用

参数 作用 示例值(以BERT-base为例)
d_model 定义编码维度,影响模型容量 768
10000 控制频率衰减速度,值越大高频分量衰减越快 固定超参数
pos 序列位置索引 输入序列的第0/1/2…位

  • Python实现(方式一)
import numpy as np
import matplotlib.pyplot as plt

def sinusoidal_position_encoding(max_len, d_model):
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

# 示例
max_len, d_model = 50, 64
pe = sinusoidal_position_encoding(max_len, d_model)

# 可视化
plt.imshow(pe, cmap='viridis', aspect='auto')
plt.title("Sinusoidal Position Encoding")
plt.colorbar()
plt.show()

输出示例:生成一个形状为 (50, 64) 的编码矩阵,低频维度变化平缓,高频维度变化剧烈。

  • Pytorch实现
import torch
import torch.nn as nn

class LearnablePositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        # 初始化位置编码矩阵为可训练参数 (1, max_len, d_model)
        self.pe = nn.Parameter(torch.empty(1, max_len, d_model))
        # 正态分布初始化(标准差0.02,与Transformer常规初始化一致)
        nn.init.normal_(self.pe, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        输入x形状: [batch_size, seq_len, d_model]
        输出形状: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)
        # 取前seq_len个位置编码(避免越界)
        position_emb = self.pe[:, :seq_len, :]
        # 将位置编码与输入相加
        return x + position_emb

二、可学习位置编码(Learnable Positional Encoding)
原理
将位置编码作为可训练参数,通过嵌入层动态学习每个位置的表示。

Python实现(PyTorch)

import torch
import torch.nn as nn

class LearnablePositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)
    
    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)
        positions = torch.arange(seq_len, device=x.device).expand(batch_size, seq_len)
        return x + self.pe(positions)

# 示例
d_model, max_len = 64, 50
inputs = torch.randn(32, max_len, d_model)  # 模拟输入 (batch_size=32, seq_len=50)
pe_layer = LearnablePositionalEncoding(max_len, d_model)
output = pe_layer(inputs)
print("Encoded shape:", output.shape)  # 输出:torch.Size([32, 50, 64])

优势:灵活性高,适合特定任务;缺点:依赖预定义的最大序列长度。


三、相对位置编码(Relative Positional Encoding)
原理
关注序列元素之间的相对位置差异,常用于长序列建模。

Python实现

class RelativePositionalEncoding(nn.Module):
    def __init__(self, max_rel_pos, d_model):
        super().__init__()
        self.emb = nn.Embedding(2 * max_rel_pos + 1, d_model)
    
    def forward(self, seq_len):
        # 生成相对位置索引矩阵(对称)
        rel_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
        rel_pos = torch.clamp(rel_pos + seq_len - 1, 0, 2 * seq_len - 2)
        return self.emb(rel_pos)

# 示例
d_model, max_rel_pos = 64, 10
rel_pe = RelativePositionalEncoding(max_rel_pos, d_model)
rel_enc = rel_pe(seq_len=5)
print("Relative encoding shape:", rel_enc.shape)  # 输出:torch.Size([5, 5, 64])

应用场景:Transformer-XL、音乐生成等长序列任务。


四、旋转位置编码(Rotary Positional Encoding, RoPE)
原理
通过旋转矩阵将绝对位置信息融入注意力计算,保持相对位置的线性性质。
旋转矩阵公式的标准数学表达式及解析:

标准数学公式
R θ , m = [ cos ⁡ ( m θ ) − sin ⁡ ( m θ ) sin ⁡ ( m θ ) cos ⁡ ( m θ ) ] , q ′ = R θ , m q , k ′ = R θ , n k R_{\theta,m} = \begin{bmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{bmatrix}, \quad q' = R_{\theta,m}q, \quad k' = R_{\theta,n}k Rθ,m=[cos(mθ)sin(mθ)sin(mθ)cos(mθ)],q=Rθ,mq,k=Rθ,nk


Python实现

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb(q, k, freq):
    cos, sin = freq.cos(), freq.sin()
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot

# 示例
d_model, seq_len = 64, 50
q = torch.randn(1, seq_len, d_model)
k = torch.randn(1, seq_len, d_model)
freq = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
q_rot, k_rot = apply_rotary_emb(q, k, freq)
print("Rotated shapes:", q_rot.shape, k_rot.shape)  # 输出:torch.Size([1, 50, 64])

优势:支持任意长度外推,广泛用于LLaMA、ChatGLM等大模型。


总结与选择建议

方法 适用场景 优点 缺点
固定位置编码 通用NLP任务 确定性,无需训练 无法自适应长序列
可学习位置编码 短序列任务 灵活性高 依赖预定义长度,泛化性差
相对位置编码 长文本生成、音乐建模 捕捉相对位置关系 计算复杂度较高
旋转位置编码 大语言模型(LLaMA等) 支持外推,数学性质优雅 实现较复杂

网站公告

今日签到

点亮在社区的每一天
去签到