Context Window: Limits, Tradeoffs, and Extensions
Why context windows are limited, the quadratic attention bottleneck, how modern models extend context, and practical strategies for working within limits.
What is the Context Window?
The context window is the maximum number of tokens a transformer can process in a single forward pass. Everything the model can "see" at inference time (the system prompt, conversation history, documents, and the current prompt) must fit within it.
Modern context windows:
- GPT-4o: 128k tokens
- Claude 3.5 Sonnet: 200k tokens
- Gemini 1.5 Pro: 1M tokens
- LLaMA-3-8B (base): 8k tokens
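To make the "must fit" constraint concrete, here is a minimal sketch of a fit check. The window sizes are the nominal figures from the list above (treating "k" as 1,000; exact limits vary by provider). The names `CONTEXT_WINDOWS`, `fits_in_window`, and `reserved_output` are illustrative assumptions, and token counts are taken as given rather than produced by a real tokenizer.

```python
# Nominal context window sizes in tokens, taken from the list above.
CONTEXT_WINDOWS = {
    "gpt-4o": 128_000,
    "claude-3.5-sonnet": 200_000,
    "gemini-1.5-pro": 1_000_000,
    "llama-3-8b": 8_000,
}


def fits_in_window(model: str, part_tokens: dict[str, int], reserved_output: int = 1_000) -> tuple[bool, int]:
    """Return (fits, remaining_tokens) for a prompt assembled from several parts.

    part_tokens maps a label (e.g. "system", "history") to a token count.
    reserved_output is a hypothetical allowance for the generated reply,
    assuming the reply shares the same window as the input.
    """
    window = CONTEXT_WINDOWS[model]
    used = sum(part_tokens.values()) + reserved_output
    remaining = window - used
    return remaining >= 0, remaining


parts = {"system": 500, "history": 3_000, "documents": 6_000, "prompt": 200}
for model in CONTEXT_WINDOWS:
    ok, remaining = fits_in_window(model, parts)
    print(f"{model:>18}: fits={ok}, remaining={remaining}")
```

With these example counts the request needs 10,700 tokens, so it overflows the 8k model but fits comfortably in the others.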
The Quadratic Bottleneck
Standard self-attention has O(n²) time and memory complexity in sequence length n. For each of the n tokens, attention computes a weighted sum over all n tokens, which requires an n × n matrix of attention weights:
```python
def attention_memory_estimate(seq_len: int, num_heads: int, dtype_bytes: int = 2) -> dict:
    """Estimate attention matrix memory in GB for one layer (batch size 1)."""
    # Attention weights: (batch=1, heads, seq, seq)
    attn_weights_elements = num_heads * seq_len * seq_len
    attn_weights_bytes = attn_weights_elements * dtype_bytes
    attn_weights_gb = attn_weights_bytes / (1024 ** 3)
    return {
        "seq_len": seq_len,
        "attn_matrix_gb": attn_weights_gb,
        "attn_matrix_elements": attn_weights_elements,
    }


for seq_len in [1_000, 4_096, 32_000, 128_000]:
    est = attention_memory_estimate(seq_len, num_heads=32)
    print(f"seq={seq_len:>7}: attention matrix = {est['attn_matrix_gb']:.2f} GB")

# seq=   1000: attention matrix = 0.06 GB
# seq=   4096: attention matrix = 1.00 GB
# seq=  32000: attention matrix = 61.04 GB
# seq= 128000: attention matrix = 976.56 GB
```

A 128k context with 32 heads at float16 would need roughly 977 GB per layer (about 30 GB per head) just to materialize the attention weight matrices, before any other computation.
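The n × n growth can be seen directly in a naive implementation. The following is a minimal single-head sketch in pure Python (no batching, no masking, and illustrative names throughout) that materializes the full weight matrix and reports how many entries it holds.

```python
import math


def naive_attention(q: list[list[float]], k: list[list[float]], v: list[list[float]]):
    """Single-head attention over n tokens; returns (outputs, weights).

    weights is the full n x n matrix that standard attention materializes.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    weights = []
    outputs = []
    for q_i in q:
        # One score per key: n scores for each of the n queries.
        scores = [scale * sum(a * b for a, b in zip(q_i, k_j)) for k_j in k]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps]
        weights.append(row)
        # Weighted sum of the value vectors.
        outputs.append([sum(w * v_j[c] for w, v_j in zip(row, v)) for c in range(len(v[0]))])
    return outputs, weights


for n in (4, 8, 16):
    x = [[float(i + j) / n for j in range(2)] for i in range(n)]
    _, w = naive_attention(x, x, x)
    print(f"n={n:>2}: {sum(len(r) for r in w)} attention weights")
```

Each doubling of n quadruples the number of stored weights, which is the quadratic cost described above.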
Positional Encoding and Context Limits
Learned absolute position embeddings (BERT, GPT-2) fail on sequences longer than they were trained on: there are no learned vectors for positions beyond max_seq_len. This is one reason modern models switched to RoPE and ALiBi.
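The failure mode can be illustrated with a toy lookup table. This is a sketch in plain Python, not any specific model's code: `MAX_SEQ_LEN`, `position_table`, and `embed_positions` are illustrative names, and the random vectors stand in for learned parameters.

```python
import random

MAX_SEQ_LEN = 8   # positions seen during training (tiny, for illustration)
EMBED_DIM = 4

random.seed(0)
# One "learned" vector per position; nothing exists for position >= MAX_SEQ_LEN.
position_table = [[random.gauss(0.0, 0.02) for _ in range(EMBED_DIM)] for _ in range(MAX_SEQ_LEN)]


def embed_positions(seq_len: int) -> list[list[float]]:
    """Look up a position vector for each index in range(seq_len)."""
    if seq_len > MAX_SEQ_LEN:
        raise ValueError(
            f"sequence length {seq_len} exceeds the {MAX_SEQ_LEN} learned positions"
        )
    return [position_table[pos] for pos in range(seq_len)]


print(len(embed_positions(8)))  # every position has a learned vector
try:
    embed_positions(9)
except ValueError as err:
    print(f"failed: {err}")
```

A relative scheme avoids the missing-row problem because it never indexes a table by absolute position, although it can still degrade at distances it was not trained on.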
RoPE extrapolation: RoPE makes attention scores depend on relative position, but during training the model only sees relative distances up to the training length. At inference on longer sequences, the rotation angles for larger distances are out-of-distribution, especially in the low-frequency dimensions. Solutions:
- Dynamic NTK scaling: Adjust the base frequency of RoPE to scale with context length
- YaRN: A more sophisticated RoPE scaling method that has been applied to extend LLaMA- and Mistral-family models, among others
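To see why raising the base helps, the sketch below computes the wavelength of each RoPE frequency pair before and after an NTK-style base adjustment. It assumes the common formulation in which pair i rotates by `position * base ** (-2 * i / head_dim)` radians. The base of 10,000, head dimension of 128, and 4× target are illustrative values, and `rope_wavelengths` is a hypothetical helper name.

```python
import math


def rope_wavelengths(base: float, head_dim: int) -> list[float]:
    """Wavelength (in token positions) of each RoPE frequency pair.

    Pair i rotates by position * base ** (-2 * i / head_dim) radians,
    so one full rotation takes 2 * pi * base ** (2 * i / head_dim) positions.
    """
    return [2 * math.pi * base ** (2 * i / head_dim) for i in range(head_dim // 2)]


head_dim = 128
scale = 4.0  # hypothetical target: 4x the trained context length
original = rope_wavelengths(10_000.0, head_dim)
# NTK-style adjustment: raise the base so the slowest frequency stretches by `scale`.
scaled = rope_wavelengths(10_000.0 * scale ** (head_dim / (head_dim - 2)), head_dim)

print(f"fastest pair: {original[0]:.2f} -> {scaled[0]:.2f} positions")
print(f"slowest pair: {original[-1]:.0f} -> {scaled[-1]:.0f} positions")
print(f"slowest-pair stretch: {scaled[-1] / original[-1]:.2f}x")
```

The exponent `head_dim / (head_dim - 2)` is chosen so that the slowest pair stretches by exactly the scale factor while the fastest pair is left unchanged, which keeps short-range position information intact.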
```python
def apply_ntk_scaling(base: float, scale_factor: float, head_dim: int) -> float:
    """Compute adjusted RoPE base for length extrapolation (NTK-aware scaling)."""
    # Increases the base to make frequencies lower (longer wavelengths),
    # which allows the model to handle longer sequences.
    return base * (scale_factor ** (head_dim / (head_dim - 2)))


original_base = 10000.0
scale_factor = 4.0  # Supporting 4× longer context
head_dim = 128
new_base = apply_ntk_scaling(original_base, scale_factor, head_dim)
print(f"Adjusted RoPE base: {new_base:.0f}")  # ~40,890
```

Flash Attention: Memory-Efficient Long Context
Flash Attention (Dao et al., 2022) reformulates attention computation to avoid materializing the full n×n attention matrix in HBM (the GPU's main high-bandwidth memory). Instead, it computes attention in tiles that fit in fast on-chip SRAM:
Standard attention: O(n²) extra memory; the full score matrix is written to and read back from HBM
Flash Attention: O(n) extra memory; tiles stay in SRAM and are recomputed on the fly in the backward pass
Flash Attention achieves the same mathematical result with:
- Up to 10× less memory for the attention computation
- 2 to 4× faster wall-clock time
- Linear (not quadratic) memory scaling in sequence length
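The key trick that makes tiling exact is a running ("online") softmax. The sketch below shows that idea for a single query in pure Python and checks it against a reference that materializes all scores. It is an illustration of the numerical technique only: the real kernel also tiles the queries, runs in GPU SRAM, and fuses the steps. The function names and the tile size are assumptions made for this example.

```python
import math


def full_softmax_attention(q, keys, values):
    """Reference: materialize all scores for one query, then normalize."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total for c in range(dim)]


def tiled_attention(q, keys, values, tile_size=2):
    """Process keys/values tile by tile with a running (online) softmax.

    Only one tile of scores exists at a time; the running max, running
    denominator, and running output are rescaled as new tiles arrive.
    """
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new reference maximum.
        correction = math.exp(running_max - new_max)
        exps = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(exps)
        acc = [a * correction + sum(e * v[c] for e, v in zip(exps, v_tile))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 1.0], [0.0, 2.0, -1.0], [3.0, 1.0, 0.5], [-1.0, 0.5, 2.0], [0.2, 0.1, 0.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = full_softmax_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, tile_size=2)
print(max(abs(a - b) for a, b in zip(ref, tiled)))  # difference is at floating-point noise level
```

Because the partial results are rescaled whenever a larger score appears, the tiled version never needs more than one tile of scores in memory, yet it matches the full softmax up to rounding.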
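```python
# PyTorch 2.0+ includes Flash Attention via scaled_dot_product_attention
import torch
import torch.nn.functional as F


def attention_flash(q, k, v, causal=True):
    """Force the Flash Attention backend of PyTorch SDPA.

    Raises at runtime if the flash kernel cannot be used for these inputs
    (it expects CUDA tensors in float16 or bfloat16).
    """
    # q, k, v: (batch, heads, seq, head_dim)
    # Note: newer PyTorch releases deprecate this context manager in favor of
    # torch.nn.attention.sdpa_kernel.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False,
    ):
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```

Sparse Attention Patterns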
For very long sequences, sparse attention restricts each token to attend to only a subset of other tokens:
- Local windowed attention: each token attends to the W nearest tokens only, giving O(n × W) instead of O(n²)
- Sliding window: Mistral 7B uses a window of 4096 tokens per layer
- Global tokens: A few designated tokens attend to everything; all others attend only locally

```python
import torch
import torch.nn.functional as F


def local_window_attention(q, k, v, window_size: int):
    """Each token attends to only the window_size most recent tokens (itself included).

    Written as an explicit loop for clarity, not speed.
    """
    # q, k, v: (batch, heads, seq, head_dim)
    seq_len = q.shape[2]
    scale = q.shape[-1] ** -0.5
    output = torch.zeros_like(q)
    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        # Only attend to tokens in [start, i]
        k_window = k[:, :, start:i + 1, :]
        v_window = v[:, :, start:i + 1, :]
        q_i = q[:, :, i:i + 1, :]
        attn = F.softmax(torch.matmul(q_i, k_window.transpose(-2, -1)) * scale, dim=-1)
        output[:, :, i:i + 1, :] = torch.matmul(attn, v_window)
    return output
```

Mistral 7B combines sliding window attention with a rolling KV cache: only the last 4096 K/V pairs are kept per layer, regardless of total generation length.
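A rolling cache of this kind can be sketched with a bounded deque. This is an illustration of the idea rather than Mistral's actual implementation (a fixed-size buffer indexed by position modulo the window size achieves the same effect). `RollingKVCache` and its methods are hypothetical names, and strings stand in for key/value tensors.

```python
from collections import deque


class RollingKVCache:
    """Keeps only the most recent `window` key/value pairs for one layer."""

    def __init__(self, window: int):
        self.window = window
        # deque(maxlen=...) evicts the oldest entry automatically when full.
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)
        self.total_seen = 0

    def append(self, key, value) -> None:
        self.keys.append(key)
        self.values.append(value)
        self.total_seen += 1

    def visible_positions(self) -> range:
        """Absolute positions of the tokens still held in the cache."""
        start = max(0, self.total_seen - self.window)
        return range(start, self.total_seen)


cache = RollingKVCache(window=4)  # a window of 4 keeps the demo readable
for pos in range(10):
    cache.append(f"k{pos}", f"v{pos}")

print(len(cache.keys), list(cache.visible_positions()), list(cache.keys))
# 4 [6, 7, 8, 9] ['k6', 'k7', 'k8', 'k9']
```

Cache memory stays constant at `window` entries per layer no matter how many tokens are generated, at the cost of losing direct access to anything older than the window.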
Practical: Working Within Context Limits
Chunking long documents
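```python
from transformers import AutoTokenizer

# Gated repository: requires accepting the license on the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")


def chunk_document(
    text: str,
    max_chunk_tokens: int = 3_000,
    overlap_tokens: int = 200,
) -> list[str]:
    """Split document into chunks that fit within context."""
    if not 0 <= overlap_tokens < max_chunk_tokens:
        # Otherwise `start` would never advance and the loop would not terminate.
        raise ValueError("overlap_tokens must be >= 0 and smaller than max_chunk_tokens")
    # Skip special tokens (e.g. BOS) so they do not leak into the decoded chunks.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + max_chunk_tokens, len(tokens))
        chunk_tokens = tokens[start:end]
        chunks.append(tokenizer.decode(chunk_tokens))
        if end >= len(tokens):
            break
        start = end - overlap_tokens  # Overlap to preserve context at boundaries
    return chunks


# For a 50k-token document with an 8k context model:
long_document = "Context windows are finite. " * 10_000  # placeholder text; substitute your own document
chunks = chunk_document(long_document, max_chunk_tokens=6_000)
print(f"Split into {len(chunks)} chunks")
```

Tracking token budget at runtime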
```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4o")


class ContextBudgetManager:
    def __init__(self, max_tokens: int = 120_000):
        self.max_tokens = max_tokens
        self.system_reserve = 2_000  # Reserve for system prompt overhead

    def available(self, system_prompt: str, conversation: list[dict]) -> int:
        """Tokens still free after the system prompt and all messages."""
        used = len(enc.encode(system_prompt))
        for msg in conversation:
            used += len(enc.encode(msg["content"])) + 4  # ~4 tokens overhead per message (approximate)
        return self.max_tokens - self.system_reserve - used

    def trim_conversation(self, system_prompt: str, conversation: list[dict], keep_last_n: int = 10) -> list[dict]:
        """Drop the oldest messages until more than 1,000 tokens are free.

        Never trims below keep_last_n messages, so the result can still be over budget.
        """
        while len(conversation) > keep_last_n:
            available = self.available(system_prompt, conversation)
            if available > 1_000:  # Enough buffer
                break
            conversation = conversation[1:]  # Drop oldest
        return conversation
```

Context Window vs Effective Context
Even when a model has a 128k context window, it doesn't use all of it uniformly. Research on the "lost in the middle" phenomenon (Liu et al., 2023) shows that:
- Models attend well to content at the beginning of context
- Models attend well to content at the end of context
- Content in the middle of a very long context is retrieved and used much less reliably
Practical implication: Put the most important instructions at the beginning (system prompt) and most relevant documents at the end (immediately before the query). Avoid burying critical information in the middle of a long context.
```python
def order_context_strategically(system_prompt, documents, query):
    """Order context to minimize lost-in-the-middle effect."""
    # Most important documents go at the END (closest to query)
    # Sort by relevance; most relevant last
    sorted_docs = sorted(documents, key=lambda d: d["relevance_score"])
    context_parts = [
        system_prompt,
        *[doc["text"] for doc in sorted_docs],  # Least relevant first
        f"\nQuestion: {query}",
    ]
    return "\n\n".join(context_parts)
```

Context Window Summary
| Model | Context | Attention | Extension Method |
|---|---|---|---|
| BERT-base | 512 | Standard | Learned absolute positions |
| GPT-2 | 1024 | Standard | Learned absolute positions |
| LLaMA-2-7B | 4096 | Standard (MHA) | RoPE |
| Mistral-7B | 32k | Sliding window | RoPE + SWA |
| LLaMA-3-8B | 8k (128k with tuning) | GQA | RoPE |
| Claude 3.5 | 200k | Unknown | Unknown |
| Gemini 1.5 | 1M | Unknown (not publicly disclosed) | Unknown |