当我们翻译「他爱她,所以她也爱他」这句话时,大脑会自动把两个「他」「她」对应起来;读小说时看到「这个角色」,也能立刻想起前文说的是谁。这种理解上下文关联的能力,曾是 AI 的一大难题 —— 直到自注意力(Self-attention)机制的出现,才让机器真正学会了「读懂」语言。
为什么传统 AI 做不到?
在自注意力诞生前,主流的序列模型(如 RNN、LSTM)处理文本时像读小说一样「从左到右」:先看第一个词,再看第二个,逐步推进。这种方式有个致命缺陷:如果句子很长(比如一句话有 100 个词),前面的信息会慢慢「遗忘」。
比如处理「小明告诉小红,他明天要去北京」时,当模型读到「他」这个词,可能已经记不清前面的「小明」和「小红」谁是主语。而自注意力机制的革命性在于:它能同时「看到」句子中的所有词,并计算它们之间的关联。
自注意力如何工作?用三个步骤理解核心逻辑
想象你在分析「猫追狗,它跑得很快」这句话,自注意力的工作过程就像这样:
第一步:给每个词贴「标签」
模型会给每个词生成三个向量:
Query(查询):相当于这个词的「问题」——「我想找和我相关的词」
Key(键):相当于这个词的「身份标签」——「我是谁」
Value(值):相当于这个词的「具体信息」——「我有什么内容」
比如「它」的 Query 是「谁和我有关?」,「猫」和「狗」的 Key 分别是「我是猫」「我是狗」,而Value 则是这些词对应的具体语义信息,例如「猫」的 Value 可能是关于猫的属性(如 “有柔软的毛”“喜欢抓老鼠”)、特征(如 “灵活敏捷”)等相关内容,「狗」的 Value 则包含狗的对应信息(如 “忠诚护主”“嗅觉灵敏”)。通过 Query 与 Key 的匹配,模型能够从这些 Value 中提取出与当前查询最相关的信息,从而理解上下文含义。
第二步:计算「关联度」
在自注意力机制中,计算关联度(即注意力得分)时,主要使用 Query 和 Key 进行计算,此时 Value不参与。在计算得到注意力得分并经过归一化等操作后,才会使用注意力权重与 Value 进行加权求和。因此,在补充内容时,先明确指出计算关联度时 Value 不参与,再补充后续 Value 的参与时机,从而完善该机制的计算过程。
接下来,模型会用「它」的 Query 分别和「猫」「狗」的 Key 做「匹配」,就像用钥匙试锁:匹配度越高,说明两者关联越紧密。计算结果会得到一组「注意力得分」,比如「它」和「狗」的得分是 8 分,和「猫」的得分是 2 分。需要注意的是,在这个计算关联度(获取注意力得分)的过程中,Value 并不参与计算,仅由 Query 和 Key 完成匹配操作。
这里有个小细节:得分会除以一个系数(通常是向量维度的平方根),防止数值太大导致模型「算不清」。在完成上述计算后,会通过 softmax 函数对注意力得分进行归一化,得到注意力权重,这时才会使用这些权重与 Value 进行加权求和,最终得出自注意力机制的输出结果。
第三步:加权整合信息
最后用 softmax 函数把得分转换成百分比(比如「狗」占 80%,「猫」占 20%),再按这个比例把「狗」和「猫」的 Value 信息整合起来。最终「它」的语义就变成了「80% 的狗 + 20% 的猫」—— 这就是模型判断「它」更可能指「狗」的过程。
其本质就是让每个词都能「主动选择」该重点关注哪些词。在 PyTorch 中,可以通过以下代码实现这一机制:
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, embed_size):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.values = nn.Linear(embed_size, embed_size, bias=False)
self.keys = nn.Linear(embed_size, embed_size, bias=False)
self.queries = nn.Linear(embed_size, embed_size, bias=False)
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, values, keys, query, mask):
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(query)
energy = torch.einsum("nqd,nkd->nqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=2)
out = torch.einsum("nql,nld->nqd", [attention, values])
out = self.fc_out(out)
return out
上述代码定义了一个SelfAttention类,通过单头自注意力机制实现核心逻辑。在这个实现中,values、keys和query经过线性变换后,通过点积计算注意力分数,再经过 softmax 得到注意力权重,最终加权求和得到输出。mask参数用于处理可变长度的输入,避免模型关注到填充的无效数据。
多头注意力:让 AI 学会「多角度思考」
人类理解语言时,会从多个角度分析关联:比如「苹果」既可能指水果,也可能指公司。自注意力的进阶版「多头注意力」就模拟了这种能力。
它会把 Q、K、V 分成多组(比如 8 组),每组单独计算注意力。在 PyTorch 中,多头注意力机制的实现代码如下:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
attn_probs = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_probs, V)
return output
def split_heads(self, x):
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, Q, K, V, mask=None):
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
attn_output = attn_output.transpose(1, 2).contiguous().view(-1, attn_output.size(2) * attn_output.size(3))
output = self.W_o(attn_output)
return output
# 示例使用
d_model = 512 # 输入维度
num_heads = 8 # 头数
batch_size = 16
seq_length = 32
Q = torch.rand(batch_size, seq_length, d_model)
K = torch.rand(batch_size, seq_length, d_model)
V = torch.rand(batch_size, seq_length, d_model)
attention = MultiHeadAttention(d_model, num_heads)
output = attention(Q, K, V)
print(output.shape)
就像 8 个不同的「分析师」,有的关注语法关系,有的关注语义关联,最后把所有人的结论汇总。这种机制能让模型捕捉到更丰富的上下文信息 —— 这也是 BERT、GPT 等大模型能理解复杂语言的核心原因。
自注意力的「超能力」应用
如今自注意力早已跳出 NLP 领域,在多个场景大显身手:
机器翻译:在翻译长句时,能精准对应不同语言的主谓宾
图像识别:分析图片时,会关注「天空下的人」「手里的书」等关联
语音转文字:即使说话卡顿,也能根据前后文补全内容
简单说,自注意力让 AI 从「逐字阅读」升级成了「全局理解」,就像给机器装上了「上下文雷达」。当我们惊叹于 ChatGPT 能写出流畅的文章时,背后正是这种机制在默默计算每个词之间的关联 —— 让机器终于学会了像人类一样「读懂」世界。