注意力机制的使用说明01

发布于:2025-07-26 ⋅ 阅读:(10) ⋅ 点赞:(0)

多头注意力机制(MHA)使用精要

核心作用: 捕捉序列数据的全局依赖关系,让每个时间点都能关注到所有其他时间点。

关键参数 (__init__)
  1. embed_dim: 特征维度 (C)。必须与输入到MHA层的数据的特征维度完全一致。

  2. num_heads: 头的数量embed_dim 必须能被 num_heads 整除。

  3. batch_first=True: 务必设为 True。这规定了MHA期望的输入格式为 (N, L, C)

实现蓝图 (forward pass)

在卷积网络(输入为 (N, C, L))中使用MHA,遵循以下三步即可:

  1. 格式转换 (Permute In):

    • x = x.permute(0, 2, 1)

    • 目的:将 (N, C, L) 转换为MHA期望的 (N, L, C)

  2. 应用注意力块 (Attention Block):

    • attn_out, _ = self.mha(x, x, x)

    • x = self.norm(x + attn_out)

    • 目的:执行自注意力计算,并用残差连接和层归一化稳定训练。

  3. 格式恢复 (Permute Back):

    • x = x.permute(0, 2, 1)

    • 目的:将 (N, L, C) 转换回 (N, C, L),以适配后续的卷积层。

黄金法则: MHA的 embed_dim 参数值,必须等于你的数据在进入MHA模块时的特征维度(通道数C),而不是最原始信号的维度。

import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(AttentionBlock, self).__init__()
        # 确保 embed_dim 能被 num_heads 整除
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) 必须能被 num_heads ({num_heads}) 整除。")

        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x 的输入格式应为 (N, C, L),这是CNN的典型输出格式
        N, C, L = x.shape

        # --- 配方第1步: 格式准备 ---
        # (N, C, L) -> (N, L, C)
        x_permuted = x.permute(0, 2, 1)

        # --- 配方第2步: 自注意力计算 ---
        attn_output, _ = self.mha(x_permuted, x_permuted, x_permuted)

        # --- 配方第3步: 稳定与融合 ---
        # 残差连接 + 层归一化
        x_stabilized = self.norm(x_permuted + attn_output)

        # --- 配方第4步: 格式恢复 ---
        # (N, L, C) -> (N, C, L)
        final_output = x_stabilized.permute(0, 2, 1)

        return final_output

# --- 使用示例 ---
# 假设我们有一个来自CNN的输出
cnn_output = torch.randn(32, 64, 1024) # (N, C, L)

# 创建并使用注意力块
attention_block = AttentionBlock(embed_dim=64, num_heads=8)
processed_output = attention_block(cnn_output)

print(f"输入形状: {cnn_output.shape}")
print(f"输出形状: {processed_output.shape}") # 输出形状应与输入完全相同


网站公告

今日签到

点亮在社区的每一天
去签到