KV Cache in Large Language Models
How autoregressive transformers avoid recomputing attention — the mechanics, memory cost, and the optimizations (GQA, PagedAttention, quantization) that make long-context inference viable.
1. Introduction
When an LLM generates text token by token, naively running the transformer on the entire sequence for each new token is wasteful — most of the work has already been done. The Key-Value (KV) cache stores the intermediate K and V tensors from each attention layer so they can be reused across decoding steps.
This single optimization is the difference between an LLM that takes minutes per token and one that runs in real time. It is also the dominant memory cost during inference — often larger than the model weights themselves at long context lengths.
The cache itself grows as O(seq_len) memory per layer per head, linear in context length.
2. Self-Attention Recap
For each input token x, an attention layer projects it into three vectors:
- Q (query) — what this token is looking for
- K (key) — what each token offers as a "match target"
- V (value) — the content returned when a query matches a key
In a decoder-only LLM, attention is causal: token t can only attend to tokens 0..t. This is the property that makes KV caching possible — past keys and values never change.
3. The Problem: Recomputation
During autoregressive generation, the model produces one token at a time. Without a cache, generating token t+1 requires running the full forward pass over all t+1 tokens — including computing K and V for tokens 0..t that we already computed in previous steps.
Over n generated tokens this adds up to O(n³) attention compute and O(n²) K and V projection compute — most of it duplicating earlier work.
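To make the waste concrete, here is a cache-free decoding loop, a minimal sketch that reuses the stand-in `embed`, `model`, `prompt_tokens`, and `max_new_tokens` names from the reference implementation in section 11 (with `model` returning only logits here):

```python
import torch

# Naive decoding: every step re-runs the transformer over the WHOLE sequence,
# recomputing K and V for all earlier tokens from scratch.
tokens = prompt_tokens
for _ in range(max_new_tokens):
    logits = model(embed(tokens))  # forward pass over all positions, every step
    next_tok = logits[:, -1].argmax(-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=1)
# Step t does O(t) K/V projections and an O(t^2) attention matmul, so n new
# tokens cost O(n^2) projections and O(n^3) attention in total.
```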
4. The Solution: KV Cache
Because attention is causal and weights are fixed, the K and V vectors for token t never change after step t. We can compute them once and append to a cache. Each subsequent step:
- Computes Q, K, V only for the new token (1 row of work).
- Appends the new K and V to the cache.
- Computes attention as softmax(Q_new · K_cachedᵀ / √d_head) · V_cached.
This takes the per-step projection cost from O(n) to O(1). The attention matmul itself is still O(n) per step (one query against n keys), but it is the cheaper of the two.
5. Memory Layout
The KV cache is a pair of 4-D tensors (one for K, one for V) stored per transformer layer, each of shape `[batch, num_heads, seq_len, head_dim]`.
The full cache for the model is this tensor pair replicated across num_layers. Each new token appends one slice along the seq_len axis at every layer.
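As a sketch of one way to hold this layout, assuming a preallocated buffer per layer (the helper name and the fill-counter bookkeeping are illustrative, not from any particular library):

```python
import torch

# One (K, V) pair per layer, each [batch, num_heads, max_seq_len, head_dim].
# Preallocating up to max_seq_len avoids a reallocation per generated token;
# a separate counter tracks how many seq_len slots are actually filled.
def alloc_kv_cache(num_layers, batch, num_heads, max_seq_len, head_dim,
                   dtype=torch.float16, device="cuda"):
    shape = (batch, num_heads, max_seq_len, head_dim)
    return [(torch.empty(shape, dtype=dtype, device=device),
             torch.empty(shape, dtype=dtype, device=device))
            for _ in range(num_layers)]
```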
6. Prefill vs Decode
Inference splits into two phases that have very different performance profiles:
| Phase | Input | Compute | Bottleneck |
|---|---|---|---|
| Prefill | Full prompt (n tokens) at once | n × full attention (parallel) | Compute (FLOPs) |
| Decode | 1 new token per step | 1 × cached attention | Memory bandwidth (KV reads) |
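A rough illustration of why decode is bandwidth-gated: each decode step must stream the entire cache through the memory system once. Using the 32k-context figure from section 7 below and an illustrative ~2 TB/s of HBM bandwidth:

```python
# Illustrative arithmetic, not a benchmark: one decode step reads the whole cache.
cache_bytes = 16e9       # ~16 GB KV cache (32k context, Llama-2 7B, FP16; see section 7)
hbm_bandwidth = 2e12     # ~2 TB/s, an A100-class figure used here as an assumption
ms_per_token = cache_bytes / hbm_bandwidth * 1e3
print(f"{ms_per_token:.1f} ms per token just for KV reads")  # 8.0 ms
```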
7. Memory Cost Analysis
Per sequence, the cache size in bytes is

2 × num_layers × num_heads × head_dim × seq_len × bytes_per_elem

The factor of 2 is because we store both K and V. Plugging in concrete numbers:
Example: Llama-2 7B (FP16)
num_layers = 32, num_heads = 32, head_dim = 128, bytes_per_elem = 2 (FP16)
| Context length | KV cache (1 sequence) | Notes |
|---|---|---|
| 2,048 | ~1 GB | Comfortable on consumer GPU |
| 4,096 | ~2 GB | Significant fraction of VRAM |
| 16,384 | ~8 GB | Majority of the ~13 GB of FP16 model weights |
| 32,768 | ~16 GB | Exceeds the model weights; one sequence fills a 24 GB GPU |
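These rows all come from the formula above; a small calculator (hypothetical helper, not library code) reproduces them:

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, bytes_per_elem=2):
    # factor of 2: one tensor for K, one for V
    return 2 * num_layers * num_heads * head_dim * seq_len * bytes_per_elem

# Llama-2 7B in FP16 at a 4,096-token context:
print(kv_cache_bytes(32, 32, 128, 4096) / 2**30, "GiB")  # 2.0 GiB
```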
8. MHA, GQA, MQA — Shrinking the Cache
The biggest architectural lever to reduce KV cache size is sharing K and V across query heads. There are three common patterns:
- Multi-Head Attention (MHA) — every query head has its own K, V head. Largest cache.
- Grouped-Query Attention (GQA) — Q heads share K, V in groups. Compromise. Used by Llama 2/3.
- Multi-Query Attention (MQA) — all Q heads share a single K, V head. Smallest cache, some quality loss.
| Variant | KV heads | Cache size | Quality | Used by |
|---|---|---|---|---|
| MHA | = num Q heads | 1.0× | Best | GPT-2/3, original Transformer |
| GQA | groups (e.g. 8) | ~⅛–¼× | Near-MHA | Llama 2/3, Mistral |
| MQA | 1 | 1/n× | Slight degradation | PaLM, Falcon |
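Only the KV heads are cached; at attention time each KV head is broadcast across its group of query heads. A sketch of that expansion (function name assumed, shapes as in the reference code):

```python
import torch

def expand_kv_for_gqa(kv, n_q_heads):
    # kv: [B, n_kv_heads, seq_len, head_dim], with n_q_heads a multiple
    # of n_kv_heads. Repeat each KV head across its query group so shapes
    # line up with Q; only the n_kv_heads copy is ever cached.
    group = n_q_heads // kv.size(1)
    return kv.repeat_interleave(group, dim=1)  # -> [B, n_q_heads, seq_len, head_dim]
```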
9. PagedAttention (vLLM)
Traditional KV caches are stored as contiguous tensors per sequence. This causes severe fragmentation when serving many concurrent requests with varying lengths — a 4096-slot buffer for a sequence that only generates 200 tokens wastes ~95% of its memory.
PagedAttention (introduced by vLLM) borrows the idea of virtual memory from operating systems. The KV cache is split into fixed-size blocks (e.g. 16 tokens each), and a per-sequence block table maps logical positions to physical blocks anywhere in GPU memory.
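A toy sketch of the bookkeeping (not vLLM's actual implementation): a shared free list of physical blocks, plus a per-sequence table mapping logical order to physical block ids.

```python
BLOCK = 16  # tokens per block (illustrative)

class BlockTable:
    """Toy per-sequence block table over a shared pool of physical blocks."""
    def __init__(self, free_blocks):
        self.free = free_blocks   # shared list of free physical block ids
        self.blocks = []          # this sequence's blocks, in logical order
        self.length = 0           # tokens written so far

    def slot_for_next_token(self):
        if self.length % BLOCK == 0:             # current block full (or first token)
            self.blocks.append(self.free.pop())  # grab any free physical block
        block_id = self.blocks[self.length // BLOCK]
        offset = self.length % BLOCK
        self.length += 1
        return block_id, offset   # where this token's K/V rows live in the pool
```

Memory is only ever wasted inside the last, partially filled block of each sequence (at most 15 tokens here), regardless of how long the sequence was expected to run.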
10. Other Optimizations
Quantization
Store K and V in INT8, INT4, or FP8 instead of FP16. INT8 halves the cache; INT4 quarters it. Quality cost is usually small if done per-channel or per-token.
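A per-token symmetric INT8 round-trip might look like this sketch (real systems fuse the dequantize into the attention kernel):

```python
import torch

def quantize_per_token(x):
    # x: [..., seq_len, head_dim] in FP16/FP32; one scale per token row
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale  # store INT8 values plus a small FP scale per token

def dequantize(q, scale):
    return q.to(scale.dtype) * scale  # back to FP for the attention matmul
```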
Sliding-Window Attention
Bound the attention window to the most recent w tokens (e.g. Mistral uses w=4096). The cache size is capped at w regardless of total sequence length, at the cost of losing direct attention to earlier tokens.
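A rolling-buffer cache makes the cap concrete; a sketch (rotary/position-offset bookkeeping omitted):

```python
import torch

def append_sliding(cache_k, new_k, w):
    # cache_k: [B, n_heads, <=w, head_dim]; concatenate, then keep only the
    # most recent w positions, so the cache can never exceed w entries
    k = torch.cat([cache_k, new_k], dim=2)
    return k[:, :, -w:, :]
```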
Cache Eviction (H₂O, StreamingLLM)
Keep only the most attention-relevant tokens or a "sink" of early tokens plus a sliding window. Lets very long sequences run with bounded cache.
Prefix Caching
When many requests share a common prefix (system prompt, few-shot examples), compute the prefix's KV cache once and reuse it. Massive savings for chat APIs with templated prompts.
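A toy version keyed on the exact prefix token ids (production servers hash token blocks rather than whole prompts; the `model` and `embed` names follow section 11):

```python
prefix_cache = {}  # tuple of prefix token ids -> prefilled per-layer KV tensors

def cached_prefill(prefix_ids, model, embed):
    key = tuple(prefix_ids.flatten().tolist())
    if key not in prefix_cache:
        _, kv = model(embed(prefix_ids), kv_cache=None)  # prefill once
        prefix_cache[key] = kv
    # with the cat-based cache in section 11, later appends build new tensors,
    # so the stored prefix entry is never mutated by individual requests
    return prefix_cache[key]
```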
Offloading
Spill cold portions of the cache to CPU RAM or NVMe and stream back on demand. Trades latency for capacity.
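At its simplest, offloading slices the cache by recency; a sketch (real systems overlap the transfers with compute, which is omitted here):

```python
def offload_old(k, keep_recent):
    # keep the most recent positions on the GPU, spill the cold prefix to host
    # RAM; assumes the cache currently holds more than keep_recent tokens
    cold = k[:, :, :-keep_recent, :].cpu()
    hot = k[:, :, -keep_recent:, :]
    return cold, hot  # stream cold back with .to("cuda") when it is needed
```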
| Optimization | Cache reduction | Quality impact |
|---|---|---|
| GQA (groups=8) | ~4–8× | Negligible |
| MQA | up to num_heads× | Small |
| INT8 quantization | 2× | Small |
| INT4 quantization | 4× | Moderate |
| Sliding window | cap at w | Lose long-range attention |
| PagedAttention | removes fragmentation | None |
| Prefix caching | per-prefix amortization | None |
11. Reference Implementation (PyTorch sketch)
A minimal causal self-attention layer with KV cache append:
```python
import math
import torch
import torch.nn as nn

class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        # x: [batch, new_tokens, d_model]
        B, T, _ = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        # each: [B, T, n_heads, head_dim] -> [B, n_heads, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        past_len = 0
        if kv_cache is not None:
            past_k, past_v = kv_cache
            past_len = past_k.size(2)
            k = torch.cat([past_k, k], dim=2)  # append along seq dim
            v = torch.cat([past_v, v], dim=2)
        new_cache = (k, v)
        # attention against full (cached + new) keys/values
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # causal mask is only needed during prefill (T > 1); during decode the
        # single new query may attend to everything already in the cache
        if T > 1:
            kpos = torch.arange(past_len + T, device=x.device)
            qpos = torch.arange(T, device=x.device) + past_len
            mask = kpos[None, :] > qpos[:, None]  # True where the key is in the future
            scores = scores.masked_fill(mask, float("-inf"))
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.out(out), new_cache
```
Generation loop:
```python
cache = None
tokens = prompt_tokens  # [batch, prompt_len] token ids; embed/model are stand-ins
for _ in range(max_new_tokens):
    # prefill: embed the whole prompt; decode: embed only the newest token
    x = embed(tokens if cache is None else tokens[:, -1:])
    logits, cache = model(x, kv_cache=cache)
    next_tok = logits[:, -1].argmax(-1, keepdim=True)  # greedy decoding
    tokens = torch.cat([tokens, next_tok], dim=1)
```
Note how after the first call (prefill), every subsequent call passes only the single new token — the cache provides the rest.
12. Summary
- What: KV cache stores past K and V tensors so attention does not recompute them at every decode step.
- Why it works: Causal masking guarantees past K, V do not change as new tokens arrive.
- What it costs: Memory grows as 2 × layers × heads × head_dim × seq_len × bytes. At long context this dominates over model weights.
- Phases: Prefill is compute-bound; decode is memory-bound, gated by KV cache bandwidth.
- Make it cheaper: GQA/MQA shrink the cache by sharing K,V across heads; quantization halves or quarters the bytes; PagedAttention removes fragmentation across concurrent requests.
- Make it longer: Sliding windows, eviction (H₂O, StreamingLLM), CPU/NVMe offloading.
- Make it free: Prefix caching reuses KV across requests that share a prompt prefix.