Learnixo
AI Systems · intermediate

Chain Rule of Probability

The chain rule of probability — how joint probabilities factorise into conditionals, its connection to language models, and how to apply it.

Asma Hafeez Khan · May 21, 2026 · 4 min read
Tags: Probability, Chain Rule, Language Models, Autoregressive, Interview

The Rule

P(A₁, A₂, ..., Aₙ) = P(A₁) × P(A₂ | A₁) × P(A₃ | A₁, A₂) × ... × P(Aₙ | A₁,...,Aₙ₋₁)

Or in product notation:
P(∩ᵢ₌₁ⁿ Aᵢ) = Πᵢ₌₁ⁿ P(Aᵢ | A₁, ..., Aᵢ₋₁)    [for i = 1 the condition is empty, so the first factor is P(A₁)]

For two events (simplest case):
P(A ∩ B) = P(A) × P(B | A)
          = P(B) × P(A | B)    [order doesn't matter — both are correct]
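A quick numerical check of the two-event case, using a fair six-sided die as an assumed example: A = "the roll is even" and B = "the roll is greater than 3". Both orderings should give the same P(A ∩ B). The helpers `prob` and `cond` are hypothetical names for this sketch; exact fractions avoid floating-point noise.

```python
from fractions import Fraction

outcomes = range(1, 7)                    # fair die: each face has probability 1/6
A = {x for x in outcomes if x % 2 == 0}   # even: {2, 4, 6}
B = {x for x in outcomes if x > 3}        # greater than 3: {4, 5, 6}

def prob(event):
    """P(event) for a fair die, as an exact fraction."""
    return Fraction(len(event), 6)

def cond(event, given):
    """P(event | given) = P(event ∩ given) / P(given)."""
    return prob(event & given) / prob(given)

joint = prob(A & B)                # P({4, 6}) = 1/3
order_ab = prob(A) * cond(B, A)    # P(A) × P(B | A) = 1/2 × 2/3
order_ba = prob(B) * cond(A, B)    # P(B) × P(A | B) = 1/2 × 2/3
print(joint, order_ab, order_ba)   # 1/3 1/3 1/3
```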

Derivation from Definition

P(A | B) = P(A ∩ B) / P(B)    [definition of conditional probability, for P(B) > 0]

→ P(A ∩ B) = P(A | B) × P(B)  [multiplication rule: 2 events]

For 3 events:
P(A ∩ B ∩ C) = P(C | A ∩ B) × P(A ∩ B)            [2-event rule, treating A ∩ B as one event]
             = P(C | A, B) × P(B | A) × P(A)     [2-event rule again on P(A ∩ B)]

By induction (treat A₁ ∩ ... ∩ Aₙ₋₁ as a single event and apply the 2-event rule), this extends to any number of events.
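The derivation can be checked by brute force. The sketch below assumes an arbitrary joint distribution over three binary variables (built from seeded random weights, purely for illustration), computes every conditional from prefix marginals using the definition above, and confirms that the product reproduces each joint probability exactly, in the order A, B, C and in the reversed order C, B, A. The helper names `prefix_marginal` and `chain_rule` are hypothetical.

```python
import itertools
import random
from fractions import Fraction

def prefix_marginal(joint, prefix):
    """P(X_1 = prefix[0], ..., X_k = prefix[k-1]), summing out the other variables."""
    k = len(prefix)
    return sum((p for outcome, p in joint.items() if outcome[:k] == prefix), Fraction(0))

def chain_rule(joint, outcome):
    """Product over i of P(x_i | x_1, ..., x_{i-1})."""
    result = Fraction(1)
    for i in range(len(outcome)):
        # Definition of conditional: P(x_i | x_<i) = P(x_1..x_i) / P(x_1..x_{i-1})
        result *= prefix_marginal(joint, outcome[: i + 1]) / prefix_marginal(joint, outcome[:i])
    return result

# An arbitrary joint over three binary variables (A, B, C) with every cell positive
rng = random.Random(0)
weights = {o: rng.randint(1, 9) for o in itertools.product([0, 1], repeat=3)}
total = sum(weights.values())
joint = {o: Fraction(w, total) for o, w in weights.items()}

# Order A, B, C: P(A) × P(B | A) × P(C | A, B)
assert all(chain_rule(joint, o) == joint[o] for o in joint)

# Order C, B, A: P(C) × P(B | C) × P(A | B, C), by reversing each outcome tuple
reversed_joint = {o[::-1]: p for o, p in joint.items()}
assert all(chain_rule(reversed_joint, o[::-1]) == joint[o] for o in joint)
print("chain rule reproduces all", len(joint), "joint probabilities in both orders")
```

Exact `Fraction` arithmetic is used so the equality checks are genuine equalities rather than floating-point approximations.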

Connection to Language Models

GPT and other autoregressive language models define the probability of a text sequence using the chain rule:

P("The patient has AF") 
= P("The") 
  × P("patient" | "The") 
  × P("has" | "The patient") 
  × P("AF" | "The patient has")

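Each term is the conditional probability of the next token given all previous tokens. (In practice the tokens are usually subword units rather than whole words, and the first term is typically conditioned on a start-of-sequence token.)
The model is trained to maximise the log of this joint probability over the training corpus, which is the same as minimising the next-token cross-entropy summed over positions.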
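To see the chain rule reproduce a joint probability exactly, here is a sketch that estimates each conditional P(token | history) by counting prefixes in a tiny, made-up corpus. The counting is a stand-in for the neural network a real model uses; the corpus and the helper names `conditional` and `chain_rule_probability` are hypothetical. All sentences have the same length, so no end-of-sequence token is needed.

```python
from collections import Counter
from fractions import Fraction

# Hypothetical toy corpus (illustrative only): each entry is one full sentence.
corpus = [
    ("The", "patient", "has", "AF"),
    ("The", "patient", "has", "AF"),
    ("The", "patient", "has", "fever"),
    ("The", "doctor", "has", "AF"),
]

# Count every prefix of every sentence, including the empty prefix ().
prefix_counts = Counter()
for sentence in corpus:
    for t in range(len(sentence) + 1):
        prefix_counts[sentence[:t]] += 1

def conditional(token, history):
    """Count-based estimate of P(token | history) = count(history + token) / count(history)."""
    return Fraction(prefix_counts[history + (token,)], prefix_counts[history])

def chain_rule_probability(sentence):
    """P(sentence) as the product of next-token conditionals."""
    p = Fraction(1)
    for t, token in enumerate(sentence):
        p *= conditional(token, sentence[:t])
        if p == 0:  # unseen continuation: later histories would have zero count
            return p
    return p

target = ("The", "patient", "has", "AF")
print(chain_rule_probability(target))                 # 1/2  (= 1 × 3/4 × 1 × 2/3)
print(Fraction(corpus.count(target), len(corpus)))    # 1/2  (2 of 4 sentences)
```

The product telescopes: each numerator cancels the next denominator, leaving count(full sentence) / count(empty prefix), which is exactly the empirical joint frequency.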
Python
import math

import torch
import torch.nn.functional as F

def sequence_log_probability(
    logits: torch.Tensor,     # shape: (seq_len, vocab_size)
    token_ids: torch.Tensor,  # shape: (seq_len,), dtype torch.long
) -> float:
    """
    Compute log P(token_2, ..., token_n | token_1) using the chain rule.
    Assumes the standard causal-LM convention: logits[t] is the model's
    prediction for token t+1 given tokens 0..t.
    """
    # log P(next token | previous tokens) for every vocabulary entry
    log_probs = F.log_softmax(logits, dim=-1)   # (seq_len, vocab_size)

    # Align predictions with targets: the logits at position t score token t+1.
    # The first token is never predicted (it has no context), and the last
    # position's prediction has no target, so both ends are trimmed.
    target_log_probs = log_probs[:-1, :].gather(
        dim=-1,
        index=token_ids[1:].unsqueeze(-1)
    ).squeeze(-1)  # (seq_len - 1,)

    # Sum of log probs = log of the product of conditionals (the joint)
    total_log_prob = target_log_probs.sum().item()
    return total_log_prob


def perplexity_from_chain_rule(log_prob: float, n_tokens: int) -> float:
    """Perplexity = exp(-log P(sequence) / n_tokens).

    n_tokens should be the number of predicted tokens (seq_len - 1 when
    log_prob comes from sequence_log_probability).
    """
    return math.exp(-log_prob / n_tokens)
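A torch-free sanity check of the perplexity formula, written as a minimal sketch with a hypothetical helper `perplexity` that takes the per-token conditionals directly. A model that is uniform over a vocabulary of V tokens assigns every next token probability 1/V, so its perplexity should come out as exactly V (a hypothetical V = 50 is used here).

```python
import math

def perplexity(token_probs):
    """Perplexity from the chain-rule conditionals P(token_t | tokens_<t)."""
    log_prob = sum(math.log(p) for p in token_probs)  # log of the joint (product)
    return math.exp(-log_prob / len(token_probs))

# Uniform model over a hypothetical 50-token vocabulary, 10 predicted tokens
uniform = [1 / 50] * 10
print(round(perplexity(uniform), 6))       # 50.0

# Perplexity is the inverse geometric mean of the per-token probabilities
print(round(perplexity([0.5, 0.25]), 6))   # 2.828427  (= sqrt(8))
```

Because it averages the log probability per token, perplexity can be compared across sequences of different lengths, unlike the raw joint probability.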

Bayesian Networks and the Chain Rule

Bayesian networks use the chain rule with independence assumptions to
simplify joint probability computation.

Full joint (no assumptions, 3 binary variables):
P(A, B, C) = P(A) × P(B | A) × P(C | A, B)   [exact: no assumptions]

With conditional independence: if C ⊥ A | B (C is independent of A given B):
P(A, B, C) = P(A) × P(B | A) × P(C | B)       [P(C | A, B) simplifies to P(C | B)]

Naive Bayes (all features conditionally independent given class Y):
P(X₁, X₂, ..., Xₙ, Y) = P(Y) × Πᵢ₌₁ⁿ P(Xᵢ | Y)
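The payoff of these assumptions is in the number of probabilities that must be specified. The sketch below counts free parameters when every variable is binary; the function names are hypothetical, and `markov_chain_params` is an assumed extension of the C ⊥ A | B pattern to n variables (each variable depends only on the one before it).

```python
def full_chain_rule_params(n: int) -> int:
    """Free parameters of P(X_1) P(X_2 | X_1) ... P(X_n | X_1..X_{n-1}), all X_i binary."""
    # Factor i conditions on i - 1 binary variables: one Bernoulli parameter
    # per configuration of those variables, i.e. 2 ** (i - 1). Total: 2 ** n - 1.
    return sum(2 ** (i - 1) for i in range(1, n + 1))

def markov_chain_params(n: int) -> int:
    """Each X_i depends only on X_{i-1}: 1 for P(X_1), then 2 per later factor."""
    return 1 + 2 * (n - 1)

def naive_bayes_params(n_features: int) -> int:
    """P(Y) × Π P(X_i | Y) with binary Y and binary features."""
    return 1 + 2 * n_features  # 1 for P(Y), 2 per feature (one per class)

print(full_chain_rule_params(3), markov_chain_params(3))    # 7 5
print(full_chain_rule_params(21), naive_bayes_params(20))   # 2097151 41
```

For a binary class with 20 binary features, the unrestricted chain rule needs over two million parameters, while Naive Bayes needs 41: exponential growth becomes linear.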
Python
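# Naive Bayes classifier: uses chain rule + conditional independence
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Example data (any numeric feature matrix works)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

gnb = GaussianNB()
gnb.fit(X_train, y_train)

# Internally, gnb scores each class as:
# log P(y | x) = log P(y) + Σᵢ log P(xᵢ | y) - log P(x)   (chain rule + independence)
# where each P(xᵢ | y) is a Gaussian density and -log P(x) normalises across classes.
proba = gnb.predict_proba(X_test)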
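The same scoring rule can be written out by hand for discrete features. This is a minimal sketch with binary features and hypothetical, made-up parameters, not a reproduction of GaussianNB's internals: it sums log P(y) and the log P(xᵢ | y) terms to get log P(x, y), then normalises with log-sum-exp to obtain P(y | x). The function name `naive_bayes_posterior` is hypothetical.

```python
import math

def naive_bayes_posterior(prior, likelihood, x):
    """P(y | x) under Naive Bayes: P(x, y) = P(y) × Π P(x_i | y).

    prior[y]          -> P(y)
    likelihood[y][i]  -> P(x_i = 1 | y) for binary feature i
    """
    log_joint = {}
    for y, p_y in prior.items():
        score = math.log(p_y)                             # log P(y)
        for i, xi in enumerate(x):
            p1 = likelihood[y][i]
            score += math.log(p1 if xi == 1 else 1 - p1)  # + log P(x_i | y)
        log_joint[y] = score                              # = log P(x, y)
    # Normalise with log-sum-exp: log P(x) = log Σ_y P(x, y)
    m = max(log_joint.values())
    log_px = m + math.log(sum(math.exp(s - m) for s in log_joint.values()))
    return {y: math.exp(s - log_px) for y, s in log_joint.items()}

# Hypothetical parameters: two classes, two binary features
prior = {0: 0.5, 1: 0.5}
likelihood = {0: [0.2, 0.5], 1: [0.8, 0.5]}
posterior = naive_bayes_posterior(prior, likelihood, (1, 1))
print(posterior)  # approximately {0: 0.2, 1: 0.8}
```

By hand: P(x, y=1) = 0.5 × 0.8 × 0.5 = 0.2 and P(x, y=0) = 0.5 × 0.2 × 0.5 = 0.05, so P(y=1 | x) = 0.2 / 0.25 = 0.8. Subtracting the maximum before exponentiating keeps the normalisation stable when there are many features.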

Worked Example: Patient Diagnosis

Three binary variables:
  S = Symptom (fever present)
  D = Disease (flu)
  T = Test (flu test result)

Assume: T ⊥ S given D (test result depends on disease, not directly on symptom)

P(S=1, D=1, T=1) = P(S=1) × P(D=1 | S=1) × P(T=1 | D=1)
                   [chain rule + conditional independence of T and S given D]

Given:
  P(S=1) = 0.30           (30% of patients have fever)
  P(D=1 | S=1) = 0.40     (40% of patients with fever have flu)
  P(T=1 | D=1) = 0.90     (test sensitivity 90%)

P(fever, flu, positive test) = 0.30 × 0.40 × 0.90 = 0.108
Python
def joint_with_chain_rule(
    p_s: float,
    p_d_given_s: float,
    p_t_given_d: float,
) -> float:
    return p_s * p_d_given_s * p_t_given_d

result = joint_with_chain_rule(0.30, 0.40, 0.90)
print(f"P(S=1, D=1, T=1) = {result:.4f}")  # 0.1080
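Multiplying three probabilities is harmless, but a language model multiplies one conditional per token, and the product quickly drops below the smallest representable float. The sketch below, with a hypothetical helper `joint_log_prob`, reproduces the worked example in log space and then shows a long sequence of assumed per-token probabilities of 0.01 underflowing to zero as a plain product while the log-space sum stays usable.

```python
import math

def joint_log_prob(conditionals):
    """log of a chain-rule product, computed as a sum of logs."""
    return sum(math.log(p) for p in conditionals)

# The worked example, via log space: same joint probability
print(round(math.exp(joint_log_prob([0.30, 0.40, 0.90])), 4))  # 0.108

# 1,000 conditionals of 0.01 each (e.g. a long token sequence)
probs = [0.01] * 1000
print(math.prod(probs))                  # 0.0  (underflow: true value is 1e-2000)
print(round(joint_log_prob(probs), 2))   # -4605.17
```

This is why the language-model code above sums log probabilities instead of multiplying probabilities.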

Interview Answer

"The chain rule of probability states that the joint probability of n events equals the product of conditional probabilities: P(A₁,...,Aₙ) = Π P(Aᵢ|A₁,...,Aᵢ₋₁). This is mathematically exact — no assumptions needed. It's the foundation of autoregressive language models: GPT decomposes P(token sequence) into a product of conditional next-token probabilities, training the model to estimate each conditional. The chain rule also appears in Bayesian networks, where conditional independence assumptions simplify the product — Naive Bayes, for instance, assumes features are independent given the class, reducing the chain rule product to P(y) × Π P(xᵢ|y)."
