Learnixo

Live Coding Interview Prep · Lesson 10 of 16

Implement Softmax: Numerical Stability Trick

What Softmax Does

Softmax converts a vector of raw scores (logits) into a probability distribution that sums to 1. It's used everywhere in LLMs: attention weights, next-token probabilities, classification heads.

softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)

Basic Implementation

Python
import math

def softmax_naive(logits: list[float]) -> list[float]:
    """Simple softmax — numerically unstable for large values."""
    exp_values = [math.exp(x) for x in logits]
    total = sum(exp_values)
    return [e / total for e in exp_values]

# Test
logits = [2.0, 1.0, 0.1]
probs = softmax_naive(logits)
print(probs)          # ≈ [0.659, 0.242, 0.099]
print(sum(probs))     # ≈ 1.0

Numerical Stability Problem

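When logits are large (e.g., 1000), exp(1000) exceeds the largest float64 value (about 1.8e308; exp overflows for any input above roughly 709.78). Python's math.exp raises OverflowError, as shown here, while NumPy's np.exp returns inf instead: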

Python
logits_large = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(logits_large)
except OverflowError:
    print("Overflow!")
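In NumPy the same overflow fails silently. np.exp(1000.0) returns inf (with a RuntimeWarning) instead of raising, and the division inf / inf then produces nan. A minimal sketch; softmax_np_naive is a hypothetical helper name used only for this illustration:

```python
import numpy as np

def softmax_np_naive(logits: np.ndarray) -> np.ndarray:
    """Unshifted NumPy softmax (hypothetical helper, for illustration only)."""
    exp_vals = np.exp(logits)          # exp(1000.0) -> inf instead of raising
    return exp_vals / exp_vals.sum()   # inf / inf -> nan

# Suppress the overflow / invalid-value warnings so only the result is shown
with np.errstate(over="ignore", invalid="ignore"):
    result = softmax_np_naive(np.array([1000.0, 1001.0, 1002.0]))

print(result)  # [nan nan nan]
```

No exception reaches the caller, so the nan values can propagate into later computations unless you check for them.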

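Fix: subtract the maximum value before computing exp. This is mathematically equivalent, because exp(x_i - max(x)) = exp(x_i) * exp(-max(x)) and the common factor exp(-max(x)) cancels between numerator and denominator. It also keeps every exponent at or below 0, so nothing overflows.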

softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
Python
def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    max_val = max(logits)
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)
    return [e / total for e in exp_shifted]

# Works for any scale
logits_large = [1000.0, 1001.0, 1002.0]
probs = softmax(logits_large)
print(probs)   # ≈ [0.090, 0.245, 0.665], identical to softmax([0.0, 1.0, 2.0])
print(sum(probs))  # 1.0

# NumPy version
import numpy as np

def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Vectorized numerically stable softmax."""
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()
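softmax_np takes the max and the sum over the whole array, which is right for a single vector. Applied to a 2-D array (for example, one row of attention scores per query), it would normalize over all elements at once, so the rows would not each sum to 1. A sketch of an axis-aware variant; the name softmax_rows and the axis=-1 default are choices made here, not from the lesson:

```python
import numpy as np

def softmax_rows(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along one axis (hypothetical helper)."""
    # keepdims=True keeps the reduced axis so the result broadcasts against logits
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / np.sum(exp_vals, axis=axis, keepdims=True)

batch = np.array([
    [2.0, 1.0, 0.1],
    [1000.0, 1001.0, 1002.0],
])
probs = softmax_rows(batch)
print(probs.sum(axis=-1))  # each row sums to 1 (up to rounding)
```

Each row gets its own max, so a row of small logits is never shifted by a large value from another row.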

Temperature Scaling

Temperature controls the sharpness of the probability distribution:

softmax_T(x_i) = softmax(x_i / T)
  • T = 1.0: Standard softmax
  • T approaching 0: Distribution concentrates on the max (greedy / deterministic)
  • T greater than 1: Distribution becomes more uniform (more random)
Python
def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Apply temperature scaling before softmax."""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")

    scaled = logits / temperature
    return softmax_np(scaled)

logits = np.array([2.0, 1.0, 0.1])

print("T=0.1 (very sharp):")
print(softmax_with_temperature(logits, 0.1).round(4))
# [1. 0. 0.] after rounding; index 1 actually gets about 4.5e-05

print("T=1.0 (standard):")
print(softmax_with_temperature(logits, 1.0).round(4))
# [0.6590  0.2424  0.0986]

print("T=2.0 (flatter):")
print(softmax_with_temperature(logits, 2.0).round(4))
# ≈ [0.5017 0.3043 0.1940], flatter than T=1.0
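One way to put a number on "sharpness" is the Shannon entropy of the output distribution: it is 0 when all mass sits on one token and log(n) for a uniform distribution over n tokens. A small sketch (the entropy helper is an addition for illustration, and softmax_t restates the temperature-scaled softmax so the snippet runs on its own) showing that entropy rises with T for the same logits:

```python
import numpy as np

def softmax_t(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Numerically stable softmax of logits / temperature."""
    scaled = logits / temperature
    exp_vals = np.exp(scaled - np.max(scaled))
    return exp_vals / exp_vals.sum()

def entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats; zero-probability terms contribute 0."""
    nonzero = probs[probs > 0]
    return float(-np.sum(nonzero * np.log(nonzero)))

logits = np.array([2.0, 1.0, 0.1])
temperatures = [0.1, 0.5, 1.0, 2.0, 10.0]
entropies = [entropy(softmax_t(logits, t)) for t in temperatures]

for t, h in zip(temperatures, entropies):
    print(f"T={t:>4}: entropy={h:.4f} nats")
print(f"uniform upper bound: log(3) = {np.log(3):.4f} nats")
```

Low T drives the entropy toward 0 (one-hot), and high T pushes it toward the log(3) ceiling of a uniform distribution over three tokens.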

Token Sampling with Temperature

In an LLM, the model produces one logit for every token in the vocabulary (tens of thousands of entries; GPT-2's vocabulary has 50,257 tokens). Softmax converts these to probabilities, and a sampling step picks the next token:

Python
import numpy as np

def sample_next_token(
    logits: np.ndarray,   # (vocab_size,) raw scores
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> int:
    """Sample next token from logits with optional top-k/top-p filtering."""

    # Temperature scaling
    scaled_logits = logits / max(temperature, 1e-8)

    # Top-k filtering: keep only top k logits
    if top_k is not None and top_k > 0:
        k = min(top_k, scaled_logits.size)  # guard against top_k > vocab size
        top_k_threshold = np.sort(scaled_logits)[-k]
        scaled_logits = np.where(scaled_logits >= top_k_threshold, scaled_logits, -np.inf)

    # Softmax to probabilities
    probs = softmax_np(scaled_logits)

    # Top-p (nucleus) filtering: keep tokens covering top p probability mass
    if top_p is not None and 0 < top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative = np.cumsum(sorted_probs)
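        # Number of tokens needed to reach top_p (clamped in case rounding leaves the total just below top_p)
        cutoff_idx = min(int(np.searchsorted(cumulative, top_p)) + 1, len(sorted_probs))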
        cutoff_prob = sorted_probs[cutoff_idx - 1]
        probs = np.where(probs >= cutoff_prob, probs, 0.0)
        probs /= probs.sum()  # Renormalize

    # Sample
    return int(np.random.choice(len(probs), p=probs))

# Simulate vocabulary sampling
vocab_size = 50257  # GPT-2 vocab
np.random.seed(42)

# Simulated logits: random scores, so a handful of tokens score high by chance
logits = np.random.randn(vocab_size) * 2
top_token = np.argmax(logits)
logits[top_token] += 5  # Make one token clearly preferred

# Greedy (temperature near 0)
greedy_token = sample_next_token(logits, temperature=0.01)
print(f"Greedy: token {greedy_token}")  # Always picks top_token

# Sampling with temperature
for _ in range(5):
    token = sample_next_token(logits, temperature=0.7, top_k=50)
    print(f"Sampled: token {token}")
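The threshold test in the top-k step (scaled_logits >= top_k_threshold) can keep more than k tokens when several logits tie at the threshold. If exactly k survivors are wanted, np.argpartition can select k indices directly. A minimal sketch; top_k_mask is a hypothetical helper name, and which of the tied tokens survives is arbitrary:

```python
import numpy as np

def top_k_mask(logits: np.ndarray, k: int) -> np.ndarray:
    """Keep exactly k logits (ties broken arbitrarily); set the rest to -inf."""
    if k <= 0:
        raise ValueError("k must be positive")
    k = min(k, logits.size)                      # guard against k > vocab size
    keep = np.argpartition(logits, -k)[-k:]      # indices of the k largest values
    masked = np.full_like(logits, -np.inf)
    masked[keep] = logits[keep]
    return masked

tied = np.array([1.0, 1.0, 1.0, 0.0])
print(np.sum(tied >= np.sort(tied)[-2]))         # threshold test keeps 3 tokens
print(np.isfinite(top_k_mask(tied, 2)).sum())    # argpartition keeps exactly 2
```

argpartition only partially orders the array, so it also avoids the full sort, which can matter for vocabularies with tens of thousands of entries.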

Log-Softmax for Training

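During training, the loss is computed from log-softmax instead of taking the log of softmax outputs, for numerical stability. PyTorch's nn.CrossEntropyLoss, for example, takes raw logits and is equivalent to LogSoftmax followed by NLLLoss: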

Python
def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Compute log(softmax(x)) with better numerical stability."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

# Cross-entropy loss = negative log probability of correct token
def cross_entropy_loss(logits: np.ndarray, target_idx: int) -> float:
    log_probs = log_softmax(logits)
    return -log_probs[target_idx]  # Negative log likelihood

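Computing log-softmax directly, as shifted - log(sum(exp(shifted))), never takes the log of a probability. That matters because a tiny probability can underflow to 0.0, and log(0) is -inf; the direct form returns the finite log-probability instead. It is also cheaper: one log of the sum instead of a division and a log per element.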
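The difference appears as soon as one probability underflows. With logits [0, 1000], the first softmax probability is exp(-1000), which is 0.0 in float64, so taking the log afterwards gives -inf, while the direct log-softmax returns exactly -1000. A self-contained comparison (softmax_np and log_softmax are restated from above so the snippet runs on its own):

```python
import numpy as np

def softmax_np(logits: np.ndarray) -> np.ndarray:
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()

def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

logits = np.array([0.0, 1000.0])

with np.errstate(divide="ignore"):               # log(0.0) would otherwise warn
    two_step = np.log(softmax_np(logits))        # exp(-1000) underflowed to 0.0

direct = log_softmax(logits)
print(two_step)  # first entry is -inf
print(direct)    # first entry is -1000.0
```

For a training example whose target is token 0, the two-step version would report an infinite loss, while cross_entropy_loss above returns 1000.0.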


Interview Questions

Q: Why subtract max before softmax?

exp(x - max) can never overflow because x - max ≤ 0, so exp(x - max) ≤ 1. The denominator is at least 1 (the max term contributes exp(0) = 1), so it can never be 0 either, even when every other term underflows. The result is mathematically equivalent to the original but numerically safe.
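The "denominator is at least 1" part of the argument also covers very negative logits. Around -1000, every math.exp(x) underflows to 0.0, so the unshifted version divides by zero; after the shift, the largest term is exactly exp(0) = 1. A small check, restating both functions from the lesson so it runs standalone:

```python
import math

def softmax_naive(logits: list[float]) -> list[float]:
    exp_values = [math.exp(x) for x in logits]
    total = sum(exp_values)
    return [e / total for e in exp_values]

def softmax(logits: list[float]) -> list[float]:
    max_val = max(logits)
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)      # >= 1.0, because the max term is exp(0.0)
    return [e / total for e in exp_shifted]

very_negative = [-1002.0, -1001.0, -1000.0]
try:
    softmax_naive(very_negative)
except ZeroDivisionError:
    print("Naive version: every exp underflowed to 0.0, so total == 0")

print([round(p, 3) for p in softmax(very_negative)])  # [0.09, 0.245, 0.665]
```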

Q: What does temperature 0 mean in practice?

Temperature 0 makes the distribution infinitely sharp: all probability mass concentrates on the highest-logit token, which is exactly greedy decoding. Because x / 0 is undefined, implementations either clamp T to a small epsilon (sample_next_token above uses 1e-8) or treat T = 0 as a special case and take the argmax of the logits directly.
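Instead of clamping T to an epsilon, temperature 0 can be handled as a special case: skip the division and the softmax and take the argmax of the logits, which is the token the sharpened distribution would concentrate on anyway. A sketch under that assumption; the name choose_token and its dispatch rule are illustrative, not taken from any particular library:

```python
import numpy as np

def choose_token(logits: np.ndarray, temperature: float, rng: np.random.Generator) -> int:
    """Greedy when temperature == 0, otherwise sample from softmax(logits / T)."""
    if temperature < 0:
        raise ValueError("Temperature must be non-negative")
    if temperature == 0:
        return int(np.argmax(logits))            # greedy decoding, no division by 0
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))      # stable softmax numerator
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1])
print(choose_token(logits, 0.0, rng))            # 0, the index of the largest logit
```

np.argmax returns the first index on ties, so the greedy branch is fully deterministic.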

Q: Why is top-p (nucleus sampling) often preferred over top-k?

Top-k always keeps exactly k tokens regardless of how confident the model is. If the model is very confident (one token has 99% probability), top-k=50 still lets 49 unlikely tokens share the remaining 1% and occasionally be sampled; if the model is uncertain, a fixed k can cut off plausible candidates. Top-p adapts to the distribution: when the model is confident, it samples from fewer tokens; when uncertain, it samples from more.
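To see the adaptive behavior, count how many tokens each rule keeps. The helper below uses the same cumulative-sum cutoff as sample_next_token, applied to two made-up 100-token distributions (one confident, one flat); nucleus_size is a hypothetical name:

```python
import numpy as np

def nucleus_size(probs: np.ndarray, top_p: float) -> int:
    """Number of highest-probability tokens needed to cover top_p of the mass."""
    sorted_probs = np.sort(probs)[::-1]          # descending
    cumulative = np.cumsum(sorted_probs)
    # first position where the running total reaches top_p, converted to a count
    return int(np.searchsorted(cumulative, top_p) + 1)

vocab = 100
confident = np.full(vocab, 0.01 / (vocab - 1))
confident[0] = 0.99                              # one token holds 99% of the mass
flat = np.full(vocab, 1.0 / vocab)               # every token equally likely

for name, probs in [("confident", confident), ("flat", flat)]:
    print(f"{name}: top-p=0.9 keeps {nucleus_size(probs, 0.9)} tokens; top-k=50 keeps 50")
```

The confident distribution needs a single token to reach 0.9, while the flat one needs about 90; top-k=50 keeps 50 in both cases.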