Learnixo
Back to blog
AI Systems · Intermediate

RNNs and LSTMs

How recurrent networks process sequences, the LSTM gate mechanism that mitigates vanishing gradients, and when to use RNNs vs Transformers for clinical time series.

Asma Hafeez Khan · May 22, 2026 · 7 min read
Deep Learning · RNN · LSTM · GRU · Time Series · Sequence Modelling · Interview
Share: 𝕏

Recurrent Networks: Processing Sequences

An RNN processes one timestep at a time, updating a hidden state:

  $h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$

  h_{t-1}: previous hidden state (memory from past)
  x_t:     current input
  h_t:     new hidden state (updated memory)

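Problem: during training (backpropagation through time), the gradient must propagate back through every timestep.
  For a 100-step sequence, the gradient involves a product of 100 Jacobians, each equal to W_h multiplied by a diagonal matrix of tanh derivatives (all ≤ 1).
  If the largest singular value of W_h is < 1: the gradient shrinks exponentially (vanishes), so early timesteps have almost no influence
  If the spectral radius of W_h is > 1: the gradient can grow exponentially (explodes)

This is why vanilla RNNs struggle to learn dependencies spanning more than ~20 timesteps.
LSTM and GRU largely solve this with gating mechanisms.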

LSTM: Long Short-Term Memory

Python
import torch
import torch.nn as nn

# LSTM maintains two state vectors:
#   h_t: hidden state (short-term memory)
#   c_t: cell state  (long-term memory — protected by gates)

class LSTMCell(nn.Module):
    """Manual LSTM cell for understanding the gate mechanism.

    Computes one LSTM timestep. The gate order (input, forget, cell, output)
    and the update equations are the same as PyTorch's ``nn.LSTMCell``. The
    difference is that all four gates share a single ``nn.Linear`` over the
    concatenation ``[x, h]``, so there is one bias vector instead of two.

    Args:
        input_size: Number of features in ``x``.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # All four gates in one matrix: one matmul instead of four (more efficient)
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(
        self,
        x: torch.Tensor,                   # (batch, input_size)
        h: "torch.Tensor | None" = None,   # (batch, hidden_size)
        c: "torch.Tensor | None" = None,   # (batch, hidden_size)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the cell by one timestep.

        Args:
            x: Input at the current timestep, shape (batch, input_size).
            h: Previous hidden state, shape (batch, hidden_size). Defaults to
                zeros (start of a sequence), like ``nn.LSTMCell``.
            c: Previous cell state, shape (batch, hidden_size). Defaults to
                zeros.

        Returns:
            ``(h_new, c_new)``, each of shape (batch, hidden_size).
        """
        # Missing states start at zero; leading dims follow x (batched or not).
        if h is None:
            h = x.new_zeros(*x.shape[:-1], self.hidden_size)
        if c is None:
            c = x.new_zeros(*x.shape[:-1], self.hidden_size)

        # Concatenate input and hidden state
        combined = torch.cat([x, h], dim=-1)   # (batch, input+hidden)

        # Compute all four gates at once
        gates = self.linear(combined)   # (batch, 4*hidden)

        # Split into four gate vectors
        i, f, g, o = gates.chunk(4, dim=-1)

        # Gate activations
        i = torch.sigmoid(i)   # input gate: how much new info to add
        f = torch.sigmoid(f)   # forget gate: how much old info to keep
        g = torch.tanh(g)      # cell gate: new candidate memory
        o = torch.sigmoid(o)   # output gate: how much to expose

        # Update cell state (long-term memory)
        # c_t = f * c_{t-1} + i * g
        # If f≈1, i≈0: cell remembers everything, ignores new input
        # If f≈0, i≈1: cell forgets everything, uses only new input
        c_new = f * c + i * g

        # Update hidden state
        h_new = o * torch.tanh(c_new)

        return h_new, c_new

# PyTorch's built-in LSTM (much faster — optimised cuDNN/CUDA kernels)
lstm = nn.LSTM(
    input_size=10,
    hidden_size=64,
    num_layers=2,       # stacked LSTMs
    batch_first=True,   # input/output: (batch, seq, features); h_n/c_n stay (layers, batch, hidden)
    dropout=0.2,        # applied between layers (not after last)
    bidirectional=False,
)

X_seq = torch.randn(32, 50, 10)   # (batch=32, seq_len=50, features=10)
output, (h_n, c_n) = lstm(X_seq)

print(f"Output shape: {output.shape}")   # (32, 50, 64) — top-layer output at each timestep
print(f"h_n shape: {h_n.shape}")         # (2, 32, 64) — last hidden state per layer
print(f"c_n shape: {c_n.shape}")         # (2, 32, 64) — last cell state per layer

GRU: Gated Recurrent Unit

Python
import torch
import torch.nn as nn

# GRU: simpler than LSTM (2 gates vs LSTM's 3), often comparable performance
# No separate cell state — merged into hidden state

class GRUCell(nn.Module):
    """Manual GRU cell illustrating the reset/update gate mechanism.

    The reset gate ``r`` scales ``h`` *before* the candidate's linear map, and
    ``z`` is the weight given to the new candidate:
    ``h_t = (1 - z) * h_{t-1} + z * n``.

    PyTorch's ``nn.GRUCell`` applies ``r`` after the hidden-to-hidden matmul.
    It also uses ``z`` as the weight on the *old* state. Its weights are
    therefore not interchangeable with this cell's.

    Args:
        input_size: Number of features in ``x``.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)   # reset gate
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)   # update gate
        self.W_n = nn.Linear(input_size + hidden_size, hidden_size)   # new (candidate) gate

    def forward(
        self, x: torch.Tensor, h: "torch.Tensor | None" = None
    ) -> torch.Tensor:
        """Advance the cell by one timestep.

        Args:
            x: Input at the current timestep, shape (batch, input_size).
            h: Previous hidden state, shape (batch, hidden_size). Defaults to
                zeros (start of a sequence).

        Returns:
            The new hidden state, shape (batch, hidden_size).
        """
        if h is None:
            h = x.new_zeros(*x.shape[:-1], self.hidden_size)

        combined = torch.cat([x, h], dim=-1)

        r = torch.sigmoid(self.W_r(combined))   # reset: how much past to use
        z = torch.sigmoid(self.W_z(combined))   # update: blend old vs new
        n = torch.tanh(self.W_n(torch.cat([x, r * h], dim=-1)))   # candidate

        # h_t = (1-z) * h_{t-1} + z * n
        # If z≈0: keep old hidden state (don't update much)
        # If z≈1: fully replace with new candidate
        return (1 - z) * h + z * n

# PyTorch GRU
gru = nn.GRU(
    input_size=10,
    hidden_size=64,
    num_layers=2,
    batch_first=True,
    bidirectional=True,   # bidirectional: reads sequence forward AND backward
)

X = torch.randn(32, 50, 10)
output, h_n = gru(X)
print(f"Bidirectional GRU output: {output.shape}")  # (32, 50, 128) — 2×64 for bidir
# h_n: (num_layers * 2, batch, hidden). h_n[-2] is the top-layer forward state and
# h_n[-1] the top-layer backward state. For a sequence summary use
# torch.cat([h_n[-2], h_n[-1]], dim=-1), not output[:, -1]: at the last
# position the backward direction has seen only one timestep.
print(f"Bidirectional GRU h_n: {h_n.shape}")        # (4, 32, 64)

Clinical Time-Series with LSTM

Python
import torch
import torch.nn as nn

class ICUTimeSeriesModel(nn.Module):
    """
    LSTM for predicting ICU mortality from hourly vital signs.

    Input:  (batch, time_steps, n_vitals) — e.g., heart rate, SpO2, MAP, RR
    Output: (batch, 1) — mortality *logit*. Apply ``torch.sigmoid`` to get a
            probability, or train with ``nn.BCEWithLogitsLoss``.

    Args:
        n_vitals: Number of input features per timestep.
        hidden_size: LSTM hidden-state size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. It is not passed to the
            LSTM when ``n_layers == 1``, where it would have no effect.
        head_dropout: Dropout applied to the final hidden state before the
            prediction head.
    """

    def __init__(
        self,
        n_vitals: int = 6,     # e.g. HR, SpO2, MAP, RR, temp, lactate
        hidden_size: int = 64,
        n_layers: int = 2,
        dropout: float = 0.3,
        head_dropout: float = 0.3,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_vitals,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            # nn.LSTM applies dropout only *between* layers; with one layer it
            # does nothing except emit a UserWarning.
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.head = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(
        self, x: torch.Tensor, lengths: "torch.Tensor | None" = None
    ) -> torch.Tensor:
        """Return one mortality logit per patient.

        Args:
            x: Padded batch of shape (batch, time_steps, n_vitals).
            lengths: Optional 1-D tensor with the true length of each
                sequence (each in ``1..time_steps``). When given, padded
                timesteps are ignored.

        Returns:
            Logits of shape (batch, 1).
        """
        if lengths is not None:
            # Pack so the LSTM stops at each sequence's true length. h_n is
            # then the state after the last *real* hour, not after padding.
            # With enforce_sorted=False, PyTorch sorts internally and returns
            # h_n in the original batch order.
            x = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(x)

        # Use last hidden state (from last layer)
        last_hidden = h_n[-1]   # (batch, hidden_size)
        return self.head(last_hidden)

model = ICUTimeSeriesModel(n_vitals=6, hidden_size=64)
model.eval()   # inference: disable dropout so predictions are deterministic
X = torch.randn(32, 48, 6)   # 32 patients, 48 hours, 6 vitals
with torch.no_grad():        # no autograd graph is needed for prediction
    out = model(X)
print(f"ICU model output: {out.shape}")   # (32, 1) — logits
probs = torch.sigmoid(out)
print(f"Mortality probs: min={probs.min():.3f}, max={probs.max():.3f}")

RNN vs Transformer for Clinical Time Series

When to use LSTM/GRU:
  ✓ Short-to-medium sequences (< 500 timesteps)
  ✓ Real-time inference (processes one step at a time — streaming)
  ✓ Limited training data (LSTMs can train well on fewer than ~10K sequences)
  ✓ Variable-length sequences with masked/packed input
  ✓ Memory-constrained deployment (LSTMs typically have far fewer parameters than transformers)
  ✓ Causal prediction (online, left-to-right)

When to use Transformers (e.g., clinical BERT for EHR):
  ✓ Very long sequences (1000+ timesteps, full admission notes)
  ✓ Pre-training on large corpus available (MIMIC, EHR)
  ✓ Need bidirectional context (classify entire record)
  ✓ Some interpretability via attention weights (useful, but not a faithful explanation on its own)
  ✓ Multi-modal data (mix text + vitals + labs)
  ✗ Small datasets: transformers overfit without large-scale pre-training

Hybrid approach (common in clinical AI):
  - Use LSTM to process time-series vitals
  - Use BERT embeddings for clinical notes
  - Concatenate features for final prediction
Python
import torch
import torch.nn as nn

class HybridClinicalModel(nn.Module):
    """Combines LSTM for vitals + text embeddings for clinical notes.

    The model has two streams:
    - The final top-layer hidden state of a 2-layer LSTM over the vitals.
    - A linear projection of a precomputed note embedding (e.g. a BERT CLS
      vector).

    The two are concatenated and passed to an MLP head that returns one
    logit per example.

    Args:
        n_vitals: Features per timestep in ``vitals``.
        text_embed_dim: Size of the text embedding (768 for BERT-base).
        hidden_size: LSTM hidden size; also the size of the projected text
            features.
    """

    def __init__(
        self,
        n_vitals: int = 6,
        text_embed_dim: int = 768,  # BERT embedding size
        hidden_size: int = 64,
    ):
        super().__init__()
        # Vital sign stream
        self.lstm = nn.LSTM(n_vitals, hidden_size, num_layers=2, batch_first=True)

        # Text stream projection
        self.text_proj = nn.Linear(text_embed_dim, hidden_size)

        # Fusion and prediction
        self.head = nn.Sequential(
            nn.Linear(hidden_size * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        vitals: torch.Tensor,       # (batch, time_steps, n_vitals)
        text_embed: torch.Tensor,   # (batch, text_embed_dim) — e.g. CLS token from BERT
        lengths: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return one logit per example, shape (batch, 1).

        Args:
            vitals: Padded vital-sign sequences.
            text_embed: One precomputed note embedding per example.
            lengths: Optional 1-D tensor of true sequence lengths (each in
                ``1..time_steps``). When given, padded timesteps do not
                affect the vitals features.
        """
        if lengths is not None:
            # Pack so h_n is taken at each sequence's last real timestep.
            vitals = nn.utils.rnn.pack_padded_sequence(
                vitals, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.lstm(vitals)
        vital_features = h_n[-1]                     # (batch, hidden_size)
        text_features = self.text_proj(text_embed)   # (batch, hidden_size)

        combined = torch.cat([vital_features, text_features], dim=-1)
        return self.head(combined)

model = HybridClinicalModel()
model.eval()   # inference: disable the head's dropout
vitals = torch.randn(16, 48, 6)
text   = torch.randn(16, 768)
with torch.no_grad():   # no autograd graph is needed for prediction
    out = model(vitals, text)
print(f"Hybrid model output: {out.shape}")   # (16, 1) — one logit per patient

Interview Answer

"RNNs process sequences one timestep at a time, maintaining a hidden state $h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$. Because backpropagation through time multiplies one Jacobian per step, gradients vanish or explode, which makes learning dependencies longer than ~20 timesteps unreliable. LSTM addresses this with a cell state (long-term memory) protected by gates: forget gate (how much to keep), input gate (how much new info to add), output gate (how much to expose). The cell state's additive update ($c_t = f \odot c_{t-1} + i \odot g$) lets gradients flow back through time along the cell path scaled only by the forget gate. When $f \approx 1$ they pass through almost unchanged, instead of being multiplied by $W_h$ and a tanh derivative at every step. GRU is a simpler alternative with two gates (reset, update) and no separate cell state — fewer parameters, often comparable performance. For clinical time series (ICU vitals, ECG): LSTM/GRU work well for short-to-medium sequences (under 500 steps) and streaming inference. For longer sequences or when pre-trained models are available (ClinicalBERT, LLMs fine-tuned on MIMIC): use transformers, which have direct attention paths between all positions."

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share: 𝕏

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.