ReLU Activation
Why ReLU became the default hidden-layer activation — its gradient, the dead-neuron problem, and variants such as ELU and GELU.
The ReLU Function
ReLU(z) = max(0, z)
Range: [0, ∞)
Gradient: 1 if z > 0, 0 if z < 0 (undefined at z=0, conventionally 0)
Advantages over sigmoid/tanh:
1. No saturation for positive inputs — the gradient is always 1
2. Sparse activation — negative inputs map to exactly 0
3. Computationally cheap — just a max operation
4. Works well with Kaiming initialisation
Disadvantages:
1. Dead neurons: if z < 0 for all training examples, neuron is permanently off
2. Not zero-centred (all outputs ≥ 0)
3. Gradient = 0 for z < 0 → no learning signal reaches the neuron for those inputs
ReLU Gradient Flow
import torch
import torch.nn as nn
# ReLU's gradient is exactly 0 or 1, so active units pass gradients through
# without attenuation. The gradient below comes from autograd (not a hand-coded
# formula), so the table shows PyTorch's actual convention at z = 0 (it uses 0).
torch.manual_seed(0)  # make the printed numbers reproducible

z = torch.tensor([-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0])
z_leaf = z.clone().requires_grad_(True)
relu_out = torch.relu(z_leaf)
# ReLU is elementwise, so d(sum ReLU(z)) / d z_i == d ReLU(z_i) / d z_i.
(relu_grad,) = torch.autograd.grad(relu_out.sum(), z_leaf)
relu_out = relu_out.detach()

print(f"{'z':>8} {'ReLU(z)':>10} {'gradient':>10}")
for zi, ri, gi in zip(z, relu_out, relu_grad):
    print(f"{zi.item():>8.2f} {ri.item():>10.4f} {gi.item():>10.4f}")

# Contrast with sigmoid, whose derivative is at most 0.25: across 10 sigmoid
# layers the activation-derivative factor is at most 0.25**10 ≈ 9.5e-7, while
# an active ReLU unit contributes a factor of exactly 1 (1**10 = 1).
# In practice, with Kaiming init, gradients stay healthy through many layers.
deep_relu = nn.Sequential(
    *[nn.Sequential(nn.Linear(64, 64, bias=False), nn.ReLU()) for _ in range(15)],
    nn.Linear(64, 1),
)

# Kaiming (He) init, scaled for ReLU, applied to every Linear weight.
for m in deep_relu.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

X = torch.randn(32, 64)
y = torch.randint(0, 2, (32,)).float()
# squeeze(-1): (32, 1) logits -> (32,) without also collapsing a batch of size 1.
loss = nn.BCEWithLogitsLoss()(deep_relu(X).squeeze(-1), y)
loss.backward()

# Weight-gradient norm of each of the 15 hidden Linear+ReLU layers, ordered from
# input to output. The output head is excluded so the list really is
# "15 ReLU layers".
grad_norms = [block[0].weight.grad.norm().item() for block in deep_relu[:-1]]
print(f"\nGrad norms across 15 ReLU layers: min={min(grad_norms):.3f}, max={max(grad_norms):.3f}")
Dead Neurons
import torch
import torch.nn as nn
def count_dead_neurons(model: nn.Module, X: torch.Tensor) -> dict:
    """Count dead units after every ``nn.ReLU`` module in *model*.

    A unit is "dead" if it outputs exactly 0 for every sample in *X*. Dim 0 of
    each activation is assumed to be the batch and dim 1 the unit axis:

    - For ``(batch, features)`` activations, a unit is one feature column.
    - For higher-rank activations such as conv maps ``(batch, channels, ...)``,
      a unit is one channel. It counts as dead only if it is 0 at every sample
      *and* every remaining position, so ``n_dead <= n_total`` always holds.

    The model is run once under ``torch.no_grad()``. Forward hooks are always
    removed afterwards, even if the forward pass raises.

    Only ``nn.ReLU`` *modules* are inspected; functional ReLU calls inside a
    ``forward`` are not seen. If one ReLU module runs several times in a single
    forward pass, only its last call is recorded.

    Args:
        model: Network to inspect. Its train/eval mode is not changed here.
        X: Input passed directly to ``model(X)``.

    Returns:
        Mapping from each ReLU module's qualified name (as yielded by
        ``model.named_modules()``) to ``(n_dead, n_total)``. A per-layer summary
        is also printed.
    """
    dead_counts = {}

    def make_hook(name: str):
        def hook(module, _inputs, output):
            # Treat an unbatched 1-D activation as a batch of one sample.
            acts = output if output.dim() > 1 else output.unsqueeze(0)
            n_total = acts.shape[1]
            # Put the unit axis first and flatten batch + any trailing dims, so
            # each row holds every value that one unit produced.
            per_unit = acts.movedim(1, 0).flatten(1)
            n_dead = int((per_unit == 0).all(dim=1).sum().item())
            dead_counts[name] = (n_dead, n_total)
        return hook

    handles = [
        module.register_forward_hook(make_hook(name))
        for name, module in model.named_modules()
        if isinstance(module, nn.ReLU)
    ]
    try:
        with torch.no_grad():
            model(X)
    finally:
        # Always detach the hooks so they never fire on later forward passes.
        for h in handles:
            h.remove()

    for name, (n_dead, n_total) in dead_counts.items():
        pct = 100 * n_dead / n_total if n_total else 0.0
        status = "OK" if pct < 10 else ("WARNING" if pct < 50 else "CRITICAL")
        print(f"{name:30s}: {n_dead}/{n_total} dead ({pct:.1f}%) [{status}]")
    return dead_counts
# Simulated failure mode: a deliberately bad initialisation that leaves many
# ReLU units dead.
model = nn.Sequential(
    nn.Linear(20, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 1),
)

# A -5 bias on the first Linear shifts its pre-activations far below zero, so
# most of the first ReLU's units are expected to output 0 for these inputs.
nn.init.constant_(model[0].bias, -5.0)

X = torch.randn(200, 20)
print("Dead neurons with large negative bias:")
count_dead_neurons(model, X)
ReLU Variants
import torch
import torch.nn as nn
import torch.nn.functional as F
z = torch.linspace(-3, 3, 100)
# โโ Leaky ReLU: gradient = ฮฑ for z < 0 (avoids dead neurons) โโ
leaky = nn.LeakyReLU(negative_slope=0.01)
# Gradient: 0.01 for z < 0, 1.0 for z > 0
# โโ Parametric ReLU (PReLU): learned ฮฑ per channel โโ
prelu = nn.PReLU()
# ฮฑ is a learnable parameter โ can be different for each neuron
# โโ ELU: Exponential Linear Unit โโ
# ELU(z) = z for z > 0, ฮฑ(e^z - 1) for z โค 0
# Smooth, zero-centred outputs, never exactly dead
elu = nn.ELU(alpha=1.0)
# โโ GELU: Gaussian Error Linear Unit โโ
# GELU(z) = z ยท ฮฆ(z) where ฮฆ is the normal CDF
# Smooth approximation of ReLU, used in Transformers (BERT, GPT)
gelu = nn.GELU()
# โโ SiLU (Swish): z ยท ฯ(z) โโ
# Self-gated, smooth, used in EfficientNet, modern CNNs
silu = nn.SiLU()
# Compare outputs at z = [-2, -1, 0, 1, 2]
test_z = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print(f"{'z':>8}", end="")
for name in ["ReLU", "Leaky", "ELU", "GELU", "SiLU"]:
print(f" {name:>8}", end="")
print()
activations = [nn.ReLU(), leaky, elu, gelu, silu]
vals = [act(test_z).detach() for act in activations]
for i, z_val in enumerate(test_z):
print(f"{z_val.item():>8.1f}", end="")
for v in vals:
print(f" {v[i].item():>8.4f}", end="")
print()When to Use Which Variant
Activation | Use when | Avoid when
-----------|--------------------------------------------|-----------------------
ReLU | Default for MLP, CNN hidden layers | Many dead neurons observed
LeakyReLU | Dead neurons are a problem (small data) | ReLU works fine already
PReLU | Have enough data to learn α | Small datasets (overfit)
ELU | Want smooth, zero-centred activations | Speed is critical (exp op)
GELU | Transformer/BERT-style models | Simple MLP (overkill)
SiLU | EfficientNet-style CNNs, modern networks | Legacy code expecting ReLU
import torch.nn as nn
# Typical activation choices for different architectures.

# Tabular-data MLP: 20 input features -> 64 -> 32 -> 1 output, with a ReLU
# after each hidden Linear layer and none after the output layer.
mlp = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)
# Transformer feed-forward (FFN) sub-layer
class TransformerFFN(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (expansion) size produced by the first Linear layer.
        dropout: Dropout probability applied after the GELU. Defaults to 0.1.
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        # Layer order is fixed, so state_dict keys stay net.0.* and net.3.*.
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # GELU is standard in transformers
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """Map ``(..., d_model)`` to ``(..., d_model)``; Linear acts on the last dim."""
        return self.net(x)
# EfficientNet-style block
class MBConvBlock(nn.Module):
"""Uses SiLU (Swish) โ smooth and performant for CNNs."""
def __init__(self, channels: int):
super().__init__()
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.conv(x))Interview Answer
"ReLU (Rectified Linear Unit) outputs max(0, z) โ zero for negative inputs, identity for positive. Its gradient is 1 for z > 0 and 0 for z โค 0. The key advantage over sigmoid: gradient doesn't attenuate for active neurons (gradient = 1, not โค 0.25), enabling training of much deeper networks. ReLU also gives sparse activations โ typically 50% of neurons are off โ which has a regularising effect. The downside is dead neurons: if z < 0 for all training inputs, the neuron never activates and never updates. Fixes: Leaky ReLU (gradient = 0.01 for z < 0), ELU (smooth, zero-centred), or better initialisation and learning rate. Modern transformers use GELU (smooth approximation of ReLU); modern CNNs often use SiLU/Swish. For most tabular and CNN tasks, plain ReLU with Kaiming initialisation is the reliable default."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.