Learnixo
AI Systems · intermediate

Speculative Decoding: Faster Inference

How speculative decoding uses a small draft model to propose tokens that a large model verifies in parallel, achieving 2-3x speedups with identical output distribution.

Asma Hafeez Khan · May 16, 2026 · 6 min read
Transformers · Speculative Decoding · Inference · Optimization

The Autoregressive Bottleneck

Standard LLM generation is inherently sequential: generate token 1, then token 2, then token 3, each requiring a full forward pass. At small batch sizes, the main bottleneck is memory bandwidth, not compute: every token requires loading all model weights from GPU memory, while the arithmetic units sit underutilized.

For a 70B-parameter model in FP16 (2 bytes per parameter), generating 1 token requires reading ~140GB of weights from GPU memory. The GPU's matrix multiplication units sit mostly idle waiting for data.
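A back-of-the-envelope sketch of that arithmetic and the latency floor it implies. The 2,000 GB/s bandwidth is an assumed round number for illustration, not the spec of any particular GPU, and the estimate ignores KV-cache reads and compute time, so it is only a lower bound.

```python
def min_decode_latency_ms(n_params: float, bytes_per_param: float, bandwidth_gb_s: float) -> float:
    """Lower bound on per-token latency if every weight is read from memory once per token."""
    weight_bytes = n_params * bytes_per_param
    return weight_bytes / (bandwidth_gb_s * 1e9) * 1e3


N_PARAMS = 70e9            # 70B parameters
BYTES_PER_PARAM = 2        # FP16
ASSUMED_BANDWIDTH = 2000   # GB/s; illustrative aggregate memory bandwidth (assumption)

weights_gb = N_PARAMS * BYTES_PER_PARAM / 1e9
latency_ms = min_decode_latency_ms(N_PARAMS, BYTES_PER_PARAM, ASSUMED_BANDWIDTH)
print(f"{weights_gb:.0f} GB of weights -> at least {latency_ms:.0f} ms per token")
# 140 GB of weights -> at least 70 ms per token
```

Because one forward pass streams the weights once regardless of how many positions it scores, scoring a handful of positions costs roughly the same as scoring one while decoding stays memory-bound.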

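Key insight: Verifying whether a proposed token sequence is correct is faster than generating it from scratch, because verification processes all proposed positions in parallel in a single forward pass, which costs about as much as generating one token while decoding is memory-bound.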


How Speculative Decoding Works

  1. A small, fast "draft" model proposes K tokens (e.g., K=4)
  2. The large "target" model verifies all K proposed tokens in a single forward pass (parallel)
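  3. Each draft token is accepted or rejected by comparing its target and draft probabilities; the first rejected token is replaced by a corrected sample, and the remaining draft tokens are discarded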
  4. Repeat
Draft model generates: "warfarin inhibits VKOR"    (4 tokens, fast)
Target model checks all 4 simultaneously          (1 forward pass, not 4)
If target accepts the first 3 but rejects the 4th:
  - Accept "warfarin inhibits VK"
  - Resample the 4th token from the adjusted distribution max(0, p_target - p_draft), normalized
  - Net result: 4 tokens (3 accepted + 1 corrected) for 1 target forward pass plus 4 cheap draft passes

Crucial property: speculative decoding produces exactly the same output distribution as the target model alone. It is a mathematically exact optimization, not an approximation.
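The exactness claim can be checked for a single position with a small sketch. It computes, in closed form, the distribution of the token emitted by the accept/resample rule and compares it with the target distribution: accepted mass is q(x)·min(1, p(x)/q(x)) = min(p(x), q(x)), and the rejection mass is redistributed over the residual max(0, p − q), whose normalizer equals that rejection mass, so the two parts sum to p(x). The helper name `speculative_output_dist` and the toy distributions are illustrative.

```python
from typing import List, Sequence


def speculative_output_dist(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Distribution of the token emitted at one position by draft-then-verify.

    p: target distribution, q: draft distribution (same vocabulary, each sums to 1).
    """
    # P(draft proposes x AND x is accepted) = q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    accept = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accept)
    # After a rejection, sample from max(0, p - q), normalized
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    residual_total = sum(residual)
    if residual_total == 0.0:  # p == q: nothing is ever rejected
        return list(accept)
    return [a + reject_mass * r / residual_total for a, r in zip(accept, residual)]


target = [0.5, 0.3, 0.2]
draft = [0.2, 0.3, 0.5]
print(speculative_output_dist(target, draft))  # matches target up to float rounding
```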


The Acceptance Criterion

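A draft token x is accepted with probability min(1, p_target(x) / p_draft(x)). On rejection, a replacement is sampled from the residual distribution max(0, p_target - p_draft), renormalized: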

Python
import torch
import torch.nn.functional as F

def speculative_sample_token(
    draft_prob: float,    # p_draft(token | context), > 0 because the draft sampled it
    target_prob: float,   # p_target(token | context)
) -> bool:
    """Accept or reject a draft token."""
    # Accept with probability min(1, p_target / p_draft)
    # If target probability >= draft probability: always accept
    # If draft overestimated: accept with probability target/draft
    acceptance_prob = min(1.0, target_prob / draft_prob)
    return torch.rand(1).item() < acceptance_prob

def speculative_generate_step(
    target_model,
    draft_model,
    input_ids: torch.Tensor,   # shape (1, seq_len); this sketch assumes batch size 1
    k: int = 4,                # Number of draft tokens to propose
) -> torch.Tensor:
    """
    One step of speculative decoding:
    - Draft model proposes k tokens
    - Target model scores all k+1 positions in one pass
    - Accept valid tokens, correct the first mismatch
    Returns a (1, m) tensor of new tokens, with 1 <= m <= k + 1.
    """
    # Step 1: Draft model proposes k tokens autoregressively
    draft_tokens = []
    draft_dists = []   # full draft distributions, needed for the rejection correction
    current_ids = input_ids

    with torch.no_grad():
        for _ in range(k):
            draft_logits = draft_model(current_ids).logits[:, -1, :]
            draft_prob_dist = F.softmax(draft_logits, dim=-1)

            # Sample next token from draft distribution
            next_token = torch.multinomial(draft_prob_dist, 1)   # shape (1, 1)
            draft_tokens.append(next_token)
            draft_dists.append(draft_prob_dist)
            current_ids = torch.cat([current_ids, next_token], dim=-1)

        # Step 2: Target model scores original + k draft tokens in one forward pass
        target_out = target_model(current_ids)
        # Logits at position t predict token t+1, so these k positions score the k drafts
        target_logits = target_out.logits[:, -(k + 1):-1, :]

    # Step 3: Accept/reject each draft token
    accepted = []
    all_accepted = True
    for i in range(k):
        target_prob_dist = F.softmax(target_logits[:, i, :], dim=-1)
        token_id = draft_tokens[i].item()
        target_prob = target_prob_dist[0, token_id].item()
        draft_prob = draft_dists[i][0, token_id].item()

        if speculative_sample_token(draft_prob, target_prob):
            accepted.append(draft_tokens[i])
        else:
            # Rejection: resample from the corrected distribution
            # p_corrected = max(0, p_target - p_draft), normalized
            correction = F.relu(target_prob_dist - draft_dists[i])
            correction = correction / correction.sum(dim=-1, keepdim=True)
            accepted.append(torch.multinomial(correction, 1))
            all_accepted = False
            break  # Stop after first rejection

    # Only if all k drafts were accepted: sample a bonus (k+1)th token from the target
    if all_accepted:
        final_target_logits = target_out.logits[:, -1, :]
        final_token = torch.multinomial(F.softmax(final_target_logits, dim=-1), 1)
        accepted.append(final_token)

    return torch.cat(accepted, dim=-1)

Expected Speedup

The speedup depends on the acceptance rate α, the probability that a given draft token is accepted (the formula below assumes acceptances are independent across positions):

Python
def expected_speedup(
    acceptance_rate: float,    # alpha: per-token acceptance probability, 0.0 to 1.0
    k: int,                    # Draft tokens proposed per step
    target_cost: float = 1.0,  # Cost of one target forward pass
    draft_cost: float = 0.1,   # Cost of ONE draft forward pass (small model)
) -> float:
    """
    Speedup in tokens per unit cost, relative to plain decoding (1 token per target pass).
    """
    # Expected tokens per step (accepted drafts + 1 corrected/bonus token):
    # geometric series 1 + a + a^2 + ... + a^k
    if acceptance_rate >= 1.0:
        expected_tokens = k + 1
    else:
        expected_tokens = (1 - acceptance_rate ** (k + 1)) / (1 - acceptance_rate)

    # Total cost per step: k draft passes + 1 target pass
    total_cost = k * draft_cost + target_cost

    # Speedup vs naive (1 token per target pass)
    naive_tokens_per_cost = 1.0 / target_cost
    spec_tokens_per_cost = expected_tokens / total_cost

    return spec_tokens_per_cost / naive_tokens_per_cost

# At 80% acceptance rate with k=4 draft tokens and a 10× cost ratio
speedup = expected_speedup(acceptance_rate=0.8, k=4, target_cost=1.0, draft_cost=0.1)
print(f"Expected speedup: {speedup:.2f}×")   # Expected speedup: 2.40×
# Typically 2–3× for well-matched draft/target pairs
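Each extra draft token adds a draft pass but contributes only α^i more expected tokens, so the same model implies an optimal k. A minimal sweep, re-implementing the formula above so the snippet runs on its own; α = 0.8 and the 0.1 draft/target cost ratio are the same illustrative inputs, not measurements.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    # 1 + alpha + ... + alpha^k, assuming independent acceptances
    return sum(alpha ** i for i in range(k + 1))


def speedup_for_k(alpha: float, k: int, cost_ratio: float) -> float:
    # cost_ratio = cost of one draft pass / cost of one target pass
    return expected_tokens_per_step(alpha, k) / (1 + k * cost_ratio)


alpha, cost_ratio = 0.8, 0.1
for k in range(1, 11):
    print(f"k={k:2d}  speedup={speedup_for_k(alpha, k, cost_ratio):.2f}×")

best_k = max(range(1, 11), key=lambda k: speedup_for_k(alpha, k, cost_ratio))
print("best k:", best_k)  # best k: 6
```

Under these assumptions the curve is flat near its peak (k = 4 to 7 all land between 2.4× and 2.5×), so small changes to k around the optimum matter little in this model.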

Acceptance rate depends on how well the draft model's distribution matches the target's. Draft models are typically:

  • A smaller version of the target (e.g., 7B draft for 70B target)
  • The same model with fewer layers (speculative decoding with early exit)
  • A specialized distilled model
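For a single context, the probability that a sampled draft token survives the acceptance test is Σ_x min(p_target(x), p_draft(x)), which equals one minus the total variation distance between the two distributions; averaging it over contexts gives α. A small sketch with made-up three-token distributions:

```python
from typing import Sequence


def acceptance_rate(p_target: Sequence[float], q_draft: Sequence[float]) -> float:
    """P(a token sampled from q_draft is accepted) = sum_x min(p(x), q(x))."""
    return sum(min(p, q) for p, q in zip(p_target, q_draft))


target = [0.5, 0.3, 0.2]
print(acceptance_rate(target, [0.4, 0.4, 0.2]))  # well-matched draft, ≈ 0.9
print(acceptance_rate(target, [0.1, 0.1, 0.8]))  # poorly matched draft, ≈ 0.4
```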

Using Speculative Decoding with HuggingFace

Python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Target: large model
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Draft: smaller model from the same family (shares the target's tokenizer)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

inputs = tokenizer("Explain warfarin's mechanism:", return_tensors="pt").to(target_model.device)

# Hugging Face generate() supports speculative decoding ("assisted generation") via assistant_model
output = target_model.generate(
    **inputs,
    assistant_model=draft_model,   # Specify the draft model
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
)

print(tokenizer.decode(output[0], skip_special_tokens=True))

Draft Model Selection

| Draft model | Target model | Typical acceptance rate |
|---|---|---|
| LLaMA-3-8B | LLaMA-3-70B | 75–85% |
| Tiny-LLaMA-1.1B | LLaMA-2-7B | 60–75% |
| Distilled model | Teacher model | 85–95% |
| Same model (layers 1-N/2) | Full model | 70–80% |

Domain matters: If the target is fine-tuned for a specific domain (medical, legal, code), the draft model should also be fine-tuned on that domain. A general draft paired with a domain-specific target tends to have low acceptance rates.


Self-Speculative Decoding

A draft model isn't always needed. Self-speculative decoding uses the target model's early layers as the draft:

Python
# Early exit speculation:
# - Run only layers 1-16 to get a "draft" distribution (cheap)
# - Run all 32 layers to get the "target" distribution (expensive)
# - Apply the same acceptance criterion

# Medusa (Cai et al., 2024): Add extra "medusa heads" to the model
# Head i predicts the token i positions further ahead, all from the same forward pass
# The target verifies the heads' candidates in one pass: no separate draft model needed
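A toy sketch of the early-exit idea: draft and target share the same layers and output head, and the draft simply stops after the first half of the layers before applying the acceptance criterion from earlier. Everything here (`make_layer`, `lm_head`, `next_token_dist`, the 4-layer stack, the 3-token vocabulary) is hypothetical and only illustrates the control flow; it assumes the shared head gives meaningful output from an intermediate layer.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def make_layer(scale: float) -> Layer:
    # Hypothetical residual layer: h + scale * tanh(h)
    def layer(h: Vector) -> Vector:
        return [x + scale * math.tanh(x) for x in h]
    return layer


def lm_head(h: Vector) -> Vector:
    # Hypothetical shared output head: 3-dim hidden state -> 3 logits
    return [h[0] - h[1], h[1] - h[2], h[2] - h[0]]


def next_token_dist(hidden: Vector, layers: Sequence[Layer], depth: int) -> Vector:
    """Run only the first `depth` layers, then the shared head."""
    for layer in layers[:depth]:
        hidden = layer(hidden)
    return softmax(lm_head(hidden))


layers = [make_layer(s) for s in (0.5, 0.4, 0.3, 0.2)]  # toy 4-layer "model"
hidden = [0.3, -0.1, 0.8]

draft = next_token_dist(hidden, layers, depth=2)             # early exit after layer 2
target = next_token_dist(hidden, layers, depth=len(layers))  # full depth
rng = random.Random(0)
token = rng.choices(range(len(draft)), weights=draft)[0]     # sample a draft token
accept_prob = min(1.0, target[token] / draft[token])         # same acceptance criterion
print(f"token={token} accept_prob={accept_prob:.3f}")
```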

When Speculative Decoding Helps

Good fit (high acceptance rate):

  • Factual question answering (predictable tokens)
  • Code completion (syntactically constrained)
  • Greedy/low-temperature sampling

Poor fit (low acceptance rate):

  • Creative writing (high entropy, many valid tokens)
  • High temperature sampling
  • Very different draft and target distributions
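With greedy decoding, the target puts all of its probability on its argmax, so min(1, p_target / p_draft) is either 1 or 0: a draft token survives only if it equals the target's top choice, and the correction is simply that top choice. A minimal sketch of this prefix-matching check, assuming the target's argmax at each of the k+1 scored positions is already available (the function name is hypothetical):

```python
from typing import List, Sequence


def greedy_verify(draft_tokens: Sequence[int], target_argmax: Sequence[int]) -> List[int]:
    """Greedy speculative verification.

    draft_tokens:  the k tokens proposed by the draft model
    target_argmax: the target's argmax at each of the k+1 positions scored in one pass
    Returns the tokens to append: the matching prefix plus one token from the target.
    """
    if len(target_argmax) != len(draft_tokens) + 1:
        raise ValueError("need k+1 target predictions for k draft tokens")
    out: List[int] = []
    for proposed, expected in zip(draft_tokens, target_argmax):
        if proposed != expected:
            out.append(expected)  # first mismatch: take the target's token and stop
            return out
        out.append(proposed)
    out.append(target_argmax[-1])  # every draft matched: bonus token from the last position
    return out


print(greedy_verify([5, 7, 9, 2], [5, 7, 3, 8, 1]))  # [5, 7, 3]
print(greedy_verify([5, 7, 9, 2], [5, 7, 9, 2, 4]))  # [5, 7, 9, 2, 4]
```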

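In practice, speculative decoding provides 2–3× speedups for standard instruction-following tasks with good draft models, with no quality degradation, since the output distribution is unchanged. Serving frameworks such as vLLM and TGI support it for large models, though it generally has to be configured rather than being on by default.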
