Deep Learning for AI Interviews · Lesson 40 of 56
Softmax for Multi-Class Output
The Softmax Function
Softmax converts a vector of logits to a probability distribution:
softmax(z)_i = exp(z_i) / Σ_j exp(z_j)
Properties:
- All outputs in (0, 1)
- Outputs sum to exactly 1
- Preserves relative ordering (argmax is unchanged)
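- Amplifies differences: the largest logit gets disproportionately large probability
- Shift-invariant: softmax(z + c) = softmax(z) for any constant c, which is what makes the max-subtraction stability trick valid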
Example (5-class clinical severity score):
logits = [2.1, 0.5, -1.2, 0.8, 1.5]
probs ≈ [0.485, 0.098, 0.018, 0.132, 0.266] → sums to 1 (up to rounding)
Softmax in Practice
import torch
import torch.nn as nn
import numpy as np
# Manual softmax
def softmax_manual(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute the softmax of ``z`` along ``dim`` using the max-subtraction trick.

    Args:
        z: Tensor of logits. Normalization happens along ``dim``.
        dim: Dimension to normalize over (default: last). For a 1-D tensor
            this is the whole vector, so existing calls behave as before.

    Returns:
        Tensor shaped like ``z`` whose slices along ``dim`` sum to 1.
    """
    # Subtract each slice's own max for numerical stability: the largest
    # exponent becomes exp(0) = 1, so exp() cannot overflow. Reducing per slice
    # (not over the whole tensor) keeps the rows of a batch independent.
    exp_z = torch.exp(z - z.amax(dim=dim, keepdim=True))
    return exp_z / exp_z.sum(dim=dim, keepdim=True)
# Example: 5-class severity classification
logits = torch.tensor([2.1, 0.5, -1.2, 0.8, 1.5])
# Same distribution computed two ways: our implementation vs. the built-in method
probs_manual = softmax_manual(logits)
probs_torch = logits.softmax(dim=-1)
print(f"Logits: {logits.numpy()}")
print(f"Manual: {probs_manual.numpy().round(4)}")
print(f"Torch: {probs_torch.numpy().round(4)}")
print(f"Sum: {probs_torch.sum().item():.6f}")  # a valid distribution sums to 1.0
# Batch softmax (most common in practice): every row is normalized on its own
batch_logits = torch.randn((32, 5))  # shape (batch=32, n_classes=5)
batch_probs = batch_logits.softmax(dim=-1)  # last axis = the class axis
print(f"\nBatch softmax: {batch_probs.shape}, row sums: {batch_probs.sum(dim=-1).mean():.6f}")
Numerical Stability
import torch
# Naive softmax overflows for large logits
def softmax_naive(z: torch.Tensor) -> torch.Tensor:
    """Softmax without the max shift (intentionally unstable, kept for demonstration).

    For sufficiently large logits, exp() overflows to inf (e.g. exp(1000) in
    float32), and inf / inf then produces NaN.
    """
    numerators = z.exp()  # ← exp(1000) = inf
    denominator = numerators.sum()
    return numerators / denominator
z_large = torch.tensor([1e3, 9e2, 8e2])  # logits far beyond float32's exp() range
# Naive: every exp() overflows to inf, so each entry becomes inf / inf = NaN
result_naive = softmax_naive(z_large)
print(f"Naive softmax: {result_naive}")  # tensor([nan, nan, nan])
# Stable: subtract max before exponentiation
# softmax(z - m) = softmax(z) for m = max(z): the common factor exp(-m) cancels between numerator and denominator
def softmax_stable(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute a numerically stable softmax of ``z`` along ``dim``.

    Args:
        z: Tensor of logits. Normalization happens along ``dim``.
        dim: Dimension to normalize over (default: last). For 1-D input this
            is the whole vector, so existing calls behave as before.

    Returns:
        Tensor shaped like ``z`` whose slices along ``dim`` sum to 1.
    """
    # Shift each slice by its own max: the max becomes 0 and all other entries
    # become <= 0. Reducing per slice (rather than over the whole tensor)
    # keeps the rows of a batch independent.
    z_shifted = z - z.amax(dim=dim, keepdim=True)
    exp_z = torch.exp(z_shifted)  # exp(0) = 1, others <= 1: no overflow
    return exp_z / exp_z.sum(dim=dim, keepdim=True)
# Stable: the largest entry becomes exp(0) = 1, so nothing overflows
result_stable = softmax_stable(z_large)
print(f"Stable softmax: {result_stable}")  # ≈ tensor([1, 0, 0]); the smaller entries underflow to (almost) 0
# The built-in softmax is already numerically stable
result_torch = z_large.softmax(dim=-1)
print(f"PyTorch softmax: {result_torch}")  # same as stable
# Prefer CrossEntropyLoss on raw logits; never build log-probabilities as torch.log(torch.softmax(...))
# CrossEntropyLoss applies log_softmax internally, which uses the log-sum-exp trick for numerical stability
CrossEntropyLoss vs Manual
import torch
import torch.nn as nn
# In PyTorch:
# CrossEntropyLoss = log_softmax + NLLLoss in one numerically stable operation
# Takes RAW LOGITS as input — do NOT apply softmax first
model_output = torch.randn(16, 5)  # (batch, n_classes) raw logits
labels = torch.randint(0, 5, (16,))  # class indices in [0, 5), not one-hot
# Correct: CrossEntropyLoss with raw logits
criterion = nn.CrossEntropyLoss()
loss_correct = criterion(model_output, labels)
# Equivalent: an explicit log_softmax followed by NLLLoss. This is exactly the
# decomposition CrossEntropyLoss performs, so it is equally stable, only more
# verbose. The variant to avoid is torch.log(torch.softmax(...)), which takes
# the log of probabilities that may have underflowed to 0 (log(0) = -inf).
log_probs = torch.log_softmax(model_output, dim=-1)
loss_manual = nn.NLLLoss()(log_probs, labels)
print(f"CrossEntropyLoss: {loss_correct.item():.4f}")
print(f"LogSoftmax+NLL: {loss_manual.item():.4f}")  # same value (up to float rounding)
# WRONG (very common bug): applying softmax before CrossEntropyLoss
probs_wrong = torch.softmax(model_output, dim=-1)
# loss_wrong = criterion(probs_wrong, labels)  # double softmax in the loss! Runs without error but gives a wrong loss. Wrong.
Temperature Scaling
import torch
# Softmax Temperature:
# softmax(z / T)_i = exp(z_i / T) / Σ_j exp(z_j / T)
#
# T = 1: standard softmax
# T > 1: "softer" distribution — less confident, more uniform
# T < 1: "harder" distribution — more confident, more peaked
# T → ∞: uniform distribution (1/K for each class)
# T → 0: one-hot (argmax)
logits = torch.tensor([2.0, 1.0, 0.5, -0.5, -1.0])
for T in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
    # log_softmax gives exact, numerically stable log-probabilities, so the
    # entropy needs no epsilon fudge term and never evaluates log(0) when a
    # probability underflows at small T (0 * finite = 0).
    log_probs = torch.log_softmax(logits / T, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum().item()
    print(f"T={T:>5.1f}: probs=[{', '.join(f'{p:.3f}' for p in probs.tolist())}], entropy={entropy:.3f}")
# Use cases for temperature:
# 1. Calibration: if the model is overconfident (high ECE), fit a single T > 1 on a held-out validation set
# 2. Soft labels for knowledge distillation: T > 1 on teacher logits
# 3. Language model sampling: T controls creativity vs coherence
# 4. Contrastive learning: temperature controls separation margin
Attention Softmax
import torch
import math
def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch, heads, seq_q, d_k)
    K: torch.Tensor,  # (batch, heads, seq_k, d_k)
    V: torch.Tensor,  # (batch, heads, seq_k, d_v)
    mask: "torch.Tensor | None" = None,
) -> torch.Tensor:
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: Queries, (..., seq_q, d_k), e.g. (batch, heads, seq_q, d_k).
        K: Keys, (..., seq_k, d_k).
        V: Values, (..., seq_k, d_v).
        mask: Optional tensor broadcastable to the score shape
            (..., seq_q, seq_k). Positions where ``mask == 0`` are excluded
            from attention.

    Returns:
        Tensor of shape (..., seq_q, d_v). A query row whose keys are all
        masked out gets all-zero attention weights (hence a zero output row)
        instead of NaN.
    """
    d_k = Q.shape[-1]
    # Similarity scores: (..., seq_q, seq_k), divided by sqrt(d_k) so that
    # their scale does not grow with d_k and the softmax does not saturate.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Softmax over seq_k: each query's weights over the keys sum to 1.
    attn_weights = torch.softmax(scores, dim=-1)
    if mask is not None:
        # A row that is entirely -inf makes softmax compute 0/0 = NaN. Zero
        # such rows so a fully-masked query contributes nothing instead of
        # propagating NaN downstream.
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
    # Weighted sum of values: (..., seq_q, d_v)
    return attn_weights @ V
batch, heads, seq_len, d_k = 2, 8, 10, 64
# Draw queries, keys and values (in that order) with identical shapes; here d_v == d_k.
Q, K, V = (torch.randn(batch, heads, seq_len, d_k) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(f"Attention output: {out.shape}")  # torch.Size([2, 8, 10, 64])
# The √d_k scaling prevents softmax from saturating when d_k is large
# Without scaling: if the entries of q and k are independent with mean 0 and variance 1, q·k has variance d_k → softmax inputs are large → very peaked distribution
# With scaling: variance = 1 → reasonable softmax distribution
Interview Answer
"Softmax converts a vector of raw logits to a probability distribution: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). All outputs are in (0,1) and sum to exactly 1. It's used at the output of multi-class classifiers and inside attention mechanisms. Numerical stability: subtract the maximum logit before exponentiation to prevent overflow (exp(large_number) = inf); this is valid because softmax is unchanged when the same constant is subtracted from every logit. In PyTorch, never apply softmax before CrossEntropyLoss — the loss does log_softmax internally in a numerically stable way using the log-sum-exp trick; applying softmax first causes a double-softmax bug. Temperature scaling divides logits before softmax: T > 1 produces a flatter, less confident distribution (used for calibration and knowledge distillation); T < 1 makes it more peaked. In attention, scores are divided by √d_k before softmax to prevent saturation — without scaling, dot products grow with dimension and softmax becomes too sharp, producing near one-hot attention weights that can't attend to multiple positions and pass back vanishingly small gradients."