Speculative Decoding: Faster Inference
How speculative decoding uses a small draft model to propose tokens that a large model verifies in parallel, achieving 2-3x speedups with identical output distribution.
The Autoregressive Bottleneck
Standard LLM generation is inherently sequential: generate token 1, then token 2, then token 3, each requiring a full forward pass. For single-sequence (small-batch) decoding, the main bottleneck is memory bandwidth rather than compute, because every token requires reading all model weights from GPU memory.
For a 70B-parameter model in FP16 (2 bytes per parameter), generating 1 token requires reading ~140GB of weights from GPU memory. The GPU's matrix multiplication units sit mostly idle waiting for data.
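The arithmetic above gives a latency floor for memory-bound decoding. Below is a minimal sketch of that calculation. The 2 TB/s aggregate memory bandwidth is a hypothetical figure chosen for illustration. The bound ignores KV-cache reads, activations, and compute, so real latency is higher.

```python
def weight_bytes(num_params: float, bytes_per_param: float = 2.0) -> float:
    """Bytes of weights read per forward pass (FP16/BF16 = 2 bytes per parameter)."""
    return num_params * bytes_per_param


def min_seconds_per_token(num_params: float, bandwidth_bytes_per_s: float,
                          bytes_per_param: float = 2.0) -> float:
    """Lower bound on decode latency when every weight is read once per token.

    Ignores KV-cache reads, activations, and compute time.
    """
    return weight_bytes(num_params, bytes_per_param) / bandwidth_bytes_per_s


params_70b = 70e9
assumed_bandwidth = 2e12  # hypothetical aggregate memory bandwidth: 2 TB/s

print(f"Weights read per token: {weight_bytes(params_70b) / 1e9:.0f} GB")
t = min_seconds_per_token(params_70b, assumed_bandwidth)
print(f"Latency floor: {t * 1e3:.0f} ms/token (at most {1 / t:.1f} tokens/s)")
```

With these assumed numbers the floor is 70 ms per token, or about 14 tokens/s. Scoring several proposed tokens in one forward pass reads the same weights once. In this memory-bound regime, that costs roughly as much as generating a single token.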
Key insight: verifying a proposed token sequence is much cheaper than generating it token by token. The target model can score all proposed positions in parallel within a single forward pass, reading its weights only once.
How Speculative Decoding Works
- A small, fast "draft" model proposes K tokens (e.g., K=4)
- The large "target" model verifies all K proposed tokens in a single forward pass (parallel)
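- Each draft token is accepted or rejected by comparing its target and draft probabilities. At the first rejection, a replacement token is sampled from a corrected distribution and the remaining drafts are discarded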
- Repeat
Draft model generates: "warfarin inhibits VKOR" (4 tokens, fast)
Target model checks all 4 simultaneously (1 forward pass, not 4)
If target agrees with first 3 but not 4th:
- Accept "warfarin inhibits VK"
- Resample 4th token from target's distribution
- Net result: 4 new tokens (3 accepted drafts plus 1 resampled token) for the cost of about 1.25 target forward passes, counting the 4 cheap draft passes
Crucial property: speculative decoding produces exactly the same output distribution as the target model alone. It is a mathematically exact optimization, not an approximation.
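The exactness claim can be checked numerically. Below is a sketch over a toy four-token vocabulary; the distributions `p` and `q` are made-up numbers. It computes the exact probability that each token is emitted at one position under the accept/resample rule of speculative sampling. That rule is: accept a draft token x with probability min(1, p(x)/q(x)), otherwise resample from max(0, p - q), renormalized. The result recovers the target distribution.

```python
import numpy as np


def speculative_output_dist(p_target: np.ndarray, q_draft: np.ndarray) -> np.ndarray:
    """Exact distribution of the token emitted at one position by speculative sampling."""
    # P(draft proposes x and it is accepted) = q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    accept = np.minimum(p_target, q_draft)
    # Total probability that the proposed token is rejected
    reject_prob = 1.0 - accept.sum()
    # On rejection, sample from the residual max(0, p - q), renormalized
    residual = np.maximum(p_target - q_draft, 0.0)
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accept + reject_prob * residual


p = np.array([0.5, 0.3, 0.2, 0.0])  # toy target distribution
q = np.array([0.2, 0.2, 0.3, 0.3])  # toy draft distribution (deliberately mismatched)
print(speculative_output_dist(p, q))  # matches p up to floating-point error
```

The draft's mismatch only affects how often tokens are rejected, and therefore speed. It never affects which tokens come out.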
The Acceptance Criterion
The acceptance/rejection is based on the ratio of target and draft probabilities:
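Written out, let q be the draft distribution and p the target distribution at one position (notation chosen here for brevity). For a draft token x ~ q, the acceptance rule, the residual distribution used after a rejection, and the reason the output follows p exactly are:

```latex
\[
\Pr[\text{accept } x] = \min\!\left(1, \frac{p(x)}{q(x)}\right), \qquad
p_{\mathrm{res}}(x) = \frac{\max\bigl(0,\, p(x) - q(x)\bigr)}{\sum_{y} \max\bigl(0,\, p(y) - q(y)\bigr)}
\]
\[
\Pr[X = x]
= \underbrace{q(x)\min\!\left(1, \frac{p(x)}{q(x)}\right)}_{\min(p(x),\,q(x))}
+ \underbrace{\Bigl(1 - \sum_{y} \min\bigl(p(y), q(y)\bigr)\Bigr)}_{\sum_{y} \max(0,\, p(y) - q(y))} \, p_{\mathrm{res}}(x)
= \min\bigl(p(x), q(x)\bigr) + \max\bigl(0,\, p(x) - q(x)\bigr)
= p(x)
\]
```

The second underbrace uses the fact that both p and q sum to 1.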
```python
import torch
import torch.nn.functional as F


def speculative_sample_token(
    draft_prob: float,   # p_draft(token | context)
    target_prob: float,  # p_target(token | context)
) -> bool:
    """Accept or reject a single draft token."""
    # Accept with probability min(1, p_target / p_draft):
    # - if the target assigns at least as much probability as the draft: always accept
    # - if the draft overestimated the token: accept with probability target / draft
    acceptance_prob = min(1.0, target_prob / draft_prob)
    return torch.rand(1).item() < acceptance_prob


@torch.no_grad()
def speculative_generate_step(
    target_model,
    draft_model,
    input_ids: torch.Tensor,  # shape (1, seq_len); batch size 1 assumed
    k: int = 4,               # number of draft tokens to propose
) -> torch.Tensor:
    """
    One step of speculative decoding:
    - Draft model proposes k tokens
    - Target model scores all k+1 positions in one pass
    - Accept valid tokens, resample the first rejected one
    Returns the new tokens, shape (1, m) with 1 <= m <= k + 1.
    """
    # Step 1: draft model proposes k tokens autoregressively.
    # Keep each full draft distribution: the rejection step needs all of it.
    draft_tokens = []
    draft_dists = []
    current_ids = input_ids
    for _ in range(k):
        draft_logits = draft_model(current_ids).logits[:, -1, :]
        draft_dist = F.softmax(draft_logits, dim=-1)
        next_token = torch.multinomial(draft_dist, 1)  # shape (1, 1)
        draft_tokens.append(next_token)
        draft_dists.append(draft_dist)
        current_ids = torch.cat([current_ids, next_token], dim=-1)

    # Step 2: one target forward pass over the prompt + k draft tokens.
    # Logits at position t predict token t+1, so the predictions for the
    # k draft tokens sit at positions -(k+1) .. -2.
    target_out = target_model(current_ids)
    target_logits = target_out.logits[:, -(k + 1):-1, :]

    # Step 3: accept/reject each draft token in order
    accepted = []
    for i in range(k):
        target_dist = F.softmax(target_logits[:, i, :], dim=-1)
        token_id = draft_tokens[i].item()
        if speculative_sample_token(
            draft_prob=draft_dists[i][0, token_id].item(),
            target_prob=target_dist[0, token_id].item(),
        ):
            accepted.append(draft_tokens[i])
        else:
            # Rejection: resample from p_corrected = max(0, p_target - p_draft), normalized
            correction = F.relu(target_dist - draft_dists[i])
            correction = correction / correction.sum()
            accepted.append(torch.multinomial(correction, 1))
            break  # discard all drafts after the first rejection

    # Step 4: if all k tokens were accepted, also sample the (k+1)th from the target
    if len(accepted) == k:
        final_dist = F.softmax(target_out.logits[:, -1, :], dim=-1)
        accepted.append(torch.multinomial(final_dist, 1))

    return torch.cat(accepted, dim=-1)
```
Expected Speedup
The speedup depends on the acceptance rate α: the probability that any given draft token is accepted, treated here as independent across positions.
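The geometric series in the code follows from a short derivation. Let N be the number of accepted draft tokens in one step (0 ≤ N ≤ k). Each step emits N + 1 tokens, because the target always contributes either the corrected token or the bonus token. Under the independence assumption, P(N ≥ i) = α^i for 0 ≤ i ≤ k, so:

```latex
\[
\mathbb{E}[N + 1] = 1 + \sum_{i=1}^{k} \Pr[N \ge i]
= \sum_{i=0}^{k} \alpha^{i}
= \frac{1 - \alpha^{k+1}}{1 - \alpha}
\]
\[
\alpha = 0.8,\; k = 4:\quad \frac{1 - 0.8^{5}}{1 - 0.8} = \frac{0.67232}{0.2} = 3.3616 \text{ tokens per step}
\]
```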
```python
def expected_speedup(
    acceptance_rate: float,    # alpha in [0.0, 1.0]: per-token acceptance probability
    k: int,                    # draft tokens proposed per step
    target_cost: float = 1.0,  # cost of one target forward pass
    draft_cost: float = 0.1,   # cost of ONE draft forward pass (small model)
) -> float:
    """Expected tokens per unit compute, relative to target-only decoding."""
    # Expected tokens emitted per step: geometric series sum_{i=0..k} alpha^i
    # (accepted drafts plus the one corrected or bonus token from the target)
    if acceptance_rate >= 1.0:
        expected_tokens = k + 1.0
    else:
        expected_tokens = (1 - acceptance_rate ** (k + 1)) / (1 - acceptance_rate)

    # Total cost per step: k draft passes + 1 target pass
    total_cost = k * draft_cost + target_cost

    # Speedup vs naive decoding (1 token per target pass)
    naive_tokens_per_cost = 1.0 / target_cost
    spec_tokens_per_cost = expected_tokens / total_cost
    return spec_tokens_per_cost / naive_tokens_per_cost


# At 80% acceptance rate with k=4 draft tokens and a 10× cost ratio
speedup = expected_speedup(acceptance_rate=0.8, k=4, target_cost=1.0, draft_cost=0.1)
print(f"Expected speedup: {speedup:.2f}×")  # Expected speedup: 2.40×
# Typically 2-3× for well-matched draft/target pairs
```
Acceptance rate depends on how well the draft model's distribution matches the target's. Draft models are typically:
- A smaller version of the target (e.g., 7B draft for 70B target)
- The same model with fewer layers (speculative decoding with early exit)
- A specialized distilled model
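The same cost model also suggests how many tokens to draft per step. More drafts raise the expected tokens per step, but each one adds draft cost, and the returns shrink as α^k decays. Below is a sketch that searches k under the simplified model above. The helper `best_k` and the 1..16 search range are illustrative choices. Like the model it builds on, it ignores the extra cost of verifying more positions and the latency of running the draft sequentially.

```python
def expected_speedup(alpha: float, k: int, draft_cost: float = 0.1) -> float:
    """Speedup over target-only decoding; target pass cost normalized to 1.0."""
    if alpha >= 1.0:
        expected_tokens = k + 1.0
    else:
        expected_tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    return expected_tokens / (1.0 + k * draft_cost)


def best_k(alpha: float, draft_cost: float = 0.1, max_k: int = 16) -> int:
    """Draft length in 1..max_k that maximizes expected speedup (hypothetical helper)."""
    return max(range(1, max_k + 1), key=lambda k: expected_speedup(alpha, k, draft_cost))


for alpha in (0.6, 0.8, 0.9):
    k = best_k(alpha)
    print(f"alpha={alpha}: best k={k}, expected speedup {expected_speedup(alpha, k):.2f}x")
```

Under this model a better-matched draft (higher α) justifies proposing more tokens per step, while a poorly matched draft favors short drafts.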
Using Speculative Decoding with HuggingFace
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Target: large model
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Draft: smaller model from the same family (must share the target's tokenizer)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
inputs = tokenizer("Explain warfarin's mechanism:", return_tensors="pt").to(target_model.device)

# generate() supports speculative decoding ("assisted generation") via assistant_model
output = target_model.generate(
    **inputs,
    assistant_model=draft_model,  # the draft model
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
Draft Model Selection
| Draft model | Target model | Typical acceptance rate |
|---|---|---|
| LLaMA-3-8B | LLaMA-3-70B | 75-85% |
| TinyLlama-1.1B | LLaMA-2-7B | 60-75% |
| Distilled model | Teacher model | 85-95% |
| Same model (layers 1-N/2) | Full model | 70-80% |
Domain matters: If the target is fine-tuned for a specific domain (medical, legal, code), the draft model should also be fine-tuned on that domain. A general draft + domain-specific target will have low acceptance rates.
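The link between distribution match and acceptance rate can be made precise. For a draft token sampled from q, the probability it passes the acceptance test is the sum over x of min(p(x), q(x)), which equals 1 minus the total variation distance between p and q. Below is a sketch with made-up toy distributions standing in for a domain-tuned target, a matched draft, and a general-purpose draft.

```python
import numpy as np


def acceptance_prob(p_target: np.ndarray, q_draft: np.ndarray) -> float:
    """Probability that a token sampled from q passes the test: sum_x min(p(x), q(x))."""
    return float(np.minimum(p_target, q_draft).sum())


def total_variation(p: np.ndarray, q: np.ndarray) -> float:
    """Total variation distance between two distributions over the same vocabulary."""
    return 0.5 * float(np.abs(p - q).sum())


p_domain = np.array([0.7, 0.2, 0.1, 0.0])     # toy domain-tuned target
q_matched = np.array([0.6, 0.25, 0.1, 0.05])  # toy draft tuned on the same domain
q_general = np.array([0.1, 0.2, 0.3, 0.4])    # toy general-purpose draft

for name, q in (("matched", q_matched), ("general", q_general)):
    a = acceptance_prob(p_domain, q)
    tv = total_variation(p_domain, q)
    print(f"{name}: acceptance={a:.2f}, 1 - TV={1 - tv:.2f}")
```

In this toy case the matched draft is accepted 90% of the time and the general draft only 40%, even though both propose valid tokens.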
Self-Speculative Decoding
A separate draft model isn't always needed. Self-speculative approaches draft with parts of the target model itself, such as its early layers or extra prediction heads:
```python
# Early-exit speculation (assumes the LM head gives usable predictions from an
# intermediate layer, which often requires training the model for it):
# - Run only layers 1-16 to get a "draft" distribution (cheap)
# - Run all 32 layers to get the "target" distribution (expensive)
# - Apply the same acceptance criterion

# Medusa (Cai et al., 2024): add extra "Medusa heads" on top of the model
# Head j predicts the token j+1 positions ahead, all from a single forward pass
# Candidate continuations are verified together in the next target pass;
# no separate draft model needed
```
When Speculative Decoding Helps
Good fit (high acceptance rate):
- Factual question answering (predictable tokens)
- Code completion (syntactically constrained)
- Greedy/low-temperature sampling
Poor fit (low acceptance rate):
- Creative writing (high entropy, many valid tokens)
- High temperature sampling
- Very different draft and target distributions
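Greedy decoding appears on the good-fit list, and in that limit verification becomes especially simple. With temperature 0 the target distribution is one-hot on its argmax. The test min(1, p/q) therefore accepts a draft token exactly when it equals the target's argmax, and the residual distribution puts all its mass on that argmax. Below is a sketch of the resulting prefix-matching rule. The function name and list-of-ids interface are illustrative assumptions, not any library's API.

```python
from typing import List


def greedy_verify(draft_tokens: List[int], target_argmax: List[int]) -> List[int]:
    """Greedy speculative verification (hypothetical helper).

    draft_tokens:  k token ids proposed by the draft model.
    target_argmax: k + 1 token ids: the target's argmax at each draft position,
                   plus one after the last draft (all from one forward pass).
    Returns the tokens to append: the longest agreeing prefix, then the
    target's own token at the first disagreement (or the bonus token).
    """
    assert len(target_argmax) == len(draft_tokens) + 1
    out: List[int] = []
    for drafted, expected in zip(draft_tokens, target_argmax):
        if drafted != expected:
            out.append(expected)  # target's choice replaces the first mismatch
            return out
        out.append(drafted)
    out.append(target_argmax[-1])  # all drafts matched: bonus token from the target
    return out


print(greedy_verify([5, 8, 2, 9], [5, 8, 2, 4, 7]))  # 3 drafts kept, then the target's token
```

Under greedy decoding the output is token-for-token identical to running the target alone, so any speedup comes purely from how often the draft's argmax agrees with the target's.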
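In practice, speculative decoding typically provides a 2-3× speedup for standard instruction-following tasks with good draft models, with no change to the output distribution. Serving frameworks such as vLLM and TGI support it, though it generally has to be configured (for example, by choosing a draft model or speculation method) rather than being on by default.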