Learnixo

Statistics & Math for AI/ML Interviews · Lesson 14 of 30

Chain Rule of Probability

The Rule

P(A₁, A₂, ..., Aₙ) = P(A₁) × P(A₂ | A₁) × P(A₃ | A₁, A₂) × ... × P(Aₙ | A₁,...,Aₙ₋₁)

Or in product notation:
P(∩ᵢ₌₁ⁿ Aᵢ) = Πᵢ₌₁ⁿ P(Aᵢ | A₁, ..., Aᵢ₋₁)    [for i = 1 the conditioning set is empty, so the first factor is P(A₁)]

For two events (simplest case):
P(A ∩ B) = P(A) × P(B | A)
          = P(B) × P(A | B)    [order doesn't matter — both are correct]

Derivation from Definition

P(A | B) = P(A ∩ B) / P(B)    [definition of conditional, for P(B) > 0]

→ P(A ∩ B) = P(A | B) × P(B)  [multiplication rule: 2 events]

For 3 events:
P(A ∩ B ∩ C) = P(C | A ∩ B) × P(A ∩ B)
             = P(C | A, B) × P(B | A) × P(A)

By induction, this extends to any number of events.
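
The derivation above can be checked numerically. The following is a minimal sketch, assuming a made-up joint distribution over three binary events (the random weights and the helper `prob` are illustrative, not part of the lesson). It compares the chain-rule product against the joint probability read directly from the table, for two different orderings of the events.

```python
import itertools
import random

random.seed(0)

# Hypothetical joint distribution over three binary events A, B, C:
# random positive weights, normalised so they sum to 1.
outcomes = list(itertools.product([0, 1], repeat=3))
weights = [random.random() + 0.01 for _ in outcomes]
total = sum(weights)
joint = {o: w / total for o, w in zip(outcomes, weights)}

NAMES = ("a", "b", "c")

def prob(**fixed):
    """Marginal probability that the named variables take the given values."""
    return sum(
        p for o, p in joint.items()
        if all(o[NAMES.index(k)] == v for k, v in fixed.items())
    )

# Ordering A, B, C: P(A) * P(B | A) * P(C | A, B)
p_a = prob(a=1)
p_b_given_a = prob(a=1, b=1) / prob(a=1)
p_c_given_ab = prob(a=1, b=1, c=1) / prob(a=1, b=1)
chain = p_a * p_b_given_a * p_c_given_ab

# Ordering C, B, A: P(C) * P(B | C) * P(A | B, C)
p_c = prob(c=1)
p_b_given_c = prob(b=1, c=1) / prob(c=1)
p_a_given_bc = prob(a=1, b=1, c=1) / prob(b=1, c=1)
reverse = p_c * p_b_given_c * p_a_given_bc

direct = joint[(1, 1, 1)]
print(f"A,B,C order: {chain:.6f}  C,B,A order: {reverse:.6f}  direct: {direct:.6f}")
```

Both orderings should agree with the direct table lookup up to floating-point error, which is the "order doesn't matter" point made for the two-event case.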

Connection to Language Models

GPT and all autoregressive models define the probability of a text sequence using the chain rule:

P("The patient has AF") 
= P("The") 
  × P("patient" | "The") 
  × P("has" | "The patient") 
  × P("AF" | "The patient has")

Each term is the conditional probability of the next token given all previous tokens.
The model is trained to maximise this joint probability over the training corpus.
Python
import torch
import torch.nn.functional as F

def sequence_log_probability(
    logits: torch.Tensor,     # shape: (seq_len, vocab_size)
    token_ids: torch.Tensor,  # shape: (seq_len,)
) -> float:
    """
    Compute log P(token_1, ..., token_{n-1} | token_0) using the chain rule
    (0-indexed tokens, n = seq_len).
    logits[t] = model's prediction for token t+1 given tokens 0..t
    (the usual causal-LM convention), so the last row of logits is unused.
    """
    # Log-probabilities over the vocabulary at every position
    log_probs = F.log_softmax(logits, dim=-1)   # (seq_len, vocab_size)

    # Select the log probability of the actual next token at each position.
    # The first token is not scored: no row of logits predicts it.
    target_log_probs = log_probs[:-1, :].gather(
        dim=-1,
        index=token_ids[1:].unsqueeze(-1)
    ).squeeze(-1)  # (seq_len - 1,)
    
    # Sum of log probs = log of joint probability (product)
    total_log_prob = target_log_probs.sum().item()
    return total_log_prob


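def perplexity_from_chain_rule(log_prob: float, n_tokens: int) -> float:
    """
    Perplexity = exp(-log P(sequence) / n_tokens).
    n_tokens is the number of tokens actually scored, i.e. seq_len - 1
    when log_prob comes from sequence_log_probability.
    """
    import math
    return math.exp(-log_prob / n_tokens)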
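
To make the arithmetic concrete without running a model, here is a small sketch in plain Python. The four conditional probabilities are invented for illustration and do not come from any real language model. It shows that the product of conditionals and the sum of their logs describe the same joint probability, and how perplexity follows from the average log-probability.

```python
import math

# Hypothetical next-token probabilities for "The patient has AF".
# These numbers are made up for illustration only.
conditionals = [
    ("The", 0.05),      # P("The")
    ("patient", 0.02),  # P("patient" | "The")
    ("has", 0.30),      # P("has" | "The patient")
    ("AF", 0.01),       # P("AF" | "The patient has")
]

# Chain rule as a product of conditionals.
joint = math.prod(p for _, p in conditionals)

# The same quantity in log space: a sum instead of a product.
log_joint = sum(math.log(p) for _, p in conditionals)

# Perplexity: exp of the average negative log-probability per scored token.
perplexity = math.exp(-log_joint / len(conditionals))

print(f"P(sequence)     = {joint:.3e}")
print(f"log P(sequence) = {log_joint:.4f}")
print(f"perplexity      = {perplexity:.2f}")
```

Working in log space matters in practice because a product of many small probabilities underflows to zero in floating point for long sequences, while the sum of logs stays well-behaved.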

Bayesian Networks and the Chain Rule

Bayesian networks use the chain rule with independence assumptions to
simplify joint probability computation.

Full joint (no assumptions, 3 binary variables):
P(A, B, C) = P(A) × P(B | A) × P(C | A, B)   [3 terms]

With independence: if C ⊥ A given B (C is independent of A given B):
P(A, B, C) = P(A) × P(B | A) × P(C | B)       [C simplified]

Naive Bayes (all features independent given class Y):
P(X₁, X₂, ..., Xₙ, Y) = P(Y) × Π P(Xᵢ | Y)
Python
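# Naive Bayes classifier: uses chain rule + conditional independence
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Illustrative data only (an assumption; any numeric feature matrix works)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

gnb = GaussianNB()
gnb.fit(X_train, y_train)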

# Internally, gnb computes:
# log P(y | x) = log P(y) + Σᵢ log P(xᵢ | y) - log P(x)   (chain rule + independence)
# log P(x) is the same for every class, so it only normalises the result.
proba = gnb.predict_proba(X_test)
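
The factorisation P(Y) × Π P(Xᵢ | Y) can also be spelled out by hand. The sketch below assumes a tiny hypothetical model with a binary class and two binary features. All probabilities and the function names `log_joint` and `posterior` are invented for illustration, and this is not how scikit-learn's `GaussianNB` is implemented (that class uses Gaussian likelihoods for continuous features).

```python
import math

# Hypothetical model: class Y in {0, 1}, two binary features X1, X2.
prior = {0: 0.7, 1: 0.3}            # P(Y = y)
# likelihood[y][i] = P(X_i = 1 | Y = y)
likelihood = {
    0: [0.2, 0.5],
    1: [0.9, 0.6],
}

def log_joint(x, y):
    """log P(x, y) = log P(y) + sum_i log P(x_i | y)."""
    total = math.log(prior[y])
    for x_i, p_one in zip(x, likelihood[y]):
        total += math.log(p_one if x_i == 1 else 1.0 - p_one)
    return total

def posterior(x):
    """Normalise the joint over classes to obtain P(y | x)."""
    joints = {y: math.exp(log_joint(x, y)) for y in prior}
    evidence = sum(joints.values())  # P(x)
    return {y: j / evidence for y, j in joints.items()}

post = posterior([1, 1])
print({y: round(p, 4) for y, p in post.items()})
```

The division by the evidence P(x) is the only step beyond the chain-rule product, and it does not change which class has the highest score.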

Worked Example: Patient Diagnosis

Three binary variables:
  S = Symptom (fever present)
  D = Disease (flu)
  T = Test (flu test result)

Assume: T ⊥ S given D (test result depends on disease, not directly on symptom)

P(S=1, D=1, T=1) = P(S=1) × P(D=1 | S=1) × P(T=1 | D=1)
                   [chain rule + conditional independence of T and S given D]

Given:
  P(S=1) = 0.30           (30% of patients have fever)
  P(D=1 | S=1) = 0.40     (40% of patients with fever have flu)
  P(T=1 | D=1) = 0.90     (test sensitivity 90%)

P(fever, flu, positive test) = 0.30 × 0.40 × 0.90 = 0.108
Python
def joint_with_chain_rule(
    p_s: float,
    p_d_given_s: float,
    p_t_given_d: float,
) -> float:
    return p_s * p_d_given_s * p_t_given_d

result = joint_with_chain_rule(0.30, 0.40, 0.90)
print(f"P(S=1, D=1, T=1) = {result:.4f}")  # P(S=1, D=1, T=1) = 0.1080
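
As a sanity check on the factorisation, the same structure can be used to build the whole joint table over (S, D, T). The worked example only gives three of the five numbers needed, so the sketch below assumes two extra values, P(D=1 | S=0) and P(T=1 | D=0), which are hypothetical and clearly marked as such. With them, the eight joint probabilities should sum to 1 and the (1, 1, 1) entry should match the value computed above.

```python
import itertools

# Values from the worked example.
p_s1 = 0.30                  # P(S=1)
p_d1_given_s = {1: 0.40}     # P(D=1 | S=1)
p_t1_given_d = {1: 0.90}     # P(T=1 | D=1)

# Hypothetical values, NOT in the worked example; assumed only so that
# the full table can be built.
p_d1_given_s[0] = 0.05       # P(D=1 | S=0), assumed
p_t1_given_d[0] = 0.10       # P(T=1 | D=0), assumed false-positive rate

def bernoulli(p_one, value):
    """Probability of a binary value given P(value = 1)."""
    return p_one if value == 1 else 1.0 - p_one

# P(S, D, T) = P(S) * P(D | S) * P(T | D), using T independent of S given D.
joint = {}
for s, d, t in itertools.product([0, 1], repeat=3):
    joint[(s, d, t)] = (
        bernoulli(p_s1, s)
        * bernoulli(p_d1_given_s[s], d)
        * bernoulli(p_t1_given_d[d], t)
    )

print(f"P(S=1, D=1, T=1) = {joint[(1, 1, 1)]:.4f}")
print(f"sum over all outcomes = {sum(joint.values()):.4f}")
```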

Interview Answer

"The chain rule of probability states that the joint probability of n events equals the product of conditional probabilities: P(A₁,...,Aₙ) = Π P(Aᵢ|A₁,...,Aᵢ₋₁). This is mathematically exact — no assumptions needed. It's the foundation of autoregressive language models: GPT decomposes P(token sequence) into a product of conditional next-token probabilities, training the model to estimate each conditional. The chain rule also appears in Bayesian networks, where conditional independence assumptions simplify the product — Naive Bayes, for instance, assumes features are independent given the class, reducing the chain rule product to P(y) × Π P(xᵢ|y)."