Deep Learning for AI Interviews · Lesson 25 of 56
The Forward Pass: Prediction Step by Step
What the Forward Pass Does
Input → [Layer 1] → [Layer 2] → ... → [Layer N] → Output → Loss
At each layer:
1. Multiply: Z = X @ W.T + b (linear transformation)
2. Activate: A = activation(Z) (non-linearity)
PyTorch records every operation on a computation graph.
This graph is then traversed backwards during backpropagation
to compute gradients automatically (autograd).

Tensor Shapes Through a Network
import torch
import torch.nn as nn
class ClinicalMLP(nn.Module):
    """MLP for predicting 30-day readmission from clinical features.

    Layer order in the forward pass::

        Linear → BatchNorm1d → ReLU → Dropout → Linear → BatchNorm1d → ReLU → Linear

    With the default arguments the widths are 20 → 64 → 32 → 1.

    Args:
        n_features: Number of input features per sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        dropout: Drop probability used by the Dropout layer that follows the
            first hidden activation.

    The network outputs one raw logit per sample; apply ``torch.sigmoid`` to
    obtain a probability.
    """

    def __init__(
        self,
        n_features: int = 20,
        *,
        hidden1: int = 64,
        hidden2: int = 32,
        dropout: float = 0.3,
    ):
        super().__init__()
        # Attribute names and registration order are kept stable so that
        # state_dict keys and named_modules() ordering do not change.
        self.layer1 = nn.Linear(n_features, hidden1)
        self.layer2 = nn.Linear(hidden1, hidden2)
        self.layer3 = nn.Linear(hidden2, 1)
        self.relu = nn.ReLU()  # one stateless instance, reused after both hidden layers
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of feature vectors to one raw logit per sample.

        Args:
            x: Tensor of shape ``(batch, n_features)``.

        Returns:
            Tensor of shape ``(batch, 1)`` containing raw logits.
        """
        # Hidden block 1: linear → normalise → activate → regularise
        z1 = self.layer1(x)   # (batch, hidden1)
        z1 = self.bn1(z1)     # (batch, hidden1) — normalise
        a1 = self.relu(z1)    # (batch, hidden1) — activate
        a1 = self.drop(a1)    # (batch, hidden1) — regularise

        # Hidden block 2: linear → normalise → activate (no dropout here)
        z2 = self.layer2(a1)  # (batch, hidden2)
        z2 = self.bn2(z2)     # (batch, hidden2)
        a2 = self.relu(z2)    # (batch, hidden2)

        out = self.layer3(a2)  # (batch, 1)
        return out  # raw logit; apply sigmoid for probability
# Build the network and push one random mini-batch through it.
model = ClinicalMLP(n_features=20)
batch_size = 16
X = torch.randn(batch_size, 20)
out = model(X)  # raw logits, one per sample
prob = torch.sigmoid(out)  # squash logits into (0, 1)

# Shapes print as torch.Size objects.
print(f"Input: {X.shape}")  # torch.Size([16, 20])
print(f"Output: {out.shape}")  # torch.Size([16, 1])
print(f"Probs: min={prob.min():.3f}, max={prob.max():.3f}")

Tracing the Computation Graph
import torch
def trace_forward_pass(model: nn.Module, X: torch.Tensor) -> None:
    """Run one forward pass and print the input/output shape of each layer.

    A temporary forward hook is attached to every ``nn.Linear``, ``nn.ReLU``,
    ``nn.BatchNorm1d`` and ``nn.Dropout`` submodule. Hooks fire in call
    order, so a module instance that is invoked several times in ``forward``
    produces several lines. The model's train/eval mode is left as the caller
    set it.

    Args:
        model: Module whose layers should be traced.
        X: Input passed to ``model`` as its single positional argument.

    Raises:
        Whatever ``model(X)`` raises; the hooks are removed first.
    """
    handles = []

    def make_hook(name: str):
        # Factory so that each hook closes over its own layer name.
        def hook(module, inputs, output):
            # ``inputs`` is the tuple of positional args given to the layer.
            in_shape = tuple(inputs[0].shape)
            out_shape = tuple(output.shape)
            print(f"{name:25s}: {str(in_shape):20s} → {str(out_shape)}")

        return hook

    try:
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.ReLU, nn.BatchNorm1d, nn.Dropout)):
                handles.append(module.register_forward_hook(make_hook(name)))

        # Only shapes are needed, so skip building the autograd graph.
        with torch.no_grad():
            model(X)
    finally:
        # Always detach the hooks, even if the forward pass failed, so the
        # model is not left printing on every later call.
        for handle in handles:
            handle.remove()
# Fresh model and random batch; only the tensor shapes matter for the trace.
model = ClinicalMLP(n_features=20)
X = torch.randn(16, 20)
print("Forward pass trace:")
trace_forward_pass(model, X)
# Lines appear in call order. "relu" is listed twice because ClinicalMLP
# calls the same nn.ReLU instance after bn1 and again after bn2.
# The real output pads the name and input-shape columns; padding is omitted below.
# Output (example):
# layer1 : (16, 20) → (16, 64)
# bn1 : (16, 64) → (16, 64)
# relu : (16, 64) → (16, 64)
# drop : (16, 64) → (16, 64)
# layer2 : (16, 64) → (16, 32)
# bn2 : (16, 32) → (16, 32)
# relu : (16, 32) → (16, 32)
# layer3 : (16, 32) → (16, 1)

Autograd: Building the Computation Graph
import torch
# Autograd tracks operations on tensors with requires_grad=True
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
w = torch.tensor([0.5, -1.0, 0.3], requires_grad=True)
b = torch.tensor([0.1], requires_grad=True)

# Forward pass — graph is built here
# b is 1-D with one element, so z and loss have shape (1,) rather than being
# 0-dim scalars; .item() and .backward() still work on single-element tensors.
z = (x * w).sum() + b  # shape (1,)
loss = z ** 2  # shape (1,)
print(f"z = {z.item():.3f}")
print(f"loss = {loss.item():.3f}")

# The graph: loss ← z ← (x*w).sum() + b
print(f"loss.grad_fn: {loss.grad_fn}")  # <PowBackward0 object at ...>
print(f"z.grad_fn: {z.grad_fn}")  # <AddBackward0 object at ...>

# Backward pass — traverse graph to compute gradients
loss.backward()
print(f"dL/dw: {w.grad}")  # chain rule: dL/dz * dz/dw = 2z * x
print(f"dL/db: {b.grad}")  # chain rule: dL/dz * dz/db = 2z * 1

# Disable graph building for inference
with torch.no_grad():
    pred = (x * w).sum() + b  # no graph built, faster
print(f"Inference pred: {pred.item():.3f}")

Intermediate Activations
import torch
import torch.nn as nn
def get_intermediate_activations(
    model: nn.Module,
    X: torch.Tensor,
    layer_names: list[str],
) -> dict[str, torch.Tensor]:
    """Capture outputs of specific layers during the forward pass.

    Temporary forward hooks are attached to every submodule whose
    ``named_modules()`` name appears in ``layer_names``; one forward pass is
    run under ``torch.no_grad()``; the hooks are then removed. The model's
    train/eval mode is left as the caller set it.

    Args:
        model: Module to run.
        X: Input passed to ``model`` as its single positional argument.
        layer_names: Submodule names (as yielded by ``model.named_modules()``)
            whose outputs should be recorded. Names that match no submodule
            are silently ignored.

    Returns:
        Mapping from layer name to a detached copy of that layer's output.
        If a module is called more than once in the forward pass, only the
        output of its last call is kept.

    Raises:
        Whatever ``model(X)`` raises; the hooks are removed first.
    """
    activations: dict[str, torch.Tensor] = {}
    handles = []

    def make_hook(name: str):
        # Factory so that each hook closes over its own layer name.
        def hook(module, inputs, output):
            # detach + clone: the stored tensor is an independent copy, so
            # later in-place ops on the live output cannot alter it.
            activations[name] = output.detach().clone()

        return hook

    try:
        for name, module in model.named_modules():
            if name in layer_names:
                handles.append(module.register_forward_hook(make_hook(name)))

        with torch.no_grad():
            model(X)
    finally:
        # Always detach the hooks, even if the forward pass failed, so the
        # model does not keep copying activations on every later call.
        for handle in handles:
            handle.remove()

    return activations
# Capture the raw (pre-normalisation) outputs of the three Linear layers.
model = ClinicalMLP(n_features=20)
X = torch.randn(16, 20)
acts = get_intermediate_activations(model, X, layer_names=["layer1", "layer2", "layer3"])
for name, tensor in acts.items():
    print(f"{name}: shape={tensor.shape}, mean={tensor.mean():.4f}, std={tensor.std():.4f}")
# Useful for:
# - Debugging dead neurons (relu output all zeros)
#   → request "relu"; ClinicalMLP reuses one ReLU instance, so only the
#     output of its last call is captured
# - Checking BatchNorm is normalising (mean≈0, std≈1)
#   → request "bn1" / "bn2" rather than the Linear layers
# - Visualising learned representations

Forward Pass in Training vs Inference
import torch
import torch.nn as nn
model = ClinicalMLP(n_features=20)
X = torch.randn(16, 20)

# ── Training mode ──
model.train()
# Dropout: randomly zeroes activations (survivors are scaled by 1/(1-p))
# BatchNorm: uses batch statistics (mean and var from current batch)
# Autograd: graph is built → backward() can be called
out_train = model(X)

# ── Evaluation mode ──
model.eval()
# Dropout: disabled — acts as the identity (all neurons active); no rescaling
#          happens at eval time because PyTorch rescales during training
# BatchNorm: uses running statistics (accumulated during training)
# Autograd: still builds graph unless torch.no_grad() is used
out_eval = model(X)

# ── Inference (fastest) ──
model.eval()  # already in eval mode; repeated so this step reads on its own
with torch.no_grad():
    out_infer = model(X)
    # No graph → less memory, faster

# Compare: eval and no_grad outputs should be identical
print(torch.allclose(out_eval, out_infer)) # True

Interview Answer
"The forward pass propagates data through the network layer by layer: at each layer, Z = X @ W.T + b (linear), then A = activation(Z) (non-linearity). PyTorch builds a computation graph during this pass — recording every tensor operation and its inputs. This graph enables automatic differentiation: during the backward pass, autograd traverses the graph in reverse applying the chain rule to compute dL/dW for every parameter. Two key modes: model.train() activates Dropout and uses batch statistics in BatchNorm; model.eval() disables Dropout and uses running statistics. For inference, torch.no_grad() prevents graph construction, saving memory and time. A common bug: forgetting to call model.eval() during validation causes Dropout to randomly drop neurons and BatchNorm to use batch statistics, giving inconsistent and noisy validation metrics."