Learnixo

Deep Learning for AI Interviews · Lesson 40 of 56

Softmax for Multi-Class Output

The Softmax Function

Softmax converts a vector of logits to a probability distribution:

  softmax(z)_i = exp(z_i) / Σ_j exp(z_j)

Properties:
  - All outputs in (0, 1)
  - Outputs sum to exactly 1
  - Preserves relative ordering (argmax is unchanged)
  - Amplifies differences: the largest logit gets disproportionately large probability

Example (5-class clinical severity score):
  logits = [2.1, 0.5, -1.2, 0.8, 1.5]
  probs  = [0.485, 0.098, 0.018, 0.132, 0.266] → sums to 1.0 (up to rounding)

Softmax in Practice

Python
import torch
import torch.nn as nn
import numpy as np

# Manual softmax
def softmax_manual(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return softmax(z) along ``dim`` (default: last dim, like torch.softmax(z, dim=-1)).

    The max of each slice along ``dim`` is subtracted before exponentiating so
    that no exponent is positive (no overflow). The shift cancels between
    numerator and denominator, so the result is unchanged. Works for a single
    logit vector of shape (C,) and for a batch of shape (N, C): each row is
    normalized independently.
    """
    # keepdim=True so the per-slice max/sum broadcast back against z.
    exp_z = torch.exp(z - z.amax(dim=dim, keepdim=True))   # subtract max for numerical stability
    return exp_z / exp_z.sum(dim=dim, keepdim=True)

# Example: 5-class severity classification
logits = torch.tensor([2.1, 0.5, -1.2, 0.8, 1.5])
probs_manual = softmax_manual(logits)
probs_torch  = torch.softmax(logits, dim=-1)

print(f"Logits: {logits.numpy()}")
print(f"Manual: {probs_manual.numpy().round(4)}")   # ≈ [0.4854 0.098 0.0179 0.1323 0.2664]
print(f"Torch:  {probs_torch.numpy().round(4)}")
print(f"Sum:    {probs_torch.sum().item():.6f}")  # should be 1.0

# Batch softmax (most common in practice)
batch_logits = torch.randn(32, 5)   # (batch=32, n_classes=5)
batch_probs  = torch.softmax(batch_logits, dim=-1)  # dim=-1 = last dim: one distribution per row
batch_manual = softmax_manual(batch_logits)         # also normalizes each row independently
print(f"\nBatch softmax: {batch_probs.shape}, row sums: {batch_probs.sum(dim=-1).mean():.6f}")
print(f"Manual matches torch on batch: {torch.allclose(batch_manual, batch_probs, atol=1e-6)}")

Numerical Stability

Python
import torch

# Naive softmax overflows for large logits
def softmax_naive(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Textbook softmax along ``dim`` with no max-shift: deliberately unsafe, shown for contrast."""
    exp_z = torch.exp(z)    # float32 exp overflows to inf above ~88.7, e.g. exp(1000) = inf
    return exp_z / exp_z.sum(dim=dim, keepdim=True)

z_large = torch.tensor([1000.0, 900.0, 800.0])

# Naive: every exp() overflows to inf, and inf / inf = NaN
result_naive = softmax_naive(z_large)
print(f"Naive softmax: {result_naive}")   # tensor([nan, nan, nan])

# Stable: subtract max before exponentiation.
# softmax(z - c) = softmax(z) for any constant c, because exp(-c) cancels
# between numerator and denominator; choosing c = max(z) makes every exponent <= 0.
def softmax_stable(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return softmax(z) along ``dim`` (default: last dim), computed as softmax(z - max(z)) to avoid overflow."""
    z_shifted = z - z.amax(dim=dim, keepdim=True)   # max becomes 0, others <= 0
    exp_z = torch.exp(z_shifted)                    # exp(0) = 1, others <= 1: no overflow
    return exp_z / exp_z.sum(dim=dim, keepdim=True)

result_stable = softmax_stable(z_large)
print(f"Stable softmax: {result_stable}")   # ≈ tensor([1., 0., 0.]); the other entries are <= e^-100

# PyTorch's torch.softmax is already stable
result_torch = torch.softmax(z_large, dim=-1)
print(f"PyTorch softmax: {result_torch}")   # same as stable

# For training, pass raw logits to CrossEntropyLoss (equivalently log_softmax + NLLLoss):
# both compute log-probabilities with the log-sum-exp trick, so both are stable.
# What to avoid is torch.log(torch.softmax(...)): a probability that underflows
# to 0 gives log(0) = -inf.

CrossEntropyLoss vs Manual

Python
import torch
import torch.nn as nn

# In PyTorch:
# CrossEntropyLoss = log_softmax + NLLLoss in one call. log_softmax uses the
# log-sum-exp trick, so the combination is numerically stable.
# Takes RAW LOGITS as input — do NOT apply softmax first.

model_output = torch.randn(16, 5)   # (batch, n_classes) logits
labels = torch.randint(0, 5, (16,))  # class indices, not one-hot

# Correct: CrossEntropyLoss with raw logits
criterion = nn.CrossEntropyLoss()
loss_correct = criterion(model_output, labels)

# Also correct and numerically equivalent: explicit log_softmax + NLLLoss.
# CrossEntropyLoss is simply the more convenient single call.
log_probs = torch.log_softmax(model_output, dim=-1)
loss_manual = nn.NLLLoss()(log_probs, labels)

print(f"CrossEntropyLoss: {loss_correct.item():.4f}")
print(f"LogSoftmax+NLL:   {loss_manual.item():.4f}")   # same

# The variant to avoid is torch.log(torch.softmax(...)) + NLLLoss: if a
# probability underflows to 0, its log is -inf and the loss/gradients blow up.

# WRONG (very common bug): applying softmax before CrossEntropyLoss.
# The loss applies log_softmax again, so probabilities in [0, 1] are treated as
# logits. The loss is squashed into a narrow range and gradients shrink. With
# 5 classes it can never drop below log(1 + 4/e) ≈ 0.905, even for a perfect model.
probs_wrong = torch.softmax(model_output, dim=-1)
loss_wrong = criterion(probs_wrong, labels)  # double softmax — runs without error, but is wrong
print(f"Softmax+CE (bug): {loss_wrong.item():.4f}")   # differs from the correct loss

Temperature Scaling

Python
import torch

# Softmax Temperature:
# softmax(z / T)_i = exp(z_i / T) / Σ_j exp(z_j / T)
#
# T = 1: standard softmax
# T > 1: "softer" distribution — less confident, more uniform
# T < 1: "harder" distribution — more confident, more peaked
# T → ∞: uniform distribution (1/K for each class)
# T → 0: one-hot on the argmax (assuming a unique largest logit)

logits = torch.tensor([2.0, 1.0, 0.5, -0.5, -1.0])

entropies = {}   # T -> entropy of softmax(logits / T), in nats
for T in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
    # log_softmax gives exact log-probabilities, so the entropy needs no
    # epsilon fudge even when some probabilities are tiny at small T.
    log_probs = torch.log_softmax(logits / T, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum().item()
    entropies[T] = entropy
    print(f"T={T:>5.1f}: probs=[{', '.join(f'{p:.3f}' for p in probs.tolist())}], entropy={entropy:.3f}")
# Entropy rises monotonically with T, approaching log(K) = log(5) ≈ 1.609 (uniform).

# Use cases for temperature:
# 1. Calibration: use T > 1 (usually fit on a held-out set) if the model is overconfident (ECE is high)
# 2. Soft labels for knowledge distillation: T > 1 on teacher logits
# 3. Language model sampling: T controls creativity vs coherence
# 4. Contrastive learning: temperature controls how sharply similarities are separated

Attention Softmax

Python
import math
from typing import Optional

import torch

def scaled_dot_product_attention(
    Q: torch.Tensor,   # (batch, heads, seq_q, d_k)
    K: torch.Tensor,   # (batch, heads, seq_k, d_k)
    V: torch.Tensor,   # (batch, heads, seq_k, d_v)
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute softmax(Q Kᵀ / √d_k) V.

    Args:
        Q, K, V: query/key/value tensors (shapes in the parameter comments).
        mask: optional tensor broadcastable to (batch, heads, seq_q, seq_k).
            Positions where ``mask == 0`` are excluded from attention.

    Returns:
        Tensor of shape (batch, heads, seq_q, d_v).

    A query row whose keys are *all* masked gets all-zero attention weights
    (and therefore an all-zero output) instead of NaN.
    """
    d_k = Q.shape[-1]

    # Similarity scores: (batch, heads, seq_q, seq_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # scale by √d_k

    if mask is not None:
        masked_out = mask == 0
        scores = scores.masked_fill(masked_out, float("-inf"))

    # Softmax converts scores to attention weights (sum to 1 across seq_k)
    attn_weights = torch.softmax(scores, dim=-1)

    if mask is not None:
        # A row of all -inf makes softmax compute 0/0 = NaN, which would poison
        # the output; give such fully-masked rows zero weight instead.
        fully_masked = masked_out.all(dim=-1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

    # Weighted sum of values
    return attn_weights @ V

batch, heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)   # d_v = d_k here

out = scaled_dot_product_attention(Q, K, V)
print(f"Attention output: {out.shape}")  # (2, 8, 10, 64)

# Causal mask: query i may attend only to keys 0..i (1 = keep, 0 = masked)
causal = torch.tril(torch.ones(seq_len, seq_len))
out_causal = scaled_dot_product_attention(Q, K, V, mask=causal)
print(f"Causal attention output: {out_causal.shape}")  # (2, 8, 10, 64)

# The √d_k scaling prevents softmax from saturating when d_k is large.
# Without scaling: if Q and K entries are independent with mean 0 and variance 1,
#   Q·Kᵀ has variance d_k → softmax inputs are large → very peaked distribution
# With scaling: variance ≈ 1 → reasonable softmax distribution

Interview Answer

"Softmax converts a vector of raw logits to a probability distribution: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). All outputs are in (0,1) and sum to exactly 1. It's used at the output of multi-class classifiers and inside attention mechanisms. Numerical stability: subtract the maximum logit before exponentiation to prevent overflow (exp(large_number) = inf). In PyTorch, never apply softmax before CrossEntropyLoss — the loss does log_softmax internally in a numerically stable way using the log-sum-exp trick; applying softmax first causes a double-softmax bug. Temperature scaling divides logits before softmax: T > 1 produces a flatter, less confident distribution (used for calibration and knowledge distillation); T < 1 makes it more peaked. In attention, scores are scaled by √d_k before softmax to prevent saturation — without scaling, dot products grow with dimension and softmax becomes too sharp, producing near one-hot attention weights that can't attend to multiple positions."