KV Cache: Accelerating Autoregressive Inference
How the key-value cache eliminates redundant attention computation during text generation. Understand cache structure, memory cost, and when caching breaks down.
The Problem: Recomputing Past Tokens
Autoregressive generation produces one token at a time. To generate token N, the model runs a forward pass over all N-1 previous tokens plus the new token. Without caching, generating a 1000-token response requires:
Token 1: attend over 1 token → 1 attention computation
Token 2: attend over 2 tokens → 2 attention computations
...
Token 1000: attend over 1000 tokens → 1000 attention computations
Total attention operations: 1 + 2 + ... + 1000 = 500,500, an O(n²) problem.
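As a quick check of that arithmetic, the sketch below counts how many token positions the model has to process in total, with and without a cache. The function names are illustrative, and the count is a proxy for work done, not a FLOP estimate.

```python
def tokens_processed_without_cache(n: int) -> int:
    """Step i re-runs the model over all i tokens seen so far."""
    return sum(range(1, n + 1))


def tokens_processed_with_cache(n: int) -> int:
    """Step i runs the model over the single new token only."""
    return n


n = 1000
no_cache = tokens_processed_without_cache(n)
with_cache = tokens_processed_with_cache(n)
print(no_cache, with_cache)  # 500500 1000
```

Each cached step still scores the new query against every stored key, so per-step cost keeps growing with context length. What the cache removes is re-running the projections and layers for tokens that were already processed.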
The key insight: the Keys and Values computed for token position j at step N are exactly the same at step N+1. Because attention is causal, they depend only on the tokens at positions up to j, which never change. Only the new token (position N) adds new K and V vectors.
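That claim can be checked on a toy example. The sketch below is a dependency-free, single-head, single-layer attention written with plain lists; the sizes, the random weights, and every function name are assumptions made for illustration. It runs the same sequence twice, once re-projecting K and V for the whole prefix at every step and once appending to a cache, then compares the outputs.

```python
import math
import random

rng = random.Random(0)
D = 4  # toy embedding size, chosen arbitrarily for this sketch


def rand_matrix(rows: int, cols: int) -> list[list[float]]:
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q: list[float], keys: list[list[float]], values: list[list[float]]) -> list[float]:
    """Softmax attention of one query over a list of keys and values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(len(q))]


W_q, W_k, W_v = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)
tokens = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(6)]  # 6 toy embeddings

# Without a cache: at every step, re-project K and V for the whole prefix.
outputs_recompute = []
for step in range(1, len(tokens) + 1):
    prefix = tokens[:step]
    keys = [matvec(W_k, t) for t in prefix]
    values = [matvec(W_v, t) for t in prefix]
    outputs_recompute.append(attend(matvec(W_q, prefix[-1]), keys, values))

# With a cache: project only the new token and append to the stored K and V.
k_cache, v_cache = [], []
outputs_cached = []
for token in tokens:
    k_cache.append(matvec(W_k, token))
    v_cache.append(matvec(W_v, token))
    outputs_cached.append(attend(matvec(W_q, token), k_cache, v_cache))

max_diff = max(
    abs(a - b)
    for out_a, out_b in zip(outputs_recompute, outputs_cached)
    for a, b in zip(out_a, out_b)
)
print(f"largest difference between the two runs: {max_diff:.2e}")
```

The two runs agree because the cached entries are the same projections the uncached run recomputes at every step. In a multi-layer model the same argument holds layer by layer, as long as attention is causal.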
How KV Cache Works
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class CachedMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,  # (batch, seq_len, d_model), only NEW tokens
        past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        batch, seq_len, _ = x.shape
        # Project new tokens only
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Concatenate with cached keys and values from previous steps
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)  # (batch, heads, past+new, head_dim)
            v = torch.cat([past_v, v], dim=2)
        # Scaled dot-product attention over full (past + new) keys
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        # Causal mask: new token i sits at absolute position past_len + i and may
        # attend only to keys at or before it. With a single new token the mask
        # is all True, so it only matters when several tokens arrive at once.
        total_len = k.size(2)
        causal = torch.ones(seq_len, total_len, dtype=torch.bool, device=x.device).tril(
            diagonal=total_len - seq_len
        )
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        output = self.out_proj(attn_output)
        # Return output AND updated cache
        new_cache = (k, v)
        return output, new_cache
Cache During Generation
def generate_with_cache(
    model,
    prompt_ids: torch.Tensor,  # (batch, prompt_len)
    max_new_tokens: int = 100,
) -> torch.Tensor:
    """Greedy autoregressive generation with KV cache."""
    generated = prompt_ids
    # First forward pass: process the full prompt and build the initial cache
    with torch.no_grad():
        outputs = model(prompt_ids, use_cache=True)
    past_key_values = outputs.past_key_values  # One (K, V) pair per layer
    next_token_logits = outputs.logits[:, -1, :]  # Last position
    for _ in range(max_new_tokens):
        # Greedy decoding: take the highest-scoring token
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
        generated = torch.cat([generated, next_token], dim=-1)
        # Stop at EOS (.item() assumes batch size 1)
        if next_token.item() == model.config.eos_token_id:
            break
        # Forward pass with only the new token; the cache supplies the rest
        with torch.no_grad():
            outputs = model(
                next_token,
                past_key_values=past_key_values,  # Reuse cached K/V
                use_cache=True,
            )
        past_key_values = outputs.past_key_values  # Updated cache
        next_token_logits = outputs.logits[:, -1, :]
    return generated

# With HuggingFace:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
inputs = tokenizer("Warfarin is", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=100,
    use_cache=True,  # Enabled by default
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Memory Cost of KV Cache
Each layer stores K and V matrices. The memory cost per token is:
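Memory per token = 2 × num_layers × num_heads × head_dim × dtype_bytes
Here the factor 2 covers K and V, and num_heads is the number of key/value heads. For a model with LLaMA-3-8B's layer count and head size, assuming standard multi-head attention (one KV head per query head):
num_layers = 32, num_heads = 32, head_dim = 128, dtype = float16 = 2 bytes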
num_layers = 32
num_heads = 32
head_dim = 128
dtype_bytes = 2 # float16
bytes_per_token = 2 * num_layers * num_heads * head_dim * dtype_bytes
kb_per_token = bytes_per_token / 1024
print(f"KV cache per token: {bytes_per_token} bytes = {kb_per_token:.1f} KB")
# Output: KV cache per token: 524288 bytes = 512.0 KB
# For 4096-token context:
context_len = 4096
total_kv_gb = (bytes_per_token * context_len) / (1024 ** 3)
print(f"KV cache for {context_len} tokens: {total_kv_gb:.2f} GB")
# Output: KV cache for 4096 tokens: 2.00 GB
For a 128k-token context window, the same arithmetic gives about 64 GB for a single sequence, so long-context serving relies on specialized techniques such as a quantized KV cache, PagedAttention, or offloading to CPU.
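To see how the cache scales with context length and head count, the calculation can be wrapped in a small helper. This is a sketch: `kv_cache_bytes` and its parameter names are made up for illustration, and it ignores allocator overhead and any padding a real runtime adds.

```python
def kv_cache_bytes(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    context_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2,  # float16
) -> int:
    """Bytes needed to hold K and V for every layer, KV head, and token."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return per_token * context_len * batch_size


GIB = 1024 ** 3
configs = {
    "4k context, 32 KV heads": dict(num_kv_heads=32, context_len=4096),
    "128k context, 32 KV heads": dict(num_kv_heads=32, context_len=128 * 1024),
    "128k context, 8 KV heads": dict(num_kv_heads=8, context_len=128 * 1024),
}
for name, cfg in configs.items():
    size = kv_cache_bytes(num_layers=32, head_dim=128, **cfg)
    print(f"{name}: {size / GIB:.1f} GiB")
```

With these inputs the three lines come out to 2.0, 64.0, and 16.0 GiB. The cache grows linearly in context length, batch size, and the number of KV heads, which is why reducing KV heads is such a direct lever.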
Grouped Query Attention (GQA)
Most modern models reduce KV cache memory by sharing K and V heads across multiple Q heads:
# Standard MHA: num_kv_heads = num_q_heads
# GQA: num_kv_heads < num_q_heads (e.g., 8 KV heads for 32 Q heads)
# MQA: num_kv_heads = 1 (extreme sharing)
# LLaMA-3-8B uses GQA with 32 Q heads and 8 KV heads
# KV cache memory is reduced by 32/8 = 4×
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.groups = num_q_heads // num_kv_heads
        self.head_dim = d_model // num_q_heads
        self.q_proj = nn.Linear(d_model, num_q_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_values=None):
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.num_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Cache K/V before repeating them: only num_kv_heads heads are stored,
        # which is where the memory saving comes from
        new_cache = (k, v)
        # Repeat KV heads to match Q heads for attention computation
        k = k.repeat_interleave(self.groups, dim=1)  # (batch, num_q_heads, total_len, head_dim)
        v = v.repeat_interleave(self.groups, dim=1)
        scale = self.head_dim ** -0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        # Causal mask: new token i may attend only to keys at or before its position
        total_len = k.size(2)
        causal = torch.ones(seq_len, total_len, dtype=torch.bool, device=x.device).tril(
            diagonal=total_len - seq_len
        )
        attn = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.out_proj(out), new_cache
Paged Attention and vLLM
Serving LLMs in production requires handling variable-length requests with different KV cache sizes. PagedAttention (from vLLM) manages the KV cache the way an operating system manages virtual memory, allocating it in fixed-size pages (blocks) on demand so that little memory is lost to fragmentation:
# Using vLLM for production serving (handles KV cache automatically)
from vllm import LLM, SamplingParams
llm = LLM(model="meta-llama/Meta-Llama-3-8B", dtype="float16")
params = SamplingParams(temperature=0.8, max_tokens=200)
prompts = [
    "Explain warfarin dosing in 3 sentences.",
    "What is the mechanism of action of metformin?",
]
outputs = llm.generate(prompts, params)
for output in outputs:
    print(output.outputs[0].text)
vLLM's authors report up to 24× higher throughput than naive HuggingFace generation, primarily through PagedAttention and continuous batching.
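The paging idea itself can be sketched with a toy block allocator. This is not vLLM's implementation: `PagedKVAllocator`, its methods, and the block sizes are hypothetical, and the sketch tracks only which blocks belong to which request, not the tensors stored in them.

```python
class PagedKVAllocator:
    """Toy allocator that hands out fixed-size KV cache blocks on demand."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables: dict[str, list[int]] = {}  # request id -> physical block ids
        self.lengths: dict[str, int] = {}  # request id -> tokens stored

    def append_token(self, request_id: str) -> tuple[int, int]:
        """Reserve a slot for one new token; returns (block id, offset in block)."""
        table = self.block_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free_blocks:
                raise MemoryError("no free KV cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[request_id] = length + 1
        return table[-1], length % self.block_size

    def release(self, request_id: str) -> None:
        """Return all of a finished request's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)


allocator = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(40):
    allocator.append_token("request-a")  # 40 tokens occupy 3 blocks
for _ in range(5):
    allocator.append_token("request-b")  # 5 tokens occupy 1 block
print(len(allocator.free_blocks))  # 4
allocator.release("request-a")
print(len(allocator.free_blocks))  # 7
```

Because blocks are claimed only when a request actually reaches them, the waste per request is at most one partly filled block, and blocks freed by a finished request can be reused by any other request regardless of where they sit in memory.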
When KV Cache Breaks Down
| Scenario | Problem | Solution |
|---|---|---|
| Very long context (128k+) | Cache exceeds GPU memory | Quantize cache to int8, offload to CPU |
| Large batch sizes | Memory × batch_size | Reduce batch size or use MQA/GQA |
| Streaming with many users | Each user needs their own cache | PagedAttention (vLLM) |
| System prompt reuse | Recomputing same prompt each call | Prompt caching (prefix caching) |
| Fine-tuning with cache | Cache invalidated by gradient updates | Cache only valid during inference |
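The int8 option in the first row can be illustrated with symmetric absmax quantization of a single cached vector. This is a minimal sketch under stated assumptions: the function names and sample values are made up, and the one-scale-per-vector granularity is only one possible choice (real systems may scale per channel, per token, or per group).

```python
def quantize_int8(vector: list[float]) -> tuple[list[int], float]:
    """Symmetric absmax quantization: map floats to integers in [-127, 127]."""
    absmax = max(abs(x) for x in vector)
    scale = absmax / 127 if absmax > 0 else 1.0
    return [round(x / scale) for x in vector], scale


def dequantize_int8(codes: list[int], scale: float) -> list[float]:
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


key_vector = [0.12, -1.5, 0.75, 0.0, 2.54, -0.33]  # made-up values for illustration
codes, scale = quantize_int8(key_vector)
restored = dequantize_int8(codes, scale)
max_error = max(abs(a - b) for a, b in zip(key_vector, restored))
print(codes)
print(f"scale={scale:.4f}, max error={max_error:.4f}")
```

Each stored value drops from 2 bytes (float16) to 1 byte plus a shared scale, roughly halving cache memory, at the cost of a rounding error of at most half the scale per element.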