Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a. Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks, referred to as Streaming Heads, do not require full attention. Based on this insight, we introduce DuoAttention, a framework that only applies a full KV cache to retrieval heads while using a lightweight, constant-length KV cache for streaming heads, which reduces both the decoding and pre-filling memory and latency of LLMs without compromising their long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models while speeding up decoding by up to 2.18x and 1.50x and accelerating pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context length on a single A100 GPU.
Large language models (LLMs) are increasingly being used for tasks that require processing long contextual sequences. However, this comes with significant computational and memory challenges, particularly due to the Key-Value (KV) cache, which stores the keys and values of all preceding tokens. As the context length grows, so do the memory requirements, resulting in inefficient and slow decoding. To address these issues, DuoAttention differentiates between retrieval heads and streaming heads, reducing the memory footprint and latency of long-context inference without compromising the long-context capabilities of LLMs.
We introduce DuoAttention, a framework that reduces memory and latency in long-context inference by using a full KV cache only for retrieval heads and a constant-length cache for streaming heads. The framework retains performance while achieving substantial improvements in efficiency.
Deploying LLMs for long-context tasks presents multiple challenges: the KV cache grows linearly with context length and consumes substantial GPU memory; decoding and pre-filling latency grow with it; and existing KV cache pruning methods either damage long-context capabilities or deliver only limited efficiency gains.
DuoAttention leverages the observation that only a fraction of attention heads, termed retrieval heads, are critical for processing long contexts, whereas streaming heads primarily attend to recent tokens and attention sinks. This insight allows for more efficient memory usage: a full KV cache is maintained only for retrieval heads, while streaming heads use a much smaller, constant-size cache.
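The cache policy itself is simple to state. Below is a minimal sketch (not the authors' implementation) of the per-head pruning rule, assuming one layer's KV tensors of shape [num_kv_heads, seq_len, head_dim] and a hypothetical boolean mask is_retrieval produced by the identification step described later; the sink and window sizes are illustrative defaults.

import torch

def prune_kv_per_head(keys, values, is_retrieval, num_sink=4, num_recent=256):
    # keys, values: [num_kv_heads, seq_len, head_dim]; is_retrieval: [num_kv_heads] bool.
    # Retrieval heads keep every token's KV states; streaming heads keep only the
    # first num_sink tokens (attention sinks) plus the most recent num_recent tokens.
    seq_len = keys.shape[1]
    if seq_len <= num_sink + num_recent:
        keep = torch.arange(seq_len)                         # nothing to evict yet
    else:
        keep = torch.cat([torch.arange(num_sink),
                          torch.arange(seq_len - num_recent, seq_len)])
    pruned = []
    for h in range(keys.shape[0]):
        if is_retrieval[h]:
            pruned.append((keys[h], values[h]))              # full cache: grows with context
        else:
            pruned.append((keys[h, keep], values[h, keep]))  # constant-size cache
    return pruned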
Figure 1 provides a visual representation of the different roles of retrieval and streaming heads. Retrieval heads, such as the head at Layer 15, Head 12, attend to contextually important tokens throughout the input. Streaming heads focus on recent tokens and attention sinks, and can tolerate a reduced KV cache without affecting performance (Figure 2).
DuoAttention introduces a lightweight, optimization-based algorithm that distinguishes retrieval heads from streaming heads based on how much the model's output deviates when a head's KV cache is compressed. The identification process trains a gate value for each attention head on synthetic data, with a regularizer that pushes gates toward the compressed (streaming) setting, minimizing performance loss while improving efficiency.
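A simplified view of that objective, assuming each head's output is blended between full and streaming attention by a trainable gate; the function names, regularization weight, and 0.5 threshold below are illustrative assumptions rather than the paper's exact hyperparameters.

import torch

def blend_head_output(full_out, stream_out, gate):
    # Per-head output during identification: a convex combination of full attention
    # and streaming (sink + recent) attention, weighted by a trainable gate in [0, 1].
    a = torch.clamp(gate, 0.0, 1.0)
    return a * full_out + (1.0 - a) * stream_out

def identification_loss(full_hidden, blended_hidden, gates, lam=0.05):
    # Match the full-attention model's hidden states (output-deviation term), while an
    # L1 penalty pushes gates toward 0, i.e. toward treating heads as streaming heads.
    deviation = torch.mean((blended_hidden - full_hidden) ** 2)
    sparsity = torch.clamp(gates, 0.0, 1.0).abs().mean()
    return deviation + lam * sparsity

# Only the gates are optimized on synthetic long-context data; afterwards, heads whose
# gate exceeds a threshold are kept as retrieval heads, e.g.:
# is_retrieval = torch.clamp(gates, 0.0, 1.0) > 0.5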
During deployment, DuoAttention applies full attention to retrieval heads and streaming attention to streaming heads (Figure 3).
DuoAttention reduces memory usage and latency during both decoding and pre-filling stages. The memory complexity of streaming heads is reduced to a constant size (O(1)), while retrieval heads maintain a full cache to handle long-context scenarios. Figure 3 shows the chunked pre-filling process, illustrating how DuoAttention retains efficiency by limiting the KV cache for streaming heads.
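The effect on memory during pre-filling can be seen with a toy cache-size count, a sketch with assumed chunk, sink, and window sizes rather than the actual kernels.

def chunked_prefill_cache_sizes(seq_len, chunk=2048, num_sink=4, num_recent=256):
    # Count how many KV entries each head type holds while pre-filling in chunks.
    # Retrieval heads append every chunk (linear growth); streaming heads are trimmed
    # back to sink + recent after each chunk, so their cache is constant in seq_len.
    retrieval_cache, streaming_cache = 0, 0
    for start in range(0, seq_len, chunk):
        n = min(chunk, seq_len - start)
        retrieval_cache += n
        streaming_cache = min(streaming_cache + n, num_sink + num_recent)
    return retrieval_cache, streaming_cache

print(chunked_prefill_cache_sizes(1_000_000))   # (1000000, 260): O(n) vs. O(1)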
DuoAttention achieves substantial reductions in memory usage and latency compared to full attention models. The memory reductions are up to 2.55× for Multi-Head Attention (MHA) and 1.67× for Grouped-Query Attention (GQA) models. The decoding speed increases by up to 2.18× for MHA models and 1.50× for GQA models. DuoAttention's improvements scale with context length, making it ideal for long-context applications like document summarization or multi-turn dialogue systems.
Moreover, combining DuoAttention with quantization techniques further boosts its memory efficiency, allowing models like Llama-3-8B to handle a context of up to 3.3 million tokens on a single A100 GPU.
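A rough back-of-the-envelope illustrates why this is plausible, assuming Llama-3-8B's GQA configuration (32 layers, 8 KV heads, head dimension 128), 4-bit KV quantization, and the ~1.67x GQA cache reduction reported above; the exact setup used in the paper may differ.

layers, kv_heads, head_dim = 32, 8, 128          # Llama-3-8B (GQA)
tokens = 3_300_000

bytes_per_token_fp16 = 2 * layers * kv_heads * head_dim * 2   # K and V, 2 bytes each
bytes_per_token_int4 = bytes_per_token_fp16 / 4               # assumed 4-bit KV quantization
full_cache_gb = tokens * bytes_per_token_int4 / 1e9
duo_cache_gb = full_cache_gb / 1.67                           # streaming heads keep ~O(1) KV

print(f"4-bit full cache: {full_cache_gb:.0f} GB; with DuoAttention: {duo_cache_gb:.0f} GB")
# Roughly 108 GB -> 65 GB, which leaves room for quantized weights on an 80 GB A100.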
Despite the efficiency gains, DuoAttention maintains high accuracy across various benchmarks. Figure 6 compares DuoAttention with other methods on long-context benchmarks such as Needle-in-a-Haystack and LongBench. It outperforms existing techniques like H2O and StreamingLLM, providing a superior trade-off between KV cache budget and accuracy.
@article{xiao2024duo,
  title={DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads},
  author={Xiao, Guangxuan and Tang, Jiaming and Zuo, Jingwei and Guo, Junxian and Yang, Shang and Tang, Haotian and Fu, Yao and Han, Song},
  journal={arXiv},
  year={2024}
}
We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Hyundai and Samsung for supporting this research. We thank NVIDIA for donating the DGX server.