October 30, 2024

Wenhao Chai

<aside> 🔖 In this blog, we introduce Jacobi Decoding, a parallel decoding algorithm for LLMs, and discuss at a high level its conceptual connection to the diffusion process.

</aside>

The decoding process of an autoregressive LLM generates text by predicting one token at a time, from left to right. The core idea is to predict the next token from the content generated so far: each generated token is appended to the input, so the model recursively uses its previous output as the input for the next prediction. Various decoding methods, such as greedy search and beam search, have been explored, with more detailed explanations available here. However, most of these methods aim to balance the stability and diversity of LLM outputs, and some introduce additional inference overhead. Moreover, because these methods are inherently sequential, each decoding step produces only a single token and cannot fully exploit the parallel processing capabilities of modern GPUs, which often results in low GPU utilization. This poses challenges for many real-world LLM applications that require fast responses, such as video understanding. In this blog, we first introduce Jacobi Decoding, a parallel decoding method for autoregressive LLMs that aims to achieve faster decoding with minimal performance drop.

1. Jacobi Decoding

Let's start by defining some basic mathematical notation. For a general autoregressive LLM, we have:

$$ y_i=\argmax_y p(y|y_{:i},x) $$

Here, $x$ represents the input prompt, often referred to as the prefix, while $y_{:i}$ denotes the previously predicted tokens $y_1,\dots,y_{i-1}$, and $y_i$ is the next token to be predicted. This is greedy decoding: the token with the highest probability is selected as the next token.
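To make the notation concrete, here is a minimal sketch of greedy decoding against a toy stand-in for an LLM. The scoring function `toy_logits`, the vocabulary size `VOCAB`, and the helper names are hypothetical and chosen only for illustration; a real model would return logits from a forward pass. Because softmax is monotonic, the argmax over logits is the same token as the argmax over $p(y|y_{:i},x)$.

```python
VOCAB = 50  # hypothetical toy vocabulary size


def toy_logits(context):
    """Hypothetical stand-in for an LLM forward pass.

    Returns one deterministic score per vocabulary id; the scores depend on
    the whole context (prefix x followed by the tokens decoded so far).
    """
    s = 0
    for t in context:
        s = (s * 31 + t + 1) % 1009  # fold the context into a small state
    return [((s + 1) * (v + 3) * 7919) % 1013 for v in range(VOCAB)]


def argmax_next(context):
    """y_i = argmax_y p(y | y_{:i}, x); argmax over logits equals argmax over softmax probs."""
    scores = toy_logits(context)
    return max(range(VOCAB), key=lambda v: scores[v])


def greedy_decode(x, n):
    """Greedy autoregressive decoding: one token per model call, left to right."""
    y = []
    for _ in range(n):
        y.append(argmax_next(list(x) + y))  # condition on x and y_{:i}
    return y


print(greedy_decode([3, 1, 4], n=6))
```

The loop makes the sequential dependency explicit: the call for position $i$ cannot start until $y_{i-1}$ is known.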

The Jacobi iteration is a classic iterative method for solving systems of equations, including nonlinear ones. Applied to LLM decoding, it allows multiple tokens to be generated at once instead of one at a time as in traditional autoregressive decoding. Specifically, we first define the objective function:

$$ f(y_i,y_{:i},x)\coloneqq y_i -\argmax_y p(y|y_{:i},x) $$
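The objective $f$ can be checked numerically. The sketch below uses a hypothetical toy scorer (`toy_logits`, `VOCAB`, and all helper names are illustrative assumptions, not from any library) and computes $f$ at every position. Because tokens are integer ids, $f = 0$ at position $i$ exactly when $y_i$ equals the greedy choice given $y_{:i}$ and $x$.

```python
VOCAB = 50  # hypothetical toy vocabulary size


def toy_logits(context):
    """Hypothetical stand-in for an LLM: one deterministic score per vocabulary id."""
    s = 0
    for t in context:
        s = (s * 31 + t + 1) % 1009
    return [((s + 1) * (v + 3) * 7919) % 1013 for v in range(VOCAB)]


def argmax_next(context):
    """Greedy choice argmax_y p(y | context) under the toy scorer."""
    scores = toy_logits(context)
    return max(range(VOCAB), key=lambda v: scores[v])


def residual(y, x):
    """f(y_i, y_{:i}, x) = y_i - argmax_y p(y | y_{:i}, x) for each position i."""
    return [y[i] - argmax_next(list(x) + y[:i]) for i in range(len(y))]


def is_solution(y, x):
    """True when f = 0 at every position, i.e. y is the greedy decoding of x."""
    return all(r == 0 for r in residual(y, x))


x = [3, 1, 4]
y = []
for _ in range(5):  # build the greedy sequence one token at a time
    y.append(argmax_next(x + y))
print(residual(y, x))  # → [0, 0, 0, 0, 0]
```

Any other candidate sequence has a nonzero residual somewhere, which is what makes $f=0$ a usable stopping criterion.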

We require that the condition $f(y_i,y_{:i},x)=0$ holds for every position $i$ in the decoded sequence. Simply put, the sequence that satisfies all of these conditions is exactly the output of greedy autoregressive decoding, which therefore serves as the upper bound for Jacobi Decoding. Assuming that $n$ tokens are decoded in parallel in each iteration, the decoding process is equivalent to iteratively solving the following system of equations:

$$ \begin{cases} y_1^{(j+1)}&=\argmax_y p(y|x) \\ y_2^{(j+1)}&=\argmax_y p(y|y_1^{(j)},x) \\ y_3^{(j+1)}&=\argmax_y p(y|y_{:3}^{(j)},x) \\ &\vdots \\ y_n^{(j+1)}&=\argmax_y p(y|y_{:n}^{(j)},x) \end{cases} $$
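Each row of this system reads only the previous iterate $y^{(j)}$ (the sequence after iteration $j$), never the new one, so all $n$ updates are independent of each other. The minimal sketch below computes one Jacobi iteration with a hypothetical toy scorer; the names are illustrative only. It loops over positions for clarity, whereas with a real causal LLM all $n$ argmaxes could come from a single forward pass over $x$ followed by $y^{(j)}$.

```python
VOCAB = 50  # hypothetical toy vocabulary size


def toy_logits(context):
    """Hypothetical stand-in for an LLM: one deterministic score per vocabulary id."""
    s = 0
    for t in context:
        s = (s * 31 + t + 1) % 1009
    return [((s + 1) * (v + 3) * 7919) % 1013 for v in range(VOCAB)]


def argmax_next(context):
    """Greedy choice argmax_y p(y | context) under the toy scorer."""
    scores = toy_logits(context)
    return max(range(VOCAB), key=lambda v: scores[v])


def jacobi_step(y_prev, x):
    """One Jacobi iteration: y_i^{(j+1)} = argmax_y p(y | y_{:i}^{(j)}, x) for all i.

    Every position conditions only on y_prev, so the updates are independent
    and could be computed in parallel.
    """
    return [argmax_next(list(x) + y_prev[:i]) for i in range(len(y_prev))]


x = [3, 1, 4]
y0 = [0, 0, 0, 0, 0]  # arbitrary initial guess y^{(0)}
y1 = jacobi_step(y0, x)
print(y0, "->", y1)  # y1[0] = argmax_y p(y | x) is already final
```

The greedy sequence is a fixed point of `jacobi_step`: feeding it back in returns it unchanged.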

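In the system above, the superscript $^{(j)}$ denotes the iteration count. Each equation has the same form as one step of greedy autoregressive decoding. The difference lies in the dependency: autoregressive decoding conditions on exact preceding tokens, whereas Jacobi Decoding updates all $n$ tokens simultaneously, conditioning on the previous iterate, which starts from an imperfect initial guess. As illustrated in the following diagram, starting from a set of random guesses, the predicted sequence repeatedly refines itself: each iteration fixes more of the sequence, which in turn yields better predictions in the next iteration, until the sequence converges to the autoregressive decoding result. The driving signal toward accuracy in this process is actually provided by the prefix $x$: $y_1$ depends only on $x$, so it is correct after the first iteration, and once $y_{:i}$ is correct, $y_i$ becomes correct in the following iteration. By induction, at least the first $j$ tokens are correct after $j$ iterations, so the process converges in at most $n$ iterations, and the number of iterations falls below $n$ whenever some iteration fixes more than one token.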

An example of Jacobi Decoding: from bottom to top are the results of each iteration. Image source.
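Putting the pieces together, the sketch below runs Jacobi Decoding from a random initial guess and compares the result with greedy decoding. As before, `toy_logits`, `VOCAB`, and the helper names are hypothetical stand-ins for a real model. The loop stops when an iteration leaves the block unchanged (every residual is zero) or after $n$ iterations. The cap of $n$ always suffices, because after $j$ iterations at least the first $j$ tokens match greedy decoding.

```python
import random

VOCAB = 50  # hypothetical toy vocabulary size


def toy_logits(context):
    """Hypothetical stand-in for an LLM: one deterministic score per vocabulary id."""
    s = 0
    for t in context:
        s = (s * 31 + t + 1) % 1009
    return [((s + 1) * (v + 3) * 7919) % 1013 for v in range(VOCAB)]


def argmax_next(context):
    """Greedy choice argmax_y p(y | context) under the toy scorer."""
    scores = toy_logits(context)
    return max(range(VOCAB), key=lambda v: scores[v])


def greedy_decode(x, n):
    """Reference: autoregressive greedy decoding, n sequential model calls."""
    y = []
    for _ in range(n):
        y.append(argmax_next(list(x) + y))
    return y


def jacobi_decode(x, n, seed=0):
    """Jacobi decoding of an n-token block; returns (tokens, iterations used)."""
    rng = random.Random(seed)
    y = [rng.randrange(VOCAB) for _ in range(n)]  # random initial guess y^{(0)}
    for it in range(1, n + 1):
        # Parallel update: every position reads only the previous iterate y^{(j)}.
        y_next = [argmax_next(list(x) + y[:i]) for i in range(n)]
        if y_next == y:  # fixed point reached: f = 0 at every position
            return y, it
        y = y_next
    # After n iterations the first n tokens, i.e. the whole block, are fixed.
    return y, n


x = [3, 1, 4]
tokens, iters = jacobi_decode(x, n=8, seed=42)
print(tokens, iters)
print(tokens == greedy_decode(x, 8))  # → True
```

With this random-looking toy scorer, later positions rarely become correct before their prefix does, so expect the iteration count to stay close to $n$. Any real speedup depends on the model predicting several correct tokens in a single iteration, which a toy scorer does not capture.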
