Autoregressive decoding in large language models (LLMs) requires \(\mathcal{O}(n)\) sequential steps for \(n\) tokens, fundamentally limiting inference throughput. Recent diffusion-based LLMs (dLLMs) enable parallel token generation through iterative denoising. However, current parallel decoding strategies rely on fixed, input-agnostic heuristics (e.g., confidence thresholds), which fail to adapt to input-specific characteristics, resulting in suboptimal speed-quality trade-offs across diverse NLP tasks. In this work, we explore a more flexible and dynamic approach to parallel decoding. We propose \(\textbf{Learning to Parallel Decode (Learn2PD)}\), a framework that trains a lightweight, adaptive filter model to predict, for each token position, whether the current prediction already matches the final output. This learned filter approximates an oracle parallel decoding strategy that unmasks tokens only when they are correctly predicted. Importantly, the filter model is learned in a post-training manner, requiring only minutes of GPU time to optimize. Additionally, we introduce \(\textbf{End-of-Text Prediction (EoTP)}\) to detect decoding completion at the end of the sequence, avoiding redundant decoding of padding tokens. Experiments on the LLaDA benchmark demonstrate that our method achieves up to a \(\textbf{22.58×}\) speedup without any performance drop, and up to \(\textbf{57.51×}\) when combined with KV-Cache.
This strategy compares the predicted tokens with the reference answer and remasks only the tokens that do not match.
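A minimal sketch of this oracle remasking rule is shown below. The mask-token id and tensor names are illustrative assumptions, not the actual implementation.

```python
import torch

MASK_ID = 126336  # placeholder mask-token id (assumption)

def oracle_remask(predicted: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Keep a predicted token only if it already equals the reference token;
    remask (reset to MASK_ID) every position that does not match."""
    keep = predicted.eq(reference)  # positions that are already decoded correctly
    return torch.where(keep, predicted, torch.full_like(predicted, MASK_ID))
```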
A trained filter \(f_\theta\) simulates the Extremely Greedy Parallel strategy after each decoding step, selecting which tokens to keep and which to remask.
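The sketch below illustrates one way such a lightweight filter could look; the feature choice (per-position top-1 probability), the MLP architecture, and the threshold `tau` are assumptions for illustration rather than the paper's exact design.

```python
import torch
import torch.nn as nn

class TokenFilter(nn.Module):
    """Lightweight per-position filter f_theta (illustrative architecture)."""
    def __init__(self, feat_dim: int = 1, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (batch, seq_len, feat_dim) per-position features, e.g. top-1 prob
        return torch.sigmoid(self.net(feats)).squeeze(-1)  # unmask probability

def filtered_remask(pred_ids, probs, filter_model, mask_id, tau=0.5):
    """Unmask a position only when the filter predicts the current token already
    matches the final output; otherwise remask it for the next diffusion step."""
    scores = filter_model(probs.unsqueeze(-1))  # (batch, seq_len)
    keep = scores > tau
    return torch.where(keep, pred_ids, torch.full_like(pred_ids, mask_id))
```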
Upon detection of an \([EoT]\) token, all subsequent tokens are assigned \([EoT]\) in parallel. Since \([EoT]\) tokens beyond the first occurrence have no effect on other tokens, we discard these redundant \([EoT]\) tokens in the next diffusion step. When the specified output length is long (for example, 1024), this significantly reduces computation by dynamically shrinking the input size during the diffusion process.
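A minimal sketch of this EoTP step, assuming a 1-D token sequence and a placeholder \([EoT]\) id; the fill-and-truncate details are an illustrative reading of the description above.

```python
import torch

EOT_ID = 126081  # placeholder [EoT] token id (assumption)

def apply_eotp(seq: torch.Tensor) -> torch.Tensor:
    """Fill everything after the first [EoT] with [EoT], then truncate so later
    diffusion steps no longer spend compute on those padding positions."""
    hits = (seq == EOT_ID).nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        return seq                      # no [EoT] yet; keep the full length
    first = hits[0].item()
    seq[first + 1:] = EOT_ID            # assign [EoT] to all later positions
    return seq[: first + 1]             # drop redundant [EoT] tokens next step
```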
Benchmark results on the LLaDA-8B-Instruct suite. Each method was evaluated using two generation lengths (256 and 1024) across four datasets. Performance is measured using three metrics: TPS (tokens/sec), speedup, and accuracy score. The highest throughput and speedup values for each configuration are highlighted in bold.
We further evaluate the compatibility of our approach with established Key-Value (KV) Cache techniques by integrating both Dual Cache and Prefix Cache strategies. Experiments are conducted on GSM8K with a generation length of 1024 tokens. As summarized in the following table, the base configuration (Learn2PD & EoTP) achieves a throughput of 12.26 TPS, a speedup of 22.58×, and an accuracy score of 79.83. When augmented with the Dual Cache, the system attains substantially higher efficiency, reaching 31.23 TPS and a 57.51× speedup, albeit with a decrease in accuracy (74.00). Similarly, incorporating the Prefix Cache also brings noticeable improvements, yielding 14.79 TPS and a 27.23× acceleration while maintaining a competitive score of 77.71. These results confirm that our method is orthogonal to and fully compatible with standard KV caching mechanisms, demonstrating its ability to leverage such strategies to enhance inference efficiency.
This website is adapted from Nerfies, licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. We thank the LLaDA team for providing access to their models and open-source projects.
Usage and License Notices: The data, code, and checkpoints are intended and licensed for research use only. They are further restricted to uses that follow the license agreements of LLaDA and Fast-dLLM. The dataset is CC BY-NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
Related Links: [LLaDA] [Fast-dLLM]