《DEEPSEEK原生应用与智能体开发实践 图书》【摘要 书评 试读】- 京东图书
在自回归生成过程中,每一次推理步骤仅生成一个token,随后将这个新生成的token拼接到当前的输入序列末尾。紧接着,基于更新后的序列,模型进行下一次推理,如此循环往复,直至生成特定的结束标志(如eos,即end of sentence)或达到预设的最大生成长度。
这种逐步生成的方式使得自回归模型能够灵活地处理长文本生成任务。通过逐步构建序列,模型能够考虑之前生成的上下文信息,从而生成更加连贯和符合逻辑的文本。然而,随着生成序列的不断增长,计算量和内存消耗也会相应增加,这对模型的推理效率和性能提出了挑战。
8.1.1 自回归模型的计算量
自回归生成模型是一种生成式模型,它逐个生成序列中的元素(通常是token),每次生成都依赖于之前已经生成的元素。这种依赖关系使得模型能够捕捉序列中的上下文信息,从而生成连贯的文本。
在自回归生成模型的推理过程中,模型需要逐步生成序列中的每个token。假设我们有一个前缀序列,其长度为P,模型从这个前缀开始生成新的序列。随着推理的进行,生成的序列长度逐渐增加,假设当前生成的序列长度为L,前缀+推理示意如图8-1所示。
图8-1 前缀+推理
则此时模型在推断完成时候,总的计算量近似表示为:
在推理过程中,模型需要计算每个新生成的token的概率分布。这个计算量取决于当前已生成的序列长度(即前缀长度P加上已生成的长度L)。
8.1.2 自回归模型的缓存优化
在上一章我们完成了一个基于注意力的自回归模型的设计,整体模型结构如图8-2所示。
首先,我们的输入是一个序列。这个序列的长度是可变的,并且会加上前次推理生成的token(在图中以深色部分表示)。这些输入通过自回归模型的Embedding权重矩阵进行映射,这两个矩阵的作用是将input_ids映射到高维空间,从而得到hidden_state张量。这个张量包含了输入序列在高维空间中的表示。
图8-2 基于注意力的自回归模型架构
接着,hidden_state张量通过模型的线性变换模块注意力进行处理。这个模块的作用是将hidden_state的维度提升3倍,然后将其分割成查询(Query)、键(Key)和值(Value)三个部分。这三个部分在后续的注意力计算中起着关键作用。
随后,Q、K和V被进一步分割成多个head,这是多头注意力机制的一部分。每个head分别进行注意力计算,即计算Q和K的点积,然后除以K的平方根得到注意力权重,这些权重再与Value相乘得到加权和。多个head的结果拼接起来后,通过另一个线性变换模块进行处理,以恢复hidden_state的原始维度或进行其他变换。
在得到新的hidden_state后,我们进行残差连接,即将新计算的hidden_state与之前的hidden_state相加。这一步有助于缓解梯度消失或梯度爆炸的问题,提高模型的训练稳定性。
接着,残差连接后的hidden_state通过前馈层FFN模块进行处理。前馈层是多个线性变换和激活函数的组合,用于进一步提取特征。处理完后,我们再次进行残差连接,得到更新后的hidden_state。
最后,更新后的hidden_state通过lm_head模块生成logits,即预测token的概率分布。lm_head模块实际上是一个线性映射,将hidden_state的维度从d_model变换到vocab_size(词汇表大小)。这样,我们就可以根据logits得到下一个token的预测结果。
值得注意的是,图8-2中张量里面的深色条带一开始表示的是输入序列的最后一个token。随着前向计算的进行,它逐渐变成了下一个token的概率分布,也就是logits计算矩阵的最后一行。而logits前面的行在推理阶段通常是没有意义的,因为它们代表的是之前已经生成的token的概率分布。
因此,我们不禁思考是否可以只计算最后一行以省略其他行的计算量。通过分析每个模块对最后一行的依赖关系,我们发现lm_head、mlp、layer_norm以及前面的线性变换模块的输出都只与 hidden_state的最后一行相关。这意味着理论上我们可以只计算最后一行来减少计算量。然而,在实际实现中还需要考虑其他因素,如内存访问模式和并行计算效率等。
如图8-3所示,Attention计算过程中,Q与K计算结果只影响Attention_score(注意力得分)的最后一行,但与全部的值(V)相关。而score则与查询(Q)的最后一行相关,并与键(K)的全部行相关。由此可以得出,Attention机制的最后一行与查询(Q)的最后一行、完整的键(K)和值(V)相关。这一结论非常重要,因为它揭示了为什么我们选择使用KV Cache而不是QKV Cache。
我们继续探讨,图8-3展示了注意力核心计算在使用计算不同状态的比较。
图8-3 注意力的核心计算(图像颜色参看配套资源中的相关文件)
从上图可以看到,此时注意力中的输入只有上一次推理生成的token,而不是整个 prompt 序列。在进行注意力计算之前,需要拼接完整的键(K)和值(V),因此需要将这两个量缓存起来,并在每次推理时复用。