DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads

Guangxuan Xiao¹, Jiaming Tang¹, Jingwei Zuo², Junxian Guo¹ ³, Shang Yang¹, Haotian Tang¹, Yao Fu⁴, Song Han¹ ⁵
¹MIT, ²Tsinghua University, ³SJTU, ⁴University of Edinburgh, ⁵NVIDIA

Abstract

Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a. Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks, referred to as Streaming Heads, do not require full attention. Based on this insight, we introduce DuoAttention, a framework that only applies a full KV cache to retrieval heads while using a lightweight, constant-length KV cache for streaming heads, which reduces both the LLM's decoding and pre-filling memory and latency without compromising its long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models while speeding up decoding by up to 2.18x and 1.50x and accelerating pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context length on a single A100 GPU.

1. Introduction

Large language models (LLMs) are increasingly used for tasks that require processing long contextual sequences. This brings significant computational and memory challenges, particularly from the Key-Value (KV) cache, which stores the keys and values of all preceding tokens: as the context length grows, so do the memory requirements, and decoding becomes slow and inefficient. Existing KV cache pruning methods either hurt the long-context capabilities of LLMs or offer only limited efficiency improvements.

We introduce DuoAttention, a framework that addresses these issues by distinguishing retrieval heads from streaming heads: it keeps a full KV cache only for retrieval heads and a lightweight, constant-length KV cache for streaming heads. This reduces memory and latency in long-context inference while retaining the model's long-context performance.

2. Challenges

Deploying LLMs for long-context tasks presents multiple challenges, such as:

  • Memory Consumption: A traditional LLM caches all Key and Value states across attention heads, so memory consumption grows linearly with the context length. For example, Llama-3-8B requires 137 GB of memory to store the KV cache for 1 million tokens (a back-of-the-envelope check is sketched after this list).
  • Decoding and Pre-filling Latency: As the sequence length grows, the time to decode and pre-fill the context increases significantly. This latency becomes a bottleneck, particularly for tasks that involve extensive context processing, such as document summarization or dialogue systems.
  • Limitations of Existing Methods: Current KV cache pruning methods either hurt the long-context processing capabilities of LLMs or provide only limited efficiency improvements.
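
As a sanity check on the 137 GB figure, here is a back-of-the-envelope calculation. It assumes the standard Llama-3-8B configuration (32 layers, 8 grouped-query KV heads, head dimension 128), 16-bit KV states, and a context of 2^20 tokens; the exact total depends on these assumptions.

```python
# Approximate KV cache size for Llama-3-8B at a ~1M-token context.
# Assumed architecture constants (standard Llama-3-8B GQA configuration).
num_layers = 32        # transformer layers
num_kv_heads = 8       # grouped-query KV heads per layer
head_dim = 128         # dimension per head
bytes_per_elem = 2     # FP16/BF16 KV states
context_len = 1 << 20  # ~1 million tokens (2^20)

# Each token stores one key and one value vector per KV head in every layer.
bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
total_bytes = bytes_per_token * context_len

print(f"{bytes_per_token / 1024:.0f} KB per token")  # 128 KB
print(f"{total_bytes / 1e9:.0f} GB total")           # ~137 GB
```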

3. Method

3.1. Retrieval and Streaming Heads

DuoAttention leverages the observation that only a fraction of attention heads, termed retrieval heads, are critical for processing long contexts, whereas streaming heads primarily focus on recent tokens. This insight allows for more efficient memory usage by maintaining a full KV cache only for retrieval heads and using a smaller cache for streaming heads.

Figure 1 visualizes the different roles of retrieval and streaming heads. Retrieval heads (e.g., Layer 15, Head 12) capture contextually important tokens far back in the sequence, while streaming heads attend mainly to recent tokens and attention sinks and can therefore tolerate a reduced KV cache without affecting performance (Figure 2).
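
For intuition, the sketch below constructs the attention pattern that a streaming head effectively uses: every query attends to a few initial attention-sink tokens plus a sliding window of recent tokens, while retrieval heads keep the ordinary full causal mask. The `sink_size` and `window_size` values here are illustrative defaults, not the paper's exact settings.

```python
import torch

def streaming_attention_mask(seq_len: int, sink_size: int = 4, window_size: int = 256) -> torch.Tensor:
    """Boolean mask for a streaming head: each query position attends only to the
    first `sink_size` tokens (attention sinks) and the `window_size` most recent tokens."""
    q = torch.arange(seq_len).unsqueeze(1)  # query positions, shape [seq_len, 1]
    k = torch.arange(seq_len).unsqueeze(0)  # key positions,   shape [1, seq_len]
    causal = k <= q                         # no attention to future tokens
    sink = k < sink_size                    # initial attention-sink tokens
    recent = (q - k) < window_size          # sliding window of recent tokens
    return causal & (sink | recent)

# At a 4096-token context, the last query of a streaming head sees only
# sink_size + window_size = 260 keys instead of all 4096.
mask = streaming_attention_mask(4096)
print(mask[-1].sum().item())  # 260
```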

3.2. DuoAttention Framework

DuoAttention distinguishes retrieval from streaming heads with a lightweight, optimization-based algorithm: each attention head is assigned a gate value, and the gates are trained on synthetic datasets to minimize the output deviation caused by compressing that head's KV cache, with a regularizer that pushes heads toward the cheaper streaming pattern. This identifies retrieval heads accurately while minimizing performance loss.
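
The toy sketch below illustrates one way to read this identification procedure: a trainable gate per head mixes full-attention and streaming-attention outputs, a distillation-style loss keeps the mixed outputs close to full attention on synthetic batches, and a sparsity penalty pushes gates toward the cheaper streaming pattern. Random tensors stand in for the frozen model's attention outputs, the deviation is measured per layer rather than on the model's final outputs, and all hyperparameters are illustrative rather than the paper's settings.

```python
import torch

torch.manual_seed(0)

# Toy dimensions; the real method operates on a frozen long-context LLM.
num_layers, num_heads, seq_len, head_dim = 4, 8, 64, 16
gates = torch.nn.Parameter(torch.ones(num_layers, num_heads))  # one gate per attention head

def mix(full_out, streaming_out, layer):
    """Per-head convex combination of full- and streaming-attention outputs."""
    a = torch.clamp(gates[layer], 0.0, 1.0).view(1, -1, 1, 1)
    return a * full_out + (1.0 - a) * streaming_out

optimizer = torch.optim.Adam([gates], lr=1e-2)
reg_weight = 0.05  # strength of the sparsity penalty (illustrative)

# Heads differ in how well streaming attention approximates full attention:
# a small error suggests a streaming head, a large error a retrieval head.
approx_error = torch.linspace(0.01, 1.0, num_heads).view(1, -1, 1, 1)

for step in range(300):
    loss = reg_weight * torch.clamp(gates, 0, 1).sum()  # push gates toward streaming (0)
    for layer in range(num_layers):
        full = torch.randn(1, num_heads, seq_len, head_dim)       # stand-in: full attention output
        streaming = full + approx_error * torch.randn_like(full)  # stand-in: streaming approximation
        loss = loss + (full - mix(full, streaming, layer)).pow(2).mean()  # stay close to full attention
    optimizer.zero_grad(); loss.backward(); optimizer.step()

# Heads whose gate stays high still need full attention: these are the retrieval heads.
retrieval_heads = torch.clamp(gates.detach(), 0, 1) > 0.5
print(retrieval_heads)
```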

During deployment, DuoAttention applies full attention to retrieval heads and streaming attention to streaming heads (Figure 3).

3.3. Pre-filling and Decoding Efficiency

DuoAttention reduces memory usage and latency during both pre-filling and decoding. Streaming heads keep a constant-size (O(1)) KV cache, while retrieval heads maintain a full cache to handle long-context lookups. During chunked pre-filling, the context is processed chunk by chunk and the streaming-head cache is truncated to the attention sinks and the most recent tokens after each chunk, which keeps the KV cache for streaming heads bounded (Figure 3).
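
A minimal cache-management sketch of the decoding side is shown below, assuming one cache object per layer, a fixed number of attention-sink tokens, and a sliding recent window for streaming heads; the class name, head indices, and sizes are illustrative, not the released implementation.

```python
from collections import deque
import torch

class DuoKVCache:
    """Per-layer KV cache sketch: retrieval heads keep every token's KV state,
    streaming heads keep only `sink` initial tokens plus a `recent`-token window."""

    def __init__(self, retrieval_heads, streaming_heads, sink=4, recent=256):
        self.retrieval_heads = retrieval_heads  # head indices with a full cache
        self.streaming_heads = streaming_heads  # head indices with a constant-size cache
        self.sink = sink
        self.full_k, self.full_v = [], []       # grows linearly with context length
        self.sink_k, self.sink_v = [], []       # fixed: first `sink` tokens
        self.win_k = deque(maxlen=recent)       # fixed: sliding window of recent tokens
        self.win_v = deque(maxlen=recent)

    def append(self, k, v):
        """k, v: [num_kv_heads, head_dim] states for one newly decoded token."""
        # Retrieval heads: unbounded cache, needed for long-range lookups.
        self.full_k.append(k[self.retrieval_heads])
        self.full_v.append(v[self.retrieval_heads])
        # Streaming heads: O(1) cache made of attention sinks + recent window.
        if len(self.sink_k) < self.sink:
            self.sink_k.append(k[self.streaming_heads])
            self.sink_v.append(v[self.streaming_heads])
        else:
            self.win_k.append(k[self.streaming_heads])
            self.win_v.append(v[self.streaming_heads])

    def size_in_tokens(self):
        return len(self.full_k), len(self.sink_k) + len(self.win_k)

# After 10,000 decoded tokens, the streaming-head cache still holds only 260 tokens.
cache = DuoKVCache(retrieval_heads=[0, 5], streaming_heads=[1, 2, 3, 4, 6, 7])
for _ in range(10_000):
    cache.append(torch.randn(8, 128), torch.randn(8, 128))
print(cache.size_in_tokens())  # (10000, 260)
```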

4. Results

4.1. Memory and Latency Reductions

DuoAttention achieves substantial reductions in memory usage and latency compared to full attention. Memory is reduced by up to 2.55× for Multi-Head Attention (MHA) and 1.67× for Grouped-Query Attention (GQA) models, decoding is accelerated by up to 2.18× (MHA) and 1.50× (GQA), and pre-filling is accelerated by up to 1.73× (MHA) and 1.63× (GQA). These improvements grow with context length, making DuoAttention well suited to long-context applications such as document summarization and multi-turn dialogue systems.

Moreover, combining DuoAttention with quantization techniques further boosts its memory efficiency, allowing models like Llama-3-8B to handle up to 3.3 million contextual tokens on a single A100 GPU.

4.2. Accuracy

Despite the efficiency gains, DuoAttention maintains high accuracy across benchmarks. Figure 6 compares DuoAttention with other methods on long-context evaluations such as Needle-in-a-Haystack and LongBench, where it provides a better trade-off between KV cache budget and accuracy than existing techniques such as H2O and StreamingLLM.

Citation

@article{xiao2024duo,
title={DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads},
author={Xiao, Guangxuan and Tang, Jiaming and Zuo, Jingwei and Guo, Junxian and Yang, Shang and Tang, Haotian and Fu, Yao and Han, Song},
journal={arXiv},
year={2024}
}

Acknowledgment

We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Hyundai and Samsung for supporting this research. We thank NVIDIA for donating the DGX server.
