Extending LLM Context Windows
How to extend LLMs beyond their trained context length. RoPE scaling, YaRN, LongLoRA, sliding window attention, and the engineering tradeoffs of long contexts.
The Context Limit Problem
Every LLM is trained on sequences up to a maximum length (LLaMA-3: 8K tokens, GPT-4 Turbo: 128K, Gemini 1.5: 1M). Beyond this limit the model sees position indices it never encountered in training: its positional encoding produces values it wasn't trained to handle.
Why this matters:
- Many real-world documents exceed 8K tokens (legal contracts, research papers, books)
- Multi-turn conversations accumulate context quickly
- RAG systems want to stuff more documents into context for better recall
Why extending is hard:
- Quadratic attention cost: doubling sequence length quadruples attention computation
- Positional encoding distributions shift outside trained range
- The model's attention patterns were learned for shorter sequences
RoPE and the Scaling Problem
RoPE (Rotary Position Embedding) encodes position by rotating query and key vectors:
```python
import math

import torch


def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    base: float = 10000.0,
    scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cos/sin tables of shape (max_seq_len, dim // 2).
    scaling_factor > 1 compresses positions (linear scaling) to extend the effective context length.
    """
    # One rotation frequency per dimension pair
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Position indices (divided by scaling_factor for linear scaling)
    positions = torch.arange(max_seq_len).float() / scaling_factor
    # Outer product: angle for every (position, pair), shape (max_seq_len, dim // 2)
    freqs = torch.outer(positions, theta)
    # cos/sin tables used to rotate each dimension pair
    return freqs.cos(), freqs.sin()


def apply_rope(
    x: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embedding to a query or key tensor.

    x: (batch, seq_len, n_heads, head_dim); freqs_*: (>= seq_len, head_dim // 2)
    """
    seq_len = x.shape[1]
    # Reshape to (1, seq_len, 1, head_dim // 2) so the tables broadcast over batch and heads
    cos = freqs_cos[:seq_len].view(1, seq_len, 1, -1)
    sin = freqs_sin[:seq_len].view(1, seq_len, 1, -1)
    # Treat interleaved dims (0, 1), (2, 3), ... as (real, imaginary) pairs
    x_r, x_i = x[..., ::2], x[..., 1::2]
    # Rotate each pair by its position-dependent angle
    x_rotated_r = x_r * cos - x_i * sin
    x_rotated_i = x_r * sin + x_i * cos
    # Interleave back to (..., head_dim)
    return torch.stack([x_rotated_r, x_rotated_i], dim=-1).flatten(-2)
```

The problem: when max_seq_len exceeds the training length, torch.arange(max_seq_len) produces position indices the model was never trained on. LLaMA-3 8K has learned to interpret rotations at positions 0-8191; positions 8192 and beyond produce combinations of angles that never appeared in training. The high-frequency pairs wrap around many times within 8K tokens, so their individual angles are familiar, but the low-frequency pairs, whose wavelengths exceed the training length, reach angles they never saw.
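One property worth keeping in mind for the scaling methods that follow: after rotation, the query-key dot product depends only on the offset between the two positions, not on where they sit in the sequence. This plain-Python sketch checks that for a single dimension pair; `rotate` and `score` are hypothetical helpers written for illustration, not library functions.

```python
import math

def rotate(pair: tuple, pos: int, theta: float) -> tuple:
    """Rotate one (real, imaginary) dimension pair by the angle pos * theta."""
    x_r, x_i = pair
    angle = pos * theta
    return (x_r * math.cos(angle) - x_i * math.sin(angle),
            x_r * math.sin(angle) + x_i * math.cos(angle))

def score(q: tuple, k: tuple, q_pos: int, k_pos: int, theta: float) -> float:
    """Dot product of a query pair and a key pair after rotating each by its position."""
    q_rot = rotate(q, q_pos, theta)
    k_rot = rotate(k, k_pos, theta)
    return q_rot[0] * k_rot[0] + q_rot[1] * k_rot[1]

q, k, theta = (0.3, -1.2), (0.8, 0.5), 0.05
near = score(q, k, q_pos=7, k_pos=3, theta=theta)       # offset 4
far = score(q, k, q_pos=5007, k_pos=5003, theta=theta)  # offset 4, far from the start
print(near, far)  # equal up to floating-point rounding
```

Because the score depends only on the offset, going past the training length becomes a problem once query-key offsets exceed anything seen in training, which is what the scaling methods below try to avoid.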
Linear RoPE Scaling
The simplest fix: divide all position indices by a scaling factor:
```python
# Instead of: positions = torch.arange(max_seq_len)
# Use:        positions = torch.arange(max_seq_len) / scaling_factor

# For a LLaMA-3 8K model that should handle 32K context:
#   scaling_factor = 32768 / 8192 = 4.0
# The last position, 32767, maps to 32767 / 4 = 8191.75, inside the trained range,
# so nothing is extrapolated. The cost is resolution: neighboring tokens are now
# only 0.25 positions apart, so the high-frequency pairs that encode fine-grained
# local order rotate 4x less per token, and nearby tokens become harder to tell apart.
```

Linear scaling (also called Position Interpolation) works to a limited extent (2-4× extension), typically after a short fine-tuning run, but degrades for larger extensions because this compression of local positions grows with the scaling factor.
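A few lines of plain Python make the effect of linear scaling concrete. The sketch assumes a LLaMA-style head_dim of 128 and base 10,000; `rope_angle_step` is a hypothetical helper that returns how far one dimension pair rotates per token.

```python
def rope_angle_step(pair_index: int, dim: int = 128, base: float = 10000.0,
                    scaling_factor: float = 1.0) -> float:
    """Angle (radians) a dimension pair advances per token under linear scaling."""
    theta = 1.0 / (base ** (2 * pair_index / dim))
    return theta / scaling_factor

s = 32768 / 8192                      # 4x extension
last_scaled_position = (32768 - 1) / s
print(last_scaled_position)           # 8191.75, still inside the trained range
# Highest-frequency pair: 1 radian per token normally, 0.25 after scaling
print(rope_angle_step(0), rope_angle_step(0, scaling_factor=s))
# Lowest-frequency pair (index 63) is slowed by the same factor of 4
print(rope_angle_step(63) / rope_angle_step(63, scaling_factor=s))
```

Every pair slows down by the same factor, including the high-frequency pairs that never needed help because they already wrap around many times within 8K tokens. NTK-aware scaling and YaRN are designed to avoid exactly this.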
NTK-Aware Scaling
A better approach: change the base of the RoPE frequencies rather than scaling positions directly.
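Writing out what the base change does to each frequency shows why it behaves better than linear scaling. With head dimension $d$, pair index $i \in \{0, \dots, d/2 - 1\}$, original base $b$, extension ratio $s$, and the scaled base $b' = b \cdot s^{d/(d-2)}$ used in the code below:

$$
\theta'_i = (b')^{-2i/d} = b^{-2i/d} \cdot s^{-\frac{2i}{d} \cdot \frac{d}{d-2}} = \theta_i \, s^{-2i/(d-2)}
$$

The highest-frequency pair ($i = 0$) is unchanged, the lowest-frequency pair ($i = d/2 - 1$) gets exactly $\theta_i / s$, the same as linear scaling, and the pairs in between are slowed by a factor that grows smoothly with $i$.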
```python
def rope_ntk_scaling(
    dim: int,
    max_seq_len: int,
    original_max_seq_len: int = 8192,
    base: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    NTK-aware RoPE scaling: enlarge the base instead of dividing positions.

    The NTK perspective: high-frequency dimensions handle local structure,
    low-frequency dimensions handle global position. Raising the base leaves
    the highest-frequency pair unchanged and stretches lower frequencies
    progressively more, up to the full extension ratio for the lowest one.
    """
    extension_ratio = max_seq_len / original_max_seq_len
    # Longer context -> larger base -> lower frequencies
    scaled_base = base * (extension_ratio ** (dim / (dim - 2)))
    theta = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    freqs = torch.outer(positions, theta)
    return freqs.cos(), freqs.sin()


# NTK scaling at 4× extension (8K -> 32K) with head_dim = 128:
#   scaled_base = 10000 * 4 ** (128 / 126) ≈ 40,900
#   (10000 is the original RoPE base; LLaMA-3 itself ships with base 500,000)
# Local (high-frequency) pairs keep their resolution, while global
# (low-frequency) pairs interpolate rather than extrapolate.
# Performance: much better than linear scaling, with little or no fine-tuning needed
```

YaRN: Yet Another RoPE Extension Method
YaRN (Peng et al., 2023) applies different scaling strategies to different frequency bands:
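In the YaRN paper's formulation, each dimension pair is classified by $r_i$, the number of full rotations it completes within the original context length $L$. With wavelength $\lambda_i = 2\pi / \theta_i$, extension ratio $s$, and ramp thresholds $\alpha$ and $\beta$ (the paper suggests $\alpha = 1$, $\beta = 32$ for LLaMA models):

$$
r_i = \frac{L}{\lambda_i}, \qquad
\gamma(r) = \begin{cases} 0 & r < \alpha \\ 1 & r > \beta \\ \dfrac{r - \alpha}{\beta - \alpha} & \text{otherwise} \end{cases}, \qquad
\theta'_i = \bigl(1 - \gamma(r_i)\bigr)\,\frac{\theta_i}{s} + \gamma(r_i)\,\theta_i
$$

Pairs that rotate many times within $L$ ($r > \beta$) are left alone, pairs that never complete a rotation ($r < \alpha$) are interpolated by $s$, and the rest are blended. YaRN also divides the attention logits by a temperature $t$ with $\sqrt{1/t} = 0.1 \ln s + 1$.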
```python
def rope_yarn(
    dim: int,
    max_seq_len: int,
    original_max_seq_len: int = 8192,
    base: float = 10000.0,
    alpha: float = 1.0,   # fewer than alpha rotations in the original context: fully interpolate
    beta: float = 32.0,   # more than beta rotations: leave the frequency unchanged
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    YaRN: apply different scaling per frequency dimension ("NTK-by-parts").
    - High-frequency dims (handle local structure): no scaling
    - Low-frequency dims (handle global position): linear interpolation by the extension ratio
    - Medium-frequency: linear ramp between the two strategies
    The returned tables also carry YaRN's attention temperature factor.
    """
    extension_ratio = max_seq_len / original_max_seq_len
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # r = number of full rotations each pair completes within the original context
    wavelength = 2 * math.pi / theta
    r = original_max_seq_len / wavelength
    # gamma = 0 -> fully interpolated (theta / s); gamma = 1 -> unchanged
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)
    scaled_theta = (1 - gamma) * theta / extension_ratio + gamma * theta
    positions = torch.arange(max_seq_len).float()
    freqs = torch.outer(positions, scaled_theta)
    # Temperature: scaling both q and k by 0.1 * ln(s) + 1 sharpens the attention logits
    mscale = 0.1 * math.log(extension_ratio) + 1.0 if extension_ratio > 1 else 1.0
    return freqs.cos() * mscale, freqs.sin() * mscale


# YaRN with fine-tuning (a few hundred steps on long documents):
# Achieves 4-8× extension with less than 1 perplexity increase
# LLaMA-3.1 (128K context) uses a closely related frequency-dependent RoPE
# rescaling; DeepSeek-V2 uses YaRN directly to reach 128K.
```

LongLoRA: Efficient Fine-Tuning for Long Context
Fine-tuning on actual long documents is the most reliable way to extend context, but full attention over long training sequences is expensive. LongLoRA (Chen et al., 2023) fine-tunes with shifted sparse attention (S²-Attn) and keeps standard full attention at inference.
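Both ideas behind S²-Attn can be checked without a GPU. This plain-Python sketch uses two hypothetical helpers: `score_entries` counts query-key scores per head for full versus grouped attention, and `group_id` reports which group a token lands in for the unshifted heads and for the heads shifted by half a group (ignoring the wraparound at the end of the sequence). The sequence and group sizes are arbitrary example values.

```python
from typing import Optional

def score_entries(seq_len: int, group_size: Optional[int] = None) -> int:
    """Query-key scores per head: seq_len² for full attention,
    seq_len × group_size when attention stays within fixed-size groups."""
    if group_size is None:
        return seq_len * seq_len
    return seq_len * group_size

def group_id(pos: int, group_size: int, shifted: bool) -> int:
    """Group that token `pos` falls in; shifted heads move every boundary by half a group."""
    offset = group_size // 2 if shifted else 0
    return (pos + offset) // group_size

seq_len, group_size = 65536, 2048
print(score_entries(seq_len) // score_entries(seq_len, group_size))  # 32x fewer scores
# Tokens 2047 and 2048 straddle a boundary in the unshifted heads...
print(group_id(2047, group_size, False), group_id(2048, group_size, False))  # 0 1
# ...but share a group in the shifted heads, so information can cross it
print(group_id(2047, group_size, True), group_id(2048, group_size, True))    # 1 1
```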
```python
# LongLoRA concept: split the sequence into groups and attend only within each group.
# In half of the heads, shift the groups by half a group so information can
# flow across group boundaries.
def shifted_sparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    group_size: int = 512,
    causal: bool = True,
) -> torch.Tensor:
    """
    Sparse attention in groups for efficient long-context fine-tuning.
    q, k, v: (batch, n_heads, seq_len, head_dim); seq_len must be divisible by group_size.
    Standard attention: O(n²) for sequence length n
    Grouped attention: O(n × group_size), linear in n
    """
    B, n_heads, T, head_dim = q.shape
    assert T % group_size == 0, "seq_len must be a multiple of group_size"
    shift_amount = group_size // 2
    half = n_heads // 2

    def shift_heads(x: torch.Tensor, shifts: int) -> torch.Tensor:
        # Roll the second half of the heads along the sequence axis; the first half stays put.
        # Note: roll wraps around, so in shifted heads the last group mixes the
        # sequence's first and last half-groups.
        return torch.cat([x[:, :half], torch.roll(x[:, half:], shifts=shifts, dims=2)], dim=1)

    q, k, v = (shift_heads(t, -shift_amount) for t in (q, k, v))

    # Reshape into groups: (B, n_heads, n_groups, group_size, head_dim)
    n_groups = T // group_size
    q = q.reshape(B, n_heads, n_groups, group_size, head_dim)
    k = k.reshape(B, n_heads, n_groups, group_size, head_dim)
    v = v.reshape(B, n_heads, n_groups, group_size, head_dim)

    # Attention within each group: (B, n_heads, n_groups, group_size, group_size)
    attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
    if causal:
        # Causal mask inside each group: a token cannot attend to later tokens
        mask = torch.ones(group_size, group_size, dtype=torch.bool, device=q.device).tril()
        attn = attn.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(attn, dim=-1)
    out = attn @ v  # (B, n_heads, n_groups, group_size, head_dim)

    # Merge the groups back and undo the shift
    out = out.reshape(B, n_heads, T, head_dim)
    return shift_heads(out, shift_amount)


# LongLoRA training: every layer uses S²-Attn, with half of the heads shifted.
# Freeze the base weights; train LoRA adapters plus the embedding and normalization layers.
# Reported result: Llama 2 7B extended to 100K context, or Llama 2 70B to 32K,
# on a single 8× A100 machine.
```

Sliding Window Attention
Mistral and Phi models use sliding window attention to handle long contexts with bounded memory:
```python
def sliding_window_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    window_size: int = 4096,
) -> torch.Tensor:
    """
    Each token attends only to the most recent `window_size` tokens (itself included).
    An optimized kernel costs O(n × window_size) instead of O(n²); this reference
    version still materializes the full T × T score matrix.
    q, k, v: (batch, n_heads, seq_len, head_dim)
    """
    B, n_heads, T, head_dim = q.shape
    scale = head_dim ** -0.5
    # Token i can attend to tokens max(0, i - window_size + 1) through i
    indices = torch.arange(T, device=q.device)
    distance = indices.unsqueeze(1) - indices.unsqueeze(0)  # distance[i, j] = i - j
    mask = (distance >= 0) & (distance < window_size)      # causal and inside the window
    mask = mask.unsqueeze(0).unsqueeze(0)                   # (1, 1, T, T)
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(attn, dim=-1)
    return attn @ v


# Limitation: a single layer cannot see tokens more than window_size positions back.
# Stacking layers extends the reach: after k layers, information can travel up to
# k × window_size tokens, so Mistral 7B (4096-token window, 32 layers) has a
# theoretical attention span of about 131K tokens, though distant information
# only arrives indirectly through intermediate tokens.
# Some models interleave sliding-window and full-attention layers (Gemma 2
# alternates them) so distant tokens stay directly reachable.
```

Memory and Compute at Long Contexts
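Memory is where a sliding window pays off at inference: since no token attends further back than the window, the KV cache can be a fixed-size rolling buffer instead of growing with the sequence (the Mistral 7B paper describes storing position i in slot i mod W). A minimal plain-Python sketch of the indexing, where `RollingKVCache` is a hypothetical class and the stored entries are placeholder strings rather than K/V tensors:

```python
class RollingKVCache:
    """Fixed-size KV cache: the entry for position i lives in slot i % window_size."""

    def __init__(self, window_size: int):
        self.window_size = window_size
        self.slots = [None] * window_size
        self.n_seen = 0  # total tokens appended so far

    def append(self, kv) -> None:
        # Overwrites the entry from position n_seen - window_size, which has
        # just fallen out of the window
        self.slots[self.n_seen % self.window_size] = kv
        self.n_seen += 1

    def visible(self) -> list:
        """Entries currently in the window (the last window_size positions), oldest first."""
        start = max(0, self.n_seen - self.window_size)
        return [self.slots[p % self.window_size] for p in range(start, self.n_seen)]


cache = RollingKVCache(window_size=4)
for pos in range(10):
    cache.append(f"kv{pos}")
print(cache.visible())   # ['kv6', 'kv7', 'kv8', 'kv9']
print(len(cache.slots))  # 4, regardless of how many tokens were appended
```

With a 4096-token window, the per-sequence cache stays at 4096 positions per layer however long the input grows.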
```python
def estimate_inference_cost(
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
    d_model: int,
    ffn_hidden: int,
    seq_len: int,
    bytes_per_element: float = 2.0,  # bfloat16
) -> dict:
    """Estimate KV-cache memory and FLOPs for one forward pass over seq_len tokens (prefill)."""
    # KV cache: one K and one V vector per layer, per KV head, per token
    # (grouped-query attention stores n_kv_heads, not n_heads)
    kv_cache_bytes = 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_element

    # Attention scores: QK^T and softmax(QK^T)·V each cost 2 × seq_len² × head_dim
    # FLOPs per query head (causal masking could skip roughly half of this)
    attn_flops = 4 * n_layers * n_heads * seq_len ** 2 * head_dim

    # Weight matmuls (2 FLOPs per parameter per token): Q and O projections,
    # K and V projections (GQA), and the three SwiGLU FFN matrices
    params_per_layer = (
        2 * d_model * n_heads * head_dim
        + 2 * d_model * n_kv_heads * head_dim
        + 3 * d_model * ffn_hidden
    )
    linear_flops = 2 * n_layers * params_per_layer * seq_len

    return {
        "kv_cache_gb": kv_cache_bytes / 1e9,
        "attention_gflops": attn_flops / 1e9,
        "linear_gflops": linear_flops / 1e9,
        "attention_fraction": attn_flops / (attn_flops + linear_flops),
    }


# LLaMA-3-8B (32 layers, 32 query heads, 8 KV heads, FFN hidden size 14336):
for seq_len in [4096, 16384, 65536, 131072]:
    costs = estimate_inference_cost(
        n_layers=32, n_heads=32, n_kv_heads=8, head_dim=128,
        d_model=4096, ffn_hidden=14336, seq_len=seq_len,
    )
    print(f"seq={seq_len:>7}: KV={costs['kv_cache_gb']:.1f}GB, "
          f"attn={costs['attention_gflops']:.0f} GFLOP ({costs['attention_fraction']:.0%} of total)")

# Output:
# seq=   4096: KV=0.5GB, attn=8796 GFLOP (13% of total)
# seq=  16384: KV=2.1GB, attn=140737 GFLOP (38% of total)
# seq=  65536: KV=8.6GB, attn=2251800 GFLOP (71% of total)
# seq= 131072: KV=17.2GB, attn=9007199 GFLOP (83% of total)

# At 128K context, attention-score FLOPs are most of the forward pass, and the
# KV cache (17 GB in bf16) is larger than the model's own bf16 weights (~16 GB).
# This is why long-context inference is much more expensive per token.
```

Practical Context Extension Strategy
For extending a model you're deploying:
- Under 4× extension: Use NTK scaling — no fine-tuning needed, minimal quality loss
- 4-16× extension: Use YaRN with 100-500 steps of fine-tuning on long documents
- 16×+ extension: Use LongLoRA or train from scratch with a larger context window
- Inference cost: Budget for 4× more KV-cache memory and 16× more attention compute over the full sequence at 4× context
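The rules of thumb above fit in a few lines. A sketch with a hypothetical helper, `plan_context_extension`, that reuses the list's own thresholds (treating exactly 4× as the YaRN band and exactly 16× as the LongLoRA band) and its cost multipliers:

```python
def plan_context_extension(trained_len: int, target_len: int) -> dict:
    """Map an extension ratio onto the rules of thumb listed above."""
    ratio = target_len / trained_len
    if ratio < 4:
        method = "NTK scaling, no fine-tuning"
    elif ratio < 16:
        method = "YaRN + 100-500 fine-tuning steps on long documents"
    else:
        method = "LongLoRA or training with a larger context window"
    return {
        "ratio": ratio,
        "method": method,
        "kv_cache_multiplier": ratio,              # KV cache grows linearly with length
        "attention_flops_multiplier": ratio ** 2,  # full-sequence attention is quadratic
    }


print(plan_context_extension(8192, 32768))
```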
Common pitfalls:
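- Shipping context extension without long-context evaluation: the "lost-in-the-middle" effect (information in the middle of the context is used less reliably than information near the start or end) tends to get worse at long contexts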
- Using a 128K context model on short queries: unnecessary KV cache overhead
- Extending context without enough long-document training data: the model doesn't learn to use the extra context
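One way to run the evaluation the first pitfall calls for is a position sweep: hide a known fact (the "needle") at different depths in filler text and check whether the model retrieves it at each depth. A minimal sketch, where `generate` stands in for whatever inference call you use and the filler, needle, and question strings are illustrative; `tail_only_model` is a stub that only exists so the harness runs end to end.

```python
from typing import Callable

def build_prompt(needle: str, question: str, filler: str,
                 n_filler: int, depth: float) -> str:
    """Insert the needle after depth * n_filler filler sentences (0.0 = start, 1.0 = end)."""
    sentences = [filler] * n_filler
    sentences.insert(round(depth * n_filler), needle)
    return " ".join(sentences) + "\n\n" + question

def position_sweep(generate: Callable[[str], str], answer: str,
                   depths=(0.0, 0.25, 0.5, 0.75, 1.0), n_filler: int = 2000) -> dict:
    """Return {depth: whether the answer appears in the model output}."""
    needle = f"The secret code is {answer}."
    question = "What is the secret code?"
    filler = "The sky was clear and the market was quiet."
    results = {}
    for depth in depths:
        prompt = build_prompt(needle, question, filler, n_filler, depth)
        results[depth] = answer in generate(prompt)
    return results

# Stub "model" that only sees the last 500 characters of its prompt
def tail_only_model(prompt: str) -> str:
    return prompt[-500:]

print(position_sweep(tail_only_model, answer="7291"))  # only depth 1.0 succeeds
```

With a real model, the interesting output is the pattern across depths: a dip in the middle depths is the lost-in-the-middle signature.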