【大模型论文阅读】2503.01821_On the Power of Context-Enhanced Learning in LLMs

发布于:2025-07-26 ⋅ 阅读:(20) ⋅ 点赞:(0)

我们提出了一种适用于大型语言模型的新概念——上下文增强学习。它在基于梯度的文本学习基础上,通过在上下文中添加额外数据(不对这些数据计算自回归梯度)来增强效果。

这一设定是常规上下文学习(ICL)的基于梯度版本,在近期的一些研究中已有体现。 借助一项多步推理任务,我们在简化场景中证明:当模型具备上下文学习能力时,上下文增强学习的样本效率可能比常规学习高出指数级。从机制层面来看,我们发现上下文增强带来的优势源于更准确的梯度学习信号。我们还通过实验表明,很难检测或恢复训练过程中用于上下文的学习材料。这一点可能对数据安全及版权问题具有重要意义。

1. Introduction

预训练大型语言模型(LLMs)(Brown 等人,2020;Touvron 等人,2023;Team 等人,2023)展现出在推理时学习新内容的强大能力,例如通过上下文学习(ICL)。此外,有新证据表明,若在上下文中加入额外的辅助文本(即便不对这些辅助文本计算自回归损失),基于梯度的文本学习(例如数学问答学习)效果可得到提升(Liao 等人,2024;Zou 等人,2024;Choi 等人,2025)。研究还显示,此类策略对预训练亦有助益——在文档前添加源URL可提高模型的训练效率和记忆能力(Allen-Zhu 与 Li,2024;Gao 等人,2025)。

译者注:预训练的作用存疑

在本文中,我们旨在对这一现象进行正式研究:大型语言模型的基于梯度的学习可通过在上下文中植入额外的辅助材料得到增强,而无需对这些材料进行实际的自回归梯度更新。我们将这种学习形式称为上下文增强学习。由于用于上下文增强的材料可在训练过程中不断演变,因此这种方法自然契合课程学习的理念。

上下文增强学习直观地反映了人类的学习方式:在解决问题时,人们会参考教科书或演示示例以获取指导,但本身并不会刻意去记忆这些资源。一个类似的概念——“利用特权信息学习”(LUPI),已在核支持向量机(SVMs)(Vapnik & Vashist, 2009)和分类模型的研究中得到充分探讨。我们的研究将这一概念适配于大型语言模型,并提出了以下问题:

  1. 问题1(Q1):尽管自回归损失是基于同一组标记计算的,但上下文增强学习是否能显著优于没有额外上下文材料的常规自回归学习?如果能,我们能否从理论上描述并理解这种改进背后的机制?
  2. 问题2(Q2):模型是否需要达到特定的能力水平才能从上下文增强学习中获益?这是一个很自然的问题,因为利用上下文信息(例如上下文学习)可能需要模型具备最低限度的能力水平或模型规模(Brown 等人,2020;Wei 等人,2022)。
  3. 问题3(Q3):上下文增强学习是否是一种在学习过程中使用特权/私有信息的可行方式?在上下文中提供此类特权信息理论上可以增强模型的学习效果,而由于无需对这些特权/私有信息进行自回归梯度更新,通过API调用导致此类信息泄露的风险可能会更低。

论文概述:第2.1节正式定义了上下文增强学习。为了严谨地理解上下文增强学习的效力,第2.2节引入了一项名为“多层翻译”的多步推理任务。这是一个包含d+1d + 1d+1种语言(L1,L2,…,Ld+1L_1, L_2, \ldots, L_{d+1}L1,L2,,Ld+1)的合成场景,这些语言均基于有限字母表构建。对于每个iii,存在一个简单的短语集,用于描述如何从LiL_iLi翻译到Li+1L_{i+1}Li+1,而从L1L_1L1Ld+1L_{d+1}Ld+1的映射则是这组短语集的依次应用。

目标是学习如何将文本从L1L_1L1翻译为Ld+1L_{d+1}Ld+1,且无需明确写出中间步骤。在训练过程中,会向学习者提供这些短语集的摘录作为上下文中的辅助信息,但不会对这些标记进行自回归梯度更新。

如果我们在训练时,以短语集摘录和输入为条件,对翻译输出计算自回归损失,那么具备一定上下文学习(ICL)能力的模型或许能通过利用上下文中的短语集快速掌握翻译任务。然而,这种学习方式可能存在脆弱性,即模型会依赖上下文中的短语集摘录。通过对上下文中的短语集标记采用概率性丢弃(dropout)策略,可逐步减弱这种依赖性。直观而言,这种训练课程会促使模型不仅能读取短语集摘录,还能逐渐内化短语集的内容。久而久之,模型在从L1L_1L1Ld+1L_{d+1}Ld+1的翻译任务中,对短语集摘录的丢弃会变得更具鲁棒性,最终即便完全移除短语集摘录,也能完成翻译。

实验表明,当学习者是具备上下文学习(ICL)能力的预训练大型语言模型时,这种训练策略确实有效(但当大型语言模型不具备上下文学习能力时,该策略则会失效)。即便在训练中采用20%的丢弃率,模型在测试时无需任何短语集摘录,也能完美地将字符串从L1L_1L1翻译为Ld+1L_{d+1}Ld+1。本文其余部分的结构如下:

  • 第3节详细介绍了我们的实验以及上述概述的研究发现。实验表明,具备上下文学习(ICL)能力的模型会对上下文中提供的短语集进行直观的顺序处理,其中Transformer层以一种符合直觉的方式对应翻译的各个阶段;例如,L3→L4L_3 \to L_4L3L4 的翻译是在 L2→L3L_2 \to L_3L2L3 之后进行的(见第3.3节)。
  • 第4节表明,经过上下文增强学习后,模型的输出概率几乎不会泄露训练过程中所见过的短语集规则信息。
  • 在第5节中,我们提出了一个理论框架,该框架采用了一个替代/简化模型,该模型代表了适用于翻译任务的理想大型语言模型(见第5.1节)。这一框架表明,模型在训练时是否使用上下文中的短语集信息,会导致样本复杂度出现指数级差距(见第5.2节和第5.3节)。实验发现,上下文增强学习样本效率提升的背后机制是梯度信号的改善,这可通过梯度预测准确性来衡量(见第5.4节)。

2. Setup

2.1. Context-Enhanced Learning

XXX为所有可能文本字符串的集合,YYY为所有可能的文本分布集合。设ggg为一项语言任务,它将输入x∈Xg⊂Xx \in X_g \subset XxXgX映射到分布Y∈YY \in YYY。设fθ:X→Yf_\theta: X \to Yfθ:XY为一个通用自回归语言模型。我们对fθf_\thetafθ在任务ggg上的能力描述如下:

定义2.1(g-能力模型,非正式表述)。若语言模型fθf_\thetafθ与任务gggXgX_gXg上通过适当的度量标准衡量时足够接近,则称该语言模型fθf_\thetafθ具备完成语言任务ggg的能力(即ggg-能力模型)。

标准监督微调(SFT)旨在通过最小化监督数据集Dg={(xi,yi)}i=1ND_g = \{(x_i, y_i)\}_{i=1}^NDg={(xi,yi)}i=1N上的自回归损失ℓauto\ell_{\text{auto}}auto来构建一个具备ggg能力的模型,其中每个xi∈Xgx_i \in X_gxiXg对应的标签yiy_iyi均从g(xi)g(x_i)g(xi)中采样得到。

上下文增强学习包括通过额外的课程文本对监督过程进行强化,这些课程文本取决于任务ggg、输入xxx和训练步骤ttt。我们将这类课程文本记为CURRg(x,t)CURR_g(x, t)CURRg(x,t),其内容可以是任何形式(如有用的解释、教科书摘录、详细的示例等)。

在这里插入图片描述

在从监督数据集中抽取的样本(x, y)上,我们使用自回归损失来训练模型,该损失基于模型在以[CURR9(x,t),x][CURR₉(x, t), x][CURR9(x,t),x]为条件的情况下对y的预测。需要注意的是,不会对课程文本标记计算损失。我们将这种损失记为ℓauto(fρ([CURR9(x,t),x,y]),y)ℓₐᵤₜₒ(fᵨ([CURR₉(x, t), x, y]), y)autofρ([CURR9(x,t),x,y]),y

2.2. Multi-level Translation (MLT)

为了研究上下文增强学习的效力,我们引入了一项多步翻译任务。借助上下文中清晰明了的课程文本,这项任务很容易学习;但仅依靠输入-输出示例,则很难掌握。

这项任务的灵感来源于加密方法¹,例如费斯妥密码(Knudsen,1993)。多层翻译(MLT)任务涉及一种字符串到字符串的双射映射,该映射由2d个更简单的双射构成,每个简单双射要么是简单的移位1位操作,要么是通过双射对双字母组合(由两个字符组成的元组)进行转换。深度为d的翻译可由O(d)位来描述。但我们将证明,在统计查询学习(SQ-learning)框架下(Kearns,1998),仅通过输入-输出对来学习这项任务,其样本复杂度需要达到e^Ω(d)(参见定理5.4)。

在这里插入图片描述

具体而言,设A1,…,Ad+1A_1, \ldots, A_{d+1}A1,,Ad+1d+1d + 1d+1个字母表,每个字母表均包含nnn个字符且大小相同。对于每对连续的字母表AiA_iAiAi+1A_{i+1}Ai+1,我们设定一个短语集πi:Ai2→Ai+12\pi_i: A_i^2 \to A_{i+1}^2πi:Ai2Ai+12,这是一个从AiA_iAi中的双字符元组到Ai+1A_{i+1}Ai+1中的双字符元组的双射映射。每个短语集πi\pi_iπi可由一个二元随机矩阵Matrix(πi)\text{Matrix}(\pi_i)Matrix(πi)表示,其中规则以独热列的形式呈现(参见定义G.4)。

在这里插入图片描述

翻译过程的输入是一个偶数长度的序列,称为s1∈A1Ls_1 \in A_1^Ls1A1L,其中LLL为序列长度。翻译过程对s1s_1s1进行递归修改。对于每个i∈[d]i \in [d]i[d],序列si∈AiLs_i \in A_i^LsiAiL将通过短语集πi\pi_iπi,经由以下两个子过程转换为si+1∈Ai+1Ls_{i+1} \in A_{i+1}^Lsi+1Ai+1L

  • 循环移位:将序列si∈AiLs_i \in A_i^LsiAiL中的字符向左循环移动1位(必要时从末尾折回),得到序列s~i∈AiL\tilde{s}_i \in A_i^Ls~iAiL。形式化地说,对于每个j∈[1,L]j \in [1, L]j[1,L],有s~i,j=si,(j+1)%L\tilde{s}_{i,j} = s_{i,(j+1)\%L}s~i,j=si,(j+1)%L
  • 翻译:利用短语集πi:Ai2→Ai+12\pi_i: A_i^2 \to A_{i+1}^2πi:Ai2Ai+12,对序列s~i\tilde{s}_is~i中连续字符的双元组(双字母组合)进行翻译,生成si+1s_{i+1}si+1。即,对于每个奇数j∈[1,L]j \in [1, L]j[1,L],有(si+1,j,si+1,j+1)=πi(s~i,j,s~i,j+1)(s_{i+1,j}, s_{i+1,j+1}) = \pi_i(\tilde{s}_{i,j}, \tilde{s}_{i,j+1})(si+1,j,si+1,j+1)=πi(s~i,j,s~i,j+1)

我们将从sis_isisi+1s_{i+1}si+1的映射记为si+1=Tπi(si)s_{i+1} = T_{\pi_i}(s_i)si+1=Tπi(si)。d步翻译被定义为复合映射sd+1=Tπd∘Tπd−1∘⋯∘Tπ1(s1)s_{d+1} = T_{\pi_d} \circ T_{\pi_{d-1}} \circ \cdots \circ T_{\pi_1}(s_1)sd+1=TπdTπd1Tπ1(s1),该映射通过d个翻译步骤将输入序列s1s_1s1转换为sd+1s_{d+1}sd+1。关于d=2d=2d=2n=8n=8n=8的直观图示,请参见图1。

我们将所有层级的短语集集合记为Π={πi}i=1d\Pi=\{\pi_i\}_{i=1}^dΠ={πi}i=1dΠ\PiΠ现从输入s1s_1s1到输出sd+1s_{d+1}sd+1的映射记为MLTΠ:s1↦MLTΠ(s1):=Tπd∘⋯∘Tπ1(s1)\text{MLT}_\Pi:s_1\mapsto\text{MLT}_\Pi(s_1):=T_{\pi_d}\circ\cdots\circ T_{\pi_1}(s_1)MLTΠ:s1MLTΠ(s1):=TπdTπ1(s1)。我们用MLT(d,n)\text{MLT}(d,n)MLT(d,n)表示包含ddd个步骤且每个字母表有nnn个字符的翻译任务族。
需要注意的是,MLT(d,n)\text{MLT}(d,n)MLT(d,n)具有两个关键特性:1. 一旦短语集固定,由于循环移位和翻译操作都是可逆的(参见引理E.1),该翻译任务就在输入字符串和输出字符串之间定义了一种双射关系。2. 输出字符串中的每个字符都依赖于输入文本字符串中的2d2d2d个字符(见图1的说明),这使得仅通过输入-输出对来学习该任务变得非常困难(定理5.4)。
对于每个短语集πi\pi_iπi,我们将其文本表示STR(πi)\text{STR}(\pi_i)STR(πi)构造为“……a b -> C D;e d -> B A;……”的形式,其中列出了前一个字母表与后一个字母表中双字符元组之间的短语集规则(规则顺序无关)。此外,我们将拼接后的字符串[STR(π1),…,STR(πd)][\text{STR}(\pi_1),\ldots,\text{STR}(\pi_d)][STR(π1),,STR(πd)]记为STR(Π)\text{STR}(\Pi)STR(Π),它将用于定义课程文本。

2.3. Needed: Curriculum without Explicit CoT

为了通过(s1s_1s1MLTΠ(s1)\text{MLT}_\Pi(s_1)MLTΠ(s1))形式的输入-输出对,教模型掌握特定的翻译任务MLTΠ\text{MLT}_\PiMLTΠ,我们可以将短语集的相关部分STR(Π)\text{STR}(\Pi)STR(Π)作为课程文本融入上下文来训练模型。但在测试时,模型无法获取短语集,因此重要的是不要教它包含上下文信息的显式思维链(CoT)。(另一个需要考虑的因素是数据隐私,短语集被视为敏感信息。)然而,思维链的一个双重作用是在推理时为模型提供额外的计算支持(Goyal等人,2023),这在此处是必要的,因为该翻译任务包含ddd个阶段。为了促进这种隐性计算,我们教模型输出固定数量的 THINK 标记,这有时被称为隐性思维链或内化思维链。

3. Experiments and Observations

在本节中,我们固定一组短语集Π∗\Pi^*Π,并研究在MLTΠ∗\text{MLT}_{\Pi^*}MLTΠ任务上的上下文增强学习。
我们首先介绍具备MLT(d,n)\text{MLT}(d, n)MLT(d,n)上下文学习(ICL)能力模型的准备工作。接着,我们将介绍一种上下文增强学习课程,该课程涉及在上下文中随机丢弃短语集规则。然后,我们将提供实证证据,以证明上下文增强学习具有显著的样本效率。最后,我们将结合关于内部表征和参数演化的机制洞察,对本节内容进行总结。
我们使用Llama 3.2-3B指令微调模型(Dubey等人,2024)作为基础模型,并固定d=5d=5d=5n=8n=8n=8n=10n=10n=10。详细配置可参见附录B.5。

3.1. Experimental Setup

(i) Preparing an MLT-ICL-Capable Model:

Llama 3.2B模型具备多层翻译(MLT)任务的上下文学习(ICL)能力,因为它在训练过程中并未接触过该任务。为了让其具备我们所需的上下文学习能力,我们遵循常见的思维链(CoT)内化流程(Deng等人,2024;Pfau等人,2024;Hao等人,2024),在其他包含随机短语集的随机翻译任务上对其进行监督微调(SFT)。对于每组短语集,我们仅使用一个训练示例,以防止模型对特定短语集产生记忆。训练结束后,给定任何输入和短语集Π,该模型都能正确生成结果。关于第一阶段训练的详细信息,请参见附录B.4。

(ii) Setting up context-enhanced learning for :

我们将上述具备多层翻译(MLT(d, n))上下文学习(ICL)能力的模型作为初始化模型,并针对MLTₗ₌₊任务进行训练。监督数据集Dₗ₌₊由输入-标签对组成,其形式为(s1,[<THINK>,……,MLTl=+(s1)])(s₁,[<THINK>,……,MLTₗ₌₊(s₁)])s1[<THINK>……MLTl=+(s1)],其中s₁是从A₁中采样的随机字符串,长度在20到40之间。我们利用短语集STR(Π∗)的摘录(基于s₁选取)来定义课程文本CURRₗ₌₊(s₁, t),其中会对规则进行随机丢弃(由训练步骤t参数化)。我们对以下课程方案进行探究,并研究它们的影响。

  • No Context (vanilla SFT)::课程文本为空。
  • Fixed Dropout:一种与步骤t无关的简单策略;给定s₁,仅选取Π∗中用于s₁翻译的规则,然后随机丢弃20%的所选规则。
  • Annealing Dropout:一种更优的策略:对于s₁,从Π∗中选取必要的规则以及25%的未使用规则。对这些规则进行随机丢弃,在训练的前60%阶段,丢弃比例从0%线性增加到100%,之后保持100%。
  • No Dropout (ablation):给定s₁,在课程文本中始终提供Π∗中所有用于s₁翻译的规则。
  • Wrong Context (ablation):与退火丢弃策略类似,但课程文本中的规则是错误的。

3.2. Experiment Results

为了验证上下文增强学习在样本效率方面的优势(问题1),我们构建了包含特定数量独特样本的监督数据集DΠ∗D_{\Pi^*}DΠ,并在每个数据集上对模型进行一个轮次的训练。² 我们报告了在不依赖课程文本(100%丢弃)的情况下,模型对留存样本的最终答案标记(忽略思维标记)的下一个标记预测准确率,并将其与监督数据集的规模进行比较。为了验证适当的上下文学习(ICL)能力的必要性(问题2),我们进行了消融实验:采用退火丢弃策略,但从不具备多层翻译(MLT)上下文学习能力的30亿参数基础模型开始训练。

在这里插入图片描述

图2展示了上下文增强学习在样本效率方面的显著优势。此外,使用短语集子集且仅采用20%丢弃率训练的模型,在测试时即使课程文本被100%丢弃,仍能在留存样本上达到完美的准确率。这表明,模型能够有效利用那些在同一训练样本中未同时出现的短语集子集所包含的短语集规则。显然,模型已对短语集形成了原子化的学习,并且能够在测试时根据需要对这些规则进行组合运用。

在一项消融实验中(附录C),我们表明,上下文增强学习只会内化那些从课程文本中被丢弃后会导致训练数据损失增加的规则

实验结果可总结如下:
(i)基于具备上下文学习(ICL)能力的模型进行上下文增强学习,能显著提升训练样本效率。
(ii)短语集规则被以原子化的方式内化,且仅当缺失这些规则会导致损失增加时才会被内化。

在这里插入图片描述

图3. 具备MLT(5, 8)上下文学习(ICL)能力的模型中存在序列处理的证据。每个条目表示在特定层、特定标记位置处,扰动前后的潜在表征之间的范数。扰动上下文中较靠后的短语集会改变较深层的表征,这表明较靠后的翻译步骤发生在较深的网络层中。为了排除“受影响的深度仅依赖于扰动位置而非语义内容”这一可能性,我们进行了一组相同的实验,只是仅扰动翻译过程中未使用的规则,结果显示表征差异可忽略不计(参见图8中的对比)。
图4. 以图3中评估的具备MLT(5, 8)上下文学习(ICL)能力的模型为起点,我们固定一组短语集Π∗\Pi^*Π,采用退火丢弃课程方案在10万个样本上进行训练,得到一个具备MLTΠ∗\text{MLT}_{\Pi^*}MLTΠ能力的模型M∗M^*M。我们通过有选择地将M∗M^*M中从第aaa层到第bbb层的网络层替换为基础模型M0M_0M0的对应层,构建出“拼接模型”,并在上下文中标注了特定丢弃短语集的情况下对其进行评估。靠近对角线的亮色区域表明,补偿上下文中被丢弃短语集所需的知识可定位到M∗M^*M中一小部分网络层。特别值得注意的是,对于任意层级iii,在上下文增强学习之前,负责存储πi\pi_iπi的网络层组的结尾,与负责读取πi\pi_iπi的网络层的起始位置相匹配(参见图3)。


网站公告

今日签到

点亮在社区的每一天
去签到