LLMs Deep Dive · Lesson 7 of 24
KV Cache: How Inference Gets Fast
Why the KV Cache Exists
During autoregressive generation, each new token must attend to every previous token. Without caching, the K and V vectors for all past tokens are recomputed at every step:
Without KV cache, generating token t:
Compute Q, K, V for ALL tokens 0..t
Projection cost: O(t) per token → O(n²) total for n tokens
With KV cache, generating token t:
K and V for tokens 0..t-1 are stored from previous steps (Q is not cached; only the newest token's query is needed)
Compute only Q, K, V for token t
Projection cost: O(1) per token → O(n) total (plus O(n) memory)
Attention for the new token still reads all t cached keys and values, so that part stays O(t) per token.
At long sequence lengths and large batch sizes, the KV cache becomes the dominant memory consumer during inference.
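As a rough illustration of the cost comparison above, the sketch below counts operations instead of running a model. The function names are hypothetical, and one "projection" is assumed to mean computing K and V for a single token in a single layer.

```python
def kv_projections_without_cache(n: int) -> int:
    """K/V projections when step t recomputes K and V for all t tokens seen so far."""
    return sum(range(1, n + 1))  # 1 + 2 + ... + n = n(n+1)/2, i.e. O(n^2)


def kv_projections_with_cache(n: int) -> int:
    """K/V projections when each step only projects the newest token."""
    return n  # O(n)


def attention_reads_with_cache(n: int) -> int:
    """Cached positions read by attention over all n steps (step t reads t positions)."""
    return sum(range(1, n + 1))  # the cache does not remove this O(n^2) term


if __name__ == "__main__":
    n = 4096
    print(kv_projections_without_cache(n))  # 8390656
    print(kv_projections_with_cache(n))     # 4096
```

The saving comes from the projections and the rest of the per-token forward pass; reading the cached keys and values still grows with the sequence length.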
What Is Cached
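import torch

class KVCache:
    """Pre-allocated K/V storage for a single sequence (batch size 1)."""

    def __init__(self, num_layers, num_kv_heads, max_seq_len, head_dim, dtype=torch.float16):
        # Layout: [layer, kv_head, position, head_dim]
        self.k_cache = torch.zeros(num_layers, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
        # Filled positions, tracked per layer because each layer appends once per step
        self.seq_len = [0] * num_layers

    def update(self, layer_idx, k_new, v_new):
        # k_new, v_new: assumed shape [1, num_kv_heads, t, head_dim]
        t = k_new.shape[2]  # number of new tokens
        start = self.seq_len[layer_idx]
        self.k_cache[layer_idx, :, start:start + t] = k_new[0]
        self.v_cache[layer_idx, :, start:start + t] = v_new[0]
        self.seq_len[layer_idx] = start + t  # advance the write position

    def get(self, layer_idx):
        # Views over the filled positions only
        n = self.seq_len[layer_idx]
        return (
            self.k_cache[layer_idx, :, :n],
            self.v_cache[layer_idx, :, :n],
        )
KV Cache Memory Calculation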
For LLaMA 2 7B:
layers = 32
num_kv_heads = 32 (no GQA in 7B)
head_dim = 128
dtype = float16 (2 bytes)
KV cache per token:
= 2 (K and V) × 32 layers × 32 heads × 128 dim × 2 bytes
= 2 × 32 × 32 × 128 × 2 = 524,288 bytes ≈ 0.5 MB per token
For max_seq_len = 4096 and batch_size = 1:
Total = 4096 × 0.5 MB = 2 GB
For batch_size = 32 with 4K context:
Total = 32 × 2 GB = 64 GB
(LLaMA 7B weights ≈ 14GB in fp16, so KV cache > model weights at large batch)
GQA Reduction
LLaMA 2 70B uses grouped-query attention (GQA): 8 KV heads are shared across 64 query heads:
LLaMA 2 70B KV cache per token:
2 × 80 layers × 8 KV heads × 128 dim × 2 bytes
= 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 0.31 MB
vs. MHA hypothetical (64 KV heads):
2 × 80 × 64 × 128 × 2 = 2,621,440 bytes ≈ 2.5 MB
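The per-token arithmetic for the 7B and 70B models can be reproduced with a small helper. This is a minimal sketch; the function names are hypothetical, and sizes are reported in binary units (1 GiB = 1024³ bytes), which is what the rounded MB/GB figures in this lesson correspond to.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             bytes_per_value: int = 2) -> int:
    """Bytes of K and V stored for one token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def kv_cache_bytes(per_token: int, seq_len: int, batch_size: int) -> int:
    """Total cache size when every sequence in the batch holds seq_len tokens."""
    return per_token * seq_len * batch_size


llama2_7b = kv_cache_bytes_per_token(32, 32, 128)        # 524288
llama2_70b_gqa = kv_cache_bytes_per_token(80, 8, 128)    # 327680
llama2_70b_mha = kv_cache_bytes_per_token(80, 64, 128)   # 2621440

GIB = 1024 ** 3
print(kv_cache_bytes(llama2_7b, 4096, 1) / GIB)    # 2.0
print(kv_cache_bytes(llama2_7b, 4096, 32) / GIB)   # 64.0
print(llama2_70b_mha // llama2_70b_gqa)            # 8
```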
GQA gives 8× KV cache reduction for 70B.
KV Cache Quantisation
Store K/V in int8 instead of float16:
fp16 → int8: 2 bytes → 1 byte = 2× memory reduction
Quantisation of KV cache:
Compute attention in fp16/bf16 normally
Before storing: quantise K, V to int8 (scale + zero-point per head)
At attention time: dequantise back to fp16
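The quantise-then-dequantise round trip can be sketched in plain Python. This is an illustration under assumptions, not any particular library's scheme: it uses asymmetric int8 with one scale and zero-point for the list of values passed in (standing in for one head's K or V entries), and the function names are hypothetical.

```python
def quantise_int8(values: list[float]) -> tuple[list[int], float, int]:
    """Asymmetric int8 quantisation; returns (quantised values, scale, zero_point)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 if all values are equal
    zero_point = round(-128 - lo / scale)  # maps lo to about -128, hi to about 127
    q = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantise_int8(q: list[int], scale: float, zero_point: int) -> list[float]:
    """Recover approximate floating-point values before computing attention."""
    return [(x - zero_point) * scale for x in q]


if __name__ == "__main__":
    head_values = [-1.0, -0.25, 0.0, 0.5, 2.0]
    q, scale, zp = quantise_int8(head_values)
    restored = dequantise_int8(q, scale, zp)
    print(max(abs(a - b) for a, b in zip(head_values, restored)) <= scale)  # True
```

Each stored value takes one byte instead of two; the scale and zero-point add a small fixed overhead per quantisation group.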
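Quality impact: usually small for most tasks (on the order of 0.1-0.5% degradation)
Memory saving: 2×. For the 32-batch × 4K context LLaMA 2 7B setup above, 64GB → 32GB
Libraries: LLM.int8() (bitsandbytes), AWQ and GPTQ quantise model weights, not the KV cache. KV cache quantisation is a separate feature of the serving runtime; vLLM (FP8 KV cache) and TensorRT-LLM (INT8/FP8 KV cache) are examples.
Paged Attention (vLLM)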
Standard KV cache pre-allocates max_seq_len × batch_size memory, even if requests are much shorter:
Problem: 32 concurrent requests, max 4K tokens each
Pre-allocated: 32 × 4K = 128K token slots × 0.5MB = 64GB
Actual usage: average request is 500 tokens = 8GB
→ 56GB wasted, limits throughput
vLLM paged attention:
KV cache is divided into fixed-size PAGES (e.g., 16 tokens per page)
Pages are allocated on demand as the sequence grows
Pages for different requests can interleave in memory
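The on-demand allocation described in these points can be sketched as a toy page table. This is not vLLM's implementation; the class and method names are hypothetical, and it tracks only which physical page each token slot lives in, without storing any K/V data.

```python
class PagedKVAllocator:
    """Hands out fixed-size pages from a shared pool as sequences grow."""

    def __init__(self, num_pages: int, page_size: int = 16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))      # physical page ids
        self.page_tables: dict[str, list[int]] = {}   # request id -> physical pages
        self.lengths: dict[str, int] = {}             # request id -> tokens stored

    def append_token(self, request_id: str) -> tuple[int, int]:
        """Reserve a slot for one new token; returns (physical_page, offset)."""
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % self.page_size == 0:  # no page yet, or the last page is full
            if not self.free_pages:
                raise MemoryError("KV cache pool exhausted")
            table.append(self.free_pages.pop(0))
        self.lengths[request_id] = length + 1
        return table[-1], length % self.page_size

    def free(self, request_id: str) -> None:
        """Return a finished request's pages to the pool."""
        self.free_pages.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)


if __name__ == "__main__":
    alloc = PagedKVAllocator(num_pages=8, page_size=16)
    for _ in range(16):
        alloc.append_token("a")
    alloc.append_token("b")
    alloc.append_token("a")  # 17th token of "a" needs a second page
    print(alloc.page_tables)  # {'a': [0, 2], 'b': [1]}
```

Because a request holds only the pages it has actually filled, a 500-token request occupies 32 pages of 16 tokens instead of a full 4K-token reservation.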
Enables 2-4× throughput improvement over naive pre-allocation
Streaming / Eviction for Long Contexts
For sequences longer than the available KV cache:
StreamingLLM (Xiao et al., 2023):
Keep the first k "sink" tokens (always high attention)
Keep a sliding window of recent tokens
Evict all tokens outside this window
Evicted tokens are lost — model cannot attend to them.
Used for: infinite-length streaming (always remember start + recent)
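The sink-plus-window rule above reduces to choosing which cache positions survive. A minimal sketch, assuming a hypothetical helper name and illustrative default sizes (the sink count and window length are tunable, not values taken from this lesson):

```python
def streaming_keep_positions(seq_len: int, num_sinks: int = 4, window: int = 1020) -> list[int]:
    """Positions kept in the cache: the first num_sinks tokens plus the last window tokens."""
    if seq_len <= num_sinks + window:
        return list(range(seq_len))  # everything still fits, nothing is evicted
    sinks = list(range(num_sinks))
    recent = list(range(seq_len - window, seq_len))
    return sinks + recent


if __name__ == "__main__":
    print(streaming_keep_positions(10, num_sinks=2, window=3))  # [0, 1, 7, 8, 9]
```

The cache size stays fixed at num_sinks + window entries no matter how long the stream runs, which is what makes unbounded generation possible.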
H2O (Heavy Hitter Oracle):
Track which past tokens have historically high attention scores
Preferentially keep those — "heavy hitters"
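The heavy-hitter idea can be sketched as a scoring step over recorded attention weights. This covers only the selection rule described here, under assumptions: the function name and input layout are hypothetical, and the published method has further details (such as how recent tokens are handled) that are not modelled.

```python
def h2o_keep_positions(attn_history: list[list[float]], budget: int) -> list[int]:
    """Pick the budget cached positions with the largest accumulated attention.

    attn_history[s][j] is the attention weight that decoding step s placed on
    cached position j; rows for earlier steps are shorter.
    """
    num_positions = max(len(row) for row in attn_history)
    scores = [0.0] * num_positions
    for row in attn_history:
        for j, weight in enumerate(row):
            scores[j] += weight
    ranked = sorted(range(num_positions), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:budget])  # kept positions, in sequence order


if __name__ == "__main__":
    history = [
        [1.0],
        [0.7, 0.3],
        [0.6, 0.1, 0.3],
        [0.4, 0.02, 0.03, 0.55],
    ]
    print(h2o_keep_positions(history, budget=2))  # [0, 3]
```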
Better than uniform eviction for most tasks
Interview Answer
"The KV cache stores the key and value vectors for all past tokens, so each generation step only computes Q, K and V for the new token and attends over the cached keys and values. That cuts the K/V projection work from O(n²) to O(n) over n tokens, at the cost of O(n) memory. Memory cost per token is 2 × layers × KV heads × head_dim × 2 bytes in fp16; for a 7B model with 4K context and batch 32 this comes to about 64GB, well above the 14GB of fp16 weights. GQA reduces KV heads (8 instead of 64 for 70B → 8× savings). vLLM's paged attention allocates cache on demand per request, avoiding most over-allocation waste. KV quantisation to int8 halves the memory cost with minimal quality loss."