Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

Jiaming Tang*, Yilong Zhao*, Kan Zhu, Guangxuan Xiao, Baris Kasikci, Song Han
Massachusetts Institute of Technology, Shanghai Jiao Tong University, University of Washington, NVIDIA
(* indicates equal contribution)

Abstract

As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128k or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since the inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens dominates the attention outcome. However, we observe that the criticality of a token highly depends on the query. To this end, we propose Quest, a query-aware token criticality estimation algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using the Query vector. By loading only the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest achieves up to 7.03× self-attention speedup, which reduces inference latency by 2.23× with negligible accuracy loss on tasks with long dependencies.

The Limits of Previous Methods

  • Many previous efforts have been dedicated to compressing the size of the KV cache to accelerate attention and reduce memory usage.
  • These methods decide which parts of the KV cache to discard based on historical information or the current state. However, discarded tokens might matter for future queries, so eviction can permanently lose important information.
  • The criticality of tokens is dynamic and highly dependent on the query vector Q.
  • Example: the token ‘B’ is critical to the current query ‘is’ and therefore receives a high attention score. Yet for every query before the final token ‘is’, ‘B’ was not critical and received very low attention scores (see the toy sketch below).
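
To make this concrete, here is a toy numerical sketch (the 4-dimensional Key and Query vectors are made up purely for illustration) of how the same Key can receive a low attention weight under earlier queries and then dominate under a later one:

import torch

# Toy Keys for three cached tokens (illustrative values only).
keys = torch.tensor([[1.0, 0.0, 0.0, 0.0],   # token 'A'
                     [0.0, 1.0, 0.0, 0.0],   # token 'B'
                     [0.0, 0.0, 1.0, 0.0]])  # token 'C'

q_early = torch.tensor([2.0, 0.0, 1.0, 0.0])  # an earlier query: barely attends to 'B'
q_late  = torch.tensor([0.0, 3.0, 0.0, 0.0])  # the query at 'is': attends mostly to 'B'

for name, q in [("early", q_early), ("late", q_late)]:
    attn = torch.softmax(keys @ q, dim=0)     # attention weights over 'A', 'B', 'C'
    print(name, [round(w, 3) for w in attn.tolist()])
# 'B' gets a small weight under q_early but dominates under q_late,
# so evicting it based on its history would hurt the later query.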

Using Query-aware Sparsity in Attention

Key idea: preserve the entire KV cache, and significantly accelerate inference by reducing memory movement from the full KV cache to a constant number of selected Top-K pages.

  • Our insight is that, to avoid missing critical tokens, we should select the pages that contain the tokens with the highest attention weights.
  • However, to select pages efficiently, we need an approximate attention score rather than the exact one.
  • We found that an upper bound on the attention weights within a page, derived from the per-channel minimal and maximal Key values and the Query vector, is a good approximation of the highest attention score in that page (see the sketch below).
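
The following sketch shows how such a per-page upper bound can be computed from the tracked min/max Key values and then used to pick the Top-K critical pages. The function names, tensor shapes, and single-head, single-query setting are our simplifying assumptions for illustration, not the released kernels.

import torch

def estimate_page_scores(query, key_min, key_max):
    """Upper-bound the attention score of any token in each KV-cache page.

    query:   [head_dim]            current Query vector for one head
    key_min: [num_pages, head_dim] element-wise min of the Keys in each page
    key_max: [num_pages, head_dim] element-wise max of the Keys in each page
    """
    # Per channel, q_j * k_j over a page is maximized at either the page's
    # min or max Key value, so the channel-wise max of the two products,
    # summed over channels, upper-bounds every token's score in the page.
    channel_upper = torch.maximum(query * key_min, query * key_max)
    return channel_upper.sum(dim=-1)  # [num_pages]

def select_topk_pages(query, key_min, key_max, k):
    """Return the indices of the k pages estimated to be most critical."""
    scores = estimate_page_scores(query, key_min, key_max)
    return torch.topk(scores, k=min(k, scores.numel())).indices

Exact attention is then computed only over the Keys and Values of the selected pages, which is where the reduction in memory movement comes from.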

Results

Accuracy of Needle-in-a-Haystack Benchmark

  • (i) Results of the 10k-length passkey retrieval test on LongChat-7b-v1.5-32k.
  • (ii) Results of the 100k-length passkey retrieval test on Yarn-Llama-2-7b-128k.
  • Quest achieves nearly perfect accuracy with KV cache budgets of 64 and 1024 tokens, respectively, which is about 1% of the total sequence length, demonstrating that Quest effectively preserves the model’s ability to handle long-dependency tasks.

Accuracy of LongBench Tasks

  • We evaluate LongChat-7b-v1.5-32k across a wide range of long-context datasets.
  • Quest with a budget of 2k tokens achieves performance comparable to the model with the full KV cache, while other baselines still exhibit a notable gap from full-cache performance even with larger budgets.
  • Single-document QA: NarrativeQA, Qasper, MultiFieldQA; multi-document QA: HotpotQA; summarization: GovReport; few-shot learning: TriviaQA.

Efficiency Comparison

  • For all sequence lengths, Quest significantly outperforms FlashInfer. Increasing the sequence length only slightly increases Quest’s latency.
  • Quest speeds up end-to-end inference by 2.23× at a sequence length of 30K, with a token budget of 2048 and 4-bit weight quantization.

Citation

@inproceedings{tang2024quest,
 title={{QUEST: Query-Aware Sparsity for Efficient Long-Context LLM Inference}},
 author={Jiaming Tang and Yilong Zhao and Kan Zhu and Guangxuan Xiao and Baris Kasikci and Song Han},
 booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
 year={2024},
}

Acknowledgment

We thank Zihao Ye for his insightful discussion, feedback, and useful advice on algorithm design and FlashInfer integration. This work was supported by generous gifts from Intel, Google, and the PRISM Research Center, a JUMP Center cosponsored by SRC and DARPA.
