Learnixo

Deep Learning for AI Interviews · Lesson 55 of 56

RNN and LSTM: Sequential Data Processing

Recurrent Networks: Processing Sequences

An RNN processes a sequence one timestep at a time, maintaining a hidden state:

  h_t = tanh(W_h · h_{t-1} + W_x · x_t + b)

  h_{t-1}: previous hidden state (memory from past)
  x_t:     current input
  h_t:     new hidden state (updated memory)

Problem: the gradient must propagate back through every timestep.
  For a 100-step sequence, the gradient involves a product of up to 100 Jacobians,
  each of the form ∂h_t/∂h_{t-1} = diag(1 − h_t²) · W_h.
  If the largest singular value of W_h is < 1: the gradient vanishes
    (early timesteps have almost no influence on learning)
  If it is > 1: the gradient can explode

This is why vanilla RNNs struggle to learn dependencies spanning more than ~20 timesteps.
LSTM and GRU mitigate this with gating mechanisms.

LSTM: Long Short-Term Memory

Python
import torch
import torch.nn as nn

# An LSTM carries two state tensors from one step to the next:
#   h_t: hidden state (the per-step output, "short-term" memory)
#   c_t: cell state   ("long-term" memory, changed only through the gates)

class LSTMCell(nn.Module):
    """Single LSTM step written out by hand to expose the gate arithmetic.

    One fused ``nn.Linear`` maps ``[x, h]`` to the pre-activations of all four
    gates (input, forget, cell candidate, output) in a single matmul instead of
    four separate ones.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        fused_width = 4 * hidden_size  # i | f | g | o stacked along the last dim
        self.linear = nn.Linear(input_size + hidden_size, fused_width)

    def forward(
        self,
        x: torch.Tensor,      # (batch, input_size)
        h: torch.Tensor,      # (batch, hidden_size)
        c: torch.Tensor,      # (batch, hidden_size)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance one timestep and return ``(h_next, c_next)``, each (batch, hidden_size)."""
        pre_activations = self.linear(torch.cat((x, h), dim=-1))   # (batch, 4*hidden)
        raw_in, raw_forget, raw_cand, raw_out = pre_activations.chunk(4, dim=-1)

        in_gate = torch.sigmoid(raw_in)          # fraction of the candidate to write
        forget_gate = torch.sigmoid(raw_forget)  # fraction of the old cell to keep
        candidate = torch.tanh(raw_cand)         # proposed new content, in (-1, 1)
        out_gate = torch.sigmoid(raw_out)        # fraction of the cell exposed as h

        # c_t = f ⊙ c_{t-1} + i ⊙ g  (additive, so the direct c-path is scaled by f)
        #   f≈1, i≈0 → the cell is carried over unchanged, new input ignored
        #   f≈0, i≈1 → the cell is overwritten by the candidate
        c_next = forget_gate * c + in_gate * candidate
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next

# nn.LSTM applies the same gate equations as a hand-written cell to the whole
# sequence at once, using fused kernels (cuDNN on GPU), which is far faster
# than looping over a Python-level cell.
lstm = nn.LSTM(
    10,                  # input_size: features per timestep
    64,                  # hidden_size
    num_layers=2,        # two stacked LSTMs; layer 2 reads layer 1's outputs
    batch_first=True,    # tensors are laid out (batch, seq, features)
    dropout=0.2,         # applied between stacked layers only, not after the last
    bidirectional=False,
)

X_seq = torch.randn(32, 50, 10)   # 32 sequences × 50 timesteps × 10 features
output, (h_n, c_n) = lstm(X_seq)

# output:   top layer's h_t at every timestep   -> (32, 50, 64)
# h_n, c_n: final timestep of every layer       -> (2, 32, 64)
#           batch_first does NOT apply to them; the layer axis stays first.
for label, value in (("Output shape", output), ("h_n shape", h_n), ("c_n shape", c_n)):
    print(f"{label}: {value.shape}")

GRU: Gated Recurrent Unit

Python
import torch
import torch.nn as nn

# GRU: a lighter alternative to the LSTM. It has 2 sigmoid gates (reset, update)
# versus the LSTM's 3 (input, forget, output), i.e. 3 weight blocks vs 4 once
# the candidate is counted, and often gives comparable accuracy.
# There is no separate cell state; the hidden state h is the only memory.

class GRUCell(nn.Module):
    """One GRU step written out by hand, with a separate Linear per gate.

    The reset gate is applied to h *before* the candidate's matmul (the original
    Cho et al. 2014 form). ``nn.GRU`` applies it after the hidden-side matmul
    and mirrors the update convention, so weights are not interchangeable.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        concat_size = input_size + hidden_size  # each gate reads [x, h] (or [x, r*h])
        self.W_r = nn.Linear(concat_size, hidden_size)   # reset gate
        self.W_z = nn.Linear(concat_size, hidden_size)   # update gate
        self.W_n = nn.Linear(concat_size, hidden_size)   # candidate ("new") state

    def forward(
        self, x: torch.Tensor, h: torch.Tensor
    ) -> torch.Tensor:
        """Advance one timestep and return the next hidden state (same shape as h)."""
        xh = torch.cat([x, h], dim=-1)
        reset = torch.sigmoid(self.W_r(xh))    # how much of h feeds the candidate
        update = torch.sigmoid(self.W_z(xh))   # blend weight: 0 keeps h, 1 takes candidate

        gated_h = reset * h
        candidate = torch.tanh(self.W_n(torch.cat([x, gated_h], dim=-1)))

        # h_t = (1 - z) ⊙ h_{t-1} + z ⊙ n
        return (1 - update) * h + update * candidate

# Bidirectional nn.GRU: one direction reads t = 0 → T-1, the other T-1 → 0,
# and their per-timestep outputs are concatenated, doubling the feature size.
gru = nn.GRU(10, 64, num_layers=2, batch_first=True, bidirectional=True)

X = torch.randn(32, 50, 10)   # (batch, seq_len, features)
output, h_n = gru(X)          # h_n: (num_layers * 2 directions, 32, 64)
print(f"Bidirectional GRU output: {output.shape}")  # (32, 50, 128): 64 forward ‖ 64 backward

Clinical Time-Series with LSTM

Python
import torch
import torch.nn as nn

class ICUTimeSeriesModel(nn.Module):
    """
    LSTM for predicting ICU mortality from hourly vital signs.

    Input:  (batch, time_steps, n_vitals), e.g. heart rate, SpO2, MAP, RR
    Output: (batch, 1) mortality *logit*. The head ends in a plain Linear, so
            apply torch.sigmoid to obtain a probability.

    Args:
        n_vitals: number of input features per timestep.
        hidden_size: LSTM hidden/cell state width.
        n_layers: number of stacked LSTM layers.
        dropout: dropout probability, used between stacked LSTM layers (only
            meaningful when n_layers > 1) and before the MLP head.
    """

    def __init__(
        self,
        n_vitals: int = 6,     # e.g. HR, SpO2, MAP, RR, temp, lactate
        hidden_size: int = 64,
        n_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_vitals,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            # nn.LSTM applies dropout only *between* layers; with a single layer
            # a non-zero value has no effect and only emits a UserWarning.
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.head = nn.Sequential(
            nn.Dropout(dropout),   # honours the `dropout` argument
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor = None) -> torch.Tensor:
        """Return mortality logits of shape (batch, 1).

        Args:
            x: padded batch, (batch, time_steps, n_vitals).
            lengths: optional (batch,) true sequence lengths. When given, the
                input is packed so that h_n is taken at each sequence's last
                real timestep rather than after the trailing padding.
        """
        if lengths is not None:
            # Packing matters for correctness with padded batches, not just speed.
            # enforce_sorted=False lets PyTorch sort internally; h_n comes back
            # in the original batch order.
            x = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(x)

        # h_n is (n_layers, batch, hidden_size) regardless of batch_first;
        # [-1] selects the top layer's final state.
        last_hidden = h_n[-1]
        return self.head(last_hidden)

model = ICUTimeSeriesModel(n_vitals=6, hidden_size=64)
X = torch.randn(32, 48, 6)   # 32 patients, 48 hours, 6 vitals

# Inference: eval() disables the dropout layers (otherwise every call returns a
# different random prediction) and no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    out = model(X)            # raw logits
print(f"ICU model output: {out.shape}")   # (32, 1)
probs = torch.sigmoid(out)    # logits -> mortality probabilities in (0, 1)
print(f"Mortality probs: min={probs.min():.3f}, max={probs.max():.3f}")

RNN vs Transformer for Clinical Time Series

When to use LSTM/GRU:
  ✓ Short-to-medium sequences (< 500 timesteps)
  ✓ Real-time inference (processes one step at a time — streaming)
  ✓ Limited training data (LSTMs work with < 10K sequences)
  ✓ Variable-length sequences with masked/packed input
  ✓ Memory-constrained deployment (LSTMs are smaller than transformers)
  ✓ Causal prediction (online, left-to-right)

When to use Transformers (e.g., clinical BERT for EHR):
  ✓ Very long sequences (1000+ timesteps, full admission notes)
  ✓ Pre-training on large corpus available (MIMIC, EHR)
  ✓ Need bidirectional context (classify entire record)
  ✓ Some interpretability via attention weights (a rough guide, not a faithful explanation)
  ✓ Multi-modal data (mix text + vitals + labs)
  ✗ Small datasets: transformers overfit without large-scale pre-training

Hybrid approach (common in clinical AI):
  - Use LSTM to process time-series vitals
  - Use BERT embeddings for clinical notes
  - Concatenate features for final prediction
Python
import torch
import torch.nn as nn

class HybridClinicalModel(nn.Module):
    """Late-fusion model: LSTM over vital signs plus a projected text embedding.

    The top-layer final LSTM hidden state (vitals) and a linear projection of a
    pre-computed text embedding (e.g. a BERT [CLS] vector) are concatenated and
    fed to an MLP head. The head ends in a plain Linear, so the output is a raw
    logit of shape (batch, 1); apply torch.sigmoid for a probability.
    """

    def __init__(
        self,
        n_vitals: int = 6,
        text_embed_dim: int = 768,  # BERT embedding size
        hidden_size: int = 64,
    ):
        super().__init__()
        # Vital-sign stream: 2 stacked LSTM layers over (batch, time, features).
        self.lstm = nn.LSTM(n_vitals, hidden_size, num_layers=2, batch_first=True)

        # Text stream: project the embedding to the LSTM's width so both
        # modalities contribute equally sized feature vectors.
        self.text_proj = nn.Linear(text_embed_dim, hidden_size)

        # Fusion and prediction on [vital_features ‖ text_features].
        self.head = nn.Sequential(
            nn.Linear(hidden_size * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        vitals: torch.Tensor,            # (batch, time_steps, n_vitals)
        text_embed: torch.Tensor,        # (batch, text_embed_dim), e.g. BERT [CLS] vector
        lengths: torch.Tensor = None,    # optional (batch,) true lengths of padded vitals
    ) -> torch.Tensor:
        """Return fused logits of shape (batch, 1).

        When ``lengths`` is given, the padded vitals are packed so the LSTM's
        final state is taken at each sequence's last real timestep rather than
        after trailing padding.
        """
        if lengths is not None:
            vitals = nn.utils.rnn.pack_padded_sequence(
                vitals, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(vitals)
        vital_features = h_n[-1]                       # (batch, hidden_size)
        text_features = self.text_proj(text_embed)     # (batch, hidden_size)

        combined = torch.cat([vital_features, text_features], dim=-1)
        return self.head(combined)

# Smoke test with random stand-ins: 16 patients, each with 48 hourly readings
# of 6 vitals plus one 768-dimensional note embedding.
model = HybridClinicalModel()
vitals = torch.randn(16, 48, 6)
text = torch.randn(16, 768)
out = model(vitals, text_embed=text)
print(f"Hybrid model output: {out.shape}")   # (16, 1): one logit per patient

Interview Answer

"RNNs process sequences one timestep at a time, maintaining a hidden state h_t = tanh(W_h·h_{t-1} + W_x·x_t + b). The vanishing gradient problem makes learning dependencies longer than ~20 timesteps unreliable. LSTM mitigates this with a cell state (long-term memory) protected by three gates: forget gate (how much to keep), input gate (how much new info to add), output gate (how much to expose). The cell state's additive update (c_t = f·c_{t-1} + i·g) gives gradients a path back through time that is scaled by the forget gate, rather than repeatedly multiplied by W_h and saturating tanh derivatives. GRU is a simpler alternative with two gates (reset, update) and no separate cell state — fewer parameters, often comparable performance. For clinical time series (ICU vitals, ECG): LSTM/GRU work well for short-to-medium sequences (under 500 steps) and streaming inference. For longer sequences or when pre-trained models are available (ClinicalBERT, LLMs fine-tuned on MIMIC): use transformers, which have direct attention paths between all positions."