Learnixo

LLMs Deep Dive · Lesson 7 of 24

KV Cache: How Inference Gets Fast

Why the KV Cache Exists

During autoregressive generation, each new token requires computing attention over all previous tokens. Without caching, the K and V vectors for all past tokens are recomputed at every step:

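Without KV cache, generating token t:
  Recompute Q, K, V for ALL tokens 0..t (the whole prefix is run through the model again)
  Projection cost: O(t) per step → O(n²) total for n tokens
  Attention cost: O(t²) per step → O(n³) total

With KV cache, generating token t:
  K and V for tokens 0..t-1 are stored from previous steps (past Q vectors are never reused, so they are not cached)
  Compute Q, K, V only for token t and append its K, V to the cache
  Projection cost: O(1) per step → O(n) total
  Attention cost: O(t) per step (one query against the cached keys) → O(n²) total
  Memory: O(n) for the cache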
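To make the comparison concrete, here is a toy single-layer, single-head decoder in plain Python. The head dimension, random weights and random token vectors are illustrative assumptions, not a real model. It checks that cached decoding gives the same attention outputs as full recomputation, and counts how many K/V projections each approach performs:

```python
import math
import random

random.seed(0)
D = 4  # toy head dimension (assumption for illustration)

def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# Toy projection weights for one attention head
W_q, W_k, W_v = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # softmax(q . k / sqrt(D)) over all keys, then a weighted sum of values
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(D)]

def decode_without_cache(xs):
    outputs, projections = [], 0
    for t in range(len(xs)):
        # Recompute K and V for every token 0..t at every step
        keys = [matvec(W_k, x) for x in xs[: t + 1]]
        values = [matvec(W_v, x) for x in xs[: t + 1]]
        projections += 2 * (t + 1)
        outputs.append(attend(matvec(W_q, xs[t]), keys, values))
    return outputs, projections

def decode_with_cache(xs):
    outputs, projections = [], 0
    k_cache, v_cache = [], []
    for x in xs:
        # Project only the new token and append its K and V to the cache
        k_cache.append(matvec(W_k, x))
        v_cache.append(matvec(W_v, x))
        projections += 2
        outputs.append(attend(matvec(W_q, x), k_cache, v_cache))
    return outputs, projections

tokens = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(16)]
out_a, proj_a = decode_without_cache(tokens)
out_b, proj_b = decode_with_cache(tokens)
print(proj_a, proj_b)  # 272 32
```

The outputs match because, in a causal model, the K and V of a past token never change once computed; only the new token's query has to be compared against them, which is why attention work per step still grows with t.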

At long sequence lengths and large batch sizes, the KV cache becomes the dominant memory consumer during inference and can exceed the size of the model weights.


What Is Cached

Python
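import torch

class KVCache:
    """Pre-allocated K/V cache for one sequence (no batch dimension). Q is never cached."""

    def __init__(self, num_layers, num_kv_heads, max_seq_len, head_dim, dtype=torch.float16):
        # Shape per tensor: [num_layers, num_kv_heads, max_seq_len, head_dim]
        self.k_cache = torch.zeros(num_layers, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.seq_len = 0  # tokens fully stored in every layer

    def update(self, layer_idx, k_new, v_new):
        # k_new, v_new: [num_kv_heads, t, head_dim] for the t new tokens
        t = k_new.shape[1]
        end = self.seq_len + t
        if end > self.max_seq_len:
            raise ValueError("KV cache is full; increase max_seq_len")
        self.k_cache[layer_idx, :, self.seq_len:end] = k_new
        self.v_cache[layer_idx, :, self.seq_len:end] = v_new
        # Layers are updated in order 0..num_layers-1 each step; advance the
        # shared position only after the last layer has written
        if layer_idx == self.num_layers - 1:
            self.seq_len = end
        # Return past + new K/V so this layer's attention sees tokens 0..end-1
        return self.k_cache[layer_idx, :, :end], self.v_cache[layer_idx, :, :end]

    def get(self, layer_idx):
        # K/V for all tokens stored so far
        return (
            self.k_cache[layer_idx, :, :self.seq_len],
            self.v_cache[layer_idx, :, :self.seq_len],
        )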

KV Cache Memory Calculation

For LLaMA 2 7B:
  layers = 32
  num_kv_heads = 32  (no GQA in 7B)
  head_dim = 128
  dtype = float16 (2 bytes)

KV cache per token:
  = 2 (K and V) × 32 layers × 32 heads × 128 dim × 2 bytes
  = 2 × 32 × 32 × 128 × 2 = 524,288 bytes ≈ 0.5 MB per token

For max_seq_len = 4096 and batch_size = 1:
  Total = 4096 × 0.5 MB = 2 GB

For batch_size = 32 with 4K context:
  Total = 32 × 2 GB = 64 GB
  (LLaMA 7B weights ≈ 14GB in fp16, so KV cache > model weights at large batch)
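The arithmetic above fits in a small helper. This is a sketch using the LLaMA 2 7B figures from the text; the function name is ours, and the "MB"/"GB" figures above correspond to binary units (MiB/GiB):

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """KV cache size: 2 (K and V) x layers x KV heads x head_dim x bytes, per token."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return per_token * seq_len * batch_size

GiB = 1024 ** 3

# LLaMA 2 7B: 32 layers, 32 KV heads, head_dim 128, fp16 (2 bytes)
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
single = kv_cache_bytes(32, 32, 128, seq_len=4096)
batched = kv_cache_bytes(32, 32, 128, seq_len=4096, batch_size=32)

print(per_token)      # 524288
print(single / GiB)   # 2.0
print(batched / GiB)  # 64.0
```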

GQA Reduction

LLaMA 2 70B uses grouped-query attention (GQA): its 64 query heads share 8 KV heads, so each KV head serves a group of 8 query heads:

LLaMA 2 70B KV cache per token:
  2 × 80 layers × 8 KV heads × 128 dim × 2 bytes
  = 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 0.31 MB

vs. MHA hypothetical (64 KV heads):
  2 × 80 × 64 × 128 × 2 = 2,621,440 bytes ≈ 2.5 MB

GQA gives 8× KV cache reduction for 70B.
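A short sketch of both halves of GQA: the per-token cache size for 8 vs 64 KV heads, and which KV head each query head reads. The mapping assumes the common layout where each KV head serves a contiguous group of query heads; helper names are illustrative:

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    # Consecutive query heads share one KV head: group size = num_q_heads // num_kv_heads
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# LLaMA 2 70B: 80 layers, 64 query heads, 8 KV heads, head_dim 128, fp16
gqa = kv_bytes_per_token(80, 8, 128)
mha = kv_bytes_per_token(80, 64, 128)
print(gqa, mha, mha // gqa)  # 327680 2621440 8

# Query heads 0-7 read KV head 0, heads 8-15 read KV head 1, and so on
mapping = [kv_head_for_query_head(h, 64, 8) for h in range(64)]
print(mapping[:9])  # [0, 0, 0, 0, 0, 0, 0, 0, 1]
```

Only the 8 KV heads are stored; at attention time each cached KV head is reused by its group of query heads, so the saving comes purely from storing fewer heads.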

KV Cache Quantisation

Store K/V in int8 instead of float16:

fp16 → int8: 2 bytes → 1 byte = 2× memory reduction

Quantisation of KV cache:
  Compute attention in fp16/bf16 normally
  Before storing: quantise K, V to int8 (scale + zero-point per head)
  At attention time: dequantise back to fp16

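Quality impact: typically small for int8 on most tasks (on the order of 0.1-0.5% degradation; the exact figure depends on model, task, and metric)
Memory saving: 2×. For the LLaMA 2 7B setup above (batch 32, 4K context), 64GB → 32GB, plus a small overhead for the per-head scales and zero-points

Libraries: LLM.int8() (bitsandbytes), AWQ, and GPTQ are weight-quantisation methods rather than KV cache quantisation. KV cache quantisation is offered by inference engines such as vLLM (FP8 KV cache), TensorRT-LLM (INT8/FP8 KV cache), and Hugging Face Transformers' quantized cache.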
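A minimal sketch of the quantise/dequantise round trip in plain Python. Assumptions: the statistics are taken over one head's vector for one token, and the zero-point is stored as a float offset (the head's minimum value), a common variant of the integer zero-point; production kernels do this on GPU tensors:

```python
import random

def quantize_int8(values):
    """Per-head asymmetric int8 quantisation: int8 codes plus a float scale and offset."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255 if hi > lo else 1.0
    # Map [lo, hi] onto the 256 int8 codes [-128, 127]
    codes = [max(-128, min(127, round((v - lo) / scale) - 128)) for v in values]
    return codes, scale, lo

def dequantize_int8(codes, scale, lo):
    # Invert the affine map; per-element error is at most scale / 2
    return [(c + 128) * scale + lo for c in codes]

random.seed(0)
head = [random.gauss(0, 1) for _ in range(128)]  # one head's K vector (toy data)
codes, scale, lo = quantize_int8(head)
restored = dequantize_int8(codes, scale, lo)
max_err = max(abs(a - b) for a, b in zip(head, restored))
print(max_err <= scale / 2 + 1e-9)  # True
```

Each 128-element head now costs 128 bytes of codes plus one scale and one offset, instead of 256 bytes in fp16, which is where the roughly 2× saving comes from.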

Paged Attention (vLLM)

A naive KV cache pre-allocates max_seq_len token slots for every request in the batch, even when most requests are much shorter:

Problem: 32 concurrent requests, max 4K tokens each
  Pre-allocated: 32 × 4K = 128K token slots × 0.5MB = 64GB
  Actual usage: average request is 500 tokens → 32 × 500 × 0.5MB ≈ 8GB
  → 56GB wasted, limits throughput

vLLM paged attention:
  KV cache is divided into fixed-size PAGES (e.g., 16 tokens per page)
  Pages are allocated on demand as the sequence grows
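  Pages need not be contiguous: a per-request block table maps logical pages to physical pages, so pages of different requests can interleave in memory
  Waste is limited to the unfilled tail of each request's last page
  The vLLM paper reports 2-4× higher throughput than prior serving systems (e.g., FasterTransformer, Orca) at similar latency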
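The allocation idea can be sketched as a toy page allocator with a block table per request. The class and method names are hypothetical and this is not vLLM's API; it only tracks page ids and token offsets, not the K/V data itself:

```python
class PagedKVAllocator:
    """Toy page allocator: each sequence owns a block table of physical page ids."""

    def __init__(self, num_pages, page_size=16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.block_tables = {}  # seq_id -> list of physical page ids
        self.lengths = {}       # seq_id -> tokens stored

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.page_size == 0:
            # Last page is full (or there is none yet): allocate a page on demand
            if not self.free_pages:
                raise MemoryError("out of KV cache pages")
            table.append(self.free_pages.pop(0))
        self.lengths[seq_id] = length + 1
        # Physical slot for this token: (page id, offset within the page)
        return table[-1], length % self.page_size

    def free(self, seq_id):
        # Return a finished sequence's pages to the pool
        self.free_pages.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_pages=8, page_size=16)
for i in range(20):
    alloc.append_token("A")
    if i < 5:
        alloc.append_token("B")

print(alloc.block_tables)  # {'A': [0, 2], 'B': [1]}
used_slots = sum(len(t) for t in alloc.block_tables.values()) * alloc.page_size
print(used_slots, sum(alloc.lengths.values()))  # 48 25
```

Request A's pages (0 and 2) are interleaved with B's page (1), and the 23 unused slots sit only in the last page of each request, rather than in two full max_seq_len reservations.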

Streaming / Eviction for Long Contexts

For sequences longer than the available KV cache:

StreamingLLM (Xiao et al., 2023):
  Keep the first few "attention sink" tokens (they receive disproportionately high attention regardless of content)
  Keep a sliding window of recent tokens
  Evict all tokens outside this window

Evicted tokens are lost — model cannot attend to them.
Used for: infinite-length streaming (always remember start + recent)
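The eviction rule can be simulated with token positions standing in for cached K/V entries. This is a simplified sketch; the defaults of 4 sinks and an 8-token window are illustrative assumptions:

```python
def streaming_step(cache, new_pos, num_sinks=4, window=8):
    """Append a token, then evict so only the sinks plus the most recent `window` tokens remain."""
    cache.append(new_pos)
    if len(cache) > num_sinks + window:
        # Evict the oldest non-sink token (the entry right after the sinks)
        del cache[num_sinks]
    return cache

cache = []
for pos in range(20):
    streaming_step(cache, pos)
print(cache)  # [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
```

The cache size stays fixed at num_sinks + window no matter how long generation runs, which is what makes unbounded streaming possible; positions 4 to 11 are gone for good.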

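H2O (Heavy-Hitter Oracle, Zhang et al., 2023):
  Track each cached token's accumulated attention score across decoding steps
  Preferentially keep the highest-scoring tokens ("heavy hitters") alongside a window of recent tokens
  Reported to retain quality better than a recent-tokens-only window at the same cache budget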
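A simplified sketch of the heavy-hitter policy: accumulate attention per cached position, then keep a recent window plus the top-scoring older positions within a fixed budget. Function names and the one-shot selection are our simplifications; the H2O paper applies eviction incrementally during decoding:

```python
def accumulate(acc_scores, attn_weights):
    """Add this step's attention weights (one per cached position) to the running totals."""
    # New positions (beyond the current list) start with a total of zero
    acc_scores.extend([0.0] * (len(attn_weights) - len(acc_scores)))
    for i, w in enumerate(attn_weights):
        acc_scores[i] += w
    return acc_scores

def h2o_keep(acc_scores, budget, num_recent):
    """Positions to keep: the num_recent newest plus the top-scoring older ones, within budget."""
    n = len(acc_scores)
    if n <= budget:
        return list(range(n))
    recent = list(range(n - num_recent, n))
    older = range(n - num_recent)
    # Heavy hitters: older positions with the largest accumulated attention
    heavy = sorted(older, key=lambda i: acc_scores[i], reverse=True)[: budget - num_recent]
    return sorted(heavy) + recent

scores = [5.0, 0.1, 0.2, 3.0, 0.1, 0.3, 0.2, 0.4, 0.1, 0.2]  # toy accumulated scores
kept = h2o_keep(scores, budget=6, num_recent=3)
print(kept)  # [0, 3, 5, 7, 8, 9]
```

Unlike the fixed sink-plus-window rule, which positions survive depends on the attention pattern, so a token from the middle of the context (positions 3 and 5 here) can be kept if the model keeps attending to it.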

Interview Answer

"The KV cache stores the key and value vectors for all past tokens, so each generation step only computes Q, K, and V for the new token and attends its query against the cached keys. That cuts K/V projection work from O(n²) to O(n) over n generated tokens; attention is still O(t) per step, so its total drops from O(n³) to O(n²). The price is O(n) memory. Memory cost is 2 × layers × KV heads × head_dim × 2 bytes per token in fp16; for LLaMA 2 7B with 4K context and batch 32 that is about 64 GB, well above the roughly 14 GB of fp16 weights. GQA shrinks the cache by sharing KV heads across query heads (8 KV heads instead of 64 for LLaMA 2 70B, an 8× saving). vLLM's paged attention allocates cache pages on demand per request, largely eliminating over-allocation waste. Quantising the KV cache to int8 halves its memory with typically small quality loss."