October 30, 2024

Wenhao Chai

<aside> 🔖 在这篇博客中,我们介绍了Jacobi Decoding,一种用于大型语言模型(LLMs)的并行解码算法,并从高层概念上探讨了其与扩散过程的联系。

</aside>

自回归(autoregressive)LLM的解码是模型在生成文本时,逐步从左到右预测每一个token的过程。其核心思想是:基于已生成的内容来预测下一个token,这样一来,每生成一个token,都会更新输入,使得模型可以递归地将前面的输出作为新一轮输入进行预测。人们已经探索了很多解码方式如greedy search和beam search,更详细的介绍可以在此处找到。然而,这些解码方式大多是为了平衡LLM输出的稳定性和多样性,部分还带来了额外的inference开销。并且,由于这些解码方式本质上还是串行的,因此每个解码步骤都不会利用现代 GPU 的并行处理能力,这通常会导致 GPU 利用率低。这对许多优先考虑快速响应时间的现实世界 LLM 应用程序提出了挑战,例如视频理解。在这篇blog中,我们将首先介绍Jacobi Decoding,这是一种应用于自回归LLM的并行解码方式,目标是在尽量不损失性能的前提下实现更快的解码速度。

1. Jacobi Decoding

让我们先定义一些基本的数学符号,对于一般的自回归LLM,我们有:

$$ y_i=\argmax_y p(y|y_{;i},x) $$

其中 $x$ 是输入的prompt被称作prefix, $y_{:i}$ 是已预测的token, 而 $y_i$ 是需要预测的下一个token。在这里我们简单的取最高概率作为下一个token的预测方式。

Jacobi iteration 是一种经典的非线性系统求解器,当被用于LLM解码时,他可以支持一次性输出多个token而非自回归的每次一个。具体来说,我们首先定义优化项:

$$ f(y_i,y_{:i},x)\coloneqq y_i -\argmax_y p(y|y_{:i},x)

$$

我们要求对于带解码序列的每一项 $i$ 都有 $f(y_i,y_{:i},x)=0$,简单来说,这一优化项与自回归解码等价并将其作为jacobi decoding的upper bound。我们假设一次性解码的token数是 $n$,则解码过程等价于不断地求解方程组:

$$ \begin{cases} y_1^{(j+1)}&=\argmax_y p(y|x) \\ y_2^{(j+1)}&=\argmax_y p(y|y_1^{(j)},x) \\ y_3^{(j+1)}&=\argmax_y p(y|y_{:3}^{(j)},x) \\ &\vdots \\ y_n^{(j+1)}&=\argmax_y p(y|y_{:n}^{(j)},x) \end{cases} $$

其中上标 $^{(j)}$表示迭代的次数。不难看出其中每一项都是标准的自回归解码。而区别在于,自回归解码需要依赖于前序精确的token,而jacobi decoding实际上是同时求解所有token,它依赖于一个不完美的前序预测。如下图所示,从一组随机猜测开始,待预测的序列会不断的self-refine,而更精确的序列又会在下一次迭代中产生更精确的结果,最后收敛到自回归解码的结果。在这其中使其趋于精确的信号实际上由prefix $x$ 提供。

一个jacobi decoding的例子,从下至上分别是每一次迭代的结果,图片来源。

一个jacobi decoding的例子,从下至上分别是每一次迭代的结果,图片来源