Chain Rule of Probability
The chain rule of probability — how joint probabilities factorise into conditionals, its connection to language models, and how to apply it.
The Rule
P(A₁, A₂, ..., Aₙ) = P(A₁) × P(A₂ | A₁) × P(A₃ | A₁, A₂) × ... × P(Aₙ | A₁,...,Aₙ₋₁)
Or in product notation:
P(A₁ ∩ ... ∩ Aₙ) = Πᵢ₌₁ⁿ P(Aᵢ | A₁, ..., Aᵢ₋₁)   [for i = 1 the term is just P(A₁)]
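Because the rule is an identity, it can be checked numerically. The sketch below assumes an arbitrary, randomly generated joint distribution over three binary variables (hypothetical, for illustration only). It derives each conditional from that joint by marginalising, then confirms that the chain-rule product reproduces every joint entry.

```python
import itertools
import random

# Hypothetical joint distribution P(A, B, C) over three binary variables:
# random positive weights, normalised so the eight probabilities sum to 1.
random.seed(0)
outcomes = list(itertools.product([0, 1], repeat=3))
weights = [random.random() + 0.01 for _ in outcomes]
total = sum(weights)
joint = {o: w / total for o, w in zip(outcomes, weights)}


def marginal(prefix):
    """P(first k variables = prefix), summing out the remaining variables."""
    k = len(prefix)
    return sum(p for o, p in joint.items() if o[:k] == tuple(prefix))


def chain_rule_product(outcome):
    """P(a) * P(b | a) * P(c | a, b), each conditional taken as a ratio of marginals."""
    prob = 1.0
    for i in range(len(outcome)):
        # P(x_i | x_0..x_{i-1}) = P(x_0..x_i) / P(x_0..x_{i-1});
        # for i = 0 the denominator is the empty marginal, which is 1.
        prob *= marginal(outcome[: i + 1]) / marginal(outcome[:i])
    return prob


max_error = max(abs(chain_rule_product(o) - joint[o]) for o in outcomes)
print(f"largest |chain rule - joint| = {max_error:.2e}")
```

The product telescopes: each denominator cancels the previous numerator. That is why the rule needs no independence assumptions, only that the conditioning events have non-zero probability.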
For two events (simplest case):
P(A ∩ B) = P(A) × P(B | A)
= P(B) × P(A | B) [the order of factorisation doesn't matter: both are correct]
Derivation from Definition
P(A | B) = P(A ∩ B) / P(B) [definition of conditional]
→ P(A ∩ B) = P(A | B) × P(B) [multiplication rule: 2 events]
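The two-event multiplication rule, and the claim that both orderings agree, can be checked exactly by enumeration. This sketch assumes a hypothetical example: two rolls of a fair six-sided die, with A = "first roll is even" and B = "total is at least 8". `Fraction` keeps the arithmetic exact, so the three results can be compared with `==`.

```python
from fractions import Fraction
from itertools import product

# Sample space: two rolls of a fair die, all 36 outcomes equally likely.
omega = list(product(range(1, 7), repeat=2))


def prob(event):
    """P(event) under the uniform distribution on omega."""
    return Fraction(sum(1 for w in omega if event(w)), len(omega))


def cond(event, given):
    """P(event | given) = P(event ∩ given) / P(given); requires P(given) > 0."""
    return prob(lambda w: event(w) and given(w)) / prob(given)


def first_even(w):
    return w[0] % 2 == 0


def total_at_least_8(w):
    return w[0] + w[1] >= 8


p_joint = prob(lambda w: first_even(w) and total_at_least_8(w))
via_a = prob(first_even) * cond(total_at_least_8, first_even)        # P(A) P(B | A)
via_b = prob(total_at_least_8) * cond(first_even, total_at_least_8)  # P(B) P(A | B)
print(p_joint, via_a, via_b)
```

All three values come out as 1/4, even though the individual factors differ (1/2 × 1/2 versus 5/12 × 3/5).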
For 3 events:
P(A ∩ B ∩ C) = P(C | A ∩ B) × P(A ∩ B)
= P(C | A, B) × P(B | A) × P(A)
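The same step works for any n. Assuming the rule already holds for n events, and that P(A₁ ∩ ... ∩ Aₙ) > 0 so every conditional is defined, apply the two-event rule with B = A₁ ∩ ... ∩ Aₙ:

```latex
\begin{aligned}
P(A_1 \cap \dots \cap A_{n+1})
  &= P(A_{n+1} \mid A_1, \dots, A_n)\, P(A_1 \cap \dots \cap A_n) \\
  &= P(A_{n+1} \mid A_1, \dots, A_n) \prod_{i=1}^{n} P(A_i \mid A_1, \dots, A_{i-1}) \\
  &= \prod_{i=1}^{n+1} P(A_i \mid A_1, \dots, A_{i-1})
\end{aligned}
```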
By induction, this extends to any number of events.
Connection to Language Models
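GPT and other autoregressive language models define the probability of a text sequence using the chain rule (here each word is treated as one token for simplicity; real tokenisers often split words into sub-word pieces):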
P("The patient has AF")
= P("The")
× P("patient" | "The")
× P("has" | "The patient")
× P("AF" | "The patient has")
Each term is the conditional probability of the next token given all previous tokens.
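A minimal sketch of that factorisation, assuming a tiny hand-written "model" whose next-token probabilities are invented purely for illustration (they do not come from any real model). It multiplies the conditionals, accumulates their logs, and converts the result to perplexity.

```python
import math

# Hypothetical next-token probabilities P(token | context); illustrative only.
toy_model = {
    (): {"The": 0.20},
    ("The",): {"patient": 0.05},
    ("The", "patient"): {"has": 0.30},
    ("The", "patient", "has"): {"AF": 0.01},
}


def chain_rule_probability(tokens):
    """Return (joint probability, log joint probability) via the chain rule."""
    prob, log_prob = 1.0, 0.0
    for t, token in enumerate(tokens):
        p = toy_model[tuple(tokens[:t])][token]  # P(token_t | tokens before t)
        prob *= p
        log_prob += math.log(p)
    return prob, log_prob


tokens = ["The", "patient", "has", "AF"]
prob, log_prob = chain_rule_probability(tokens)
perplexity = math.exp(-log_prob / len(tokens))
print(f"P = {prob:.2e}, log P = {log_prob:.3f}, perplexity = {perplexity:.2f}")
```

Working in log space (summing logs rather than multiplying raw probabilities) avoids floating-point underflow on long sequences. This toy also scores the first token, because it has an explicit P("The") entry; a model that only sees preceding context cannot score the first token that way.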
The model is trained to maximise the log of this joint probability over the training corpus (equivalently, to minimise the average next-token cross-entropy).

```python
import math

import torch
import torch.nn.functional as F


def sequence_log_probability(
    logits: torch.Tensor,     # shape: (seq_len, vocab_size)
    token_ids: torch.Tensor,  # shape: (seq_len,), integer token ids
) -> float:
    """
    Compute log P(token_2, ..., token_n | token_1) using the chain rule.

    Assumes the usual causal-LM convention: logits[t] are the model's
    predictions for the token at position t+1, given tokens 0..t.
    """
    # log P(next token | previous tokens) over the whole vocabulary
    log_probs = F.log_softmax(logits, dim=-1)  # (seq_len, vocab_size)

    # Pair the prediction made at position t with the actual token at t+1.
    # The first token is not scored: no prediction precedes it.
    target_log_probs = log_probs[:-1, :].gather(
        dim=-1,
        index=token_ids[1:].unsqueeze(-1),
    ).squeeze(-1)  # (seq_len - 1,)

    # Sum of log probs = log of the product of conditionals (chain rule)
    return target_log_probs.sum().item()


def perplexity_from_chain_rule(log_prob: float, n_tokens: int) -> float:
    """Perplexity = exp(-log P(sequence) / n_tokens).

    n_tokens is the number of scored tokens (seq_len - 1 when log_prob
    comes from sequence_log_probability).
    """
    return math.exp(-log_prob / n_tokens)
```

Bayesian Networks and the Chain Rule
Bayesian networks combine the chain rule with conditional independence assumptions to simplify the joint probability.
Full joint (no assumptions, 3 binary variables):
P(A, B, C) = P(A) × P(B | A) × P(C | A, B) [3 terms]
With independence: if C ⊥ A given B (C is independent of A given B):
P(A, B, C) = P(A) × P(B | A) × P(C | B) [P(C | A, B) reduces to P(C | B)]
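A quick numerical check of that simplification, using hypothetical conditional probability tables (the numbers are made up). The joint is built from P(A) × P(B | A) × P(C | B); recovering P(C | A, B) from it by the definition of conditional probability shows the result no longer depends on A. For three binary variables this also shrinks the number of free parameters from 1 + 2 + 4 = 7 to 1 + 2 + 2 = 5.

```python
from itertools import product

# Hypothetical CPTs for a chain A -> B -> C of binary variables (illustrative only).
p_a = {1: 0.3, 0: 0.7}
p_b_given_a = {1: {1: 0.8, 0: 0.2}, 0: {1: 0.1, 0: 0.9}}   # p_b_given_a[a][b]
p_c_given_b = {1: {1: 0.6, 0: 0.4}, 0: {1: 0.05, 0: 0.95}}  # p_c_given_b[b][c]

# Joint via the chain rule plus C independent of A given B.
joint = {
    (a, b, c): p_a[a] * p_b_given_a[a][b] * p_c_given_b[b][c]
    for a, b, c in product([0, 1], repeat=3)
}


def p_c_given_ab(c, a, b):
    """P(C=c | A=a, B=b) = P(a, b, c) / P(a, b), computed from the joint."""
    return joint[(a, b, c)] / sum(joint[(a, b, x)] for x in (0, 1))


# The full chain-rule conditional matches P(C | B) for every value of A.
for a, b in product([0, 1], repeat=2):
    print(f"a={a} b={b}: P(C=1 | a, b) = {p_c_given_ab(1, a, b):.4f}, "
          f"P(C=1 | b) = {p_c_given_b[b][1]:.4f}")
```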
Naive Bayes (all features independent given class Y):
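P(X₁, X₂, ..., Xₙ, Y) = P(Y) × Π P(Xᵢ | Y)

```python
# Naive Bayes classifier: chain rule + conditional independence
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = load_iris(return_X_y=True)  # example dataset so the snippet runs
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

gnb = GaussianNB()
gnb.fit(X_train, y_train)
# Internally, gnb scores each class as (GaussianNB models P(xᵢ | y) as a normal density):
# log P(y | x) = log P(y) + Σᵢ log P(xᵢ | y) - log P(x)   (chain rule + independence)
proba = gnb.predict_proba(X_test)
```

Worked Example: Patient Diagnosis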
Three binary variables:
S = Symptom (fever present)
D = Disease (flu)
T = Test (flu test result)
Assume: T ⊥ S given D (test result depends on disease, not directly on symptom)
P(S=1, D=1, T=1) = P(S=1) × P(D=1 | S=1) × P(T=1 | D=1)
[chain rule + conditional independence of T and S given D]
Given:
P(S=1) = 0.30 (30% of patients have fever)
P(D=1 | S=1) = 0.40 (40% of patients with fever have flu)
P(T=1 | D=1) = 0.90 (test sensitivity 90%)
P(fever, flu, positive test) = 0.30 × 0.40 × 0.90 = 0.108

```python
def joint_with_chain_rule(
    p_s: float,
    p_d_given_s: float,
    p_t_given_d: float,
) -> float:
    """P(S=1, D=1, T=1) = P(S=1) × P(D=1 | S=1) × P(T=1 | D=1), assuming T ⊥ S given D."""
    return p_s * p_d_given_s * p_t_given_d


result = joint_with_chain_rule(0.30, 0.40, 0.90)
print(f"P(S=1, D=1, T=1) = {result:.4f}")  # P(S=1, D=1, T=1) = 0.1080
```

Interview Answer
"The chain rule of probability states that the joint probability of n events equals the product of conditional probabilities: P(A₁,...,Aₙ) = Π P(Aᵢ|A₁,...,Aᵢ₋₁). This is mathematically exact — no assumptions needed. It's the foundation of autoregressive language models: GPT decomposes P(token sequence) into a product of conditional next-token probabilities, training the model to estimate each conditional. The chain rule also appears in Bayesian networks, where conditional independence assumptions simplify the product — Naive Bayes, for instance, assumes features are independent given the class, reducing the chain rule product to P(y) × Π P(xᵢ|y)."