Depth vs Width in Neural Networks

Why deeper networks learn hierarchical features, why wider networks have more capacity, and how to choose architecture dimensions for your task.

Asma Hafeez Khan · May 22, 2026 · 5 min read
Tags: Deep Learning, Architecture, Depth, Width, Capacity, Interview

The Core Trade-off

Width (neurons per layer):
  More neurons → more capacity within each layer
  Can represent more features at the same abstraction level
  Parallel computation — wide layers run efficiently on GPUs
  Risk: memorisation of training data if too wide with limited data

Depth (number of layers):
  More layers → hierarchical feature learning
  Each layer builds on abstractions from the previous
  For some function classes, a deep network needs exponentially fewer neurons than a shallow one
  Risk: vanishing gradients (mitigated by residual connections, BatchNorm, ReLU)

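In practice: at a fixed parameter budget, depth often matters more than width for complex tasks.
  A 4-layer narrow network frequently outperforms a 2-layer wide network
  with the same parameter count on structured data, but this is a tendency,
  not a guarantee; validate both shapes on your own data.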
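
The vanishing-gradient risk listed under depth can be observed directly. The sketch below is an illustration, not a benchmark: it builds plain MLPs without BatchNorm or skip connections and prints the gradient norm that reaches the first layer at several depths. The helper name `first_layer_grad_norm`, the hidden width of 32, and the random inputs are all assumptions made for the example.

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(depth: int, activation, width: int = 32, seed: int = 0) -> float:
    """Gradient norm at the first Linear layer of a plain MLP (hypothetical helper)."""
    torch.manual_seed(seed)
    layers = []
    in_d = 20
    for _ in range(depth):
        layers += [nn.Linear(in_d, width), activation()]
        in_d = width
    layers.append(nn.Linear(in_d, 1))
    model = nn.Sequential(*layers)

    x = torch.randn(64, 20)           # synthetic inputs, for illustration only
    loss = model(x).pow(2).mean()     # arbitrary scalar loss so we can backpropagate
    loss.backward()
    return model[0].weight.grad.norm().item()

for depth in [2, 8, 16]:
    sig = first_layer_grad_norm(depth, nn.Sigmoid)
    relu = first_layer_grad_norm(depth, nn.ReLU)
    print(f"depth={depth:2d}: sigmoid grad norm={sig:.3e}, relu grad norm={relu:.3e}")
```

In runs like this the sigmoid gradients usually shrink by orders of magnitude as depth grows. Exact values depend on the initialisation and the seed, so treat the printed numbers as indicative only.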

Experimental Comparison

Python
import math

import torch
import torch.nn as nn

def make_mlp_by_budget(
    n_features: int,
    n_params_target: int,
    depth: int,
    n_outputs: int = 1,
) -> nn.Module:
    """Create a uniform-width MLP with approximately n_params_target parameters."""
    # Exact count for hidden width h (Linear weights and biases plus BatchNorm affine):
    #   n_params = (depth - 1) * h^2 + (n_features + n_outputs + 3 * depth) * h + n_outputs
    # Solve this quadratic for h. The shortcut h ≈ sqrt(n_params / depth) ignores the
    # input layer and undershoots the budget badly for shallow networks.
    a = depth - 1
    b = n_features + n_outputs + 3 * depth
    c = n_outputs - n_params_target
    if a == 0:
        h = -c / b                                        # single hidden layer: linear in h
    else:
        h = (-b + math.sqrt(b * b - 4 * a * c)) / (2 * a)  # positive root
    h = max(8, int(h))

    dims = [n_features] + [h] * depth + [n_outputs]
    layers = []
    for i, (in_d, out_d) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(in_d, out_d))
        if i < depth:  # hidden layers only; the output layer stays linear
            layers.extend([nn.BatchNorm1d(out_d), nn.ReLU(), nn.Dropout(0.2)])

    return nn.Sequential(*layers)

# Same parameter budget (~50K params), different depth
for depth in [1, 2, 4, 8]:
    model = make_mlp_by_budget(n_features=20, n_params_target=50_000, depth=depth)
    n_params = sum(p.numel() for p in model.parameters())

    # Check output shape in eval mode (Dropout off, BatchNorm uses running statistics)
    model.eval()
    X = torch.randn(32, 20)
    with torch.no_grad():
        out = model(X)

    print(f"depth={depth}: params={n_params:,}, output={tuple(out.shape)}")

Why Depth Enables Hierarchical Learning

Layer 1: Combines raw features
  Input: [age=72, INR=3.2, n_meds=12, systolic_bp=145, ...]
  Output: [frailty_signal, coagulation_signal, polypharmacy_risk, ...]

Layer 2: Combines layer-1 features into higher-level concepts
  Input: frailty_signal + coagulation_signal + polypharmacy_risk
  Output: [overall_risk_cluster, therapeutic_challenge_indicator, ...]

Layer 3: Combines layer-2 concepts into prediction-relevant features
  Input: risk_cluster + therapeutic_challenge
  Output: [readmission_probability_features]

Layer 4 (output): Maps to final prediction

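Each layer uses the previous layer's abstractions as building blocks.
The signal names above are illustrative: hidden units in a trained network rarely map onto named concepts this cleanly.
The same principle is why CNNs tend to learn [edges] → [textures] → [parts] → [objects].

Python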
import torch
import torch.nn as nn

class InspectableMLP(nn.Module):
    """MLP that saves intermediate activations for inspection."""
    
    def __init__(self, n_features: int, hidden_dims: list[int]):
        super().__init__()
        self.layers = nn.ModuleList()
        dims = [n_features] + hidden_dims
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            self.layers.append(nn.Sequential(
                nn.Linear(in_d, out_d),
                nn.BatchNorm1d(out_d),
                nn.ReLU(),
            ))
        self.head = nn.Linear(hidden_dims[-1], 1)
        self.intermediate = {}
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            self.intermediate[f"layer_{i}"] = x.detach()
        return self.head(x)

model = InspectableMLP(20, [64, 32, 16])
X = torch.randn(100, 20)
_ = model(X)

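for name, activations in model.intermediate.items():
    # A neuron is "dead" only if it outputs zero for every sample in the batch.
    # The overall share of zeros is ordinary ReLU sparsity and is reported separately.
    dead_neurons_pct = (activations == 0).all(dim=0).float().mean().item() * 100
    zeros_pct = (activations == 0).float().mean().item() * 100
    print(f"{name}: mean={activations.mean():.4f}, std={activations.std():.4f}, "
          f"zeros={zeros_pct:.1f}%, dead={dead_neurons_pct:.1f}%")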

Width: Effective Representation Size

Python
import torch
import torch.nn as nn

def bottleneck_vs_wide(n_features: int = 20):
    """Compare a wide-shallow and a deep-narrow MLP at similar parameter count.

    With the default n_features=20 both models have roughly 5.6K parameters.
    """

    # Wide: one wide hidden layer
    wide = nn.Sequential(
        nn.Linear(n_features, 256),
        nn.ReLU(),
        nn.Linear(256, 1),
    )

    # Deep-narrow: two narrower hidden layers.
    # It is labelled "bottleneck" in the output, but a true bottleneck
    # (wide → narrow → wide) is the BottleneckBlock defined further below.
    bottleneck = nn.Sequential(
        nn.Linear(n_features, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
    )

    for name, model in [("wide", wide), ("bottleneck", bottleneck)]:
        n_params = sum(p.numel() for p in model.parameters())
        print(f"{name:12s}: {n_params:>8,} params")

bottleneck_vs_wide()

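# Bottleneck architecture (as in ResNet bottleneck blocks):
# Wide → Narrow → Wide
# Projects to low-dimensional space, transforms, projects back
# Computational efficiency: cheap operations in low-dimensional space
# Note: the standard Transformer feed-forward block is the inverse shape
# (narrow → wide → narrow), expanding the hidden size before projecting back.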
class BottleneckBlock(nn.Module):
    def __init__(self, dim: int, bottleneck_ratio: int = 4):
        super().__init__()
        bottleneck_dim = dim // bottleneck_ratio
        self.net = nn.Sequential(
            nn.Linear(dim, bottleneck_dim),  # compress
            nn.ReLU(),
            nn.Linear(bottleneck_dim, dim),  # expand
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)   # residual connection
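
To make the efficiency claim concrete, the following sketch counts parameters for a compress-then-expand pair against a single full-width layer. The dimension of 256 and the ratio of 4 are example values chosen here, and `count_params` is a hypothetical helper rather than part of PyTorch.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of trainable and non-trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

dim, ratio = 256, 4  # example values, matching bottleneck_ratio=4 above

# Same layer layout as the block's inner network: compress, ReLU, expand
bottleneck = nn.Sequential(
    nn.Linear(dim, dim // ratio),
    nn.ReLU(),
    nn.Linear(dim // ratio, dim),
)
# Baseline: one full-width dim -> dim transformation
full = nn.Linear(dim, dim)

bottleneck_params = count_params(bottleneck)
full_params = count_params(full)
print(f"bottleneck: {bottleneck_params:,} params")  # 33,088
print(f"full-width: {full_params:,} params")        # 65,792
```

At these sizes the bottleneck pair uses about half the parameters of one full-width layer while adding a nonlinearity, which is the trade the wide → narrow → wide shape is making.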

Architecture Search Heuristics

Task                           | Recommended architecture
-------------------------------|------------------------------------------
Simple binary classification   | 2 layers, [64, 32], dropout=0.2
Complex tabular (50+ features) | 3–4 layers, [256, 128, 64], dropout=0.3
Time series features           | 3 layers, [128, 64, 32], + temporal features
Image classification           | CNN backbone, not MLP
Text classification            | Transformer, not MLP
Mixed (clinical + imaging)     | Two branches (MLP + CNN), fused at late layer

Depth guidelines:
  1 hidden layer: almost always enough for linear-ish problems
  2–3 layers: standard for tabular data
  4–8 layers: complex tasks with large datasets (use ResNet-style connections)
  > 8 layers: use skip connections to prevent vanishing gradients

Width guidelines:
  Start: 2–4× the input feature dimension
  Decrease by half each layer (pyramid shape)
  Final hidden layer: 16–64 (enough for output head)
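
The width guidelines above can be turned into a small helper. This is a sketch under stated assumptions: `pyramid_dims` and `build_pyramid_mlp` are hypothetical names, the floor of 16 units is taken from the lower end of the "final hidden layer" range, and the dropout rate is borrowed from the table.

```python
import torch
import torch.nn as nn

def pyramid_dims(n_features: int, depth: int, multiplier: int = 4, min_width: int = 16) -> list[int]:
    """Hidden widths that start at multiplier * n_features and halve at each layer."""
    width = multiplier * n_features
    dims = []
    for _ in range(depth):
        dims.append(max(min_width, width))  # never go below the floor
        width //= 2
    return dims

def build_pyramid_mlp(n_features: int, depth: int, n_outputs: int = 1, dropout: float = 0.2) -> nn.Sequential:
    """Assemble Linear -> ReLU -> Dropout blocks following pyramid_dims, plus a linear head."""
    dims = [n_features] + pyramid_dims(n_features, depth)
    layers = []
    for in_d, out_d in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(in_d, out_d), nn.ReLU(), nn.Dropout(dropout)]
    layers.append(nn.Linear(dims[-1], n_outputs))
    return nn.Sequential(*layers)

print(pyramid_dims(20, 3))                 # [80, 40, 20]
print(pyramid_dims(20, 4, multiplier=2))   # [40, 20, 16, 16]

mlp = build_pyramid_mlp(n_features=20, depth=3).eval()
with torch.no_grad():
    y = mlp(torch.randn(8, 20))
print(tuple(y.shape))                      # (8, 1)
```

The result is only a starting point for a search; the multiplier and depth are the two values worth tuning first.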

When Width Wins

Python
import torch
import torch.nn as nn

# Width matters for: many features that each contribute linearly
# E.g., genomic data with 10,000+ SNPs (single nucleotide polymorphisms)

class WideAndShallowModel(nn.Module):
    """For high-dimensional sparse inputs (genomics, text bag-of-words)."""
    
    def __init__(self, n_features: int = 10000, n_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 512),    # compress high-dim sparse input
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, n_classes),     # classify
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# For sparse genomic data: width handles the many features simultaneously
# Depth is less important here — relationships are largely linear/additive
model = WideAndShallowModel(n_features=10000)
n_params = sum(p.numel() for p in model.parameters())
print(f"Wide-shallow params: {n_params:,}")
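
One consequence of this design is worth checking: with 10,000 inputs, nearly all parameters sit in the first layer. The short sketch below rebuilds the same layer stack as a plain `nn.Sequential` (the variable names are chosen here for illustration) and reports that share.

```python
import torch.nn as nn

# Same layer sizes as WideAndShallowModel with n_features=10000, n_classes=2
net = nn.Sequential(
    nn.Linear(10_000, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 2),
)

first = sum(p.numel() for p in net[0].parameters())  # input layer only
total = sum(p.numel() for p in net.parameters())     # whole model
print(f"first layer: {first:,} of {total:,} params ({first / total:.2%})")
# first layer: 5,120,512 of 5,121,538 params (99.98%)
```

Because the input projection dominates the parameter count, regularising it (the heavy dropout above, or weight decay) matters more here than adding further hidden layers.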

Interview Answer

"Width (neurons per layer) increases capacity within a layer, which is useful for representing many features simultaneously, such as genomic inputs with thousands of SNPs. Depth (number of layers) enables hierarchical abstraction: each layer builds on the previous one's concepts, which is how CNNs learn edges → textures → parts → objects. At parameter-budget parity, deeper networks often outperform wider ones on complex tasks, because some functions need exponentially fewer units in a deep network than in a shallow one. Practical trade-off: depth improves expressivity but adds gradient-flow challenges (mitigated by ReLU, BatchNorm, and skip connections); width is straightforward but can cause memorisation on small datasets. For clinical tabular data: 3–4 layers, pyramid shape (wider early, narrower late), starting at 2–4× the input dimension. Always check for dead ReLU neurons: a unit that outputs zero for every input contributes nothing, and a layer made up mostly of such units is wasted capacity."
