Implement Softmax and Temperature Scaling

Implement softmax from scratch, handle numerical stability, and understand temperature scaling. See how softmax converts logits to probabilities in LLM token sampling.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Tags: Live Coding, Softmax, Temperature, Python

What Softmax Does

Softmax converts a vector of raw scores (logits) into a probability distribution that sums to 1. It's used everywhere in LLMs: attention weights, next-token probabilities, classification heads.

softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)

Basic Implementation

Python
import math

def softmax_naive(logits: list[float]) -> list[float]:
    """Simple softmax — numerically unstable for large values."""
    exp_values = [math.exp(x) for x in logits]
    total = sum(exp_values)
    return [e / total for e in exp_values]

# Test
logits = [2.0, 1.0, 0.1]
probs = softmax_naive(logits)
print(probs)          # approx. [0.659, 0.242, 0.099]
print(sum(probs))     # 1.0 (up to floating-point rounding)

Numerical Stability Problem

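When logits are large (e.g., 1000), exp(1000) exceeds the largest representable float. Python's math.exp raises OverflowError; NumPy's np.exp instead returns inf, and the resulting inf / inf division yields nan: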

Python
logits_large = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(logits_large)
except OverflowError:
    print("Overflow!")
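The try/except above works because math.exp raises an exception. With NumPy the same failure is silent, which is arguably worse. A minimal sketch of that case follows; softmax_naive_np is a hypothetical name used only for this illustration, and the errstate context manager merely hides the runtime warnings NumPy would otherwise print.

```python
import numpy as np

def softmax_naive_np(logits: np.ndarray) -> np.ndarray:
    """Naive NumPy softmax with no max-subtraction (illustration only)."""
    exp_values = np.exp(logits)          # exp(1000) overflows to inf
    return exp_values / exp_values.sum() # inf / inf evaluates to nan

# Suppress the overflow / invalid-value warnings for a clean demo
with np.errstate(over="ignore", invalid="ignore"):
    result = softmax_naive_np(np.array([1000.0, 1001.0, 1002.0]))

print(result)  # [nan nan nan]
```

No exception is raised here, so the nan values would propagate into any later computation unless checked explicitly.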

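Fix: Subtract the maximum value before computing exp. This doesn't change the result mathematically, because the common factor exp(-max(x)) cancels between numerator and denominator, but it prevents overflow since every exponent is now ≤ 0.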

softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
Python
def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    max_val = max(logits)
    exp_shifted = [math.exp(x - max_val) for x in logits]
    total = sum(exp_shifted)
    return [e / total for e in exp_shifted]

# Works for any scale
logits_large = [1000.0, 1001.0, 1002.0]
probs = softmax(logits_large)
print(probs)   # [0.090, 0.245, 0.665] — same proportional answer
print(sum(probs))  # 1.0

# NumPy version
import numpy as np

def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Vectorized numerically stable softmax."""
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()
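Note that softmax_np takes one global max and one global sum, so it is only correct for a single 1-D vector. For a batch of logit rows, a possible extension is sketched below; softmax_batched and its axis parameter are assumptions for this example, not part of the original code.

```python
import numpy as np

def softmax_batched(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along one axis (sketch)."""
    # keepdims=True keeps the reduced axis so broadcasting lines up per row
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / np.sum(exp_vals, axis=axis, keepdims=True)

batch = np.array([
    [2.0, 1.0, 0.1],
    [1000.0, 1001.0, 1002.0],
])
probs = softmax_batched(batch)
print(probs.sum(axis=-1))  # each row sums to 1, up to floating-point rounding
```

Each row is shifted by its own maximum, so one row with very large logits does not flatten the others to zero.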

Temperature Scaling

Temperature controls the sharpness of the probability distribution:

softmax_T(x_i) = softmax(x_i / T)
  • T = 1.0: Standard softmax
  • T approaching 0: Distribution concentrates on the max (greedy / deterministic)
  • T greater than 1: Distribution becomes more uniform (more random)
Python
def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Apply temperature scaling before softmax."""
    if temperature <= 0:
        raise ValueError("Temperature must be positive")

    scaled = logits / temperature
    return softmax_np(scaled)

logits = np.array([2.0, 1.0, 0.1])

print("T=0.1 (very sharp):")
print(softmax_with_temperature(logits, 0.1).round(4))
# [1. 0. 0.] after rounding: index 0 holds about 99.995% of the mass

print("T=1.0 (standard):")
print(softmax_with_temperature(logits, 1.0).round(4))
# [0.6590  0.2424  0.0986]

print("T=2.0 (flatter):")
print(softmax_with_temperature(logits, 2.0).round(4))
# approx. [0.5017, 0.3043, 0.1940]: flatter, closer to uniform
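One way to quantify "sharper" versus "flatter" is the Shannon entropy of the resulting distribution, which should grow with temperature. The sketch below is self-contained; softmax_t and entropy_bits are hypothetical helper names introduced for this example.

```python
import numpy as np

def softmax_t(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Stable softmax of logits / temperature."""
    scaled = logits / temperature
    shifted = scaled - np.max(scaled)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()

def entropy_bits(probs: np.ndarray) -> float:
    """Shannon entropy in bits, skipping zero-probability entries."""
    nonzero = probs[probs > 0]
    return float(-np.sum(nonzero * np.log2(nonzero)))

logits = np.array([2.0, 1.0, 0.1])
entropies = {t: entropy_bits(softmax_t(logits, t)) for t in (0.1, 1.0, 2.0, 10.0)}

for t, h in entropies.items():
    print(f"T={t}: entropy = {h:.3f} bits")
```

For three tokens the entropy is bounded above by log2(3), roughly 1.585 bits, which is the value a uniform distribution would reach.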

Token Sampling with Temperature

In an LLM, logits are computed for every token in the vocabulary (50,000+ tokens). Softmax converts these to probabilities, then sampling picks the next token:

Python
import numpy as np

def sample_next_token(
    logits: np.ndarray,   # (vocab_size,) raw scores
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> int:
    """Sample next token from logits with optional top-k/top-p filtering."""

    # Temperature scaling
    scaled_logits = logits / max(temperature, 1e-8)

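    # Top-k filtering: keep only the k largest logits (ties at the threshold are all kept)
    if top_k is not None and top_k > 0:
        k = min(top_k, scaled_logits.size)  # guard against top_k > vocab size
        top_k_threshold = np.sort(scaled_logits)[-k]
        scaled_logits = np.where(scaled_logits >= top_k_threshold, scaled_logits, -np.inf)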

    # Softmax to probabilities
    probs = softmax_np(scaled_logits)

    # Top-p (nucleus) filtering: keep tokens covering top p probability mass
    if top_p is not None and 0 < top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cumulative = np.cumsum(sorted_probs)
        cutoff_idx = np.searchsorted(cumulative, top_p) + 1
        cutoff_prob = sorted_probs[cutoff_idx - 1]
        probs = np.where(probs >= cutoff_prob, probs, 0.0)
        probs /= probs.sum()  # Renormalize

    # Sample
    return int(np.random.choice(len(probs), p=probs))

# Simulate vocabulary sampling
vocab_size = 50257  # GPT-2 vocab
np.random.seed(42)

# Simulate logits — a few tokens have high scores
logits = np.random.randn(vocab_size) * 2
top_token = np.argmax(logits)
logits[top_token] += 5  # Make one token clearly preferred

# Near-greedy (temperature close to 0): still sampling, but almost all mass sits on one token
greedy_token = sample_next_token(logits, temperature=0.01)
print(f"Greedy: token {greedy_token}")  # Picks top_token with probability ~1

# Sampling with temperature
for _ in range(5):
    token = sample_next_token(logits, temperature=0.7, top_k=50)
    print(f"Sampled: token {token}")

Log-Softmax for Training

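During training, frameworks typically combine log-softmax with a negative log-likelihood loss (NLLLoss in PyTorch) instead of taking the log of softmax outputs, for numerical stability: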

Python
def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Compute log(softmax(x)) with better numerical stability."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

# Cross-entropy loss = negative log probability of correct token
def cross_entropy_loss(logits: np.ndarray, target_idx: int) -> float:
    log_probs = log_softmax(logits)
    return -log_probs[target_idx]  # Negative log likelihood

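Computing log-softmax directly as shifted - log(sum(exp(shifted))) avoids taking the log of a probability that has already underflowed to 0, which would give -inf. It is more numerically stable than log(softmax(x)) and skips the intermediate division.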
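The difference shows up as soon as one logit is far below the maximum. The following sketch redefines the two helpers so it runs on its own and compares log(softmax(x)) with the direct log-softmax on a deliberately extreme input; the input values are chosen purely for illustration.

```python
import numpy as np

def softmax_np(logits: np.ndarray) -> np.ndarray:
    shifted = logits - np.max(logits)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()

def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

logits = np.array([0.0, 1000.0])

# exp(-1000) underflows to 0.0, so the naive route takes log(0)
with np.errstate(divide="ignore"):
    naive = np.log(softmax_np(logits))

stable = log_softmax(logits)

print(naive)   # first entry is -inf
print(stable)  # first entry is -1000.0
```

With the naive route, a target token at index 0 would produce an infinite loss; the direct form keeps it finite.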


Interview Questions

Q: Why subtract max before softmax?

exp(x - max) can never overflow because x - max ≤ 0, so exp(x - max) ≤ 1. The denominator is at least 1 (the max term contributes exp(0) = 1). This is mathematically equivalent to the original but numerically safe.

Q: What does temperature 0 mean in practice?

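Temperature 0 is the limit in which the distribution becomes infinitely sharp: all probability mass concentrates on the highest-logit token, which is equivalent to greedy decoding. Because dividing by 0 is undefined, implementations either clamp the temperature to a small epsilon (sample_next_token above uses 1e-8) or special-case T = 0 as an argmax.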
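A minimal sketch of the special-case approach, assuming temperature 0 should mean plain argmax. The function name pick_token and the use of np.random.default_rng are choices made for this example rather than anything from the code earlier in the article.

```python
import numpy as np

def pick_token(logits: np.ndarray, temperature: float, rng: np.random.Generator) -> int:
    """Return argmax when temperature is 0, otherwise sample from softmax(logits / T)."""
    if temperature < 0:
        raise ValueError("Temperature must be non-negative")
    if temperature == 0:
        return int(np.argmax(logits))  # greedy decoding, no division needed

    scaled = logits / temperature
    shifted = scaled - np.max(scaled)
    probs = np.exp(shifted)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1])

greedy = pick_token(logits, 0.0, rng)
sampled = pick_token(logits, 1.0, rng)
print(greedy)  # 0, the index of the largest logit
```

Branching on zero avoids both the division and the random draw, so the greedy path is fully deterministic.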

Q: Why is top-p (nucleus sampling) often preferred over top-k?

Top-k always keeps exactly k tokens regardless of how confident the model is. If the model is very confident (one token has 99% probability), top-k=50 still samples from 50 tokens, introducing noise. Top-p adapts to the distribution — when the model is confident, it samples from fewer tokens; when uncertain, it samples from more.
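To make the adaptivity concrete, the sketch below counts how many tokens fall inside the nucleus for a confident and an uncertain distribution. nucleus_size is a hypothetical helper and both toy distributions are invented for illustration.

```python
import numpy as np

def nucleus_size(probs: np.ndarray, top_p: float) -> int:
    """Number of highest-probability tokens needed to reach cumulative mass top_p."""
    sorted_probs = np.sort(probs)[::-1]          # descending
    cumulative = np.cumsum(sorted_probs)
    return int(np.searchsorted(cumulative, top_p) + 1)

# Confident: one token holds 99% of the mass, 99 others share the rest
confident = np.array([0.99] + [0.01 / 99] * 99)
# Uncertain: 100 equally likely tokens
uncertain = np.full(100, 0.01)

print(nucleus_size(confident, 0.9))  # 1
print(nucleus_size(uncertain, 0.9))  # about 90
```

With top_k=50, both cases would keep 50 candidates: far more than the confident case needs, and fewer than the uncertain case's nucleus.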
