Learnixo
AI Systems · intermediate

Speculative Decoding

How speculative decoding uses a small draft model to speed up generation from a large model, the acceptance criterion, and the latency gains achievable in practice.

Asma Hafeez Khan · May 16, 2026 · 5 min read
LLMs · Speculative Decoding · Inference · Latency · Interview

The Problem: Autoregressive Generation Is Slow

Standard LLM generation is serial, one token per forward pass:

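Generating 100 tokens from a 70B model:
  100 sequential forward passes
  Each pass: ~2 × model_params FLOPs per new token with a KV cache (O(n × model_size) if the full n-token prefix is recomputed)
  On A100s (140GB of fp16 weights, so sharded across several 80GB GPUs): ~10-15 ms per token
  Total: 1.0-1.5 seconds for 100 tokens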

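The bottleneck: we're not compute-bound, we're MEMORY-bandwidth-bound
  The model weights (140GB at fp16) must be read from HBM at every step
  HBM bandwidth: ~2TB/s per A100 → streaming 140GB through one GPU's bandwidth takes ~70ms for the weight reads alone (sharding across N GPUs divides this roughly by N)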
  
  For batch_size=1, the GPU compute units sit mostly idle
  between memory reads (arithmetic intensity is too low)
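The memory-bound claim can be checked with back-of-the-envelope arithmetic. The sketch below is a simplified roofline-style estimate, not a profiler: it assumes ~2 FLOPs per weight per token, fp16 weights, the ~2 TB/s HBM bandwidth quoted above, and ~312 TFLOPS of dense fp16 tensor-core peak (the published A100 figure). It counts only weight traffic, ignoring KV-cache and activation reads, and `decode_step_bounds` is a hypothetical helper name.

```python
def decode_step_bounds(n_params, bytes_per_param, hbm_bytes_per_s, peak_flops, tokens_per_pass=1):
    """Rough per-pass lower bounds for a decoder forward pass (hypothetical helper).

    Counts only weight reads and weight FLOPs; ignores KV cache, activations, overheads.
    Returns (memory_time_s, compute_time_s, flops_per_byte).
    """
    weight_bytes = n_params * bytes_per_param
    # ~2 FLOPs (one multiply, one add) per weight for each token in the pass
    flops = 2 * n_params * tokens_per_pass
    memory_time = weight_bytes / hbm_bytes_per_s   # weights are streamed once per pass
    compute_time = flops / peak_flops
    return memory_time, compute_time, flops / weight_bytes


# 70B params, fp16 (2 bytes), ~2 TB/s, ~312 TFLOPS: one token vs. five tokens per pass
for tokens in (1, 5):
    mem_s, cmp_s, intensity = decode_step_bounds(70e9, 2, 2e12, 312e12, tokens_per_pass=tokens)
    print(f"{tokens} token(s): memory {mem_s * 1e3:.1f} ms, "
          f"compute {cmp_s * 1e3:.2f} ms, {intensity:.0f} FLOP/byte")
```

For one token this works out to roughly 0.45 ms of compute against 70 ms of weight reads (1 FLOP per byte). An A100's compute-to-bandwidth ratio is about 312 / 2 = 156 FLOP/byte, so in this simplified model a pass needs on the order of 150 tokens before compute, rather than weight reads, becomes the bottleneck.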

Speculative decoding exploits this memory-bound regime.


Core Idea

Use a small, fast draft model to speculatively generate k tokens. Verify all k tokens in a single forward pass of the large model:

Step 1: Draft model generates k tokens speculatively
  M_small generates: [tok_1, tok_2, tok_3, tok_4]  (k=4, very fast)

Step 2: Large model runs ONE forward pass on the original context + all k draft tokens
  M_large processes [context, tok_1, tok_2, tok_3, tok_4] in parallel
  Gets a probability distribution at each position: P_large(·|context), P_large(·|context, tok_1), ...

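Step 3: Accept or reject each draft token, left to right, using rejection sampling
  Compare P_large(tok_i | ...) vs P_draft(tok_i | ...)
  Accept tok_i with probability min(1, P_large(tok_i | ...) / P_draft(tok_i | ...))
  Stop at the first rejection and resample that position from the residual distribution
  norm(max(0, P_large(·|...) - P_draft(·|...))), not from P_large directly

Step 4: Every iteration yields at least one new token: the resampled token after a
        rejection, or, if all k draft tokens are accepted, a bonus token sampled from
        M_large's distribution at the position after tok_k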
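Steps 3 and 4 at a single position can be sketched in pure Python on a toy three-token vocabulary. This is an illustrative sketch, not code from any library: `p` stands in for P_large and `q` for P_draft at one position, and the helper names `sample` and `speculative_position` are hypothetical. Sampling many times shows the exactness property: the output frequencies follow `p`, not `q`.

```python
import random


def sample(dist, rng):
    """Draw an index from a discrete distribution given as a list of probabilities."""
    r = rng.random()
    cumulative = 0.0
    for i, prob in enumerate(dist):
        cumulative += prob
        if r < cumulative:
            return i
    return len(dist) - 1  # guard against floating-point round-off


def speculative_position(p, q, rng):
    """Draft a token from q, then accept it or resample so the result follows p.

    Returns (token, accepted_flag).
    """
    x = sample(q, rng)                       # draft model proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True                       # accepted with probability min(1, p/q)
    # Rejected: resample from the residual norm(max(0, p - q))
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    return sample([r / total for r in residual], rng), False


rng = random.Random(0)
p = [0.5, 0.3, 0.2]   # toy target (large-model) distribution
q = [0.2, 0.3, 0.5]   # toy draft distribution
n = 100_000
counts = [0, 0, 0]
accepted = 0
for _ in range(n):
    tok, ok = speculative_position(p, q, rng)
    counts[tok] += 1
    accepted += ok
print([c / n for c in counts], accepted / n)
```

The frequencies land close to `p`, and the acceptance rate comes out near sum(min(p, q)) = 0.2 + 0.3 + 0.2 = 0.7; a draft distribution closer to the target accepts more often.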

Why This Works

A single forward pass of the large model at batch size 1 processes k+1 tokens in roughly the same time as processing 1 token (memory-bound, not compute-bound):

Memory cost of a single forward pass:
  Reads all model weights once, regardless of how many tokens are in the pass (up to a point)
  Processing 1 token: read all weights → ~70ms (140GB at ~2TB/s)
  Processing k+1 tokens: read all weights → ~70-75ms (barely more)

So if we verify k tokens in the same time as 1 token (and the draft model's cost is negligible):
  Average tokens per large-model forward pass = E[accepted tokens] + 1
  Speedup ≈ E[accepted tokens] + 1

If k=4 and an average of 3 tokens are accepted:
  Speedup ≈ 3+1 = 4× lower per-token latency at batch_size=1 (ideal case; drafting cost reduces this in practice)
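The estimate above counts only large-model passes. A slightly fuller sketch, under the simplifying assumption that each draft token is accepted independently with the same probability α, gives E[tokens per iteration] = 1 + α + ... + α^k = (1 - α^(k+1)) / (1 - α). Dividing by the per-iteration cost k·c + 1, where c is an assumed cost of one draft pass relative to one large-model pass, gives the expected speedup. The function names are hypothetical and the value c=0.1 below is an illustrative guess, not a measurement.

```python
def expected_tokens_per_iteration(alpha: float, k: int) -> float:
    """E[new tokens per large-model pass] = 1 + alpha + ... + alpha**k (i.i.d. acceptance assumed)."""
    if alpha >= 1.0:
        return float(k + 1)          # every draft accepted, plus the bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Speedup over plain decoding when one draft pass costs c large-model passes."""
    # Each iteration costs k draft passes plus one verification pass
    return expected_tokens_per_iteration(alpha, k) / (k * c + 1)


for alpha in (0.65, 0.8):
    print(alpha, round(expected_tokens_per_iteration(alpha, 4), 2),
          round(expected_speedup(alpha, 4, c=0.1), 2))
```

With acceptance rates of 65-80%, k=4 and the assumed c=0.1, this gives roughly 2.5-3.4 tokens per large-model pass and a 1.8-2.4× speedup. Real acceptance is not independent across positions, so treat this as a rough planning estimate.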

Implementation

Python
import torch
import torch.nn.functional as F


@torch.no_grad()
def speculative_decode(
    large_model, draft_model, input_ids,
    num_draft_tokens: int = 4,
    max_new_tokens: int = 100,
    temperature: float = 1.0
):
    # Assumes batch_size=1, temperature > 0, and that both models are callables
    # mapping token ids [1, seq_len] -> logits [1, seq_len, vocab] over the SAME vocabulary.
    # No KV cache is used, so every call recomputes the whole prefix (simple, not fast).
    prompt_len = input_ids.shape[1]
    generated = input_ids.clone()

    while generated.shape[1] - prompt_len < max_new_tokens:
        # 1. Draft k tokens, keeping the full draft distribution q_i at each step
        draft_ids = generated.clone()
        draft_dists = []
        for _ in range(num_draft_tokens):
            logits = draft_model(draft_ids)[:, -1, :]
            q = F.softmax(logits / temperature, dim=-1)          # [1, vocab]
            next_tok = torch.multinomial(q, 1)                   # [1, 1]
            draft_dists.append(q)
            draft_ids = torch.cat([draft_ids, next_tok], dim=1)

        start = generated.shape[1]
        draft_tokens = draft_ids[:, start:]                      # k new tokens, [1, k]

        # 2. Verify with the large model in ONE forward pass over context + drafts.
        # Logits at index j are the distribution for the token at index j + 1.
        large_logits = large_model(draft_ids)
        large_probs_all = F.softmax(large_logits / temperature, dim=-1)

        # 3. Accept/reject draft tokens left to right
        n_accepted = 0
        new_tok = None
        for i in range(num_draft_tokens):
            p = large_probs_all[:, start + i - 1, :]             # P_large at this position
            q = draft_dists[i]                                   # P_draft at this position
            tok = draft_tokens[:, i : i + 1]                     # [1, 1]
            ratio = (p.gather(-1, tok) / q.gather(-1, tok)).item()
            if torch.rand(1).item() <= ratio:
                n_accepted += 1
            else:
                # Rejected: resample from norm(max(0, p - q)) so the output stays exact
                residual = torch.clamp(p - q, min=0.0)
                new_tok = torch.multinomial(residual / residual.sum(dim=-1, keepdim=True), 1)
                break

        # 4. If every draft token was accepted, sample a bonus token from P_large after tok_k
        if new_tok is None:
            new_tok = torch.multinomial(large_probs_all[:, start + num_draft_tokens - 1, :], 1)

        accepted = draft_tokens[:, :n_accepted]
        generated = torch.cat([generated, accepted, new_tok], dim=1)

    # The final iteration can overshoot by up to k tokens; trim to max_new_tokens
    return generated[:, : prompt_len + max_new_tokens]

Draft Model Choices

Large model → Draft model pairs:

LLaMA 2 70B → LLaMA 2 7B
  ~10× parameter difference, same tokeniser required
  Acceptance rate: ~65-80% on typical prompts

GPT-4 → GPT-3.5 (speculated pairing; not publicly confirmed)

Self-speculative decoding (drafting by layer skipping):
  Use the same model with fewer layers for drafting
  No separate model needed โ€” saves memory

Medusa heads:
  Add k separate prediction heads to the large model
  Head i predicts the token i+1 positions ahead (the original LM head still predicts the next token), all from one forward pass
  Fine-tuned to match the model's distribution
  Speedup without a separate draft model
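The layer-skipping idea can be sketched on a toy residual network in pure Python. Everything here is a hypothetical illustration: `forward`, `argmax_head`, the toy blocks and the `skip` set are not taken from any specific self-speculative method, which would choose the skipped layers far more carefully. Drafting runs the same weights with some blocks skipped; verification runs all of them, so no second model has to be stored.

```python
from typing import Callable, Sequence

Vector = list[float]
Block = Callable[[Vector], Vector]


def forward(blocks: Sequence[Block], head: Callable[[Vector], int],
            h: Vector, skip: frozenset = frozenset()) -> int:
    """Run residual blocks h -> h + f(h), skipping indices in `skip`, then the output head."""
    for i, f in enumerate(blocks):
        if i in skip:
            continue                  # a skipped residual block leaves h unchanged
        h = [a + b for a, b in zip(h, f(h))]
    return head(h)


def argmax_head(h: Vector) -> int:
    # Toy output head: "token id" = index of the largest hidden value
    return max(range(len(h)), key=lambda i: h[i])


# Toy model: one set of four blocks shared by drafting and verification
blocks = [
    lambda h: [0.1 * x for x in h],
    lambda h: [0.0, 0.5, 0.0],
    lambda h: [0.2, 0.0, 0.0],
    lambda h: [0.0, 0.0, 0.05],
]
h0 = [0.3, 0.2, 0.1]
draft_token = forward(blocks, argmax_head, h0, skip=frozenset({2, 3}))  # cheap: 2 of 4 blocks
full_token = forward(blocks, argmax_head, h0)                            # verification: all blocks
print(draft_token, full_token)
```

Here the shallow draft and the full model agree, so the draft token would be accepted; when the skipped blocks change the prediction, the verification pass catches it exactly as with a separate draft model.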

Practical Speedup

Latency improvement (batch_size=1, A100):
  Without speculative decoding: 10-15 ms/token for 70B
  With speculative decoding (k=4, 7B draft): 3-5 ms/token effective rate
  Speedup: 2-4× depending on prompt and acceptance rate

No quality degradation:
  The acceptance criterion plus residual resampling (rejection sampling) ensures
  the output distribution exactly matches the large model's
  Speculative decoding is mathematically equivalent to
  sampling directly from the large model

Limitations:
  Requires a compatible draft model (same tokeniser)
  Memory overhead for keeping two models loaded
  Speedup decreases at large batch sizes (compute-bound regime)

Interview Answer

"Speculative decoding speeds up LLM inference by using a small draft model to generate k tokens speculatively, then verifying all k tokens with a single forward pass of the large model. The speedup comes from a key observation: at batch_size=1, a large model is memory-bound, so processing k+1 tokens takes roughly the same time as processing 1 token (the weights are read once either way). Draft tokens are accepted or rejected via rejection sampling, with a rejected position resampled from a residual distribution, which preserves the exact output distribution of the large model. In practice, 2-4× latency improvements are achievable at batch_size=1 with a 10× smaller draft model."
