Learnixo

Deep Learning for AI Interviews · Lesson 29 of 56

Network Depth vs Width: Tradeoffs

The Core Trade-off

Width (neurons per layer):
  More neurons → more capacity within each layer
  Can represent more features at the same abstraction level
  Parallel computation — wide layers run efficiently on GPUs
  Risk: memorisation of training data if too wide with limited data

Depth (number of layers):
  More layers → hierarchical feature learning
  Each layer builds on abstractions from the previous
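  Some functions need exponentially fewer units in a deep network than in a shallow one
  (e.g. the number of linear regions of a ReLU network can grow exponentially with depth but only polynomially with width)
  Risk: vanishing/exploding gradients (mitigated, not eliminated, by residual connections, BatchNorm, ReLU-family activations and careful initialisation)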
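A small pure-Python illustration of the depth claim, based on the well-known "sawtooth" construction. A "tent" map on [0, 1] can be built from two ReLU units. Stacking it k times produces 2**k linear pieces from only 2k units. By contrast, one hidden layer of n ReLU units on a 1-D input has at most n + 1 pieces, because each unit adds at most one breakpoint. The helper names (`tent`, `deep_sawtooth`, `count_linear_pieces`) are illustrative, not from any library.

```python
from typing import Callable


def relu(z: float) -> float:
    return max(0.0, z)


def tent(x: float) -> float:
    # Tent map on [0, 1] built from two ReLU units: rises as 2x, falls as 2 - 2x
    return 2 * relu(x) - 4 * relu(x - 0.5)


def deep_sawtooth(x: float, depth: int) -> float:
    # Stack `depth` tent layers: 2 hidden units per layer, 2 * depth units in total
    for _ in range(depth):
        x = tent(x)
    return x


def count_linear_pieces(f: Callable[[float], float], n_grid: int = 2**14) -> int:
    # Count slope changes on a uniform dyadic grid over [0, 1]
    # (the grid contains every breakpoint for depth <= 14, so counting is exact)
    xs = [i / n_grid for i in range(n_grid + 1)]
    ys = [f(x) for x in xs]
    slopes = [round((ys[i + 1] - ys[i]) * n_grid, 6) for i in range(n_grid)]
    return 1 + sum(1 for a, b in zip(slopes, slopes[1:]) if a != b)


for depth in [1, 2, 4, 8]:
    pieces = count_linear_pieces(lambda x: deep_sawtooth(x, depth))
    print(f"depth={depth}: {2 * depth} ReLU units, {pieces} linear pieces")
```

At depth 8 the stacked network uses 16 ReLU units and produces 256 pieces. A single hidden layer with the same 16 units could produce at most 17.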

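In practice: for complex tasks, depth often matters more than width.
  A 4-layer narrow network often outperforms a 2-layer wide network
  with the same parameter count when the data has hierarchical structure;
  the size of the gap is task-dependent, so confirm it on a validation set.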
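To make "same parameter count" concrete, here is a small helper that counts the weights and biases of a plain MLP with no normalisation layers. The name `mlp_param_count` is ours, for illustration. The widths are illustrative choices picked so the two budgets roughly match; they are not recommendations.

```python
def mlp_param_count(dims: list[int]) -> int:
    # Each Linear(in_d, out_d) has in_d * out_d weights and out_d biases
    return sum(in_d * out_d + out_d for in_d, out_d in zip(dims[:-1], dims[1:]))


wide_2_layer = [20, 256, 256, 1]              # 2 hidden layers of 256 units
narrow_4_layer = [20, 150, 150, 150, 150, 1]  # 4 hidden layers of 150 units

print(mlp_param_count(wide_2_layer))    # 71425
print(mlp_param_count(narrow_4_layer))  # 71251
```

Doubling the depth here only requires shrinking the width from 256 to about 150. The hidden-to-hidden weights grow with the square of the width but only linearly with the number of layers.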

Experimental Comparison

Python
import math

import torch
import torch.nn as nn

def make_mlp_by_budget(
    n_features: int,
    n_params_target: int,
    depth: int,
    n_outputs: int = 1,
) -> nn.Module:
    """Create a uniform-width MLP with approximately n_params_target parameters."""
    # Exact trainable-parameter count for hidden width h (Linear weights + biases,
    # plus BatchNorm1d weight and bias; running stats are buffers, not parameters):
    #   n_params = (depth-1)*h^2 + (n_features + n_outputs + 3*depth)*h + n_outputs
    # Solve this quadratic in h so every depth gets the same budget.
    a = depth - 1
    b = n_features + n_outputs + 3 * depth
    c = n_outputs - n_params_target
    if a == 0:
        h = -c / b  # single hidden layer: the count is linear in h
    else:
        h = (-b + math.sqrt(b * b - 4 * a * c)) / (2 * a)
    h = max(8, round(h))  # no upper clamp: it would break budget parity

    dims = [n_features] + [h] * depth + [n_outputs]
    layers = []
    for i, (in_d, out_d) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(in_d, out_d))
        if i < depth:  # hidden layers only; the output layer stays linear
            layers.extend([nn.BatchNorm1d(out_d), nn.ReLU(), nn.Dropout(0.2)])

    return nn.Sequential(*layers)

# Same parameter budget (~50K params), different depth
for depth in [1, 2, 4, 8]:
    model = make_mlp_by_budget(n_features=20, n_params_target=50_000, depth=depth)
    model.eval()  # use BatchNorm running stats and disable Dropout
    n_params = sum(p.numel() for p in model.parameters())

    # Check output shape
    X = torch.randn(32, 20)
    with torch.no_grad():
        out = model(X)

    print(f"depth={depth}: params={n_params:,}, output={tuple(out.shape)}")

Why Depth Enables Hierarchical Learning

Layer 1: Combines raw features
  Input: [age=72, INR=3.2, n_meds=12, systolic_bp=145, ...]
  Output: [frailty_signal, coagulation_signal, polypharmacy_risk, ...]

Layer 2: Combines layer-1 features into higher-level concepts
  Input: frailty_signal + coagulation_signal + polypharmacy_risk
  Output: [overall_risk_cluster, therapeutic_challenge_indicator, ...]

Layer 3: Combines layer-2 concepts into prediction-relevant features
  Input: risk_cluster + therapeutic_challenge
  Output: [readmission_probability_features]

Layer 4 (output): Maps to final prediction

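Each layer uses the previous layer's abstractions as building blocks.
(The feature names above are illustrative: learned features are usually spread across many neurons and are rarely this cleanly interpretable.)
CNNs show the same pattern: [edges] → [textures] → [parts] → [objects].
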
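The hand-wired toy below shows this composition in the smallest possible case. The weights are chosen by hand for illustration, not learned. From binary inputs, the first layer computes two intermediate features: a count, `relu(x1 + x2)`, and an AND detector, `relu(x1 + x2 - 1)`. The second layer combines them linearly into XOR, which no single linear unit can compute.

```python
def relu(z: float) -> float:
    return max(0.0, z)


def xor_net(x1: float, x2: float) -> float:
    # Layer 1: two hand-set ReLU features built from the raw inputs
    count = relu(x1 + x2)     # number of active inputs: 0, 1 or 2
    both = relu(x1 + x2 - 1)  # 1 only when both inputs are 1 (an AND detector)
    # Layer 2: a linear combination of layer-1 features yields XOR
    return count - 2 * both


for x1, x2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    print(x1, x2, "->", xor_net(x1, x2))
```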
Python
import torch
import torch.nn as nn

class InspectableMLP(nn.Module):
    """MLP that saves intermediate activations for inspection."""
    
    def __init__(self, n_features: int, hidden_dims: list[int]):
        super().__init__()
        self.layers = nn.ModuleList()
        dims = [n_features] + hidden_dims
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            self.layers.append(nn.Sequential(
                nn.Linear(in_d, out_d),
                nn.BatchNorm1d(out_d),
                nn.ReLU(),
            ))
        self.head = nn.Linear(hidden_dims[-1], 1)
        self.intermediate = {}
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            self.intermediate[f"layer_{i}"] = x.detach()
        return self.head(x)

model = InspectableMLP(20, [64, 32, 16])
X = torch.randn(100, 20)
_ = model(X)

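for name, activations in model.intermediate.items():
    # Sparsity: share of all activation values that are exactly zero (normal for ReLU)
    zeros_pct = (activations == 0).float().mean().item() * 100
    # Dead neurons: units that output zero for every sample in the batch
    dead_neurons_pct = (activations == 0).all(dim=0).float().mean().item() * 100
    print(
        f"{name}: mean={activations.mean().item():.4f}, std={activations.std().item():.4f}, "
        f"zeros={zeros_pct:.1f}%, dead={dead_neurons_pct:.1f}%"
    )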

Width: Effective Representation Size

Python
import torch
import torch.nn as nn

def deep_vs_wide(n_features: int = 20) -> None:
    """Compare one wide hidden layer with two narrower layers at a similar parameter count."""

    # Wide: one hidden layer of 256 units (~5.6K params with 20 inputs)
    wide = nn.Sequential(
        nn.Linear(n_features, 256),
        nn.ReLU(),
        nn.Linear(256, 1),
    )

    # Deep-narrow: two hidden layers of 64 units (~5.6K params with 20 inputs)
    deep_narrow = nn.Sequential(
        nn.Linear(n_features, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
    )

    for name, model in [("wide", wide), ("deep_narrow", deep_narrow)]:
        n_params = sum(p.numel() for p in model.parameters())
        print(f"{name:12s}: {n_params:>8,} params")

deep_vs_wide()

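# Bottleneck architecture (e.g. ResNet-50's bottleneck blocks):
# Wide → Narrow → Wide
# Projects to a low-dimensional space, transforms there, then projects back
# Efficiency: the expensive transformation runs in the cheap low-dimensional space
# (Transformer feed-forward layers do the inverse: expand to ~4× d_model, then project back)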
class BottleneckBlock(nn.Module):
    def __init__(self, dim: int, bottleneck_ratio: int = 4):
        super().__init__()
        bottleneck_dim = dim // bottleneck_ratio
        self.net = nn.Sequential(
            nn.Linear(dim, bottleneck_dim),  # compress
            nn.ReLU(),
            nn.Linear(bottleneck_dim, dim),  # expand
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)   # residual connection
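To quantify the efficiency claim, the sketch below counts weights and biases for a block shaped like `BottleneckBlock` (dim → dim // ratio → dim) and compares it with a single full-width dim × dim linear layer. The helper names are illustrative.

```python
def linear_params(in_d: int, out_d: int) -> int:
    return in_d * out_d + out_d  # weights + biases


def bottleneck_params(dim: int, ratio: int = 4) -> int:
    mid = dim // ratio
    return linear_params(dim, mid) + linear_params(mid, dim)  # compress + expand


dim = 256
print(bottleneck_params(dim))   # 33088
print(linear_params(dim, dim))  # 65792
```

The block's weights scale as roughly 2·dim²/ratio, so ratio=4 costs about half of a full dim × dim layer. Larger ratios save more, at the price of a narrower intermediate representation.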

Architecture Search Heuristics

Task                           | Recommended architecture
-------------------------------|------------------------------------------
Simple binary classification   | 2 layers, [64, 32], dropout=0.2
Complex tabular (50+ features) | 3–4 layers, [256, 128, 64], dropout=0.3
Time series features           | 3 layers, [128, 64, 32], + temporal features
Image classification           | CNN backbone, not MLP
Text classification            | Transformer, not MLP
Mixed (clinical + imaging)     | Two branches (MLP + CNN), fused at late layer

Depth guidelines:
  1 hidden layer: almost always enough for linear-ish problems
  2–3 layers: standard for tabular data
  4–8 layers: complex tasks with large datasets (use ResNet-style connections)
  > 8 layers: use skip connections to prevent vanishing gradients

Width guidelines:
  Start: 2–4× the input feature dimension
  Decrease by half each layer (pyramid shape)
  Final hidden layer: 16–64 (enough for output head)
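The width rules of thumb above can be encoded in a few lines. `pyramid_widths` is a hypothetical helper, not a library API. It starts at `multiplier` × the input dimension, halves the width each layer, and never goes below `min_width`.

```python
def pyramid_widths(
    n_features: int, n_layers: int, multiplier: int = 4, min_width: int = 16
) -> list[int]:
    widths = []
    width = multiplier * n_features  # start at 2-4x the input dimension
    for _ in range(n_layers):
        widths.append(max(min_width, width))  # floor keeps late layers usable
        width //= 2  # halve each layer (pyramid shape)
    return widths


print(pyramid_widths(64, n_layers=3))                # [256, 128, 64]
print(pyramid_widths(50, n_layers=4, multiplier=2))  # [100, 50, 25, 16]
```

For 64 input features this reproduces the [256, 128, 64] row in the table. Treat the output as a starting point for tuning rather than a final choice.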

When Width Wins

Python
import torch
import torch.nn as nn

# Width matters for: many features that each contribute linearly
# E.g., genomic data with 10,000+ SNPs (single nucleotide polymorphisms)

class WideAndShallowModel(nn.Module):
    """For high-dimensional sparse inputs (genomics, text bag-of-words)."""
    
    def __init__(self, n_features: int = 10000, n_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 512),    # compress high-dim sparse input
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, n_classes),     # classify
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# For sparse genomic data: width handles the many features simultaneously
# Depth is less important here: relationships are largely linear/additive
model = WideAndShallowModel(n_features=10000)
n_params = sum(p.numel() for p in model.parameters())
print(f"Wide-shallow params: {n_params:,}")

Interview Answer

"Width (neurons per layer) increases capacity within a layer, which helps when many features must be represented simultaneously, such as genomic inputs with thousands of SNPs. Depth (number of layers) enables hierarchical abstraction: each layer builds on the previous layer's concepts, which is how CNNs learn edges → textures → parts → objects. At equal parameter budgets, deeper networks often outperform wider ones on complex tasks because some functions that a deep network represents compactly would need exponentially many units in a shallow one. The practical trade-off: depth improves expressivity but makes gradient flow harder (mitigated by ReLU, BatchNorm and skip connections); width is straightforward to train but can memorise small datasets. For clinical tabular data: 3–4 layers in a pyramid shape (wider early, narrower late), starting at 2–4× the input dimension. Always check for dead ReLU neurons: a unit that outputs zero for every input receives no gradient and contributes nothing, and a layer made mostly of dead units blocks the signal."