Statistics & Math for AI/ML Interviews · Lesson 11 of 30

Law of Total Probability

The Law

If B₁, B₂, ..., Bₙ are mutually exclusive events that together cover the whole sample space (a partition), each with P(Bᵢ) > 0, then for any event A:

P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ)

In words: "total probability of A = sum of (probability of A in each scenario × probability of that scenario)"
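
The formula follows in three steps from the partition property and the definition of conditional probability. The derivation below is a short sketch; Ω denotes the sample space and the indexing i = 1, ..., n matches the statement above.

```latex
\begin{aligned}
A &= \bigcup_{i=1}^{n} (A \cap B_i)
  && \text{(the } B_i \text{ cover } \Omega\text{)} \\
P(A) &= \sum_{i=1}^{n} P(A \cap B_i)
  && \text{(the sets } A \cap B_i \text{ are pairwise disjoint)} \\
     &= \sum_{i=1}^{n} P(A \mid B_i)\, P(B_i)
  && \text{(definition of conditional probability, } P(B_i) > 0\text{)}
\end{aligned}
```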

Intuition: Weighted Average of Conditional Probabilities

Scenario: probability of a positive COVID test result in a population

Two groups:
  B₁ = actually infected (40% of population)
  B₂ = not infected (60% of population)

Given:
  P(positive | infected)    = 0.95  (sensitivity)
  P(positive | not infected) = 0.02  (1 - specificity)

P(positive) = P(positive | infected) × P(infected)
             + P(positive | not infected) × P(not infected)
            = 0.95 × 0.40 + 0.02 × 0.60
            = 0.380 + 0.012
            = 0.392

About 39% of tests come back positive. Nearly all of that (0.380 of 0.392)
comes from true positives: 40% of people are infected and the sensitivity is high.
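
One way to sanity-check the 0.392 figure is to simulate the two-stage process directly: first draw whether a person is infected, then draw the test result with the matching conditional probability. The sketch below uses only the standard library; the function name, population size, and seed are illustrative choices, not part of the original example.

```python
import random


def simulate_positive_rate(
    n_people: int = 200_000,
    p_infected: float = 0.40,
    sensitivity: float = 0.95,
    false_positive_rate: float = 0.02,
    seed: int = 0,
) -> float:
    """Estimate P(positive) by simulating infection status, then the test result."""
    rng = random.Random(seed)
    positives = 0
    for _ in range(n_people):
        infected = rng.random() < p_infected
        # The chance of a positive result depends on which group the person is in.
        p_positive = sensitivity if infected else false_positive_rate
        if rng.random() < p_positive:
            positives += 1
    return positives / n_people


estimate = simulate_positive_rate()
exact = 0.95 * 0.40 + 0.02 * 0.60  # law of total probability
print(f"simulated: {estimate:.3f}, exact: {exact:.3f}")
```

The simulated fraction should land close to the exact value, with the gap shrinking as `n_people` grows.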

Implementation

Python
def law_of_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],
    p_a_given_partition: dict[str, float],
) -> float:
    """
    P(A) = Σ P(A|Bᵢ) × P(Bᵢ)
    
    partitions: list of partition labels
    p_partition: P(Bᵢ) for each partition
    p_a_given_partition: P(A|Bᵢ) for each partition
    """
    total = 0.0
    for b in partitions:
        total += p_a_given_partition[b] * p_partition[b]
    return total


# Medical example: P(sepsis readmission) across hospitals with different care quality
partitions = ["high_quality_hospital", "medium_quality_hospital", "low_quality_hospital"]
p_hospital = {
    "high_quality_hospital": 0.30,    # 30% of patients at high-quality hospitals
    "medium_quality_hospital": 0.50,
    "low_quality_hospital": 0.20,
}
p_readmission_given_hospital = {
    "high_quality_hospital": 0.05,
    "medium_quality_hospital": 0.12,
    "low_quality_hospital": 0.25,
}

p_readmission = law_of_total_probability(
    partitions, p_hospital, p_readmission_given_hospital
)
print(f"P(readmission) = {p_readmission:.4f}")  # 0.0150 + 0.0600 + 0.0500 = 0.1250

The Partition Requirement

Critical: the Bᵢ must partition Ω:
  1. Mutually exclusive: Bᵢ ∩ Bⱼ = ∅ for i ≠ j (no overlap)
  2. Collectively exhaustive: B₁ ∪ B₂ ∪ ... ∪ Bₙ = Ω (cover everything)

Common mistakes:
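  Bᵢ that overlap → outcomes in the overlap are double-counted, so the sum can overestimate P(A)
  Bᵢ that don't cover all cases → the sum misses part of A and can underestimate P(A)

Quick check: P(B₁) + P(B₂) + ... + P(Bₙ) = 1 (necessary but not sufficient: an overlap and a gap can cancel out)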
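
The sum-to-one check is easy to automate before plugging numbers into the formula. The helper below is a hypothetical sketch (its name and tolerance are assumptions). It can only catch inconsistent probabilities; it cannot prove that the underlying events are truly disjoint and exhaustive.

```python
import math


def check_partition_probabilities(
    p_partition: dict[str, float], tol: float = 1e-9
) -> None:
    """Raise ValueError if the P(B_i) values cannot come from a valid partition."""
    for label, p in p_partition.items():
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"P({label}) = {p} is outside [0, 1]")
    total = sum(p_partition.values())
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"partition probabilities sum to {total}, expected 1")


# Valid: passes silently.
check_partition_probabilities({"infected": 0.40, "not_infected": 0.60})

# Invalid: 10% of the population is unaccounted for.
try:
    check_partition_probabilities({"infected": 0.40, "not_infected": 0.50})
except ValueError as err:
    print(err)
```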

Connection to Bayes' Theorem

Total probability is the denominator in Bayes' theorem:

P(B | A) = P(A | B) × P(B) / P(A)

P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ)   [law of total probability]

Therefore:
P(Bⱼ | A) = P(A | Bⱼ) × P(Bⱼ) / [Σᵢ P(A | Bᵢ) × P(Bᵢ)]

The denominator is what normalises the posteriors P(Bⱼ | A) so that they sum to 1 over j.

Python
def bayes_with_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],          # prior P(Bᵢ)
    p_a_given_partition: dict[str, float],  # likelihood P(A|Bᵢ)
    target: str,                            # which Bⱼ to compute posterior for
) -> float:
    """P(target | A) using Bayes' theorem."""
    p_a = law_of_total_probability(partitions, p_partition, p_a_given_partition)
    
    # Bayes numerator
    numerator = p_a_given_partition[target] * p_partition[target]
    
    return numerator / p_a


# P(infected | positive test) from earlier example
p_infected_given_positive = bayes_with_total_probability(
    partitions=["infected", "not_infected"],
    p_partition={"infected": 0.40, "not_infected": 0.60},
    p_a_given_partition={"infected": 0.95, "not_infected": 0.02},
    target="infected",
)
print(f"P(infected | positive) = {p_infected_given_positive:.4f}")
# = (0.95 × 0.40) / 0.392 = 0.380 / 0.392 ≈ 0.9694
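
To see the normalisation at work, it helps to compute the posterior for every Bⱼ at once and confirm that the values sum to 1. This is a self-contained sketch; `posterior_over_partition` is a hypothetical helper name, and the numbers are the COVID example from earlier.

```python
def posterior_over_partition(
    p_partition: dict[str, float],
    p_a_given_partition: dict[str, float],
) -> dict[str, float]:
    """Return P(B_j | A) for every B_j in the partition."""
    # Joint terms P(A | B_j) * P(B_j): the Bayes numerators.
    joint = {b: p_a_given_partition[b] * p_partition[b] for b in p_partition}
    # Law of total probability: P(A) is the sum of the joint terms.
    p_a = sum(joint.values())
    if p_a == 0:
        raise ValueError("P(A) is 0, so the posterior is undefined")
    return {b: j / p_a for b, j in joint.items()}


posterior = posterior_over_partition(
    p_partition={"infected": 0.40, "not_infected": 0.60},
    p_a_given_partition={"infected": 0.95, "not_infected": 0.02},
)
for label, p in posterior.items():
    print(f"P({label} | positive) = {p:.4f}")  # 0.9694 and 0.0306
print(f"sum = {sum(posterior.values()):.4f}")  # 1.0000
```

Dividing every numerator by the same P(A) is exactly what forces the posteriors to sum to 1.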

Total Probability in ML: Marginalising Over Latent Variables

Python
# Example: Gaussian Mixture Model (GMM)
# p(x) = Σ_k p(x | cluster=k) × P(cluster=k)
# This is the law of total probability with clusters as the partition.
# x is continuous, so p(x) is a probability density, not a probability.

from scipy.stats import norm

def gmm_probability(
    x: float,
    means: list[float],
    stds: list[float],
    weights: list[float],
) -> float:
    """Mixture density p(x) for a 1-D Gaussian Mixture Model."""
    # Weighted sum of component densities; the weights must sum to 1.
    return sum(
        w * norm(mu, sigma).pdf(x)
        for w, mu, sigma in zip(weights, means, stds)
    )


# Two-component GMM: healthy patients and sick patients
p_x = gmm_probability(
    x=5.0,                            # lab value
    means=[3.0, 7.0],                 # healthy mean=3, sick mean=7
    stds=[0.5, 1.0],
    weights=[0.7, 0.3],               # 70% healthy, 30% sick
)
print(f"p(x=5.0) = {p_x:.4f}")        # ≈ 0.0164
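
The same mixture sum also serves as the Bayes denominator when asking which cluster a point most likely came from. The sketch below computes these posterior cluster probabilities (often called responsibilities) for the lab-value example. It uses the standard library's `statistics.NormalDist` instead of SciPy purely to stay dependency-free, and `gmm_responsibilities` is a hypothetical helper name.

```python
from statistics import NormalDist


def gmm_responsibilities(
    x: float,
    means: list[float],
    stds: list[float],
    weights: list[float],
) -> list[float]:
    """P(cluster=k | x) for each component k of a 1-D Gaussian mixture."""
    # Numerators: weight_k times the density of component k at x.
    joint = [
        w * NormalDist(mu, sigma).pdf(x)
        for w, mu, sigma in zip(weights, means, stds)
    ]
    # Denominator: the mixture density p(x), i.e. the total-probability sum.
    p_x = sum(joint)
    return [j / p_x for j in joint]


resp = gmm_responsibilities(
    x=5.0, means=[3.0, 7.0], stds=[0.5, 1.0], weights=[0.7, 0.3]
)
print([round(r, 4) for r in resp])  # [healthy, sick]
```

At x = 5.0 the sick component takes almost all of the weight: the point is four standard deviations from the healthy mean but only two from the sick mean, which outweighs the healthy cluster's larger prior.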

Interview Answer

"The law of total probability states: P(A) = Σᵢ P(A|Bᵢ) × P(Bᵢ), where B₁,...,Bₙ partition the sample space. It decomposes the probability of an event into a weighted average of its conditional probabilities across mutually exclusive, exhaustive scenarios, with the scenario probabilities as the weights. It's the denominator of Bayes' theorem: the normalisation constant that ensures posteriors sum to 1. In ML, it appears in Gaussian Mixture Models (p(x) = Σ_k p(x|cluster=k)×P(cluster=k)), in topic models that marginalise over latent topics, and whenever we need to reason about probabilities across subpopulations."