
KV Cache in Large Language Models

How autoregressive transformers avoid recomputing attention — the mechanics, memory cost, and the optimizations (GQA, PagedAttention, quantization) that make long-context inference viable.

1. Introduction

When an LLM generates text token by token, naively running the transformer on the entire sequence for each new token is wasteful — most of the work has already been done. The Key-Value (KV) cache stores the intermediate K and V tensors from each attention layer so they can be reused across decoding steps.

This single optimization is the difference between an LLM that takes minutes per token and one that runs in real time. It is also the dominant memory cost during inference — often larger than the model weights themselves at long context lengths.

TL;DR: KV cache trades memory for compute. Each generated token reuses cached K/V from all prior tokens instead of recomputing them. Cost: O(seq_len) memory per layer per head.

2. Self-Attention Recap

For each input token x, an attention layer projects it into three vectors:

  • Q (query) — what this token is looking for
  • K (key) — what each token offers as a "match target"
  • V (value) — the content returned when a query matches a key
Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V
Figure 1 — Q, K, V projections inside one attention layer.

In a decoder-only LLM, attention is causal: token t can only attend to tokens 0..t. This is the property that makes KV caching possible — past keys and values never change.
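
The formula and the causal mask can be checked in a few lines of PyTorch. This is a toy sketch with arbitrary shapes, not a full transformer layer:

import math
import torch

# toy single-head attention over a short sequence (shapes are illustrative)
seq_len, d_k = 5, 16
torch.manual_seed(0)
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))

scores = (Q @ K.T) / math.sqrt(d_k)
# causal mask: token t may only attend to tokens 0..t
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
scores = scores.masked_fill(~causal, float("-inf"))
out = scores.softmax(dim=-1) @ V        # [seq_len, d_k]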

3. The Problem: Recomputation

During autoregressive generation, the model produces one token at a time. Without a cache, generating token t+1 requires running the full forward pass over all t+1 tokens — including computing K and V for tokens 0..t that we already computed in previous steps.

Figure 2 — Naive autoregressive generation recomputes all prior K,V at every step.
Cost without cache: Generating an n-token response requires O(n²) total attention compute and O(n²) projection compute on K and V — most of it duplicating earlier work.

4. The Solution: KV Cache

Because attention is causal and weights are fixed, the K and V vectors for token t never change after step t. We can compute them once and append to a cache. Each subsequent step:

  1. Computes Q, K, V only for the new token (1 row of work).
  2. Appends the new K and V to the cache.
  3. Computes attention as softmax(Q_new · K_cachedᵀ / √d_k) · V_cached.
Figure 3 — With caching, each step computes only one new column of K,V.
Speedup: Per-step cost drops from O(n) projections to O(1) projections. The attention matmul itself is still O(n) per step (one query against n keys), but it is the cheaper of the two.
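
A minimal sketch of one cached decode step on raw tensors makes the bookkeeping concrete (single head, arbitrary sizes; the full multi-head module appears in section 11):

import math
import torch

d_k = 16
# hypothetical cache built during earlier steps: one row per past token
K_cached = torch.randn(10, d_k)
V_cached = torch.randn(10, d_k)

# step 1: project Q, K, V only for the new token (one row of work)
q_new, k_new, v_new = (torch.randn(1, d_k) for _ in range(3))

# step 2: append the new K and V to the cache
K_cached = torch.cat([K_cached, k_new], dim=0)
V_cached = torch.cat([V_cached, v_new], dim=0)

# step 3: one query attends to all cached keys and values
attn = (q_new @ K_cached.T / math.sqrt(d_k)).softmax(dim=-1)
out = attn @ V_cached                   # [1, d_k]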

5. Memory Layout

The KV cache is a pair of 4-D tensors (one for K, one for V) stored per transformer layer:

shape = [batch_size, num_heads, seq_len, head_dim]
Figure 4 — KV cache tensor layout per layer (×2 for K and V, ×num_layers total).

The full cache for the model is this tensor pair replicated across num_layers. Each new token appends one slice along the seq_len axis at every layer.
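
As a sketch, the per-layer cache for a Llama-2-7B-like layer and one append step might look like this (real servers preallocate the full seq_len buffer, or use paged blocks as in section 9, rather than concatenating):

import torch

batch_size, num_heads, head_dim = 1, 32, 128
k_cache = torch.empty(batch_size, num_heads, 0, head_dim, dtype=torch.float16)
v_cache = torch.empty(batch_size, num_heads, 0, head_dim, dtype=torch.float16)

# each decode step appends one slice along the seq_len axis (dim=2) at every layer
k_new = torch.randn(batch_size, num_heads, 1, head_dim, dtype=torch.float16)
v_new = torch.randn(batch_size, num_heads, 1, head_dim, dtype=torch.float16)
k_cache = torch.cat([k_cache, k_new], dim=2)
v_cache = torch.cat([v_cache, v_new], dim=2)
print(k_cache.shape)    # torch.Size([1, 32, 1, 128])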

6. Prefill vs Decode

Inference splits into two phases that have very different performance profiles:

Phase   | Input                          | Compute                       | Bottleneck
Prefill | Full prompt (n tokens) at once | n × full attention (parallel) | Compute (FLOPs)
Decode  | 1 new token per step           | 1 × cached attention          | Memory bandwidth (KV reads)
Figure 5 — Prefill is one parallel pass; decode is many small sequential passes.
Why decode is memory-bound: Each decode step does very little arithmetic per byte loaded. The GPU spends most of its time streaming the KV cache from HBM rather than doing matmuls. This is why KV cache size directly limits decode throughput.
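
A rough back-of-envelope sketch of the arithmetic intensity of one decode step's attention (per layer, using the Llama-2 7B shapes from section 7; the exact crossover point depends on the GPU):

seq_len, num_heads, head_dim, bytes_per_elem = 4096, 32, 128, 2

kv_bytes_read = 2 * seq_len * num_heads * head_dim * bytes_per_elem   # K and V streamed from HBM
attn_flops = 2 * (2 * seq_len * num_heads * head_dim)                 # q·Kᵀ plus attn·V, 2 FLOPs per multiply-add

print(kv_bytes_read / 2**20)        # 64.0 MiB read per layer per step
print(attn_flops / kv_bytes_read)   # 1.0 FLOPs per byte

At roughly one FLOP per byte loaded, this sits far below the hundred-plus FLOPs per byte a modern accelerator needs to stay compute-bound, so each decode step is dominated by streaming the cache out of HBM.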

7. Memory Cost Analysis

cache_bytes = 2 × batch × seq_len × num_layers × num_heads × head_dim × bytes_per_elem

The factor of 2 is because we store both K and V. Plugging in concrete numbers:

Example: Llama-2 7B (FP16)

  • num_layers = 32
  • num_heads = 32
  • head_dim = 128
  • bytes_per_elem = 2 (FP16)
per token = 2 (K and V) × 32 (layers) × 32 (heads) × 128 (head_dim) × 2 (bytes) = 524,288 bytes ≈ 512 KB
Context length | KV cache (1 sequence) | Notes
2,048          | ~1 GB                 | Comfortable on a consumer GPU
4,096          | ~2 GB                 | Significant fraction of VRAM
16,384         | ~8 GB                 | More than half the size of the 7B FP16 weights (~13 GB)
32,768         | ~16 GB                | Exceeds the 7B FP16 weights; cache plus weights overflows a 24 GB GPU
Figure 6 — Llama-2 7B FP16 KV cache size vs context length (single sequence).
Implication: At long context, KV cache — not weights — becomes the limiting factor for batch size and concurrency.
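
The formula above as a small helper, reproducing the table (defaults match the Llama-2 7B FP16 example; adjust for other models):

def kv_cache_bytes(seq_len, batch=1, num_layers=32, num_heads=32,
                   head_dim=128, bytes_per_elem=2):
    # factor of 2: both K and V are stored
    return 2 * batch * seq_len * num_layers * num_heads * head_dim * bytes_per_elem

for n in (2_048, 4_096, 16_384, 32_768):
    print(n, kv_cache_bytes(n) / 2**30, "GiB")   # 1.0, 2.0, 8.0, 16.0 GiB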

8. MHA, GQA, MQA — Shrinking the Cache

The biggest architectural lever to reduce KV cache size is sharing K and V across query heads. There are three common patterns:

  • Multi-Head Attention (MHA) — every query head has its own K, V head. Largest cache.
  • Grouped-Query Attention (GQA) — Q heads share K, V in groups; a compromise between the two. Used by Llama 2 (70B), Llama 3, and Mistral.
  • Multi-Query Attention (MQA) — all Q heads share a single K, V head. Smallest cache, some quality loss.
Figure 7 — Sharing K,V across query heads shrinks the cache proportionally.
Variant | KV heads        | Cache size   | Quality            | Used by
MHA     | = num Q heads   | 1.0×         | Best               | GPT-2/3, original Transformer
GQA     | groups (e.g. 8) | ~⅛–¼×        | Near-MHA           | Llama 2 (70B), Llama 3, Mistral
MQA     | 1               | 1/num_heads× | Slight degradation | PaLM, Falcon
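
At attention time, GQA is typically applied by broadcasting each cached K/V head across its group of query heads. A sketch with 32 query heads and 8 KV heads (Llama-3-8B-like; shapes illustrative):

import torch

batch, seq_len, head_dim = 1, 1024, 128
num_q_heads, num_kv_heads = 32, 8            # 4 query heads per KV group
group_size = num_q_heads // num_kv_heads

# the cache stores only the KV heads: 4x smaller than a full MHA cache here
k_cache = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v_cache = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# broadcast each KV head to its query-head group before the attention matmul
k_for_attn = k_cache.repeat_interleave(group_size, dim=1)   # [1, 32, 1024, 128]
v_for_attn = v_cache.repeat_interleave(group_size, dim=1)

Setting num_kv_heads = 1 recovers MQA, and num_kv_heads = num_q_heads recovers MHA.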

9. PagedAttention (vLLM)

Traditional KV caches are stored as contiguous tensors per sequence. This causes severe fragmentation when serving many concurrent requests with varying lengths — a 4096-slot buffer for a sequence that only generates 200 tokens wastes ~95% of its memory.

PagedAttention (introduced by vLLM) borrows the idea of virtual memory from operating systems. The KV cache is split into fixed-size blocks (e.g. 16 tokens each), and a per-sequence block table maps logical positions to physical blocks anywhere in GPU memory.

Figure 8 — PagedAttention maps logical KV blocks to scattered physical blocks.
Wins: Near-zero internal fragmentation, enables prefix sharing across requests (same prompt prefix → shared physical blocks), supports copy-on-write for parallel sampling. vLLM achieves 2–4× higher throughput than naive systems largely from this.
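
A toy sketch of the bookkeeping, not vLLM's actual API: each sequence keeps a block table mapping logical block indices to slots in a shared physical pool.

BLOCK_SIZE = 16                       # tokens per block

class BlockPool:
    """Toy allocator for physical KV blocks in GPU memory."""
    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))

    def alloc(self):
        return self.free.pop()

pool = BlockPool(num_blocks=1024)
block_tables = {}                     # sequence id -> list of physical block ids

def slot_for_token(seq_id, position):
    """Return (physical_block, offset) where this token's K,V slice is written."""
    table = block_tables.setdefault(seq_id, [])
    if position % BLOCK_SIZE == 0:    # crossed a block boundary: grab a new physical block
        table.append(pool.alloc())
    return table[position // BLOCK_SIZE], position % BLOCK_SIZE

Prefix sharing then amounts to pointing several block tables at the same physical blocks, with copy-on-write when sequences diverge.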

10. Other Optimizations

Quantization

Store K and V in INT8, INT4, or FP8 instead of FP16. INT8 halves the cache; INT4 quarters it. Quality cost is usually small if done per-channel or per-token.
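
A minimal per-token symmetric INT8 sketch (illustrative; production kernels fuse the quantization into the attention op and store the scales alongside the cache):

import torch

def quantize_per_token(x):
    # x: [batch, heads, seq_len, head_dim]; one scale per (token, head)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    return q.to(torch.float16) * scale

k = torch.randn(1, 32, 1024, 128, dtype=torch.float16)
k_int8, k_scale = quantize_per_token(k)      # INT8 payload: half the bytes of FP16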

Sliding-Window Attention

Bound the attention window to the most recent w tokens (e.g. Mistral uses w=4096). The cache size is capped at w regardless of total sequence length, at the cost of losing direct attention to earlier tokens.
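
A sketch of the simplest form, truncating the cache to the most recent w positions (real implementations usually write into a fixed ring buffer of length w to avoid the copy):

import torch

def append_with_window(cache, new, window):
    # cache, new: [batch, heads, len, head_dim]; keep only the last `window` positions
    cache = torch.cat([cache, new], dim=2)
    return cache[:, :, -window:, :]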

Cache Eviction (H₂O, StreamingLLM)

Keep only the most attention-relevant tokens or a "sink" of early tokens plus a sliding window. Lets very long sequences run with bounded cache.

Prefix Caching

When many requests share a common prefix (system prompt, few-shot examples), compute the prefix's KV cache once and reuse it. Massive savings for chat APIs with templated prompts.
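
A toy sketch of the idea; model.prefill is a hypothetical helper returning (logits, kv_cache), and real servers key on hashes of token blocks and manage eviction:

prefix_store = {}    # prompt-prefix token ids (tuple) -> per-layer KV cache

def kv_for_prefix(prefix_token_ids, model):
    key = tuple(prefix_token_ids)
    if key not in prefix_store:
        # compute the shared prefix's KV cache once and keep it for later requests
        _, prefix_store[key] = model.prefill(prefix_token_ids)
    return prefix_store[key]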

Offloading

Spill cold portions of the cache to CPU RAM or NVMe and stream back on demand. Trades latency for capacity.

Optimization      | Cache reduction         | Quality impact
GQA (groups=8)    | ~4–8×                   | Negligible
MQA               | up to num_heads×        | Small
INT8 quantization | 2×                      | Small
INT4 quantization | 4×                      | Moderate
Sliding window    | capped at w tokens      | Lose long-range attention
PagedAttention    | removes fragmentation   | None
Prefix caching    | per-prefix amortization | None

11. Reference Implementation (PyTorch sketch)

A minimal causal self-attention layer with KV cache append:

import math
import torch
from torch import nn

class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        # x: [batch, new_tokens, d_model]
        B, T, _ = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        # each: [B, T, n_heads, head_dim] -> [B, n_heads, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=2)   # append along seq dim
            v = torch.cat([past_v, v], dim=2)
        new_cache = (k, v)

        # attention against full (cached + new) keys/values
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if T > 1:
            # causal mask is needed during prefill (T > 1); during decode T = 1
            # and the single query may attend to every cached position
            S = k.size(2)
            causal = torch.tril(torch.ones(T, S, device=x.device), diagonal=S - T).bool()
            scores = scores.masked_fill(~causal, float("-inf"))
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.out(out), new_cache

Generation loop:

cache = None
tokens = prompt_tokens                    # [batch, prompt_len] token ids
for _ in range(max_new_tokens):
    # first iteration (prefill): embed the whole prompt; afterwards: only the newest token
    x = embed(tokens if cache is None else tokens[:, -1:])
    logits, cache = model(x, kv_cache=cache)
    next_tok = logits[:, -1].argmax(-1, keepdim=True)   # greedy decoding
    tokens = torch.cat([tokens, next_tok], dim=1)

Note how after the first call (prefill), every subsequent call passes only the single new token — the cache provides the rest.

12. Summary

  • What: KV cache stores past K and V tensors so attention does not recompute them at every decode step.
  • Why it works: Causal masking guarantees past K, V do not change as new tokens arrive.
  • What it costs: Memory grows as 2 × layers × heads × head_dim × seq_len × bytes. At long context this dominates over model weights.
  • Phases: Prefill is compute-bound; decode is memory-bound, gated by KV cache bandwidth.
  • Make it cheaper: GQA/MQA shrink the cache by sharing K,V across heads; quantization halves or quarters the bytes; PagedAttention removes fragmentation across concurrent requests.
  • Make it longer: Sliding windows, eviction (H₂O, StreamingLLM), CPU/NVMe offloading.
  • Make it free: Prefix caching reuses KV across requests that share a prompt prefix.
Bottom line: The KV cache is the single most important inference-time data structure in modern LLMs. Almost every serving-side optimization — from GQA in model architecture to PagedAttention in the runtime — exists to shrink it, share it, or stream it.