Activation Functions: Interview Q&A
Five key interview questions on sigmoid, ReLU, softmax, dead neurons, and choosing activations for different architectures.
Q1: Why do we need activation functions at all?
Answer: Without activation functions, any number of stacked linear layers is equivalent to a single linear (affine) transformation. If Layer1(x) = W1x + b1 and Layer2(h) = W2h + b2, then Layer2(Layer1(x)) = W2(W1x + b1) + b2 = (W2W1)x + c, where c = W2b1 + b2, which is still linear in x. Non-linear activations (ReLU, sigmoid, tanh) break this collapse, giving the network the ability to approximate non-linear functions. The Universal Approximation Theorem states that a single hidden layer with a non-polynomial activation can approximate any continuous function on a compact domain to arbitrary accuracy, given enough hidden units. Without activations, depth adds no expressivity: you'd just have a linear model regardless of architecture.
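As a small illustration of what the non-linearity buys, here is a sketch of a two-unit ReLU network that computes XOR. The weights are hand-set for illustration (not learned), and `xor_net` is a hypothetical name. No affine model can fit XOR, because any affine f satisfies f(0,0) + f(1,1) = f(0,1) + f(1,0), while XOR gives 0 on the left and 2 on the right.

```python
def relu(z: float) -> float:
    """ReLU activation: max(0, z)."""
    return max(0.0, z)


def xor_net(x1: float, x2: float) -> float:
    """Two hidden ReLU units plus a linear output, with hand-set weights."""
    h1 = relu(x1 + x2)        # active when at least one input is on
    h2 = relu(x1 + x2 - 1.0)  # active only when both inputs are on
    return h1 - 2.0 * h2      # linear read-out of the hidden units


inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
outputs = [xor_net(a, b) for a, b in inputs]
for (a, b), out in zip(inputs, outputs):
    print(f"xor_net({a}, {b}) = {out}")
```

Removing the `relu` calls turns `xor_net` back into an affine function of its inputs, and the fit to XOR is lost.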
import torch
import torch.nn as nn
# Proof: stacking linear layers without activation = one linear layer
W1 = torch.randn(5, 10)
W2 = torch.randn(3, 5)
b1 = torch.randn(5)
b2 = torch.randn(3)
X = torch.randn(8, 10)
# Two linear layers
out_two = X @ W1.T + b1
out_two = out_two @ W2.T + b2
# Equivalent single linear layer
W_eq = W2 @ W1
b_eq = W2 @ b1 + b2
out_one = X @ W_eq.T + b_eq
print(f"Max difference: {(out_two - out_one).abs().max().item():.2e}") # ~0, up to float32 rounding
Q2: What is the vanishing gradient problem and which activations cause it?
Answer: Vanishing gradients occur when the gradient signal shrinks to near-zero as it propagates backward through many layers. The chain rule multiplies local gradients together: if each is small (< 1), the product shrinks exponentially. Sigmoid and tanh are the primary culprits because both saturate: sigmoid's maximum gradient is 0.25 (at z=0) and approaches 0 as z → ±∞. Across 10 sigmoid layers, the activation derivatives alone contribute a factor of at most 0.25^10 ≈ 10^-6, so early layers can learn essentially nothing. ReLU avoids this for positive inputs (gradient = 1), though dead neurons (gradient = 0 for z < 0) remain an issue. Residual connections (ResNet) provide a gradient highway that bypasses non-linearities.
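The arithmetic behind those numbers can be checked with a few lines of plain Python. This is a minimal sketch that looks only at the activation derivatives and ignores the weight matrices, which also enter the chain rule; the helper names are illustrative.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z: float) -> float:
    """Derivative of sigmoid: s(z) * (1 - s(z))."""
    s = sigmoid(z)
    return s * (1.0 - s)


# The derivative peaks at z = 0 and decays quickly as |z| grows.
for z in (0.0, 2.0, 5.0):
    print(f"sigmoid'({z}) = {sigmoid_grad(z):.5f}")

# Upper bound on the product of activation derivatives across `depth` layers.
for depth in (5, 10, 20):
    print(f"depth={depth}: bound = {0.25 ** depth:.2e}")
```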
import torch
import torch.nn as nn
# Compare gradient norms at the first layer for sigmoid vs ReLU
def first_layer_grad_norm(activation: nn.Module, n_layers: int = 8) -> float:
    layers = []
    for _ in range(n_layers):
        layers.extend([nn.Linear(32, 32), activation])
    layers.append(nn.Linear(32, 1))
    model = nn.Sequential(*layers)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
    X = torch.randn(64, 32)
    y = torch.randint(0, 2, (64,)).float()
    nn.BCEWithLogitsLoss()(model(X).squeeze(), y).backward()
    return list(model.parameters())[0].grad.norm().item()
print(f"Sigmoid first-layer grad: {first_layer_grad_norm(nn.Sigmoid()):.2e}")
print(f"ReLU first-layer grad: {first_layer_grad_norm(nn.ReLU()):.2e}")
# Sigmoid should be orders of magnitude smaller
Q3: What is a dead ReLU neuron and how do you fix it?
Answer: A dead ReLU neuron always outputs 0 because its pre-activation is negative for every training sample. Since ReLU's gradient is 0 for z ≤ 0, the neuron never receives a gradient and cannot learn: it is effectively permanently disabled. Causes: a large negative bias after a bad update, or a too-large learning rate. Detection: check what fraction of neurons output 0 across a validation batch. Fixes: (1) Leaky ReLU: gradient = 0.01 for z < 0, preventing permanent death; (2) ELU: smooth exponential for z < 0; (3) reduce the learning rate; (4) Kaiming initialisation, which makes dead neurons less likely early in training. A dead neuron rate above 10% is a warning sign worth investigating.
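To make the Leaky ReLU fix concrete, here is a toy sketch in plain Python: a single neuron with a badly negative bias, where only the bias is trained by gradient descent on a squared error. The setup (fixed weight, three inputs, target of 1.0, function names) is an assumption chosen for illustration, not a realistic training loop.

```python
def relu(z: float) -> float:
    return z if z > 0 else 0.0


def relu_grad(z: float) -> float:
    return 1.0 if z > 0 else 0.0


def leaky_relu(z: float, slope: float = 0.01) -> float:
    return z if z > 0 else slope * z


def leaky_relu_grad(z: float, slope: float = 0.01) -> float:
    return 1.0 if z > 0 else slope


def train_bias(act, act_grad, steps: int = 200, lr: float = 0.5) -> float:
    """Train only the bias of y = act(w*x + b) with loss 0.5 * (y - target)^2."""
    w, b = 0.5, -5.0           # b = -5 simulates a bad update
    xs = [-1.0, 0.0, 1.0]      # toy inputs
    target = 1.0
    for _ in range(steps):
        grad_b = 0.0
        for x in xs:
            z = w * x + b
            grad_b += (act(z) - target) * act_grad(z)  # chain rule for dL/db
        b -= lr * grad_b / len(xs)
    return b


b_relu = train_bias(relu, relu_grad)
b_leaky = train_bias(leaky_relu, leaky_relu_grad)
print(f"ReLU bias after training:       {b_relu:.3f}")
print(f"Leaky ReLU bias after training: {b_leaky:.3f}")
```

With ReLU the gradient is exactly zero on every step, so the bias never moves. With Leaky ReLU the small negative-side slope lets the bias drift back toward the active region.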
import torch
import torch.nn as nn
def dead_neuron_rate(model: nn.Module, X: torch.Tensor) -> float:
    """Fraction of ReLU neurons that output 0 for all samples in X."""
    dead = 0
    total = 0
    def hook(module, input, output):
        nonlocal dead, total
        is_dead = (output == 0).all(dim=0)
        dead += is_dead.sum().item()
        total += output.shape[1]
    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.ReLU)]
    with torch.no_grad():
        model(X)
    for h in handles:
        h.remove()
    return dead / total if total > 0 else 0.0
# Model with bad init
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
with torch.no_grad():
    model[0].bias.fill_(-5.0) # simulate bad update
rate = dead_neuron_rate(model, torch.randn(200, 20))
print(f"Dead neuron rate: {rate:.1%}")
Q4: When should you use sigmoid vs softmax at the output layer?
Answer: Use sigmoid for binary classification (1 output neuron) and multi-label classification (N independent binary decisions). Use softmax for multi-class classification (exactly one class is correct, classes are mutually exclusive). The key difference: sigmoid treats each output independently (each probability can be anywhere between 0 and 1 regardless of the others); softmax forces the outputs to compete and sum to 1. Multi-label bug: using softmax for multi-label pushes probability mass toward a single label and suppresses the others, so several labels cannot all be likely at once. Binary case: a two-output softmax is mathematically equivalent to a sigmoid on the difference of the two logits, so it works but carries a redundant output. In PyTorch: use BCEWithLogitsLoss (implicit sigmoid) for binary/multi-label, CrossEntropyLoss (implicit softmax) for multi-class.
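A quick numeric check of the binary case, in plain Python: the probability that a two-output softmax assigns to the first class matches a sigmoid applied to the difference of the two logits. The logit values and helper names are arbitrary, chosen for illustration.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits: list[float]) -> list[float]:
    """Softmax with the maximum subtracted first for numerical safety."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


a, b = 1.3, -0.4  # hypothetical logits for "positive" and "negative"
p_softmax = softmax([a, b])[0]
p_sigmoid = sigmoid(a - b)
print(f"softmax([a, b])[0] = {p_softmax:.6f}")
print(f"sigmoid(a - b)     = {p_sigmoid:.6f}")
```

Only the difference a - b matters to the two-output softmax, which is why a single logit with a sigmoid carries the same information.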
import torch
import torch.nn as nn
# Binary: is_readmitted? -> 1 output neuron, sigmoid
binary_output = torch.tensor([0.7]) # logit
print(f"Binary prob: {torch.sigmoid(binary_output).item():.3f}")
# Multi-class: severity level [mild, moderate, severe, critical] -> 4 outputs, softmax
mc_logits = torch.tensor([1.2, 0.5, -0.3, 2.1])
print(f"Multi-class probs: {torch.softmax(mc_logits, dim=-1).numpy().round(3)}")
print(f"Sum: {torch.softmax(mc_logits, dim=-1).sum():.3f}") # must be 1.0
# Multi-label: [readmitted, ICU_admission, AF_detected] -> 3 outputs, sigmoid each
ml_logits = torch.tensor([0.8, -1.2, 2.0])
ml_probs = torch.sigmoid(ml_logits)
print(f"Multi-label probs: {ml_probs.numpy().round(3)}")
print(f"Sum: {ml_probs.sum():.3f}") # can be any value
Q5: What is temperature in softmax and when is it used?
Answer: Temperature T is a scalar that divides logits before softmax: softmax(z/T). T=1 is standard. T > 1 produces a flatter distribution, so the model is less confident. T < 1 makes it sharper, with more extreme probabilities. Temperature is used in three contexts: (1) Calibration: if a model is overconfident (ECE is high), find the T > 1 that minimises NLL on validation data; (2) Knowledge distillation: the teacher model uses T > 1 to produce soft labels that encode inter-class similarity, helping the student learn more than just the hard label; (3) Language model decoding: sampling with a higher T produces more diverse text (T = 0.8 to 1.2 is typical for creative generation, T < 0.5 for near-deterministic, factual responses).
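The calibration use case can be sketched in plain Python. This assumes a tiny hypothetical validation set where the model is very confident (logits [5, 0]) but right only three times out of four, and uses a simple grid search over T. In practice T is usually fitted with a gradient-based optimiser, so treat the search strategy and all names here as illustrative.

```python
import math


def softmax(logits: list[float], T: float = 1.0) -> list[float]:
    """Temperature-scaled softmax: softmax(z / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(all_logits: list[list[float]], labels: list[int], T: float) -> float:
    """Mean negative log-likelihood of the true labels at temperature T."""
    losses = [-math.log(softmax(l, T)[y]) for l, y in zip(all_logits, labels)]
    return sum(losses) / len(losses)


# Hypothetical overconfident model: same confident logits, 75% accuracy.
val_logits = [[5.0, 0.0]] * 4
val_labels = [0, 0, 0, 1]

candidates = [0.5 + 0.1 * i for i in range(96)]  # T from 0.5 to 10.0
best_T = min(candidates, key=lambda T: nll(val_logits, val_labels, T))

print(f"NLL at T=1:      {nll(val_logits, val_labels, 1.0):.3f}")
print(f"Best T on grid:  {best_T:.1f}")
print(f"NLL at best T:   {nll(val_logits, val_labels, best_T):.3f}")
```

For this toy set the optimum can be derived by hand: NLL is minimised when the predicted confidence equals the 75% accuracy, which gives T = 5 / ln 3 ≈ 4.55. The grid search lands on a neighbouring grid point.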
import torch
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
for T in [0.5, 1.0, 2.0, 5.0]:
    probs = torch.softmax(logits / T, dim=-1)
    conf = probs.max().item()
    print(f"T={T}: probs={probs.numpy().round(3)}, max_confidence={conf:.3f}")
# T=0.5: approx. [0.84, 0.11, 0.04, 0.00] -> very confident
# T=2.0: approx. [0.43, 0.26, 0.21, 0.10] -> much less confident
# T=5.0: approx. [0.32, 0.26, 0.24, 0.18] -> nearly uniform
Interview Answer
"Activation functions are essential because without them, any depth of linear layers reduces to a single linear transformation; they provide the non-linearity that enables neural networks to approximate complex functions. Sigmoid (max gradient 0.25) causes vanishing gradients in deep networks, so use ReLU for hidden layers instead; ReLU's gradient of 1 for positive inputs preserves signal (though dead neurons are possible). For outputs: sigmoid for binary classification (maps logit to probability), softmax for multi-class (mutual exclusion, sums to 1), sigmoid per-output for multi-label. Never apply sigmoid/softmax before the PyTorch loss functions: BCEWithLogitsLoss and CrossEntropyLoss handle them internally in a numerically stable way. Temperature scales softmax sharpness: T > 1 for calibration and knowledge distillation, T < 1 for more peaked predictions. Modern architectures: GELU for Transformers, SiLU for modern CNNs, ReLU for most other cases."