论文阅读:speculative decoding

发布于:2025-06-14 ⋅ 阅读:(21) ⋅ 点赞:(0)

Fast Inference from Transformers via Speculative Decoding

论文地址:https://arxiv.org/pdf/2211.17192

speculative sampling

为了从分布 p ( x ) p(x) p(x) 中采样,我们实际上是从分布 q ( x ) q(x) q(x) 中采样 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)p(x),则保留该样本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),则以概率 1 − p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1q(x)p(x) 拒绝该样本,并重新从调整后的分布 p ′ ( x ) = norm ( max ⁡ ( 0 , p ( x ) − q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p(x)=norm(max(0,p(x)q(x))) 中采样。对于任何分布 p ( x ) p(x) p(x) q ( x ) q(x) q(x),以及以此方式采样的 x x x,确实有 x ∼ p ( x ) x \sim p(x) xp(x)

给定通过在条件前缀上运行 M q M_q Mq 获得的分布 q ( x ) q(x) q(x),我们可以采样一个标记 x 1 ∼ q ( x ) x_1 \sim q(x) x1q(x)。然后,我们通过在前缀上运行 M p M_p Mp 来计算分布 p ( x ) p(x) p(x),同时并行地推测性地计算下一个标记 x 2 x_2 x2 的分布,即在前缀上追加 x 1 x_1 x1 后运行 M p M_p Mp。一旦两项计算都完成,我们就按上述方式处理:如果 x 1 x_1 x1 被拒绝,我们丢弃 x 2 x_2 x2 的计算,并从调整后的分布中重新采样 x 1 x_1 x1;如果 x 1 x_1 x1 被接受,我们就保留两个标记。算法 1 将这一想法推广为一次采样 1 到 γ + 1 \gamma + 1 γ+1 个标记。
运行算法

分析

有几个证明需要注意一下:

单次算法期望能生成的token
  1. 单次算法期望能生成的token数量服从几何分布,但是求和项是有限制的,这里推导下​

  2. ​接受率β的定义​
    设目标模型分布为 p(x),草稿模型分布为 q(x)。草稿模型生成的单个token被目标模型接受的概率为:

β = ∑ x min ⁡ ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=xmin(q(x),p(x))

  1. ​拒绝率α的定义​

α = 1 − β = 1 − ∑ x min ⁡ ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1β=1xmin(p(x),q(x))x

  • 假设每个token的接受事件独立且同分布(i.i.d.),草稿模型一次生成 K 个token:

  • ​首次拒绝发生在位置 r​ 的概率为:

    P ( r ) = ( 1 − β ) β r − 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1β)βr1(1rK)

    所有token均被接受​​ 的概率为: β K \beta^K βK

  • 综上期望能生成的token数量为:

    γ = ∑ r = 1 K r ⋅ P ( r ) ⏟ 拒绝前生成的token + K ⋅ β K ⏟ 全接受时生成K个token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒绝前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受时生成K个token}} γ=拒绝前生成的token r=1KrP(r)+全接受时生成Ktoken KβK

代入 P ( r ) P(r) P(r) 后展开:

γ = ∑ r = 1 K r ⋅ ( 1 − β ) β r − 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1Kr(1β)βr1+KβK

  1. 几何级数求和​

几何级数求和公式为:

∑ r = 1 K r β r − 1 \sum_{r=1}^K r \beta^{r-1} r=1Krβr1 求和处理:

  • ​令 S = ∑ r = 1 K β r − 1 S = \sum_{r=1}^K \beta^{r-1} S=r=1Kβr1​:

S = 1 + β + β 2 + ⋯ + β K − 1 = 1 − β K 1 − β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2++βK1=1β1βK

  • ​对 S S S 求导​​:

∑ r = 1 K r β r − 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 − β K + 1 1 − β ) = 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} r=1Krβr1=dβd(r=0Kβr)=dβd(1β1βK+1)=(1β)21(K+1)βK+KβK+1

  • ​代入γ表达式​​:

γ = ( 1 − β ) ⋅ 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 + K β K = 1 − ( K + 1 ) β K + K β K + 1 1 − β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1β)(1β)21(K+1)βK+KβK+1+KβK=1β1(K+1)βK+KβK+1+KβK

  • 化简​​:

γ = 1 − β K 1 − β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1β1βK

​物理意义​​:

  • K → ∞ K \to \infty K时, γ → 1 1 − β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ1β1=α1(理想无限长草稿)。
  • 例如 β \beta β = 0.8` 时, γ max = 5 \gamma_{\text{max}} = 5 γmax=5,即平均每次生成5个token。

得证

Walltime的时间优化

​定理 3.8​​:算法 1 在总运行时间上的预期改进因子为
‘ 1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` (1α)(γc+1)1αγ+1

​证明​​:
记运行目标模型 M p M_p Mp​单步​​的成本为 T T T
算法 1 的​​单次运行成本​​为 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于运行近似模型 M q M_q Mq γ \gamma γ 次, T T T 用于运行 M p M_p Mp 一次)。
根据单次算法期望能生成的token算法推导,单次运行​​平均生成 token 数量​​为 1 − α γ + 1 1 − α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1α1αγ+1
因此,使用算法 1 生成单个 token 的​​总体预期成本​​为:
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1αγ+1(cγ+1)(1α)T
由于标准解码算法生成单个 token 的成本为 T
比较可得上述改进因子。∎
(注:符号 “∎” 表示证明结束)


关键术语说明:

英文术语 中文翻译 符号 含义
walltime 总运行时间 - 算法从启动到结束的时钟时间
expected improvement factor 预期改进因子 - 优化后时间开销的缩减比例
cost per step 单步成本 T T T 目标模型 M p M_p Mp 推理一个 token 的时间
approximation model 近似模型 M q M_q Mq 快速但低精度的草稿模型
tokens 标记(Token) - 模型生成的基本文本单位
rejection rate 拒绝率 α \alpha α 草稿模型 M q M_q Mq 的 token 被目标模型 M p M_p Mp 拒绝的概率
γ \gamma γ 生成长度 γ \gamma γ 草稿模型单次运行的 token 生成数
cost ratio 成本比 c c c M q M_q Mq M p M_p Mp 的单步时间比值( 0 < c < 1 0 < c < 1 0<c<1

公式解析:

  1. ​改进因子​
    1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1α)(γc+1)1αγ+1
  • ​分子​ 1 − α γ + 1 1 - \alpha^{\gamma+1} 1αγ+1:草稿模型连续生成 \gamma 个 token 均未被拒绝的概率补偿
  • ​分母​ ( 1 − α ) (1-\alpha) (1α):单 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+验证的总时间成本

该值 ​​>1​​ 时表示加速,值越大加速效果越显著

  1. ​单 token 成本公式​
    ( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1αγ+1(cγ+1)(1α)T
  • ​分子​ ( c γ + 1 ) ( 1 − α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1α)T:草稿生成+验证的实际计算量
  • ​分母​ 1 − α γ + 1 1-\alpha^{\gamma+1} 1αγ+1:有效 token 产出的概率加权
操作数计算

操作数的计算量也是类似的,直接贴结论了

( 1 − α ) ( γ c ^ + γ + 1 ) 1 − α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1αγ+1(1α)(γc^+γ+1)

采样和原分布的等价性证明

参考https://arxiv.org/pdf/2302.01318
其中需要一步代换证明下面两个公式等价:

原始公式

第一个公式:
= 1 − ∑ x ′ min ⁡ ( p ( x ′ ) , q ( x ′ ) ) =1-\sum_{x^{\prime}}\min\left(p\left(x^{\prime}\right),q\left(x^{\prime}\right)\right) =1xmin(p(x),q(x))

第二个公式:
= ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) =\sum_{x^{\prime}}\max\left(0,q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) =xmax(0,q(x)p(x))

推导步骤

步骤 1: 应用 min 函数的恒等式

对于任何两个实数 a a a b b b,都存在以下恒等关系:
min ⁡ ( a , b ) = a − max ⁡ ( 0 , a − b ) \min(a,b) = a - \max(0, a - b) min(a,b)=amax(0,ab)

b = p ( x ′ ) b = p(x') b=p(x) a = q ( x ′ ) a = q(x') a=q(x),得到:
min ⁡ ( p ( x ′ ) , q ( x ′ ) ) = q ( x ′ ) − max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) \min(p(x'),q(x')) = q(x') - \max(0, q(x') - p(x')) min(p(x),q(x))=q(x)max(0,q(x)p(x))

步骤 2: 代入第一个公式

将恒等式代入原始公式:
1 − ∑ x ′ min ⁡ ( p ( x ′ ) , q ( x ′ ) ) = 1 − ∑ x ′ [ q ( x ′ ) − max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) ] \begin{aligned} &1 - \sum_{x^{\prime}} \min(p(x'),q(x')) \\ &= 1 - \sum_{x^{\prime}} \left[ q(x') - \max(0, q(x') - p(x')) \right] \end{aligned} 1xmin(p(x),q(x))=1x[q(x)max(0,q(x)p(x))]

步骤 3: 拆分求和运算

将求和符号分配到表达式内部:
= 1 − [ ∑ x ′ p ( x ′ ) − ∑ x ′ max ⁡ ( 0 , p ( x ′ ) − q ( x ′ ) ) ] = 1 - \left[ \sum_{x^{\prime}} p(x') - \sum_{x^{\prime}} \max(0, p(x') - q(x')) \right] =1[xp(x)xmax(0,p(x)q(x))]
= 1 − ∑ x ′ q ( x ′ ) + ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - \sum_{x^{\prime}} q(x') + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1xq(x)+xmax(0,q(x)p(x))

步骤 4: 应用概率分布性质

因为 p p p q q q 都是概率分布函数,满足:
∑ x ′ p ( x ′ ) = 1 和 ∑ x ′ q ( x ′ ) = 1 \sum_{x^{\prime}} p(x') = 1 \quad \text{和} \quad \sum_{x^{\prime}} q(x') = 1 xp(x)=1xq(x)=1

代入表达式:
= 1 − 1 + ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - 1 + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =11+xmax(0,q(x)p(x))
= ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = \sum_{x^{\prime}} \max(0, q(x') - p(x')) =xmax(0,q(x)p(x))

得证

Reference

https://arxiv.org/pdf/2211.17192