Implement Softmax and Temperature Scaling
Implement softmax from scratch, handle numerical stability, and understand temperature scaling. See how softmax converts logits to probabilities in LLM token sampling.
What Softmax Does
Softmax converts a vector of raw scores (logits) into a probability distribution that sums to 1. It's used everywhere in LLMs: attention weights, next-token probabilities, classification heads.
softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)
Basic Implementation
import math

def softmax_naive(logits: list[float]) -> list[float]:
    """Simple softmax: numerically unstable for large values."""
    exp_values = [math.exp(x) for x in logits]
    total = sum(exp_values)
    return [e / total for e in exp_values]

# Test
logits = [2.0, 1.0, 0.1]
probs = softmax_naive(logits)
print(probs)       # ≈ [0.659, 0.242, 0.099]
print(sum(probs))  # ≈ 1.0 (up to floating-point rounding)
Numerical Stability Problem
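When logits are large (e.g., 1000), exp(1000) exceeds the largest double (about 1.8e308). NumPy's np.exp returns inf (and inf / inf then yields nan), while Python's math.exp raises OverflowError: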
logits_large = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(logits_large)
except OverflowError:
    print("Overflow!")
Fix: Subtract the maximum value before computing exp. This doesn't change the output (the two forms are mathematically equivalent) but prevents overflow.
softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))

def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    max_val = max(logits)
    # Every shifted value is <= 0, so each exp() lies in (0, 1]
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)
    return [e / total for e in exp_shifted]

# Works for any scale
logits_large = [1000.0, 1001.0, 1002.0]
probs = softmax(logits_large)
print(probs)       # ≈ [0.090, 0.245, 0.665], same as softmax([0.0, 1.0, 2.0])
print(sum(probs))  # ≈ 1.0
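Why the shift is safe: exp(x_i - m) = exp(x_i) · exp(-m), so the common factor exp(-m) cancels between numerator and denominator. A quick check of that shift invariance, repeating the stable definition so the snippet runs on its own; the base logits and the constant 1000.0 are arbitrary choices for illustration:

```python
import math

def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax (same definition as above)."""
    max_val = max(logits)
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)
    return [e / total for e in exp_shifted]

base = [0.0, 1.0, 2.0]
shifted = [x + 1000.0 for x in base]  # same logits plus a constant

p_base = softmax(base)
p_shifted = softmax(shifted)

# The constant cancels, so both inputs give the same distribution
for a, b in zip(p_base, p_shifted):
    assert math.isclose(a, b, rel_tol=1e-9)
print([round(p, 3) for p in p_base])  # [0.09, 0.245, 0.665]
```

Adding a constant to every logit never changes softmax, which is exactly why subtracting max(x) is allowed.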
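# NumPy version
import numpy as np

def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Vectorized numerically stable softmax over the last axis."""
    # keepdims=True lets the max and the sum broadcast per row for 2-D inputs
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum(axis=-1, keepdims=True)
Temperature Scaling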
Temperature controls the sharpness of the probability distribution:
softmax_T(x_i) = softmax(x_i / T) = exp(x_i / T) / sum(exp(x_j / T) for all j)
- T = 1.0: Standard softmax
- T approaching 0: Distribution concentrates on the max (greedy / deterministic)
- T greater than 1: Distribution becomes more uniform (more random)
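Both limits in the list follow from dividing the numerator and denominator of the tempered softmax by exp(x_i / T). A short derivation, assuming n tokens and a unique largest logit x_k:

```latex
\mathrm{softmax}_T(x_i)
  = \frac{e^{x_i/T}}{\sum_{j=1}^{n} e^{x_j/T}}
  = \frac{1}{\sum_{j=1}^{n} e^{(x_j - x_i)/T}}

% T -> 0+: for i = k, every j \neq k has x_j - x_k < 0, so e^{(x_j - x_k)/T} -> 0
\lim_{T \to 0^+} \mathrm{softmax}_T(x_k) = \frac{1}{1 + 0} = 1

% T -> infinity: every exponent (x_j - x_i)/T -> 0, so each of the n terms -> 1
\lim_{T \to \infty} \mathrm{softmax}_T(x_i) = \frac{1}{n}
```

For any i other than k, the j = k term in the denominator grows without bound as T → 0+, so its probability goes to 0.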
def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Apply temperature scaling before softmax."""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    scaled = logits / temperature
    return softmax_np(scaled)

logits = np.array([2.0, 1.0, 0.1])
print("T=0.1 (very sharp):")
print(softmax_with_temperature(logits, 0.1).round(4))
# [1. 0. 0.]  index 0 gets about 99.995% of the mass, so it is almost always picked
print("T=1.0 (standard):")
print(softmax_with_temperature(logits, 1.0).round(4))
# [0.659  0.2424 0.0986]
print("T=2.0 (flatter):")
print(softmax_with_temperature(logits, 2.0).round(4))
# [0.5017 0.3043 0.194 ]  more uniform
Token Sampling with Temperature
In an LLM, logits are computed for every token in the vocabulary (50,257 tokens for GPT-2; vocabularies typically run to tens of thousands of tokens or more). Softmax converts these to probabilities, then sampling picks the next token:
import numpy as np

def sample_next_token(
    logits: np.ndarray,  # (vocab_size,) raw scores
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> int:
    """Sample next token from logits with optional top-k/top-p filtering."""
    # Temperature scaling (clamped so temperature=0 doesn't divide by zero)
    scaled_logits = logits / max(temperature, 1e-8)
    # Top-k filtering: keep only the k largest logits (ties at the threshold are kept too)
    if top_k is not None and top_k > 0:
        top_k = min(top_k, scaled_logits.size)  # avoid IndexError when k > vocab size
        top_k_threshold = np.sort(scaled_logits)[-top_k]
        scaled_logits = np.where(scaled_logits >= top_k_threshold, scaled_logits, -np.inf)
    # Softmax to probabilities (filtered tokens get exp(-inf) = 0)
    probs = softmax_np(scaled_logits)
    # Top-p (nucleus) filtering: keep the smallest set of most-likely tokens
    # whose cumulative probability reaches top_p
    if top_p is not None and 0 < top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative = np.cumsum(sorted_probs)
        # First position where the running total reaches top_p (clamped to the last index)
        cutoff_idx = min(int(np.searchsorted(cumulative, top_p)), len(sorted_probs) - 1)
        cutoff_prob = sorted_probs[cutoff_idx]
        probs = np.where(probs >= cutoff_prob, probs, 0.0)
        probs /= probs.sum()  # Renormalize
    # Sample
    return int(np.random.choice(len(probs), p=probs))

# Simulate vocabulary sampling
vocab_size = 50257  # GPT-2 vocab
np.random.seed(42)
# Simulate logits: random scores, then one token is boosted
logits = np.random.randn(vocab_size) * 2
top_token = np.argmax(logits)
logits[top_token] += 5  # Make one token clearly preferred
# Greedy (temperature near 0)
greedy_token = sample_next_token(logits, temperature=0.01)
print(f"Greedy: token {greedy_token}")  # Always picks top_token
# Sampling with temperature
for _ in range(5):
    token = sample_next_token(logits, temperature=0.7, top_k=50)
    print(f"Sampled: token {token}")
Log-Softmax for Training
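During training, the loss is computed from log-softmax rather than from softmax probabilities, for numerical stability. PyTorch, for example, pairs log_softmax with NLLLoss, and nn.CrossEntropyLoss combines the two in one step: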
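def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Compute log(softmax(x)) directly, without forming probabilities first."""
    shifted = logits - np.max(logits)
    # log-sum-exp trick: log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))
    return shifted - np.log(np.sum(np.exp(shifted)))

# Cross-entropy loss = negative log probability of correct token
def cross_entropy_loss(logits: np.ndarray, target_idx: int) -> float:
    log_probs = log_softmax(logits)
    return float(-log_probs[target_idx])  # Negative log likelihood
Computing log-softmax directly avoids taking the log of probabilities that were first computed with exp and a division: a probability that underflows to 0 gives log(0) = -inf, whereas shifted - log(sum(exp(shifted))) stays finite. It is also cheaper, needing a single log for the whole vector instead of one log per element.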
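A small example of the underflow case, assuming logits [0, -1000]: exp(-1000) underflows to 0.0, so the two-step log(softmax(x)) returns -inf for the second token, while log-softmax returns the exact value -1000. The functions are repeated so the snippet runs on its own; log_of_softmax is a hypothetical name for the naive two-step approach.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """log(softmax(x)) via the log-sum-exp trick (same as above)."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

def log_of_softmax(logits: np.ndarray) -> np.ndarray:
    """Hypothetical naive version: form probabilities first, then take the log."""
    exp_vals = np.exp(logits - np.max(logits))
    probs = exp_vals / exp_vals.sum()
    return np.log(probs)

def cross_entropy_loss(logits: np.ndarray, target_idx: int) -> float:
    """Negative log-probability of the target token."""
    return float(-log_softmax(logits)[target_idx])

logits = np.array([0.0, -1000.0])

# exp(-1000) underflows to 0.0, so the naive path hits log(0) = -inf
with np.errstate(divide="ignore"):  # silence the log(0) warning for the demo
    naive = log_of_softmax(logits)
print(naive)                          # second entry is -inf
print(log_softmax(logits))            # second entry is exactly -1000.0
print(cross_entropy_loss(logits, 1))  # 1000.0
```

With the naive path, the loss for the second token would be inf and its gradient unusable; the log-softmax path gives a finite loss of 1000.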
Interview Questions
Q: Why subtract max before softmax?
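exp(x - max) can never overflow because x - max ≤ 0, so exp(x - max) ≤ 1. The denominator is at least 1 (the max term contributes exp(0) = 1), so it cannot underflow to 0 and the division is always safe. Because the shift multiplies numerator and denominator by the same factor exp(-max), the result is mathematically equivalent to the original but numerically safe.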
Q: What does temperature 0 mean in practice?
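In the limit T → 0, the distribution becomes infinitely sharp: all probability mass concentrates on the highest-logit token (split evenly if several tokens tie for the max). Since x / 0 is undefined, T = 0 is handled in practice either by special-casing it as an argmax or by clamping T to a small epsilon (sample_next_token above uses 1e-8). Either way, the result is equivalent to greedy decoding.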
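A sketch of the special-casing approach, using a hypothetical helper choose_token (not any particular library's API): temperature == 0 short-circuits to np.argmax, and a positive temperature samples from the tempered softmax. The second half checks that sampled frequencies approach the softmax probabilities; the toy logits, seed, and number of draws are arbitrary.

```python
import numpy as np

def choose_token(logits: np.ndarray, temperature: float,
                 rng: np.random.Generator) -> int:
    """Hypothetical helper: greedy at temperature 0, sampling otherwise."""
    if temperature < 0:
        raise ValueError("Temperature must be non-negative")
    if temperature == 0:
        # Greedy decoding: skip the division and pick the highest logit
        return int(np.argmax(logits))
    scaled = logits / temperature
    exp_vals = np.exp(scaled - np.max(scaled))  # stable softmax numerator
    probs = exp_vals / exp_vals.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1])

# T = 0 is deterministic: the highest logit (index 0) every time
greedy = {choose_token(logits, 0.0, rng) for _ in range(100)}
print(greedy)  # {0}

# T = 1: empirical frequencies approach softmax(logits) ≈ [0.659, 0.242, 0.099]
n_draws = 20_000
draws = [choose_token(logits, 1.0, rng) for _ in range(n_draws)]
freqs = np.bincount(draws, minlength=3) / n_draws
print(freqs.round(3))
```

Special-casing T = 0 avoids the division entirely and makes greedy decoding exactly deterministic, rather than relying on the clamped epsilon to make the non-max probabilities negligible.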
Q: Why is top-p (nucleus sampling) often preferred over top-k?
Top-k always keeps exactly k tokens regardless of how confident the model is. If the model is very confident (one token has 99% probability), top-k=50 still leaves 49 unlikely tokens in the candidate pool, introducing noise; if the distribution is flat, a fixed k can cut off tokens that are just as plausible as the ones kept. Top-p adapts to the distribution: when the model is confident, it samples from fewer tokens; when uncertain, it samples from more.
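The difference can be measured by counting how many tokens each filter keeps. The sketch below assumes two hypothetical distributions over 64 tokens: a confident one (one token at 0.99, the rest sharing 0.01) and a flat one (1/64 each). The helper nucleus_size is illustrative and uses the same cumulative-sum rule as the top-p step in sample_next_token; top-k keeps k tokens by definition (ignoring ties at the threshold).

```python
import numpy as np

def nucleus_size(probs: np.ndarray, top_p: float) -> int:
    """Number of most-likely tokens needed for cumulative probability >= top_p."""
    cumulative = np.cumsum(np.sort(probs)[::-1])
    return int(min(np.searchsorted(cumulative, top_p), len(probs) - 1)) + 1

vocab_size = 64
top_k, top_p = 50, 0.9

confident = np.full(vocab_size, 0.01 / (vocab_size - 1))
confident[0] = 0.99
flat = np.full(vocab_size, 1.0 / vocab_size)

for name, probs in [("confident", confident), ("flat", flat)]:
    kept_k = min(top_k, vocab_size)  # top-k: fixed count, ignoring ties
    kept_p = nucleus_size(probs, top_p)
    print(f"{name:>9}: top-k keeps {kept_k}, top-p keeps {kept_p}")
# confident: top-k keeps 50, top-p keeps 1
#      flat: top-k keeps 50, top-p keeps 58
```

With the flat distribution, 1/64 is exactly representable in binary, so the running totals are exact and 58 tokens (58/64 = 0.90625) are the first to reach 0.9.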