【论文笔记】OctoThinker:突破 Llama 推理瓶颈的中期训练范式

发布于:2025-07-06 ⋅ 阅读:(13) ⋅ 点赞:(0)


原文链接:《OctoThinker: Mid-training Incentivizes Reinforcement Learning Scaling》

在这里插入图片描述

原文摘要:不同的基础语言模型家族,例如 Llama 和 Qwen,在使用强化学习 (RL) 进行后训练时,尤其是在推理密集型任务上,表现出不同的行为特性。是什么使得基础语言模型适合强化学习? 更深入地了解这个问题对于开发下一代可扩展 RL 的基础模型至关重要。本研究中,我们探讨了中期训练策略如何塑造 RL 动态,聚焦于两个代表性模型家族:Qwen 与 Llama。研究发现:
(1)高质量的数学语料库,例如 MegaMath-Web-Pro,显著提升了基础模型及 RL 性能,而现有替代品(例如 FineMath-4plus)则未能达到同等效果;
(2)进一步加入问答风格数据,特别是长链式思维 (CoT) 推理示例,能增强 RL 效果,且指令数据进一步释放了这一潜力;
(3)虽然长 CoT 提升了推理深度,但也可能导致模型回答冗长及 RL 训练不稳定,这凸显了数据格式化的重要性
(4)中期训练的规模扩展持续带来更强的下游 RL 性能
基于这些洞见,我们提出了一种两阶段中期训练策略——“稳定后衰减 (Stable-then-Decay)”,即基础模型先以恒定学习率训练 2000 亿 token,随后在三个 CoT 重点分支上以学习率衰减方式训练 200 亿 token。由此诞生了 OctoThinker 模型家族,展现了优异的 RL 兼容性,并缩小了与更 RL 友好模型家族(如 Qwen)的性能差距。我们期望本工作能为 RL 时代的基础模型预训练策略提供指导。为支持进一步研究,我们开源了模型及一个精选的超过 700 亿 token 的数学推理密集型语料库(即 MegaMath-Web-Pro-Max)。

一、核心要点

1.1 核心问题:基础模型的RL适应性差异

大型语言模型(LLMs)通过思维链(CoT) 结合强化学习(RL) 在复杂推理任务(如数学竞赛题)上取得突破,但不同基础模型对 RL 训练的适应性差异显著。

  • Qwen 系列模型:对 RL 扩展适应性强,性能提升显著。
  • Llama 系列模型:在 RL 训练中易出现过早预测答案输出重复问题,难以复现 R1-Zero 的成功。

核心科学问题

  • 基础模型的哪些特性导致其对 RL 训练敏感性的差异?
  • Mid-training 能否作为可控干预手段,弥合不同基座在 RL 中的表现鸿沟?

1.2 关键发现:语料质量的决定性作用

  • 数学预训练语料质量是决定 RL 性能的关键因素:
    • 高质量数学语料(如 MegaMath-Web-Pro)显著优于普通语料(如 FineMath-4plus)。
  • 数据混合策略影响 RL 效果:
    • 加入 QA 风格数据和少量指令跟随数据可进一步提升性能。
    • 长思维链(Long CoT)数据会导致 RL 训练不稳定。(解决方案:优化 RL 提示(prompt)设计;采用渐进式最大响应长度调度器稳定训练。
  • 训练规模的重要性
    • 扩大中期训练规模可提升下游 RL 性能。
    • 但这一提升无法通过标准评估指标直接反映在中期训练基础模型,突显预训练与 RL 能力评估的鸿沟。

1.3 研究目标:构建RL友好的Llama基础模型

探索能否通过大规模、以推理为中心的中期训练,将 Llama 转化为适应 RL 扩展的基础模型,弥合其与 Qwen 等模型的差距。

1.4 研究框架图示:从问题到解决方案

中期阶段I:基座构建
中期阶段II:分支精炼
强化学习阶段
基础模型差异
Qwen系列
Llama系列
适应RL训练 → 性能提升
早答/重复问题 → RL失败
关键因素: 预训练语料质量 + 数据混合策略 + 训练规模
解决方案
构建高质量数学语料
MegaMath-Web-Pro-Max
优化数据混合
(QA + 指令数据)
设计稳定RL训练策略
扩大中期训练规模
(→100B tokens)
目标: 将Llama转化为RL友好的基础模型
OctoThinker-Base-Stable
OctoThinker-Base
OctoThinker-Zero

二、初步实验与分析

2.1 实验目标:验证RL适应性差异

从数学推理的视角,发现两个主要模型家族——Qwen 和 Llama——在强化学习(RL)动态中存在关键差异。

2.2 RL训练配置:奖励机制与超参设置

  • 框架与算法:基于 verl 框架,使用 GRPO 算法进行 RL 实验。
  • 训练数据集:采用 MATH8K 数据集,因其适中的难度和简洁的结构。
  • 训练参数
    • 全局训练批量大小:128
    • 每个查询的 rollout 响应数量:16
    • PPO 的小批量大小:64
    • 采样温度:1.0
    • 最大输出长度:4096 tokens
    • 学习率:1 × 10⁻⁶
    • KL 损失系数:0
    • 采样与梯度更新的比例:2(用于稳定 RL 训练)
  • 提示模板:采用简单的提示模板“问题:{} 答案:{}”来格式化训练样本。
  • 模型选择:采用 Llama-3.2-3B-BaseQwen2.5-3B-Base 进行 R1-Zero 风格的 RL 训练,选择它们是因为模型大小适中。

2.3 评估体系:数学推理基准

  • 评估方式
    • 基础语言模型:采用少样本提示(few-shot prompting)进行评估。
    • RL 微调模型:采用零样本(zero-shot)进行评估。
  • 评估指标任务
    • 分析 RL 动态:GSM8K、MATH500、OlympiadBench、AMC23
    • 基础模型性能:MATH、SAT-MATH、MathQA、MMLU-STEM、OCW Course、MAWPS、SVAMP、ASDiv、TabMWP

2.4 实验发现:Qwen与Llama的RL表现对比

在这里插入图片描述

  • Qwen2.5-3B-Base

    • 在整个训练过程中,正确响应的长度稳定且合理地增加。
    • 表现出良好的训练动态,响应长度的增加是渐进和可控的。
    • 在多个基准测试中表现出显著的性能提升,显示出其对 RL 训练的良好适应性。
  • Llama-3.2-3B-Base

    • 表现出异常行为,平均响应长度急剧增加,达到最大响应长度(4096 tokens)。
    • 输出通常以特定的格式开始(例如 “\boxed: {}”),然后是非常明显地重复,直到达到最大响应长度。
    • 在 GSM8K 上仅经历了微小的增益,甚至出现了倒退。
  • 归因:两个模型的预训练的潜在差异可能影响了它们在 RL 训练中的表现。

三、中期训练干预的探索

什么是中期训练(Mid-training)?
中期训练是一个中期阶段,其计算和数据需求介于预训练和后训练之间。它旨在实现特定目标,例如领域和语言扩展、长上下文扩展、提高数据质量、利用大规模合成数据,以及为后训练做准备等——通过显著改变数据质量和分布或修改模型架构以提高推理效率。

3.1 实验目标:探索中期训练优化路径

旨在通过一系列对照实验,深入探究在大型语言模型(LLM)的中期训练阶段,不同因素对后续强化学习(RL)性能的具体影响。

3.2 实验设计:多变量控制

  • 系统地考察多个关键变量,包括:
    1. 数学网络语料库的数据质量
    2. 是否引入问答(QA)格式的数据
    3. QA数据本身的特性(例如,思维链的长度)
    4. 中期训练数据中是否包含通用的指令跟随数据
    5. 预训练阶段所使用的token预算规模
  • 通过对这些因素进行精细化的控制和对比分析,期望能够更深刻地理解预训练数据与策略对模型在RL阶段行为模式和最终性能的内在联系。

在这里插入图片描述

3.3 实验设置:语料组合与训练规模控制

中期训练设置

  • 基础模型:默认使用Llama-3.2-3B-Base模型
  • 训练token预算:200亿
  • 学习率调度器:余弦学习率调度器,无预热
    • 峰值学习率:3e-5
    • 最小学习率:峰值学习率的十分之一
  • 序列长度:8,192个token
  • 训练批量大小:400万token
  • 训练框架:Nanotron框架

RL 设置(同2.2)

数据集
为了支持可控的中期训练实验,使用了一系列多样化的数据集,涵盖了数学网络文档、问答格式数据以及通用的指令跟随数据:

数据集名称 类型 Token数量 (B)
FineMath-4plus 数学网络文档 9.57
MegaMath-Web-Pro 数学网络文档 13.00
MegaMath-Web-Pro-Max 数学网络文档 (本研究构建) 73.80
MegaMath-QA 问答 (短思维链 CoT) 5.94
OpenR1-Math-220K 问答 (长思维链 CoT) 1.05
TULU3-sft 通用指令跟随 0.01
WildChat 通用指令跟随 0.29
UltraChat-220K 通用指令跟随 0.51

标准化的数据处理方式:

  • 对于OpenR1数据集,将其中的问题与包含在<think></think>标签内的思考过程通过换行符进行拼接。
  • 对于通用的指令跟随数据集,仅保留了高质量的对话,并将对话格式化为“User:{} Assistant:{}”的形式。

MegaMath-Web-Pro-Max 的整理

  1. 数据筛选与优化

    • 使用高效分类器从 MegaMath-Web 语料库中召回相关的数学文档。
    • 利用 Llama-3.1-70B-instruct 模型对文档进行标注和评分。
    • 评分标准:0到5分,低于3分为负样本,3分及以上为正样本。
  2. 分类器选择

    • 选择 fasttext 模型作为分类器,因其高效特性。
    • 数据预处理对召回性能至关重要,包括文本转换为小写、过滤过长单词、移除换行符和多余字符。
  3. 召回阈值

    • 召回阈值控制数据数量和质量之间的权衡。
    • 较高的阈值(例如0.9)会得到更高质量的数据,但保留的token数量会更少。
    • 选择0.4的阈值以平衡数据质量和数量。
  4. 数据精炼

    • 使用 Llama-3.1-70B-instruct 模型,并借鉴 MegaMath-Web-Pro 的提示对文本进行精炼。
    • 形成 MegaMath-Web-Pro-Max 数据集,其 token 数量约为 MegaMath-Web-Pro 的 5.5 倍。

3.4 实验发现:语料质量与RL性能的相关性

  1. 高质量数学语料库的重要性:研究发现,像 MegaMath-Web-Pro 这样的高质量数学语料库,相较于 FineMath-4plus 等现有替代方案,能显著提升基础模型和 RL 性能。
    在这里插入图片描述

  2. QA 格式数据与指令数据的增益:在高质量数学预训练语料库基础上,加入 QA 样式数据(尤其是长链推理示例)可增强 RL 效果,但收益取决于其与下游任务的分布差距。而少量指令数据的引入能进一步释放 QA 数据潜力,并减轻长链思考引起的 RL 训练崩溃。
    在这里插入图片描述

  3. 长链推理的双刃剑效应:长链推理虽能提升推理深度,但也可能引发模型响应冗长及 RL 训练不稳定问题(体现在突然的性能下降)。解决训练不稳定问题:

    • 设计指令增强提示模板,抑制重复输出
      在这里插入图片描述

    • 设置渐进最大响应长度调度器,按照训练进度解决长链推理引发的训练不稳定
      在这里插入图片描述

  4. 中等训练规模扩展的效益:增加中等训练数据量可带来更强劲的下游 RL 性能,即使基础模型评估中未明显体现这些增益。这表明,中等训练阶段的扩展对于提升模型的最终 RL 表现具有重要意义。
    在这里插入图片描述

影响因素 正向案例 负向案例 背后的洞察
数学语料质量 MegaMath-Web-Pro FineMath-4plus 低质量语料导致RL响应崩溃
QA数据分布 竞赛题库型推理数据 网页爬取短短思维链推理数据 分布对齐下游任务稳定RL训练
指令数据注入 添加1% GPT-4级指令跟随数据 纯数学语料(加问答数据)训练 解锁模型指令跟随能力
中期训练规模 100B token扩展训练 ≤20B token小规模训练 隐性提升RL潜力,超越基座模型评估指标

四、Llama 模型的 RL 能力优化

能否通过扩大中期训练规模,将 Llama 转变为适合强化学习扩展的基础模型?

采用 稳定后衰减 (stable-then-decay) 两阶段中期训练策略来实现:

4.1 第一阶段:构建强推理基座(200B tokens)

使用恒定学习率对 Llama 模型进行 200B tokens 训练,主要依赖高质量预训练语料库(如 MegaMath-Web-Pro 和 DCLM-Baselines),辅以少量合成数据,构建稳固的推理基础。这一阶段的目标是使模型在大规模数据上逐步提升推理能力,为后续的 RL 训练打下坚实基础,产出:OctoThinker-Base-Stable 系列基模型

4.2 第二阶段:分支专业化训练(20B tokens)

学习率衰减(余弦衰减至初始 LR 的 10%),引入不同数据混合(短链推理、长链推理及其混合),训练三个分支模型,塑造多样化模型行为。这一阶段旨在通过数据多样性和学习率调整,进一步提升模型的推理能力和适应性。

  • 稳定后衰减的设置提供了灵活性,衰减阶段可以在任何时候开始,能独立于固定时间表选择检查点。
  • 第二阶段降低学习率会放大注入数据的效果,有助于更有效地塑造模型行为。
  • 衰减阶段通常较短,因此这种方法通常也降低了总体训练成本

三大推理分支:

分支类型 数据组成 推理特性
短链分支 30% 竞赛短推理QA + 55% 数学语料 + 10% 指令跟随数据 + 5% 通用语料 高效步骤式解题
长链分支 30% 反思型长推理QA + 55% 数学语料 + 10% 指令跟随数据 + 5% 通用语料 深度反思,自我纠错
混合分支 15% 短链 + 15% 长链 + 55% 数学语料 + 10% 指令跟随数据 + 5% 通用语料 平衡效率与深度

由此产生的模型家族命名为 OctoThinker

“Octo”源自“octopus”(章鱼),象征基础模型家族分支成用不同策略训练的变体。“Thinker”反映了模型的最后阶段——强化学习——在那里它被训练去思考和推理,表现出频繁的自我反思和强大的推理能力,其灵感来自章鱼的多臂结构,反映了它的多个分支。

4.3 OctoThinker 性能表现

OctoThinker 基座模型

经两阶段中等训练后的 OctoThinker 基础模型系列,在数学推理基准测试中表现出色,相较于原始 Llama 基础模型,在所有模型尺寸上均实现了 10%-20% 的显著性能提升,为 RL 扩展奠定了坚实基础。例如,在 GSM8K 和 MATH500 等基准测试中,OctoThinker 基座模型的准确率和推理深度均有明显提升。

OctoThinker-Zero 家族
进一步对 OctoThinker 基础模型进行 RL 训练后,生成 OctoThinker-Zero 家族(包括短链、混合链和长链推理分支)。

OctoThinker vs Qwen2.5
比较三个 3B 规模的基础模型:Llama-3.2-3B-Base、OctoThinker-Long-3B-Base 和 Qwen2.5-3B-Base。

在这里插入图片描述
在这里插入图片描述

结果表明:在强化学习阶段,OctoThinker-Long-3B 始终优于原始的 Llama-3.2-3B 模型,它达到了与 Qwen2.5-3B 相当的性能,而混合和短分支则略逊一筹,尤其是在具有挑战性的基准测试中。

总体而言,这些结果突显了我们中期训练策略引入的实质性增益,并证实 OctoThinker 有效地缩小了性能差距,提升了 Llama-3.2 模型在数学推理任务中的竞争力。

五、总结与展望

这项工作研究了为什么 Llama 和 Qwen 等基础模型在强化学习推理中表现出不同的行为,并证明了中期训练可以发挥决定性作用。研究结果表明,高质量的、推理密集型的语料库——尤其是像 MegaMath-Web-Pro 这样的语料库——可以显著提高强化学习的稳定性和有效性。基于这些见解,引入了一种两阶段的中期训练策略,将 Llama 转变为更具强化学习可扩展性的基础模型。由此产生的 OctoThinker 模型在数学推理任务中取得了强劲的性能,缩小了与 RL 友好模型家族之间的差距。

未来工作将积极探索更多内容,包括:
(1) 进一步精炼数学预训练语料库以增强中期训练效果
(2) 采用开放配方设计无需从强大长链推理模型蒸馏的 RL 友好型基础模型
(3) 将问答格式和内容分离,以更好地理解它们各自的贡献;
(4) 拓展 OctoThinker 家族,增加如工具集成推理等新分支,以期为预训练与强化学习的交互机制提供更深入洞见。


网站公告

今日签到

点亮在社区的每一天
去签到