
KV Cache: Accelerating Autoregressive Inference

How the key-value cache eliminates redundant attention computation during text generation. Understand cache structure, memory cost, and when caching breaks down.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: Transformers, KV Cache, Inference, Optimization

The Problem: Recomputing Past Tokens

Autoregressive generation produces one token at a time. To generate token N, the model runs a forward pass over all N-1 previous tokens plus the new token. Without caching, generating a 1000-token response requires:

Token 1: attend over 1 token       → 1 attention computation
Token 2: attend over 2 tokens      → 2 attention computations
...
Token 1000: attend over 1000 tokens → 1000 attention computations

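Total attention operations: 1 + 2 + ... + 1000 = 500,500, which grows as O(n²). Without a cache, every step also re-projects the Keys and Values of all earlier tokens in every layer, and all of that work is redundant.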
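The sum above is easy to check with a few lines of Python. This is only a sanity check of the arithmetic; the function name `attention_scores_total` is made up for illustration.

```python
def attention_scores_total(n: int) -> int:
    """Total query-key comparisons when step t attends over t tokens."""
    return sum(t for t in range(1, n + 1))


n = 1000
total = attention_scores_total(n)
closed_form = n * (n + 1) // 2  # arithmetic-series formula, grows as O(n^2)

print(total, closed_form)  # 500500 500500
```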

The key insight: for any token position j at step N, the Keys and Values it contributes are exactly the same at step N+1. Because attention is causal, they depend only on the tokens at positions up to and including j, and those never change once generated. Only the new token (position N) adds new K and V vectors.


How KV Cache Works

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class CachedMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,                # (batch, seq_len, d_model) — only NEW tokens
        past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        batch, seq_len, _ = x.shape

        # Project new tokens only
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Concatenate with cached keys and values from previous steps
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)  # (batch, heads, past+new, head_dim)
            v = torch.cat([past_v, v], dim=2)

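        # Scaled dot product attention over full (past + new) keys
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Causal mask: new token i sits at absolute position past_len + i and may
        # only attend to keys at or before that position (matters during prefill)
        total_len = k.size(2)
        causal = torch.tril(
            torch.ones(seq_len, total_len, device=x.device),
            diagonal=total_len - seq_len,
        ).bool()
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)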

        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        output = self.out_proj(attn_output)

        # Return output AND updated cache
        new_cache = (k, v)
        return output, new_cache
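One way to convince yourself that caching changes nothing numerically is to compare it against full recomputation. The sketch below is a deliberately stripped-down, single-head version (random weights, no batching, helper names `causal_attention` and `cached_step` are invented for this example), not the class above. It decodes one token at a time with a cache and checks that the outputs match a single causal pass over the whole sequence.

```python
import torch

torch.manual_seed(0)
d = 8
# Random projection weights standing in for a trained layer (assumption)
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))


def causal_attention(x):
    """Full recomputation: single-head causal attention over the whole sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / d ** 0.5
    mask = torch.tril(torch.ones(len(x), len(x))).bool()
    return torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ v


def cached_step(x_new, cache):
    """One decoding step: project only the new token, append its K/V to the cache."""
    q, k, v = x_new @ w_q, x_new @ w_k, x_new @ w_v
    if cache is not None:
        k = torch.cat([cache[0], k], dim=0)
        v = torch.cat([cache[1], v], dim=0)
    scores = (q @ k.T) / d ** 0.5  # one query row against all cached keys
    return torch.softmax(scores, dim=-1) @ v, (k, v)


x = torch.randn(6, d)  # 6 token embeddings
full = causal_attention(x)

cache, rows = None, []
for t in range(len(x)):
    out, cache = cached_step(x[t:t + 1], cache)
    rows.append(out)
incremental = torch.cat(rows, dim=0)

print(torch.allclose(full, incremental, atol=1e-4))
```

No mask is needed inside `cached_step` because a single new query can legitimately see every cached key; the mask only matters when several new tokens are processed at once.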

Cache During Generation

Python
def generate_with_cache(
    model,
    prompt_ids: torch.Tensor,  # (batch, prompt_len)
    max_new_tokens: int = 100,
) -> torch.Tensor:
    """Autoregressive generation with KV cache."""
    generated = prompt_ids

    # First forward pass: process full prompt, build initial cache
    with torch.no_grad():
        # Process the full prompt
        outputs = model(prompt_ids, use_cache=True)
        past_key_values = outputs.past_key_values  # One (K, V) pair per layer
        next_token_logits = outputs.logits[:, -1, :]  # Last position

    for _ in range(max_new_tokens):
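        # Greedy decoding: pick the highest-probability token
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
        generated = torch.cat([generated, next_token], dim=-1)

        # Stop when every sequence in the batch emits EOS at this step
        # (.item() would raise for batch sizes above 1)
        if (next_token == model.config.eos_token_id).all():
            break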

        # Forward pass with only the new token — cache handles the rest
        with torch.no_grad():
            outputs = model(
                next_token,
                past_key_values=past_key_values,  # Reuse cached K/V
                use_cache=True,
            )
            past_key_values = outputs.past_key_values  # Updated cache
            next_token_logits = outputs.logits[:, -1, :]

    return generated

# With HuggingFace:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

inputs = tokenizer("Warfarin is", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=100,
    use_cache=True,  # Enabled by default
)

Memory Cost of KV Cache

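Each layer stores one K and one V tensor per KV head. The memory cost per token is:

Memory per token = 2 × num_layers × num_kv_heads × head_dim × dtype_bytes

Under standard multi-head attention, num_kv_heads equals num_heads. For a model with LLaMA-3-8B's dimensions and one KV head per query head:

  • num_layers = 32
  • num_heads = 32
  • head_dim = 128
  • dtype = float16 = 2 bytes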
Python
num_layers = 32
num_heads = 32
head_dim = 128
dtype_bytes = 2  # float16

bytes_per_token = 2 * num_layers * num_heads * head_dim * dtype_bytes
kb_per_token = bytes_per_token / 1024
print(f"KV cache per token: {bytes_per_token} bytes = {kb_per_token:.1f} KB")
# Output: KV cache per token: 524288 bytes = 512.0 KB

# For 4096-token context:
context_len = 4096
total_kv_gb = (bytes_per_token * context_len) / (1024 ** 3)
print(f"KV cache for {context_len} tokens: {total_kv_gb:.2f} GB")
# Output: KV cache for 4096 tokens: 2.00 GB

At 512 KB per token, a 128k-token context window needs 64 GB of KV cache for a single sequence. That is why long-context serving relies on specialized solutions such as a quantized KV cache, PagedAttention, or offloading to CPU.
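To see how quickly this grows, the per-token formula can be wrapped in a small helper that also multiplies by sequence length and batch size. This is a back-of-the-envelope sketch: the function name `kv_cache_bytes` is hypothetical, "128k" is taken to mean 128 × 1024 tokens, and the second call assumes 8 KV heads as in grouped-query attention.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, dtype_bytes=2):
    """Total KV cache size in bytes (factor 2 = one K and one V per layer)."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes * seq_len * batch_size


GIB = 1024 ** 3
ctx = 128 * 1024  # assumption: "128k" means 131,072 tokens

mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=ctx)
gqa = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=ctx)

print(f"32 KV heads: {mha / GIB:.1f} GiB")  # 32 KV heads: 64.0 GiB
print(f" 8 KV heads: {gqa / GIB:.1f} GiB")  #  8 KV heads: 16.0 GiB
```

Both figures are per sequence, so a batch of concurrent requests multiplies them again.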


Grouped Query Attention (GQA)

Most modern models reduce KV cache memory by sharing K and V heads across multiple Q heads:

Python
# Standard MHA: num_kv_heads = num_q_heads
# GQA: num_kv_heads < num_q_heads (e.g., 8 KV heads for 32 Q heads)
# MQA: num_kv_heads = 1 (extreme sharing)

# LLaMA-3-8B uses GQA with 32 Q heads and 8 KV heads
# KV cache memory is reduced by 32/8 = 4×

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.groups = num_q_heads // num_kv_heads
        self.head_dim = d_model // num_q_heads

        self.q_proj = nn.Linear(d_model, num_q_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_values=None):
        batch, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch, seq_len, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

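        # Cache the compact (num_kv_heads) tensors; expand only for the attention math
        new_cache = (k, v)
        k_rep = k.repeat_interleave(self.groups, dim=1)  # (batch, num_q_heads, total_len, head_dim)
        v_rep = v.repeat_interleave(self.groups, dim=1)

        scale = self.head_dim ** -0.5
        scores = torch.matmul(q, k_rep.transpose(-2, -1)) * scale

        # Causal mask so new tokens cannot attend to later positions during prefill
        total_len = k.size(2)
        causal = torch.tril(
            torch.ones(seq_len, total_len, device=x.device),
            diagonal=total_len - seq_len,
        ).bool()
        attn = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)

        out = torch.matmul(attn, v_rep).transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.out_proj(out), new_cache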
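The head sharing implied by `repeat_interleave` is easier to see as an explicit mapping: consecutive query heads share one KV head. The helper below is purely illustrative (its name is made up) and uses the 32-query-head, 8-KV-head split mentioned above.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head that a given query head reads from under GQA."""
    assert num_q_heads % num_kv_heads == 0
    groups = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // groups


mapping = [kv_head_for_query_head(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Setting `num_kv_heads=1` maps every query head to KV head 0, which is the MQA case.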

Paged Attention and vLLM

Serving LLMs in production means handling many variable-length requests, each with a KV cache of a different size. PagedAttention (from vLLM) manages the KV cache like virtual memory: it is allocated in fixed-size blocks (pages) that need not be contiguous, so memory is not reserved up front for a maximum-length sequence and fragmentation stays low:

Python
# Using vLLM for production serving (handles KV cache automatically)
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B", dtype="float16")
params = SamplingParams(temperature=0.8, max_tokens=200)

prompts = [
    "Explain warfarin dosing in 3 sentences.",
    "What is the mechanism of action of metformin?",
]

outputs = llm.generate(prompts, params)
for output in outputs:
    print(output.outputs[0].text)

The vLLM authors report up to 24× higher throughput than HuggingFace Transformers generation, primarily through PagedAttention and continuous batching.
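The block-table idea behind paged KV storage can be illustrated with a toy allocator. To be clear, this is not vLLM's implementation or API: the class `PagedKVAllocator`, its methods, and the block size of 16 are all assumptions chosen for the example. It only tracks which physical blocks each sequence owns, which is enough to show that a sequence wastes at most one partially filled block.

```python
class PagedKVAllocator:
    """Toy block-table allocator; illustrative only, not vLLM's implementation."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}                      # seq id -> list of physical blocks
        self.seq_lens = {}                          # seq id -> tokens stored so far

    def append_token(self, seq_id: str):
        """Reserve a slot for one token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        offset = length % self.block_size
        if offset == 0:  # sequence is new, or its last block is full
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = length + 1
        return table[-1], offset

    def free(self, seq_id: str) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)


alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token("request-a")
for _ in range(5):
    alloc.append_token("request-b")

blocks_a = len(alloc.block_tables["request-a"])  # 20 tokens need 2 blocks
blocks_b = len(alloc.block_tables["request-b"])  # 5 tokens need 1 block
alloc.free("request-a")
free_after = len(alloc.free_blocks)              # 8 - 1 block still held by request-b
print(blocks_a, blocks_b, free_after)            # 2 1 7
```

Because blocks are handed out on demand and returned as soon as a request finishes, short and long requests can share one pool without reserving worst-case space for each.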


When KV Cache Breaks Down

| Scenario | Problem | Solution |
|---|---|---|
| Very long context (128k+) | Cache exceeds GPU memory | Quantize cache to int8, offload to CPU |
| Large batch sizes | Memory × batch_size | Reduce batch size or use MQA/GQA |
| Streaming with many users | Each user needs their own cache | PagedAttention (vLLM) |
| System prompt reuse | Recomputing same prompt each call | Prompt caching (prefix caching) |
| Fine-tuning with cache | Cache invalidated by gradient updates | Cache only valid during inference |
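As a rough illustration of the "quantize cache to int8" row, the sketch below applies symmetric absmax quantization to a cached K tensor, with one scale per token and head. This is one simple scheme among several and is an assumption of this example, not a description of what any particular serving stack does; the helper names are invented, and float32 is used only so the demo runs on CPU.

```python
import torch


def quantize_int8(t: torch.Tensor):
    """Symmetric int8 quantization with one scale per vector along head_dim."""
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an approximation of the original tensor."""
    return q.to(scale.dtype) * scale


torch.manual_seed(0)
k = torch.randn(1, 8, 16, 128)  # (batch, kv_heads, seq_len, head_dim)

k_q, k_scale = quantize_int8(k)
k_hat = dequantize_int8(k_q, k_scale)

max_err = (k - k_hat).abs().max().item()
print(k_q.element_size(), "byte per value; max abs error:", max_err)
```

Each cached value shrinks to one byte, plus one scale per token and head, at the cost of a small rounding error that is bounded by half a quantization step.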
