Learnixo
Back to blog
AI Systems · Intermediate

DataLoader and Data Pipeline

Building efficient PyTorch data pipelines — Dataset, DataLoader, transforms, and handling clinical data with proper train/val/test splits.

Asma Hafeez Khan · May 22, 2026 · 4 min read
Deep Learning · DataLoader · Dataset · Pipeline · PyTorch · Interview
Share:𝕏

The Dataset Abstraction

Python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

class ClinicalDataset(Dataset):
    """Dataset for tabular clinical data.

    The selected feature columns and the label column are converted to
    ``float32`` tensors once, at construction time, so ``__getitem__`` only
    indexes into tensors that are already held in memory.

    Args:
        df: Source table; must contain ``feature_cols`` and ``label_col``.
        feature_cols: Names of the columns used as model inputs.
        label_col: Name of the column holding the target.
        transform: Optional callable applied to each feature vector when a
            sample is fetched. ``None`` (the default) disables it.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        feature_cols: list[str],
        label_col: str,
        transform=None,
    ):
        self.features = torch.tensor(df[feature_cols].values, dtype=torch.float32)
        self.labels   = torch.tensor(df[label_col].values, dtype=torch.float32)
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples (rows of the feature tensor)."""
        return len(self.features)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, label)`` for row ``idx``."""
        x = self.features[idx]
        y = self.labels[idx]
        # Compare against None instead of relying on truthiness: a transform
        # object that defines __len__ (e.g. an empty container of steps) is
        # falsy and would otherwise be skipped silently.
        if self.transform is not None:
            x = self.transform(x)
        return x, y


# Usage: load the patient table and wrap four numeric columns as features,
# with the "readmitted_30d" column as the label.
df = pd.read_csv("patients.csv")
feature_cols = ["age", "INR", "n_meds", "systolic_bp"]
dataset = ClinicalDataset(df, feature_cols, label_col="readmitted_30d")

print(f"Dataset size: {len(dataset)}")
# Fetch the first sample: a feature vector and its label.
x, y = dataset[0]
print(f"Feature shape: {x.shape}, Label: {y.item()}")

Train/Val/Test Split

Python
from torch.utils.data import random_split, Subset
from sklearn.model_selection import train_test_split

def create_stratified_splits(
    dataset: Dataset,
    labels: np.ndarray,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
) -> tuple[Dataset, Dataset, Dataset]:
    """Stratified train/val/test split preserving class proportions.

    The test set is carved off first; the validation set is then taken from
    the remainder, with its fraction rescaled so that ``val_size`` is still
    expressed relative to the full dataset.

    Args:
        dataset: Dataset to split; only its length and indexing are used.
        labels: One label per sample, in dataset order. Any array-like is
            accepted and converted to a NumPy array.
        val_size: Fraction of the full dataset used for validation.
        test_size: Fraction of the full dataset used for testing.
        seed: Random state passed to both splits for reproducibility.

    Returns:
        ``(train, val, test)`` as ``Subset`` views over ``dataset``.

    Raises:
        ValueError: If ``labels`` does not have one entry per sample.
    """
    # Positional indexing below must not depend on the container type: a
    # pandas Series would be indexed by *label* and a plain list cannot be
    # indexed with an index array at all.
    labels = np.asarray(labels)

    n = len(dataset)
    if len(labels) != n:
        raise ValueError(
            f"labels has {len(labels)} entries but dataset has {n} samples"
        )
    all_indices = np.arange(n)

    # First split off test set
    train_val_idx, test_idx = train_test_split(
        all_indices, test_size=test_size,
        stratify=labels, random_state=seed
    )

    # Then split val from train. The remaining pool is (1 - test_size) of the
    # data, so rescale to keep val_size relative to the whole dataset.
    adjusted_val = val_size / (1 - test_size)
    train_idx, val_idx = train_test_split(
        train_val_idx, test_size=adjusted_val,
        stratify=labels[train_val_idx], random_state=seed
    )

    return (
        Subset(dataset, train_idx),
        Subset(dataset, val_idx),
        Subset(dataset, test_idx),
    )

# Stratify on the same column the dataset uses as its label.
labels = df["readmitted_30d"].values
train_ds, val_ds, test_ds = create_stratified_splits(dataset, labels)
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

DataLoader Configuration

Python
# Training DataLoader
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,           # crucial for training — randomise each epoch
    num_workers=4,          # parallel data loading (set to n_cpu_cores - 1)
    pin_memory=True,        # faster CPU→GPU transfer
    drop_last=True,         # drop last partial batch (consistency for BatchNorm)
    prefetch_factor=2,      # prefetch 2 batches per worker
    persistent_workers=True, # don't recreate workers each epoch
)

# Validation / Test DataLoader (no shuffling needed)
# Note: drop_last is not set here, so the final batch may be smaller than 128.
val_loader = DataLoader(
    val_ds,
    batch_size=128,         # can use larger batch for eval (no gradient storage)
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Inspect a batch (only the first one; the loop breaks immediately)
for X, y in train_loader:
    print(f"X shape: {X.shape}")   # (64, 4)
    print(f"y shape: {y.shape}")   # (64,)
    print(f"Class balance: {y.mean():.3f}")  # mean of the labels in this batch
    break

Custom Transforms

Python
class Normalise(torch.nn.Module):
    """Standardise tabular features with precomputed statistics.

    ``mean`` and ``std`` are stored as module buffers, so they travel with
    the module (``state_dict``, ``.to(device)``) without being trainable.
    """

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        epsilon = 1e-8  # keeps the divisor non-zero for constant features
        safe_std = std + epsilon
        self.register_buffer("mean", mean)
        self.register_buffer("std", safe_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` centred on the stored mean and scaled by the stored std."""
        centred = x - self.mean
        return centred / self.std

# Compute statistics from training data only, so nothing from the
# validation or test subsets influences the normalisation.
# Each train_ds[i][0] is one sample's feature vector; stacking puts samples
# along dim 0, so the reductions below are per feature.
train_features = torch.stack([train_ds[i][0] for i in range(len(train_ds))])
mean = train_features.mean(dim=0)
std  = train_features.std(dim=0, unbiased=True)

norm_transform = Normalise(mean, std)

# Apply to all datasets
class TransformedSubset(Dataset):
    """Wrap a subset so a transform is applied to each sample's features.

    The label is passed through untouched; only the first element of each
    ``(features, label)`` pair goes through ``transform``.
    """

    def __init__(self, subset: Subset, transform):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        """Return the number of samples in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(transform(features), label)`` for sample ``idx``."""
        features, target = self.subset[idx]
        return self.transform(features), target

# All three subsets share the same transform, i.e. the statistics that
# norm_transform was built with are reused for validation and test.
train_ds_norm = TransformedSubset(train_ds, norm_transform)
val_ds_norm   = TransformedSubset(val_ds, norm_transform)
test_ds_norm  = TransformedSubset(test_ds, norm_transform)

Image Dataset with Augmentation

Python
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from pathlib import Path

class ChestXrayDataset(Dataset):
    """Image dataset that loads each file from disk when it is requested.

    Args:
        image_paths: Paths of the image files, one per sample.
        labels: Integer class label for each path, in the same order.
        transform: Optional callable applied to the loaded PIL image
            (typically a ``torchvision.transforms`` pipeline).
    """

    def __init__(self, image_paths: list, labels: list, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a ``torch.long`` tensor."""
        # Imported here so the module can be imported without Pillow installed.
        from PIL import Image

        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new, fully loaded image
        # that stays valid after the source file is closed.
        with Image.open(self.image_paths[idx]) as source:
            img = source.convert("RGB")
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        # Explicit None check: a transform object that happens to be falsy
        # (e.g. one defining __len__ == 0) must still be applied.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# Training pipeline: the random crop, flip and colour jitter are re-sampled
# on every call, so the same image is augmented differently each time.
train_transform = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),
    T.RandomHorizontalFlip(0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    # NOTE(review): these look like the usual ImageNet channel mean/std —
    # confirm they are appropriate for the model and images in use.
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation pipeline: same resize and normalisation as training, but with a
# fixed centre crop and no random augmentation.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

Full Training Loop

Python
def _squeeze_keep_batch(output):
    """Drop singleton dimensions from ``output`` but keep the batch dimension."""
    squeezed = output.squeeze()
    # A bare squeeze() also removes dim 0 when the batch holds one sample,
    # which leaves the output with a different shape than the (1,) target.
    if output.dim() > 0 and output.size(0) == 1:
        squeezed = squeezed.unsqueeze(0)
    return squeezed


def train_epoch(model, loader, optimizer, criterion, device, scaler=None):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable yielding ``(X, y)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(output, y)``.
        device: Device the batches are moved to.
        scaler: Optional gradient scaler. When given, the forward pass runs
            under ``torch.cuda.amp.autocast`` and the scaler drives the
            backward pass and optimizer step.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    use_amp = scaler is not None
    if use_amp:
        from torch.cuda.amp import autocast

    for X, y in loader:
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad()

        if use_amp:
            with autocast():
                output = _squeeze_keep_batch(model(X))
                loss = criterion(output, y)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient magnitudes rather than the scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = _squeeze_keep_batch(model(X))
            loss = criterion(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, auc)``.

    Predictions are the sigmoid of the model output, so the model is
    expected to emit one logit per sample for a binary target.

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable of ``(X, y)`` batches that also supports ``len()``.
        criterion: Loss called as ``criterion(output, y)``.
        device: Device the batches are moved to.

    Returns:
        Tuple of the summed batch losses divided by ``len(loader)`` and the
        ROC AUC computed over all samples.
    """
    model.eval()
    total_loss = 0.0
    all_preds, all_targets = [], []

    for X, y in loader:
        X = X.to(device)
        y = y.to(device)
        raw = model(X)
        output = raw.squeeze()
        # squeeze() also strips the batch dimension from a single-sample
        # batch (possible here because the final batch is not dropped).
        # Restore it so the loss sees matching shapes and tolist() below
        # yields a list rather than a bare float.
        if raw.dim() > 0 and raw.size(0) == 1:
            output = output.unsqueeze(0)
        total_loss += criterion(output, y).item()
        all_preds.extend(torch.sigmoid(output).cpu().tolist())
        all_targets.extend(y.cpu().tolist())

    from sklearn.metrics import roc_auc_score
    auc = roc_auc_score(all_targets, all_preds)
    return total_loss / len(loader), auc

Interview Answer

"The PyTorch data pipeline: Dataset (defines `__len__` and `__getitem__`) → DataLoader (handles batching, shuffling, parallel loading). For clinical data: use stratified splits to preserve class proportions; fit normalisation statistics on training data only (never including val/test). Key DataLoader settings for GPU training: num_workers=4+ (parallel CPU loading), pin_memory=True (faster CPU→GPU transfer), shuffle=True for training, drop_last=True for BatchNorm stability, persistent_workers=True to avoid worker restart overhead. Performance rule of thumb: if GPU utilisation is below 80%, the bottleneck is often data loading — try increasing num_workers or prefetch_factor. Transforms should be applied in `__getitem__`, not pre-applied and cached in memory, so that random augmentations differ on every epoch."

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share:𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.