LLM笔记 - 简单认识Attention机制

发布于:2024-08-10 ⋅ 阅读:(76) ⋅ 点赞:(0)

谈及LLM,都在讨论它的多头自注意力MHA,那么,注意力机制究竟是什么,本文尝试着给出一个从开发者角度的科普解释。

Attention机制简介

Attention机制最初是在机器翻译任务中提出的,它的主要思想是让模型在预测下一个词时不仅关注当前输入,还能关注输入序列中的其他位置。Attention机制可以帮助模型更好地捕捉输入序列中不同部分之间的关系,提高模型的表现。

Key-Value Pair

在Attention机制中,通常会涉及到三个重要的概念:Query(查询)Key(键),和Value(值)。这些概念可以帮助模型计算不同位置之间的相关性。

  1. Query: 查询向量,表示当前需要计算Attention的词。
  2. Key: 键向量,表示输入序列中每个词的位置。
  3. Value: 值向量,表示输入序列中每个词的实际内容。

Attention机制的计算过程主要包括以下几个步骤:

  1. 计算Query和Key的点积(Dot Product),得到Attention Score。
  2. 对Attention Score进行归一化处理,通常使用Softmax函数。
  3. 使用归一化的Attention Score对Value进行加权求和,得到最终的Attention输出。

Attention公式

假设有一个Query向量( q ),Key向量( k ),和Value向量( v ),那么Attention的计算可以表示为:
Attention ( q , k , v ) = softmax ( q ⋅ k T d k ) ⋅ v \text{Attention}(q, k, v) = \text{softmax}\left(\frac{q \cdot k^T}{\sqrt{d_k}}\right) \cdot v Attention(q,k,v)=softmax(dk qkT)v

其中, d k d_k dk是Key向量的维度,用于缩放点积结果。

PyTorch中的简单实现

下面是一个使用PyTorch实现Attention机制的简单例子:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    def __init__(self, d_model):
        super(SimpleAttention, self).__init__()
        self.d_model = d_model
        
        # Linear layers for projecting Query, Key, and Value
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        # Project the Query, Key, and Value
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)
        
        # Calculate the dot products between Query and Key
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        
        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum of the Value vectors
        attn_output = torch.matmul(attn_weights, v)
        
        return attn_output, attn_weights

# Example usage
d_model = 64
batch_size = 1
seq_len = 10

# Create dummy data for Query, Key, and Value
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))

# Initialize the attention mechanism
attention = SimpleAttention(d_model)

# Apply attention
attn_output, attn_weights = attention(q, k, v)

print("Attention output:", attn_output)
print("Attention weights:", attn_weights)

解释代码

  1. 定义SimpleAttention类:这是一个继承自nn.Module的自定义Attention模块。它包含三个线性层,用于对Query、Key和Value进行线性变换。
  2. forward方法:这是PyTorch模型的前向传播函数,计算Attention的输出。
  3. 创建示例数据:生成一些随机的Query、Key和Value向量。
  4. 应用Attention机制:初始化并应用自定义的Attention模块,计算Attention输出和权重。

通过这种方式,我们可以实现一个基本的Attention机制,并在实际任务中应用它。


网站公告

今日签到

点亮在社区的每一天
去签到