Speculative Decoding
How speculative decoding uses a small draft model to speed up generation from a large model, the acceptance criterion, and the latency gains achievable in practice.
The Problem: Autoregressive Generation Is Slow
Standard LLM generation is serial, producing one token per forward pass:
Generating 100 tokens from a 70B model:
100 forward passes
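Each pass: ~2 × model_size FLOPs per token (with a KV cache), plus attention over the context
On A100s: ~10-15 ms per token (the 140GB of fp16 weights must be sharded across several GPUs, e.g. with tensor parallelism)
Total: 1.0-1.5 seconds for 100 tokens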
The bottleneck: we're not compute-bound, we're MEMORY-bound
The model weights (140GB at fp16) must be read from HBM each step
HBM bandwidth: ~2TB/s per A100, so streaming 140GB through one GPU's bandwidth takes ~70ms just for weight reads (sharding across N GPUs divides this roughly by N)
For batch_size=1, the GPU compute units sit mostly idle
waiting on memory reads (arithmetic intensity is too low).
Speculative decoding exploits this memory-bound regime.
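The "arithmetic intensity is too low" point can be made concrete with a back-of-the-envelope calculation. This is a minimal sketch, assuming roughly 2 FLOPs per parameter per token (one multiply-add) and illustrative single-A100 figures: ~2 TB/s HBM bandwidth (as above) and an assumed ~312 TFLOPS dense fp16 tensor-core peak. Attention and KV-cache traffic are ignored.

```python
# Rough roofline arithmetic for batch-1 decoding of a 70B model at fp16.
# Assumptions (illustrative, not measured): ~2 FLOPs per parameter per token,
# ~2e12 B/s HBM bandwidth and ~312e12 FLOP/s fp16 peak for one A100.
PARAMS = 70e9
BYTES_PER_PARAM = 2      # fp16
HBM_BANDWIDTH = 2e12     # bytes/s
PEAK_FLOPS = 312e12      # FLOP/s


def arithmetic_intensity(tokens_per_pass: int) -> float:
    """FLOPs performed per byte of weights read in one forward pass."""
    flops = 2 * PARAMS * tokens_per_pass
    bytes_read = PARAMS * BYTES_PER_PARAM
    return flops / bytes_read


# Ridge point: the intensity needed before compute, not memory, limits speed.
ridge = PEAK_FLOPS / HBM_BANDWIDTH

print(f"weight read time: {PARAMS * BYTES_PER_PARAM / HBM_BANDWIDTH * 1e3:.0f} ms")
print(f"intensity at 1 token/pass: {arithmetic_intensity(1):.0f} FLOP/byte")
print(f"ridge point: {ridge:.0f} FLOP/byte")
```

At one token per pass the intensity is about 1 FLOP per byte, two orders of magnitude below the ~156 FLOP/byte ridge point under these assumptions, so the tensor cores mostly wait for weights to stream in.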
Core Idea
Use a small, fast draft model to speculatively generate k tokens. Verify all k tokens in a single forward pass of the large model:
Step 1: Draft model generates k tokens speculatively
M_small generates: [tok_1, tok_2, tok_3, tok_4] (k=4, very fast)
Step 2: Large model runs ONE forward pass on the original context + all k draft tokens
M_large processes [context, tok_1, tok_2, tok_3, tok_4] simultaneously
Gets probability distributions at each position: P_large(·|context), P_large(·|context, tok_1), ..., P_large(·|context, tok_1..tok_k)
Step 3: Accept or reject each draft token using rejection sampling
Compare P_large(tok_i | ...) vs P_draft(tok_i | ...)
Accept tok_i with probability min(1, P_large/P_draft)
Stop at the first rejection and resample that position from the residual distribution norm(max(0, P_large - P_draft)), not from P_large directly
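Step 4: Every iteration yields at least one new token from M_large: the resampled token
after a rejection or, if all k draft tokens are accepted, a bonus token sampled from
P_large(·|context, tok_1..tok_k), so each step emits between 1 and k+1 tokens
Why This Works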
A single forward pass of the large model at batch size 1 processes k+1 tokens in roughly the same time as processing 1 token (memory-bound, not compute-bound):
Memory cost of a single forward pass:
Reads all model weights once, regardless of how many new tokens are in the pass (up to a point)
Processing 1 token: read all weights → ~70ms
Processing k+1 tokens: read all weights → ~70-75ms (barely more)
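The "barely more" claim can be sketched with a simple roofline model, in which a pass takes roughly the larger of its weight-read time and its compute time. This is a minimal sketch, assuming ~2 TB/s bandwidth (as stated above), an assumed ~312 TFLOPS fp16 peak, and ~2 FLOPs per parameter per token; attention, KV-cache reads and kernel overheads are ignored.

```python
# Simple roofline estimate of one large-model forward pass (illustrative only).
PARAMS = 70e9               # 70B parameters
WEIGHT_BYTES = PARAMS * 2   # fp16 weights, ~140 GB
HBM_BANDWIDTH = 2e12        # bytes/s, ~one A100 (assumption)
PEAK_FLOPS = 312e12         # FLOP/s, ~A100 dense fp16 (assumption)


def pass_time_ms(num_tokens: int) -> float:
    """Roofline time for one forward pass over num_tokens new positions."""
    memory_s = WEIGHT_BYTES / HBM_BANDWIDTH            # weights read once per pass
    compute_s = 2 * PARAMS * num_tokens / PEAK_FLOPS   # ~2 FLOPs/param/token
    return max(memory_s, compute_s) * 1e3


for n in (1, 5, 200):
    print(f"{n:>3} tokens: {pass_time_ms(n):.1f} ms")
```

In this model, verifying k+1 = 5 tokens costs the same ~70 ms as one token; the pass only becomes compute-bound beyond roughly 156 tokens, which is the "up to a point" caveat.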
So if we verify k tokens in the same time as 1 token:
Average tokens per forward pass = E[accepted tokens] + 1
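Speedup ≈ E[accepted tokens] + 1 (an upper bound that ignores the cost of the k draft-model passes)
If k=4 and on average 3 tokens are accepted:
Speedup ≈ 3+1 = 4× latency reduction at batch_size=1 (before draft-model overhead)
Implementation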
```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def speculative_decode(
    large_model, draft_model, input_ids,
    num_draft_tokens: int = 4,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
):
    # Assumes batch_size=1 and that both models are callables returning
    # logits of shape [batch, seq_len, vocab] over the same vocabulary.
    # No KV cache is used, so every call re-encodes the full prefix.
    prompt_len = input_ids.shape[1]
    generated = input_ids.clone()
    while generated.shape[1] - prompt_len < max_new_tokens:
        start = generated.shape[1]

        # 1. Draft k tokens with the small model, keeping each full
        #    draft distribution q_i (needed for residual resampling).
        draft_ids = generated
        draft_dists = []
        for _ in range(num_draft_tokens):
            logits = draft_model(draft_ids)[:, -1, :]
            q = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(q, 1)
            draft_dists.append(q[0])
            draft_ids = torch.cat([draft_ids, next_tok], dim=1)
        draft_tokens = draft_ids[0, start:]  # the k new tokens

        # 2. Verify: ONE large-model pass over context + k draft tokens.
        #    Logits at position t give the distribution for token t+1.
        large_logits = large_model(draft_ids)
        p_all = F.softmax(large_logits / temperature, dim=-1)[0]

        # 3. Accept/reject draft tokens left to right.
        n_accepted = 0
        new_tok = None
        for i in range(num_draft_tokens):
            tok = draft_tokens[i]
            p = p_all[start + i - 1]  # large-model distribution for slot i
            q = draft_dists[i]
            u = torch.rand((), device=p.device)
            if u < p[tok] / q[tok]:  # accept with probability min(1, p/q)
                n_accepted += 1
            else:
                # Rejected: resample from norm(max(0, p - q)) so the output
                # distribution stays exactly equal to the large model's.
                residual = torch.clamp(p - q, min=0)
                new_tok = torch.multinomial(residual / residual.sum(), 1)
                break

        # 4. All k accepted: sample a bonus token from the large model's
        #    distribution after the last draft token.
        if new_tok is None:
            new_tok = torch.multinomial(p_all[start + num_draft_tokens - 1], 1)

        accepted = draft_tokens[:n_accepted]
        generated = torch.cat(
            [generated, torch.cat([accepted, new_tok]).unsqueeze(0)], dim=1
        )

    # An iteration can emit up to k+1 tokens; trim any overshoot.
    return generated[:, : prompt_len + max_new_tokens]
```
Draft Model Choices
Large model → Draft model pairs:
LLaMA 2 70B → LLaMA 2 7B
~10ร parameter difference, same tokeniser required
Acceptance rate: ~65-80% on typical prompts
GPT-4 → GPT-3.5 (a speculated pairing; not publicly confirmed)
Self-speculative decoding (drafting by layer skipping):
Use the same model with fewer layers for drafting
No separate model needed, which saves memory
Medusa heads:
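Add k extra prediction heads on top of the large model's last hidden state
Head i predicts the token i+1 positions ahead (the original LM head still predicts the next token), so one pass proposes several future tokens
Fine-tuned to match the model's distribution
Speedup without a separate draft model
Practical Speedup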
Latency improvement (batch_size=1, A100):
Without speculative: 10-15 ms/token for 70B
With speculative (k=4, 7B draft): 3-5 ms/token effective rate
Speedup: 2-4ร depending on prompt and acceptance rate
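The dependence on acceptance rate can be made concrete with a small model. This is a minimal sketch, assuming each draft token is accepted independently with probability alpha (a common simplification, not a property of real prompts). The expected number of tokens per large-model pass is then (1 - alpha^(k+1)) / (1 - alpha), and dividing by the relative cost of one iteration, 1 + k*c with c the cost of a draft pass relative to a large pass, gives an estimated speedup. The alpha and c values below are illustrative assumptions.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per large-model pass, assuming each draft
    token is accepted independently with probability alpha (0 <= alpha < 1)."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def estimated_speedup(alpha: float, k: int, c: float) -> float:
    """Estimated wall-clock speedup over plain decoding, where c is the cost
    of one draft-model pass relative to one large-model pass."""
    return expected_tokens_per_pass(alpha, k) / (1 + k * c)


# Illustrative: k=4 and c=0.1 (roughly the 10x parameter ratio of a
# 70B/7B pair in a memory-bound regime; an assumption, not a measurement).
for alpha in (0.6, 0.7, 0.8):
    print(alpha,
          round(expected_tokens_per_pass(alpha, 4), 2),
          round(estimated_speedup(alpha, 4, 0.1), 2))
```

Setting c=0 recovers the idealised "E[accepted tokens] + 1" figure; any nonzero draft cost pulls the estimate below it.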
No quality degradation:
The acceptance criterion (rejection sampling, with rejected positions resampled
from the residual distribution) ensures the output distribution exactly matches the large model's
Speculative decoding is mathematically equivalent to
sampling directly from the large model
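The exactness claim can be checked numerically on a toy vocabulary. This minimal sketch draws a token from a draft distribution q, accepts it with probability min(1, p/q), and otherwise resamples from the residual max(0, p - q), renormalised; the empirical frequencies should converge to the target p. The distributions p and q are made up for illustration, and only the Python standard library is used.

```python
import random


def speculative_sample(p, q, rng):
    """Sample one token whose distribution is exactly p, using a proposal from q."""
    tokens = range(len(p))
    x = rng.choices(tokens, weights=q)[0]      # draft proposal x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):   # accept with prob min(1, p/q)
        return x
    # Rejected: resample from the residual max(0, p - q); choices() renormalises.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(tokens, weights=residual)[0]


p = [0.5, 0.3, 0.2]   # toy "large model" distribution (illustrative)
q = [0.2, 0.5, 0.3]   # toy "draft model" distribution (illustrative)

rng = random.Random(0)
n = 100_000
counts = [0] * len(p)
for _ in range(n):
    counts[speculative_sample(p, q, rng)] += 1

print([round(c / n, 3) for c in counts])  # close to p, not q
```

Resampling from p instead of the residual would bias the output toward tokens the draft over-proposes: for these p and q it gives 0.35, 0.39 and 0.26 instead of 0.5, 0.3 and 0.2.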
Limitations:
Requires a compatible draft model (same tokeniser)
Memory overhead for keeping two models loaded
Speedup decreases at large batch sizes (compute-bound regime)
Interview Answer
"Speculative decoding speeds up LLM inference by using a small draft model to generate k tokens speculatively, then verifying all k tokens with a single forward pass of the large model. The speedup comes from a key observation: for batch_size=1, a large model is memory-bound, so processing k+1 tokens takes roughly the same time as processing 1 token (weights are read once regardless). Each draft token is accepted or rejected via rejection sampling, and a rejected position is resampled from a residual distribution, which preserves the exact output distribution of the large model. In practice, 2-4× latency improvements are achievable at batch_size=1 with a 10× smaller draft model."