RNNs and LSTMs
How recurrent networks process sequences, the LSTM gate mechanism that solved vanishing gradients, and when to use RNNs vs Transformers for clinical time series.
Recurrent Networks: Processing Sequences
An RNN processes one timestep at a time, maintaining a hidden state:
$h_t = \tanh(W_h \cdot h_{t-1} + W_x \cdot x_t + b)$
h_{t-1}: previous hidden state (memory from past)
x_t: current input
h_t: new hidden state (updated memory)
Problem: gradient must propagate back through all timesteps.
For a 100-step sequence: gradient involves product of 100 Jacobians.
If ‖W_h‖ < 1 (largest singular value below 1): gradient vanishes (early timesteps have no influence)
If ‖W_h‖ > 1: gradient can explode
This is why vanilla RNNs fail for sequences longer than ~20 timesteps.
LSTM and GRU solve this with gating mechanisms.

LSTM: Long Short-Term Memory
import torch
import torch.nn as nn
# An LSTM carries two state vectors from one timestep to the next:
#   h_t — hidden state (short-term memory)
#   c_t — cell state (long-term memory, protected by gates)
class LSTMCell(nn.Module):
    """Hand-written LSTM cell that spells out the gate arithmetic.

    Args:
        input_size: Number of features in each input vector ``x``.
        hidden_size: Width of the hidden state ``h`` and the cell state ``c``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # One projection produces the pre-activations of all four gates
        # (more efficient than four separate layers).
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(
        self,
        x: torch.Tensor,  # (batch, input_size)
        h: torch.Tensor,  # (batch, hidden_size)
        c: torch.Tensor,  # (batch, hidden_size)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the cell by one timestep.

        Returns:
            ``(h_new, c_new)``, each of shape (batch, hidden_size).
        """
        # Project [x, h] -> (batch, 4*hidden), then slice into the four
        # gate pre-activations in the order i, f, g, o.
        pre_i, pre_f, pre_g, pre_o = torch.chunk(
            self.linear(torch.cat((x, h), dim=-1)), 4, dim=-1
        )

        input_gate = torch.sigmoid(pre_i)   # how much new info to add
        forget_gate = torch.sigmoid(pre_f)  # how much old info to keep
        candidate = torch.tanh(pre_g)       # new candidate memory
        output_gate = torch.sigmoid(pre_o)  # how much of the memory to expose

        # Long-term memory update: c_t = f * c_{t-1} + i * g
        #   f≈1, i≈0: the cell remembers everything and ignores new input
        #   f≈0, i≈1: the cell forgets everything and uses only new input
        c_new = forget_gate * c + input_gate * candidate

        # Short-term memory is a gated, squashed view of the cell state.
        h_new = output_gate * torch.tanh(c_new)
        return h_new, c_new
# nn.LSTM is PyTorch's built-in implementation (much faster — optimised
# CUDA kernels) and is what you would use outside of a teaching example.
lstm = nn.LSTM(
    10,  # input_size
    64,  # hidden_size
    num_layers=2,  # stacked LSTMs
    dropout=0.2,  # applied between layers (not after the last one)
    batch_first=True,  # input is (batch, seq, features)
    bidirectional=False,
)

X_seq = torch.randn(32, 50, 10)  # (batch=32, seq_len=50, features=10)
output, final_states = lstm(X_seq)
h_n, c_n = final_states

print(f"Output shape: {output.shape}")  # (32, 50, 64) — output at each timestep
print(f"h_n shape: {h_n.shape}")  # (2, 32, 64) — last hidden state per layer
print(f"c_n shape: {c_n.shape}") # (2, 32, 64) — last cell state per layer

GRU: Gated Recurrent Unit
import torch
import torch.nn as nn
# A GRU is a lighter alternative to the LSTM (2 gates vs 4) with often
# comparable performance. It has no separate cell state — the memory is
# merged into the hidden state.
class GRUCell(nn.Module):
    """Hand-written GRU cell that spells out the gate arithmetic.

    Args:
        input_size: Number of features in each input vector ``x``.
        hidden_size: Width of the hidden state ``h``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        joint_size = input_size + hidden_size
        self.W_r = nn.Linear(joint_size, hidden_size)  # reset gate
        self.W_z = nn.Linear(joint_size, hidden_size)  # update gate
        self.W_n = nn.Linear(joint_size, hidden_size)  # candidate ("new") state

    def forward(
        self, x: torch.Tensor, h: torch.Tensor
    ) -> torch.Tensor:
        """Advance the cell by one timestep and return the new hidden state."""
        joint = torch.cat((x, h), dim=-1)

        reset = torch.sigmoid(self.W_r(joint))   # how much of the past to use
        update = torch.sigmoid(self.W_z(joint))  # blend between old and new

        # The candidate sees the input plus a reset-gated view of the past.
        gated_past = reset * h
        candidate = torch.tanh(self.W_n(torch.cat((x, gated_past), dim=-1)))

        # h_t = (1 - z) * h_{t-1} + z * n
        #   z≈0: keep the old hidden state (barely update)
        #   z≈1: fully replace it with the new candidate
        return (1 - update) * h + update * candidate
# The built-in PyTorch GRU, here run in both directions over the sequence.
gru = nn.GRU(
    10,  # input_size
    64,  # hidden_size
    num_layers=2,
    bidirectional=True,  # reads the sequence forward AND backward
    batch_first=True,
)

X = torch.randn(32, 50, 10)  # (batch, seq_len, features)
gru_result = gru(X)
output, h_n = gru_result
print(f"Bidirectional GRU output: {output.shape}") # (32, 50, 128) — 2×64 for bidir

Clinical Time-Series with LSTM
import torch
import torch.nn as nn
class ICUTimeSeriesModel(nn.Module):
    """LSTM for predicting ICU mortality from hourly vital signs.

    Input:  (batch, time_steps, n_vitals) — e.g., heart rate, SpO2, MAP, RR
    Output: (batch, 1) — mortality *logit*; the head has no final sigmoid,
            so apply ``torch.sigmoid`` to obtain a probability.

    Args:
        n_vitals: Number of vital-sign features per timestep.
        hidden_size: Width of the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout applied between stacked LSTM layers. The
            classification head uses its own fixed dropout of 0.3.
    """

    def __init__(
        self,
        n_vitals: int = 6,  # HR, SpO2, MAP, RR, temp, lactate
        hidden_size: int = 64,
        n_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_vitals,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            # Inter-layer dropout does nothing with a single layer and makes
            # PyTorch emit a warning, so switch it off in that case.
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        lengths: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return one mortality logit per patient.

        Args:
            x: Padded vitals of shape (batch, time_steps, n_vitals).
            lengths: Optional true (unpadded) length of each sequence, as a
                1-D tensor or a sequence of ints. When given, the batch is
                packed so the LSTM stops at each patient's last real
                timestep instead of running on through the padding.

        Returns:
            Tensor of shape (batch, 1) containing logits.
        """
        lstm_input = x
        if lengths is not None:
            # pack_padded_sequence wants the lengths on the CPU;
            # enforce_sorted=False lets callers pass batches in any order.
            lstm_input = nn.utils.rnn.pack_padded_sequence(
                x,
                torch.as_tensor(lengths).cpu(),
                batch_first=True,
                enforce_sorted=False,
            )
        _, (h_n, _) = self.lstm(lstm_input)

        # Summarise each sequence by the final hidden state of the top layer.
        last_hidden = h_n[-1]  # (batch, hidden_size)
        return self.head(last_hidden)
model = ICUTimeSeriesModel(n_vitals=6, hidden_size=64)
# This is an inference demo: put the model in eval mode so dropout is
# disabled, and skip autograd bookkeeping. Without this the printed
# probabilities change from call to call even with fixed weights.
model.eval()

X = torch.randn(32, 48, 6)  # 32 patients, 48 hours, 6 vitals
with torch.no_grad():
    out = model(X)
print(f"ICU model output: {out.shape}")  # (32, 1)

# The model emits logits; squash them to probabilities for reporting.
probs = torch.sigmoid(out)
print(f"Mortality probs: min={probs.min():.3f}, max={probs.max():.3f}")

RNN vs Transformer for Clinical Time Series
When to use LSTM/GRU:
✓ Short-to-medium sequences (< 500 timesteps)
✓ Real-time inference (processes one step at a time — streaming)
✓ Limited training data (LSTMs work with < 10K sequences)
✓ Variable-length sequences with masked/packed input
✓ Memory-constrained deployment (LSTMs are smaller than transformers)
✓ Causal prediction (online, left-to-right)
When to use Transformers (e.g., clinical BERT for EHR):
✓ Very long sequences (1000+ timesteps, full admission notes)
✓ Pre-training on large corpus available (MIMIC, EHR)
✓ Need bidirectional context (classify entire record)
✓ Interpretability via attention weights
✓ Multi-modal data (mix text + vitals + labs)
✗ Small datasets: transformers overfit without large-scale pre-training
Hybrid approach (common in clinical AI):
- Use LSTM to process time-series vitals
- Use BERT embeddings for clinical notes
- Concatenate features for final prediction

import torch
import torch.nn as nn
class HybridClinicalModel(nn.Module):
    """Combines LSTM for vitals + text embeddings for clinical notes.

    The vitals sequence is summarised by the final hidden state of a
    two-layer LSTM, the note embedding is projected to the same width, and
    the two vectors are concatenated and fed to a small MLP head.

    Args:
        n_vitals: Number of vital-sign features per timestep.
        text_embed_dim: Size of the incoming note embedding (768 matches
            a BERT-base embedding).
        hidden_size: Width of the LSTM hidden state and of the projected
            text features.
    """

    def __init__(
        self,
        n_vitals: int = 6,
        text_embed_dim: int = 768,  # BERT embedding size
        hidden_size: int = 64,
    ):
        super().__init__()
        # Vital sign stream
        self.lstm = nn.LSTM(n_vitals, hidden_size, num_layers=2, batch_first=True)
        # Text stream projection
        self.text_proj = nn.Linear(text_embed_dim, hidden_size)
        # Fusion and prediction
        self.head = nn.Sequential(
            nn.Linear(hidden_size * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        vitals: torch.Tensor,  # (batch, time_steps, n_vitals)
        text_embed: torch.Tensor,  # (batch, text_embed_dim) — e.g. CLS token from BERT
        lengths: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return one logit per patient, shape (batch, 1).

        Args:
            vitals: Padded vitals of shape (batch, time_steps, n_vitals).
            text_embed: Note embedding of shape (batch, text_embed_dim).
            lengths: Optional true (unpadded) length of each vitals
                sequence, as a 1-D tensor or a sequence of ints. When
                given, the batch is packed so each patient's summary comes
                from their last real timestep rather than from padding.
        """
        lstm_input = vitals
        if lengths is not None:
            # pack_padded_sequence wants the lengths on the CPU;
            # enforce_sorted=False lets callers pass batches in any order.
            lstm_input = nn.utils.rnn.pack_padded_sequence(
                vitals,
                torch.as_tensor(lengths).cpu(),
                batch_first=True,
                enforce_sorted=False,
            )
        _, (h_n, _) = self.lstm(lstm_input)

        vital_features = h_n[-1]  # (batch, hidden_size)
        text_features = self.text_proj(text_embed)  # (batch, hidden_size)
        combined = torch.cat([vital_features, text_features], dim=-1)
        return self.head(combined)
# Smoke-test the hybrid model on random data:
# 16 patients, 48 hourly readings of 6 vitals, one 768-d note embedding each.
batch_size = 16
model = HybridClinicalModel()
vitals = torch.randn(batch_size, 48, 6)
text = torch.randn(batch_size, 768)
out = model(vitals, text)
print(f"Hybrid model output: {out.shape}") # (16, 1)

Interview Answer
"RNNs process sequences one timestep at a time, maintaining a hidden state h_t = tanh(W_hยทh_ + W_xยทx_t). The vanishing gradient problem makes training on sequences longer than ~20 timesteps unreliable. LSTM solves this with a cell state (long-term memory) protected by gates: forget gate (how much to keep), input gate (how much new info to add), output gate (how much to expose). The cell state's additive update (c_t = fยทc_ + iยทg) allows gradients to flow back through time without multiplication by potentially-small activations. GRU is a simpler alternative with two gates (reset, update) and no separate cell state โ fewer parameters, often comparable performance. For clinical time series (ICU vitals, ECG): LSTM/GRU work well for short-to-medium sequences (under 500 steps) and streaming inference. For longer sequences or when pre-trained models are available (ClinicalBERT, LLMs fine-tuned on MIMIC): use transformers, which have direct attention paths between all positions."
Found this helpful?
Leave a comment
Have a question, correction, or just found this helpful? Leave a note below.