Context Window: Limits, Tradeoffs, and Extensions

Why context windows are limited, the quadratic attention bottleneck, how modern models extend context, and practical strategies for working within limits.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Tags: Transformers, Context Window, Long Context, Architecture
Share:š•

What is the Context Window?

The context window is the maximum number of tokens a transformer can process in a single forward pass. Everything the model can "see" at inference time (the system prompt, conversation history, documents, the current prompt, and any tokens it has generated so far) must fit within it.

Modern context windows:

  • GPT-4o: 128k tokens
  • Claude 3.5 Sonnet: 200k tokens
  • Gemini 1.5 Pro: 1M tokens
  • LLaMA-3-8B (base): 8k tokens

The Quadratic Bottleneck

Standard self-attention has O(n²) time and memory complexity in sequence length n. For each of the n tokens, attention computes a weighted sum over all n tokens, which means n × n attention weights per head:

Python
import torch

def attention_memory_estimate(seq_len: int, num_heads: int, dtype_bytes: int = 2) -> dict:
    """Estimate attention matrix memory in GB."""
    # Attention weights: (batch=1, heads, seq, seq)
    attn_weights_elements = num_heads * seq_len * seq_len
    attn_weights_bytes = attn_weights_elements * dtype_bytes
    attn_weights_gb = attn_weights_bytes / (1024 ** 3)

    return {
        "seq_len": seq_len,
        "attn_matrix_gb": attn_weights_gb,
        "attn_matrix_elements": attn_weights_elements,
    }

for seq_len in [1_000, 4_096, 32_000, 128_000]:
    est = attention_memory_estimate(seq_len, num_heads=32)
    print(f"seq={seq_len:>7}: attention matrix = {est['attn_matrix_gb']:.2f} GB")

# seq=   1000: attention matrix = 0.06 GB
# seq=   4096: attention matrix = 1.00 GB
# seq=  32000: attention matrix = 61.04 GB
# seq= 128000: attention matrix = 976.56 GB

At 128k tokens with 32 heads in float16, fully materializing the attention weights would take roughly 977 GB for a single layer (about 30.5 GB per head), before any other computation. Storing the full matrix is therefore not an option at long context lengths.
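
To make the n × n structure concrete, here is a minimal sketch of single-head attention in dependency-free Python. It is meant only to show where the quadratic matrix comes from, not how production kernels are written; the function name and the toy sizes are illustrative assumptions.

```python
import math
import random

def naive_attention(q, k, v):
    """Single-head attention on lists of vectors; returns (outputs, weights)."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    weights = []
    for i in range(n):
        # One score per (query i, key j) pair: n scores for each of n queries.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(n)]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    # Each output row is the weighted sum of all n value vectors.
    outputs = [
        [sum(weights[i][j] * v[j][c] for j in range(n)) for c in range(len(v[0]))]
        for i in range(n)
    ]
    return outputs, weights

random.seed(0)
n, d = 6, 4  # toy sizes chosen for illustration
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

outputs, weights = naive_attention(q, k, v)
print(len(weights), "x", len(weights[0]), "attention weights")  # 6 x 6
```

Doubling n doubles the number of rows and the length of each row, so the number of stored weights grows fourfold.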


Positional Encoding and Context Limits

Learned absolute position embeddings (BERT, GPT-2) fail on sequences longer than they were trained on — there are no learned vectors for positions beyond max_seq_len. This is one reason modern models switched to RoPE and ALiBi.
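
The failure mode is easy to see in a toy lookup table. The sketch below assumes a hypothetical model with only 8 trained positions; the names MAX_SEQ_LEN, position_table, and position_embedding are illustrative and do not come from any library.

```python
import random

MAX_SEQ_LEN = 8  # hypothetical training-time limit, kept tiny for illustration
EMBED_DIM = 4

random.seed(0)
# One "learned" vector per position seen during training (random stand-ins here).
position_table = [
    [random.gauss(0.0, 0.02) for _ in range(EMBED_DIM)]
    for _ in range(MAX_SEQ_LEN)
]

def position_embedding(position: int) -> list[float]:
    """Look up the vector for a position; positions past the table do not exist."""
    if not 0 <= position < MAX_SEQ_LEN:
        raise IndexError(
            f"position {position} has no learned embedding (max_seq_len={MAX_SEQ_LEN})"
        )
    return position_table[position]

print(len(position_embedding(7)))  # last trained position works

try:
    position_embedding(8)  # first position beyond the training length
except IndexError as err:
    print("lookup failed:", err)
```

Relative schemes such as RoPE and ALiBi compute the positional signal from a formula, so there is no table to run out of, although quality can still degrade beyond the training length.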

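RoPE extrapolation: RoPE makes attention scores depend on relative position, but the model is only trained on the rotation angles that occur within its training length. At inference on longer sequences, the low-frequency dimensions reach rotation angles the model has never seen, so its inputs are out-of-distribution. Solutions:

  • Dynamic NTK scaling: Increase the base of RoPE as the context grows, which lowers the rotation frequencies so that longer sequences map onto familiar angles
  • YaRN: A more sophisticated RoPE scaling method that treats frequency bands differently and rescales attention logits; it has been applied to LLaMA and Mistral checkpoints, among others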
Python
def apply_ntk_scaling(base: float, scale_factor: float, head_dim: int) -> float:
    """Compute adjusted RoPE base for length extrapolation."""
    # Increases the base to make frequencies lower (longer wavelengths)
    # Allows the model to handle longer sequences
    return base * (scale_factor ** (head_dim / (head_dim - 2)))

original_base = 10000.0
scale_factor = 4.0  # Supporting 4× longer context
head_dim = 128

new_base = apply_ntk_scaling(original_base, scale_factor, head_dim)
print(f"Adjusted RoPE base: {new_base:.0f}")  # ~40,890
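
To see what the larger base does to the individual rotation frequencies, the sketch below recomputes them before and after scaling. It assumes the common RoPE convention theta_i = base^(-2i / head_dim); the helper names are illustrative, and the scaling formula is restated so the block runs on its own.

```python
import math

def rope_inv_freq(base: float, head_dim: int) -> list[float]:
    """Rotation frequency for each dimension pair: base^(-2i / head_dim)."""
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

def ntk_scaled_base(base: float, scale_factor: float, head_dim: int) -> float:
    """Same formula as the NTK scaling function in this section."""
    return base * (scale_factor ** (head_dim / (head_dim - 2)))

base, scale_factor, head_dim = 10000.0, 4.0, 128
original = rope_inv_freq(base, head_dim)
scaled = rope_inv_freq(ntk_scaled_base(base, scale_factor, head_dim), head_dim)

# Fastest-rotating pair (i = 0): base^0 = 1 either way, so it is unchanged.
highest_ratio = original[0] / scaled[0]
# Slowest-rotating pair: its frequency is divided by exactly scale_factor.
lowest_ratio = original[-1] / scaled[-1]

print(f"highest-frequency ratio: {highest_ratio:.3f}")
print(f"lowest-frequency ratio:  {lowest_ratio:.3f}")
```

The exponent head_dim / (head_dim - 2) is what makes the slowest dimension stretch by the full scale factor while the fastest dimension, which carries fine-grained local order, is left alone.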

Flash Attention: Memory-Efficient Long Context

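Flash Attention (Dao et al., 2022) reformulates attention computation to avoid materializing the full n×n attention matrix in HBM (the GPU's large off-chip memory). Instead, it computes attention tile by tile in fast on-chip SRAM, keeping running softmax statistics so the result stays exact:

Standard attention: O(n²) extra memory, Θ(nd + n²) HBM accesses
Flash Attention: O(n) extra memory, Θ(n²d²/M) HBM accesses for head dimension d and SRAM size M (tiles are recomputed in the backward pass instead of being stored)

Flash Attention achieves the same mathematical result with the following benefits, as reported by the authors:

  • Up to 10× less memory for the attention computation
  • 2–4× faster wall-clock time
  • Linear (not quadratic) memory scaling in sequence length, although the arithmetic is still O(n²)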
Python
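# PyTorch 2.0+ includes Flash Attention via scaled_dot_product_attention
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch 2.3+

def attention_flash(q, k, v, causal=True):
    """Force the Flash Attention backend of PyTorch SDPA.

    No fallback backend is enabled, so this raises an error when the flash
    kernel cannot handle the inputs (for example on CPU or with float32).
    """
    # q, k, v: (batch, heads, seq, head_dim), float16/bfloat16 on a CUDA device
    # On PyTorch 2.0-2.2, use torch.backends.cuda.sdp_kernel(enable_flash=True,
    # enable_math=False, enable_mem_efficient=False) as the context manager instead.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)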
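
The tiling idea rests on an "online softmax": a running maximum and running sum let each tile of keys be folded in without ever holding all n scores at once. The following is a minimal sketch for a single query in plain Python, under the assumption that this captures the core recurrence; the real kernel also tiles over queries and runs on GPU SRAM. Function names and sizes are illustrative.

```python
import math
import random

def full_attention_row(query, keys, values):
    """Reference: softmax over every key at once (holds all n scores)."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(a * b for a, b in zip(query, key)) for key in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [sum(e * val[c] for e, val in zip(exps, values)) / total
            for c in range(len(values[0]))]

def tiled_attention_row(query, keys, values, tile_size=2):
    """Same result, but only tile_size scores exist at any moment."""
    scale = 1.0 / math.sqrt(len(query))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])  # unnormalized weighted sum of values
    for start in range(0, len(keys), tile_size):
        tile_keys = keys[start:start + tile_size]
        tile_values = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(query, key)) for key in tile_keys]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (0.0 on the first tile).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for score, val in zip(scores, tile_values):
            weight = math.exp(score - new_max)
            running_sum += weight
            acc = [a + weight * x for a, x in zip(acc, val)]
        running_max = new_max
    return [a / running_sum for a in acc]

random.seed(1)
dim, n = 4, 7  # n deliberately not a multiple of tile_size
query = [random.gauss(0, 1) for _ in range(dim)]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]

reference = full_attention_row(query, keys, values)
tiled = tiled_attention_row(query, keys, values, tile_size=2)
max_diff = max(abs(a - b) for a, b in zip(reference, tiled))
print(f"max difference vs. reference: {max_diff:.2e}")
```

The two functions agree up to floating-point rounding, which is the sense in which Flash Attention is exact rather than an approximation.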

Sparse Attention Patterns

For very long sequences, sparse attention restricts each token to attend to only a subset of other tokens:

  • Local windowed attention: each token attends to the W nearest tokens only, which costs O(n × W) instead of O(n²)
  • Sliding window: Mistral 7B uses a window of 4096 tokens per layer
  • Global tokens: A few designated tokens attend to, and are attended by, every position; all other tokens attend locally plus to the global tokens
Python
import torch
import torch.nn.functional as F

def local_window_attention(q, k, v, window_size: int):
    """Causal sliding window: each token attends to itself and the
    window_size - 1 tokens before it. A readable reference loop, not a fast kernel."""
    # q, k, v: (batch, heads, seq, head_dim)
    seq_len = q.shape[2]
    scale = q.shape[-1] ** -0.5  # constant, so compute it once outside the loop
    output = torch.zeros_like(q)

    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        # Only attend to tokens in [start, i]
        k_window = k[:, :, start:i+1, :]
        v_window = v[:, :, start:i+1, :]
        q_i = q[:, :, i:i+1, :]

        attn = F.softmax(torch.matmul(q_i, k_window.transpose(-2, -1)) * scale, dim=-1)
        output[:, :, i:i+1, :] = torch.matmul(attn, v_window)

    return output
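
The global-token pattern from the list above can be expressed as a boolean mask. This is a small sketch in plain Python, assuming a bidirectional (encoder-style) layout with a symmetric local band; the function name, the window convention, and the choice of position 0 as the global token are illustrative assumptions.

```python
def sparse_attention_mask(seq_len: int, window_size: int, global_positions: list[int]):
    """mask[i][j] is True when query i may attend to key j."""
    global_set = set(global_positions)
    half = window_size // 2  # tokens visible on each side of the query
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            is_local = abs(i - j) <= half
            # Global tokens see everything, and everything sees them.
            is_global = i in global_set or j in global_set
            row.append(is_local or is_global)
        mask.append(row)
    return mask

mask = sparse_attention_mask(seq_len=8, window_size=2, global_positions=[0])
allowed = sum(cell for row in mask for cell in row)

for row in mask:
    print("".join("x" if cell else "." for cell in row))
print(f"{allowed} of {8 * 8} pairs are computed")
```

The number of allowed pairs grows linearly with sequence length for a fixed window and a fixed number of global tokens, which is where the savings over dense attention come from.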

Mistral 7B combines sliding window attention with a rolling KV cache — only the last 4096 K/V pairs are kept per layer, regardless of total generation length.
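
A rolling cache of this kind can be sketched with a fixed-size buffer. The example below uses collections.deque as a stand-in for the per-layer ring buffer; the class name and the string placeholders for key and value tensors are illustrative assumptions, not Mistral's actual implementation.

```python
from collections import deque

class RollingKVCache:
    """Keeps only the most recent window_size key/value entries for one layer."""

    def __init__(self, window_size: int):
        # deque(maxlen=...) silently evicts the oldest entry once it is full.
        self.keys = deque(maxlen=window_size)
        self.values = deque(maxlen=window_size)

    def append(self, key, value) -> None:
        self.keys.append(key)
        self.values.append(value)

    def __len__(self) -> int:
        return len(self.keys)

cache = RollingKVCache(window_size=4)
for position in range(10):
    # Strings stand in for the real key/value tensors of each generated token.
    cache.append(f"k{position}", f"v{position}")

print(len(cache))        # 4
print(list(cache.keys))  # ['k6', 'k7', 'k8', 'k9']
```

Memory for the cache stays constant no matter how many tokens are generated, at the cost of each layer no longer seeing tokens that have fallen out of the window directly.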


Practical: Working Within Context Limits

Chunking long documents

Python
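from transformers import AutoTokenizer

# Gated repository: requires accepting the Llama 3 license on Hugging Face
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

def chunk_document(
    text: str,
    max_chunk_tokens: int = 3_000,
    overlap_tokens: int = 200,
) -> list[str]:
    """Split document into overlapping chunks that fit within context."""
    if not 0 <= overlap_tokens < max_chunk_tokens:
        # Otherwise the window would never advance and the loop would not end
        raise ValueError("overlap_tokens must be >= 0 and smaller than max_chunk_tokens")

    # Skip special tokens so a BOS marker is not baked into the decoded chunks
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + max_chunk_tokens, len(tokens))
        chunk_tokens = tokens[start:end]
        chunks.append(tokenizer.decode(chunk_tokens))

        if end >= len(tokens):
            break
        start = end - overlap_tokens  # Overlap to preserve context at boundaries

    return chunks

# For a 50k-token document with an 8k context model
# (long_document is a placeholder for your own text):
chunks = chunk_document(long_document, max_chunk_tokens=6_000)
print(f"Split into {len(chunks)} chunks")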
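
It can help to know how many chunks to expect before running the tokenizer. The sketch below works through the arithmetic for the 50k-token example, assuming the same stride rule as the chunking loop (each new chunk starts overlap_tokens before the previous one ended); the helper name is an illustrative assumption.

```python
def expected_chunk_count(total_tokens: int, max_chunk_tokens: int, overlap_tokens: int) -> int:
    """Number of chunks produced by a sliding window with the given overlap."""
    if not 0 <= overlap_tokens < max_chunk_tokens:
        raise ValueError("overlap_tokens must be >= 0 and smaller than max_chunk_tokens")
    if total_tokens <= 0:
        return 0
    if total_tokens <= max_chunk_tokens:
        return 1
    stride = max_chunk_tokens - overlap_tokens  # new tokens covered per extra chunk
    remaining = total_tokens - overlap_tokens
    return -(-remaining // stride)  # integer ceiling division

# 50,000 tokens, 6,000-token chunks, 200-token overlap:
# stride = 5,800 and ceil(49,800 / 5,800) = 9
print(expected_chunk_count(50_000, 6_000, 200))  # 9
```

Nine chunks of up to 6,000 tokens each leave roughly 2,000 tokens of an 8k window for instructions and the model's answer.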

Tracking token budget at runtime

Python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4o")

class ContextBudgetManager:
    def __init__(self, max_tokens: int = 120_000):
        self.max_tokens = max_tokens
        self.system_reserve = 2_000  # Reserve for system prompt overhead

    def available(self, system_prompt: str, conversation: list[dict]) -> int:
        used = len(enc.encode(system_prompt))
        for msg in conversation:
            used += len(enc.encode(msg["content"])) + 4  # 4 tokens overhead per message
        return self.max_tokens - self.system_reserve - used

    def trim_conversation(self, system_prompt: str, conversation: list[dict], keep_last_n: int = 10) -> list[dict]:
        """Drop the oldest messages until more than 1,000 tokens are free.

        Never trims below keep_last_n messages, so the result can still
        exceed the budget if the remaining messages are very long.
        """
        while len(conversation) > keep_last_n:
            available = self.available(system_prompt, conversation)
            if available > 1_000:  # Enough buffer
                break
            conversation = conversation[1:]  # Drop oldest
        return conversation

Context Window vs Effective Context

Even when a model has a 128k context window, it doesn't use all of it uniformly. Research on the "lost in the middle" phenomenon (Liu et al., 2023) shows that:

  • Models use content at the beginning of the context well
  • Models use content at the end of the context well
  • Content in the middle of a very long context is retrieved less reliably and is sometimes effectively ignored

Practical implication: Put the most important instructions at the beginning (system prompt) and most relevant documents at the end (immediately before the query). Avoid burying critical information in the middle of a long context.

Python
def order_context_strategically(system_prompt, documents, query):
    """Order context to minimize lost-in-the-middle effect."""
    # Most important documents go at the END (closest to query)
    # Sort by relevance; most relevant last
    sorted_docs = sorted(documents, key=lambda d: d["relevance_score"])

    context_parts = [
        system_prompt,
        *[doc["text"] for doc in sorted_docs],  # Least relevant first
        f"\nQuestion: {query}",
    ]
    return "\n\n".join(context_parts)
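
Since both ends of the context are used well, one possible variant is to push the least relevant documents toward the middle instead of the start. The sketch below is an assumption-laden alternative, not the ordering used above: it expects the same hypothetical "relevance_score" field, and whether it helps depends on the model and task.

```python
def reorder_for_edges(documents: list[dict]) -> list[dict]:
    """Most relevant documents at both ends, least relevant in the middle."""
    ranked = sorted(documents, key=lambda d: d["relevance_score"], reverse=True)
    front, back = [], []
    for rank, doc in enumerate(ranked):
        # Alternate: best goes to the end (next to the query), second best to the start.
        if rank % 2 == 0:
            back.append(doc)
        else:
            front.append(doc)
    return front + back[::-1]

docs = [
    {"id": "a", "relevance_score": 0.9},
    {"id": "b", "relevance_score": 0.7},
    {"id": "c", "relevance_score": 0.5},
    {"id": "d", "relevance_score": 0.3},
    {"id": "e", "relevance_score": 0.1},
]

ordered_ids = [doc["id"] for doc in reorder_for_edges(docs)]
print(ordered_ids)  # ['b', 'd', 'e', 'c', 'a']
```

The best document still sits immediately before the query, while the weakest one lands in the middle where it costs the least if it is overlooked.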

Context Window Summary

| Model | Context | Attention | Position encoding / extension |
|---|---|---|---|
| BERT-base | 512 | Standard (MHA) | Learned absolute positions |
| GPT-2 | 1024 | Standard (MHA) | Learned absolute positions |
| LLaMA-2-7B | 4096 | Standard (MHA) | RoPE |
| Mistral-7B | 32k | GQA + sliding window | RoPE + SWA |
| LLaMA-3-8B | 8k (128k with tuning) | GQA | RoPE |
| Claude 3.5 | 200k | Undisclosed | Undisclosed |
| Gemini 1.5 | 1M | Undisclosed | Undisclosed |
