Learnixo

Deep Learning for AI Interviews · Lesson 42 of 56

Full CNN Architecture: Conv → Pool → FC

Standard CNN Structure

Input Image (3 × 224 × 224)
    ↓
Stem: Conv 7×7, stride 2 → (64 × 112 × 112)
    ↓
Block 1: Conv 3×3, stride 1 → (64 × 112 × 112)
    ↓
Downsample: stride-2 conv (or MaxPool + conv to widen channels) → (128 × 56 × 56)
    ↓
Block 2: Conv 3×3 → (128 × 56 × 56)
    ↓
Downsample → (256 × 28 × 28)
    ↓
Block 3 → (256 × 28 × 28)
    ↓
Downsample → (512 × 14 × 14)
    ↓
Block 4 → (512 × 14 × 14)
    ↓
Global Average Pool → (512 × 1 × 1) → flatten → (512,)
    ↓
Linear Classification Head → (n_classes,)

Pattern: at each downsampling stage, spatial resolution ↓ (halved) and channels ↑ (doubled)

Building a CNN from Scratch

Python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Convolution → BatchNorm → ReLU, the basic unit every stage is built from.

    The convolution is created without a bias term: the BatchNorm that follows
    subtracts the per-channel mean (cancelling any conv bias) and supplies its
    own learnable shift.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Square kernel size (default 3).
        stride: Convolution stride; values > 1 downsample (default 1).
        padding: Zero padding on each side (default 1, which keeps H and W
            unchanged for a 3×3 kernel at stride 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        # The attribute name "block" determines the state_dict keys (block.0.*, block.1.*).
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalisation and activation to ``x``."""
        result = self.block(x)
        return result

class SimpleCNNClassifier(nn.Module):
    """Four-stage convolutional image classifier.

    Input:  ``(B, in_channels, H, W)``, e.g. a chest X-ray at ``(3, 224, 224)``.
    Output: ``(B, n_classes)`` raw logits (no softmax applied).

    Layout:
    - Stem: a 7×7 stride-2 ConvBlock followed by a 3×3 stride-2 max-pool.
    - Four stages of two ConvBlocks each, with widths 64/128/256/512.
      Stages 2–4 open with a stride-2 block.
    - Head: global average pool → dropout → linear.
    """

    def __init__(self, n_classes: int = 2, in_channels: int = 3):
        super().__init__()

        # Stem: 4× total downsampling; a 224×224 input leaves as (B, 32, 56, 56).
        self.stem = nn.Sequential(
            ConvBlock(in_channels, 32, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Backbone stages. Registration order determines the order of the
        # random weight draws and of the state_dict keys, so it is kept fixed.
        self.stage1 = self._make_stage(32, 64, stride=1, n_blocks=2)
        self.stage2 = self._make_stage(64, 128, stride=2, n_blocks=2)
        self.stage3 = self._make_stage(128, 256, stride=2, n_blocks=2)
        self.stage4 = self._make_stage(256, 512, stride=2, n_blocks=2)

        # Head: adaptive pooling collapses any (non-empty) spatial size to 1×1,
        # giving (B, 512, 1, 1) -> flatten to (B, 512) -> dropout -> logits.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(512, n_classes),
        )

        self._init_weights()

    def _make_stage(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        n_blocks: int,
    ) -> nn.Sequential:
        """Build one stage as a list of ConvBlocks.

        The first block changes the width (and downsamples when ``stride`` > 1).
        It is followed by ``n_blocks - 1`` shape-preserving ConvBlocks.
        """
        entry = ConvBlock(in_channels, out_channels, stride=stride)
        followers = [ConvBlock(out_channels, out_channels) for _ in range(n_blocks - 1)]
        return nn.Sequential(entry, *followers)

    def _init_weights(self) -> None:
        """Re-initialise parameters.

        - Conv2d: Kaiming-normal weights (fan_out, ReLU gain).
        - BatchNorm2d: identity affine transform (weight 1, bias 0).
        - Linear: small normal weights (std 0.01) and zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(B, n_classes)`` for the batch ``x``."""
        for layer in (self.stem, self.stage1, self.stage2, self.stage3, self.stage4):
            x = layer(x)
        return self.head(x)

# Smoke test: a batch of 8 RGB 224×224 images should produce (8, 2) logits.
model = SimpleCNNClassifier(n_classes=2)
X = torch.randn(8, 3, 224, 224)

with torch.no_grad():
    out = model(X)

n_params = sum(param.numel() for param in model.parameters())

print(f"Input: {X.shape}")
print(f"Output: {out.shape}")   # torch.Size([8, 2])
print(f"Parameters: {n_params:,}")   # 4,709,794 for n_classes=2

Feature Map Shapes Through Stages

Python
import torch
import torch.nn as nn

def trace_cnn_shapes(
    model: nn.Module,
    input_shape: tuple,
    stage_names: tuple[str, ...] = ("stem", "stage1", "stage2", "stage3", "stage4"),
) -> None:
    """Print the feature-map shape produced by each named stage of ``model``.

    A random probe tensor of ``input_shape`` is pushed through ``model`` once.
    Forward hooks are attached to every attribute of ``model`` named in
    ``stage_names``; names that are missing are skipped.

    Side-effect free:
    - The probe runs under ``torch.no_grad()`` with the model temporarily in
      eval mode, so BatchNorm running statistics are not overwritten by the
      random input.
    - Every submodule's original ``training`` flag is restored afterwards.
    - All hooks are removed even if the forward pass raises.

    Args:
        model: Network whose stages are exposed as attributes.
        input_shape: Shape of the probe input, e.g. ``(B, C, H, W)``.
        stage_names: Attribute names of the stages to trace. Shapes are
            reported in the order the stages actually execute.

    Each traced stage output is assumed to be 4-D ``(B, C, H, W)``.
    """
    shapes: dict[str, tuple] = {}
    handles = []
    # Record each submodule's mode so mixed train/eval setups are restored exactly.
    modes = {m: m.training for m in model.modules()}
    # Create the probe on the same device as the model (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    try:
        for name in stage_names:
            module = getattr(model, name, None)
            if module is None:
                continue

            # Bind ``name`` as a default argument to avoid the late-binding closure trap.
            def hook(m, inp, out, n=name):
                shapes[n] = tuple(out.shape)

            handles.append(module.register_forward_hook(hook))

        model.eval()
        with torch.no_grad():
            model(torch.randn(*input_shape, device=device))
    finally:
        for handle in handles:
            handle.remove()
        for m, was_training in modes.items():
            m.training = was_training

    print(f"Input: {input_shape}")
    for stage, shape in shapes.items():
        _, c, h, w = shape
        print(f"  {stage:8s}: {c:>3d} channels × {h}×{w} = {c*h*w:,} values per sample")

# Trace per-stage feature-map shapes for a batch of 4 RGB 224×224 images.
model = SimpleCNNClassifier(n_classes=2)
trace_cnn_shapes(model, input_shape=(4, 3, 224, 224))

Downsampling Strategies

Python
import torch
import torch.nn as nn

# Option 1: MaxPool2d — non-learnable; keeps the strongest activation per window
maxpool_down = nn.MaxPool2d(kernel_size=2, stride=2)

# Option 2: Strided convolution — learnable (preferred in modern architectures).
# It can also change the channel count (64 -> 128 here). The conv bias is
# dropped because BatchNorm follows, matching ConvBlock's convention.
strided_conv = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(),
)

# Option 3: Average pooling — non-learnable; smoother than max, preserves more information
avgpool_down = nn.AvgPool2d(kernel_size=2, stride=2)

x = torch.randn(8, 64, 56, 56)

# All three operators take the same 64-channel input and halve H and W.
# Pooling keeps the channel count; the strided conv maps 64 -> 128.
for name, module in [("MaxPool", maxpool_down), ("StridedConv", strided_conv), ("AvgPool", avgpool_down)]:
    out = module(x)
    print(f"{name}: {tuple(out.shape)}")

# Modern recommendation:
#   - Use strided convolution for first downsampling (learns what to keep)
#   - Use AdaptiveAvgPool at the end (eliminates fixed-size constraint)
#   - Avoid MaxPool2d in main stages (use strided conv instead)

Multi-Scale Feature Aggregation

Python
import torch
import torch.nn as nn

class FPN(nn.Module):
    """Feature Pyramid Network — aggregates features at multiple scales.

    The network works in three steps:
    1. Each backbone feature map is projected to ``out_channels`` by a 1×1
       lateral conv.
    2. A top-down pass upsamples every coarser map (nearest-neighbour) to the
       size of the next finer map and adds it in.
    3. A 3×3 conv smooths each merged map.

    Args:
        in_channels_list: Channel count of each input feature map, in the same
            order as the ``features`` later passed to ``forward``.
        out_channels: Channel count of every output map.
    """

    def __init__(self, in_channels_list: list[int], out_channels: int = 256):
        super().__init__()
        # Lateral connections: 1×1 convs that bring every level to out_channels.
        self.lateral = nn.ModuleList([
            nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels_list
        ])
        # Output convolutions: 3×3 convs that smooth upsampling artifacts after the merge.
        self.output = nn.ModuleList([
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in in_channels_list
        ])

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        """Fuse backbone feature maps top-down.

        Args:
            features: Backbone maps ordered from FINEST (highest resolution,
                e.g. C2) to COARSEST (lowest resolution, e.g. C5), one per
                entry of ``in_channels_list``. The top-down pass walks this
                list from the last (coarsest) entry back to the first.

        Returns:
            Fused maps in the same order as ``features``. Each has
            ``out_channels`` channels and the spatial size of its input.

        Raises:
            ValueError: If the number of feature maps differs from the number
                of lateral convs. Without this check, ``zip`` would silently
                drop the extra levels.
        """
        if len(features) != len(self.lateral):
            raise ValueError(
                f"expected {len(self.lateral)} feature maps, got {len(features)}"
            )
        laterals = [lat(f) for lat, f in zip(self.lateral, features)]

        # Top-down pathway: upsample each coarser level to the next finer one and add.
        for i in range(len(laterals) - 1, 0, -1):
            h, w = laterals[i - 1].shape[-2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=(h, w), mode="nearest"
            )

        return [out(lat) for out, lat in zip(self.output, laterals)]

# FPN is used in object detection (e.g. Faster R-CNN with an FPN backbone, RetinaNet,
# and YOLOv3-style multi-scale necks) for detecting objects at multiple scales

Interview Answer

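"A CNN processes images in stages: a stem (aggressive early downsampling, typically a 7×7 stride-2 conv, often followed by a stride-2 MaxPool), followed by 3–4 feature extraction stages, where each downsampling stage doubles the channels while halving the spatial resolution. Key design decisions: (1) Downsampling — strided convolution (learnable) or MaxPool; modern networks generally prefer strided conv; (2) Global Average Pooling before the classification head eliminates the fixed-size constraint — the network can process any input resolution large enough to survive the downsampling; (3) BatchNorm + ReLU after every conv layer (no bias is needed in the conv when BN follows, because BN's mean subtraction cancels it and BN supplies its own learnable shift); (4) Kaiming init for conv layers. The channel-doubling / spatial-halving pattern maintains roughly constant computation per stage: halving H and W cuts the number of positions by 4×, while doubling both input and output channels raises the cost per position by 4×. The classification head is simply GAP → Dropout → Linear. For medical imaging (chest X-ray, pathology slides): use the same CNN backbone but train with appropriate augmentations (e.g. random flips, colour jitter, rotation — chosen per modality, since horizontal flips can be inappropriate for chest X-rays where left/right anatomy matters) and a class-weighted loss for imbalanced conditions."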