Live Coding Interview Prep · Lesson 10 of 16
Implement Softmax: Numerical Stability Trick
What Softmax Does
Softmax converts a vector of raw scores (logits) into a probability distribution that sums to 1. It's used everywhere in LLMs: attention weights, next-token probabilities, classification heads.
softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)
Basic Implementation
import math
def softmax_naive(logits: list[float]) -> list[float]:
    """Simple softmax; numerically unstable for large values."""
    exp_values = [math.exp(x) for x in logits]
    total = sum(exp_values)
    return [e / total for e in exp_values]
# Test
logits = [2.0, 1.0, 0.1]
probs = softmax_naive(logits)
print(probs)       # approximately [0.659, 0.242, 0.099]
print(sum(probs))  # 1.0 (up to floating-point rounding)
Numerical Stability Problem
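When logits are large (e.g., 1000), exp(1000) exceeds the largest representable float. Python's math.exp raises OverflowError; NumPy's np.exp returns inf instead, which turns the softmax output into NaN: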
logits_large = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(logits_large)
except OverflowError:
    print("Overflow!")
Fix: Subtract the maximum value before computing exp. This is mathematically equivalent, because the common factor exp(-max(x)) cancels between numerator and denominator, but it prevents overflow.
softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    max_val = max(logits)
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)
    return [e / total for e in exp_shifted]
# Works for any scale
logits_large = [1000.0, 1001.0, 1002.0]
probs = softmax(logits_large)
print(probs) # approximately [0.090, 0.245, 0.665]; only the differences between logits matter
print(sum(probs)) # 1.0
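The claim that subtracting the maximum leaves the result unchanged can be checked directly. The following is a minimal sketch rather than part of the lesson's own code: it redefines both variants so that it runs on its own, the name softmax_stable is used only here, and the tolerances are arbitrary choices.

```python
import math

def softmax_naive(logits):
    """Direct formula; overflows for large inputs."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(logits):
    """Same formula after shifting every logit by the maximum."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

small = [2.0, 1.0, 0.1]

# On inputs where the naive version does not overflow, both agree up to rounding.
agree = all(
    math.isclose(a, b, rel_tol=1e-12)
    for a, b in zip(softmax_naive(small), softmax_stable(small))
)
print(agree)

# Adding a constant to every logit leaves the stable result unchanged,
# because only differences between logits enter the formula.
shifted_in = [x + 1000.0 for x in small]
invariant = all(
    math.isclose(a, b, rel_tol=1e-9)
    for a, b in zip(softmax_stable(small), softmax_stable(shifted_in))
)
print(invariant)
```

Both checks should print True. The second one is the interview-relevant property: softmax depends only on the gaps between logits, not on their absolute scale.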
# NumPy version
import numpy as np
def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Vectorized numerically stable softmax."""
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()
Temperature Scaling
Temperature controls the sharpness of the probability distribution:
softmax_T(x_i) = exp(x_i / T) / sum(exp(x_j / T) for all j)
- T = 1.0: Standard softmax
- T approaching 0: Distribution concentrates on the max (greedy / deterministic)
- T greater than 1: Distribution becomes more uniform (more random)
def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Apply temperature scaling before softmax."""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    scaled = logits / temperature
    return softmax_np(scaled)
logits = np.array([2.0, 1.0, 0.1])
print("T=0.1 (very sharp):")
print(softmax_with_temperature(logits, 0.1).round(4))
# [1. 0. 0.]; index 0 has probability of about 0.99995, so it is almost always picked
print("T=1.0 (standard):")
print(softmax_with_temperature(logits, 1.0).round(4))
# [0.6590 0.2424 0.0986]
print("T=2.0 (flatter):")
print(softmax_with_temperature(logits, 2.0).round(4))
# approximately [0.5017 0.3043 0.1940], closer to uniform
Token Sampling with Temperature
In an LLM, logits are computed for every token in the vocabulary (50,000+ tokens). Softmax converts these to probabilities, then sampling picks the next token:
import numpy as np
def sample_next_token(
    logits: np.ndarray,  # (vocab_size,) raw scores
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> int:
    """Sample next token from logits with optional top-k/top-p filtering."""
    # Temperature scaling; the clamp avoids division by zero
    scaled_logits = logits / max(temperature, 1e-8)
    # Top-k filtering: keep only the k largest logits (ties at the threshold are also kept)
    if top_k is not None and top_k > 0:
        k = min(top_k, len(scaled_logits))  # guard against top_k larger than the vocabulary
        top_k_threshold = np.sort(scaled_logits)[-k]
        scaled_logits = np.where(scaled_logits >= top_k_threshold, scaled_logits, -np.inf)
    # Softmax to probabilities
    probs = softmax_np(scaled_logits)
    # Top-p (nucleus) filtering: keep the smallest set of top tokens whose mass reaches top_p
    if top_p is not None and 0 < top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative = np.cumsum(sorted_probs)
        # Index of the last token kept; clamped in case rounding leaves cumulative[-1] below top_p
        cutoff_idx = min(int(np.searchsorted(cumulative, top_p)), len(sorted_probs) - 1)
        cutoff_prob = sorted_probs[cutoff_idx]
        probs = np.where(probs >= cutoff_prob, probs, 0.0)
        probs /= probs.sum()  # Renormalize
    # Sample
    return int(np.random.choice(len(probs), p=probs))
# Simulate vocabulary sampling
vocab_size = 50257 # GPT-2 vocab
np.random.seed(42)
# Simulate logits — a few tokens have high scores
logits = np.random.randn(vocab_size) * 2
top_token = np.argmax(logits)
logits[top_token] += 5 # Make one token clearly preferred
# Near-greedy (temperature close to 0)
greedy_token = sample_next_token(logits, temperature=0.01)
print(f"Greedy: token {greedy_token}")  # Picks top_token with probability effectively 1
# Sampling with temperature
for _ in range(5):
    token = sample_next_token(logits, temperature=0.7, top_k=50)
    print(f"Sampled: token {token}")
Log-Softmax for Training
During training, frameworks such as PyTorch typically compute cross-entropy as log-softmax followed by a negative log-likelihood loss (NLLLoss), for numerical stability:
def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Compute log(softmax(x)) with better numerical stability."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))
# Cross-entropy loss = negative log probability of correct token
def cross_entropy_loss(logits: np.ndarray, target_idx: int) -> float:
    log_probs = log_softmax(logits)
    return float(-log_probs[target_idx])  # Negative log likelihood
Log-softmax never takes the log of an individual probability, which may have underflowed to 0 and would give -inf. It computes x_i - max(x) - log(sum(exp(x_j - max(x)))) directly; the sum is at least 1, so every log-probability stays finite for finite logits.
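The difference between the two routes shows up as soon as one token is extremely unlikely. The sketch below is illustrative only: it repeats the two helper functions so that it runs on its own, and the logit values are chosen to force underflow.

```python
import numpy as np

def softmax_np(logits):
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()

def log_softmax(logits):
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

logits = np.array([0.0, -1000.0])  # the second token is extremely unlikely

# Route 1: softmax first, then log. exp(-1000) underflows to 0.0, so log gives -inf.
with np.errstate(divide="ignore"):
    naive_log_probs = np.log(softmax_np(logits))

# Route 2: log-softmax directly. The log-probability stays finite.
stable_log_probs = log_softmax(logits)

print(naive_log_probs)
print(stable_log_probs)
```

A cross-entropy loss built on the first route would be infinite if the target were the second token, and its gradient would be unusable. The second route reports a finite loss of 1000.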
Interview Questions
Q: Why subtract max before softmax?
exp(x - max) can never overflow because x - max ≤ 0, so exp(x - max) ≤ 1. The denominator is at least 1 (the max term contributes exp(0) = 1), so division by zero cannot occur either. The result is mathematically identical to the original because the factor exp(-max) cancels between numerator and denominator.
Q: What does temperature 0 mean in practice?
As T approaches 0 the distribution becomes infinitely sharp: all probability mass concentrates on the highest-logit token, which is equivalent to greedy decoding. T = 0 itself would divide by zero, so implementations either clamp the temperature to a small epsilon (sample_next_token above uses 1e-8) or special-case T = 0 as argmax.
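One way to handle temperature 0 without an epsilon is to branch to argmax explicitly. The sketch below assumes a hypothetical helper named pick_token, which does not appear in the lesson, and uses NumPy's Generator API for sampling.

```python
import numpy as np

def pick_token(logits, temperature, rng):
    """Hypothetical helper: argmax when temperature is 0, sampling otherwise."""
    if temperature < 0:
        raise ValueError("Temperature must be non-negative")
    if temperature == 0:
        # Greedy decoding: no division and no randomness.
        return int(np.argmax(logits))
    scaled = logits / temperature
    scaled = scaled - np.max(scaled)  # stability shift before exp
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1])

greedy = pick_token(logits, 0.0, rng)
print(greedy)  # 0, the index of the largest logit

# A tiny positive temperature behaves the same way in practice.
near_greedy = [pick_token(logits, 1e-3, rng) for _ in range(100)]
print(set(near_greedy))
```

The explicit branch makes the greedy case deterministic by construction, whereas the epsilon approach relies on every other probability underflowing to zero.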
Q: Why is top-p (nucleus sampling) often preferred over top-k?
Top-k always keeps exactly k tokens regardless of how confident the model is. If the model is very confident (one token has 99% probability), top-k=50 still samples from 50 tokens, introducing noise. Top-p adapts to the distribution — when the model is confident, it samples from fewer tokens; when uncertain, it samples from more.
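The adaptive behaviour of top-p can be illustrated by counting how many tokens end up in the nucleus for a confident and an uncertain distribution. This is a sketch under assumed inputs: nucleus_size is a hypothetical helper, and the vocabulary size and both distributions are invented for illustration.

```python
import numpy as np

def nucleus_size(probs, top_p):
    """Number of tokens in the smallest top-ranked set whose mass reaches top_p."""
    sorted_probs = np.sort(probs)[::-1]
    cumulative = np.cumsum(sorted_probs)
    # First position where the running total reaches top_p, converted to a count.
    size = int(np.searchsorted(cumulative, top_p)) + 1
    return min(size, len(probs))

vocab = 1000  # hypothetical vocabulary size

# Confident model: one token holds 99% of the mass, the rest share 1%.
confident = np.full(vocab, 0.01 / (vocab - 1))
confident[0] = 0.99

# Uncertain model: mass spread evenly over all tokens.
uncertain = np.full(vocab, 1.0 / vocab)

print(nucleus_size(confident, 0.9))  # 1: the single dominant token already covers 0.9
print(nucleus_size(uncertain, 0.9))  # roughly 900 of the 1000 tokens
```

With top_k=50, both distributions would be cut to 50 candidates, which is too many in the first case and arguably too few in the second.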