Statistics & Math for AI/ML Interviews · Lesson 11 of 30
Law of Total Probability
The Law
If B₁, B₂, ..., Bₙ are mutually exclusive events that partition the sample space (together they cover all possibilities), then:
P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ)
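The formula follows in two steps from the partition property and the definition of conditional probability. A short derivation, assuming each P(Bᵢ) > 0 so that the conditional probabilities are defined:

```latex
\begin{align*}
A &= \bigcup_{i=1}^{n} (A \cap B_i)
  && \text{(the } B_i \text{ cover the sample space)} \\
P(A) &= \sum_{i=1}^{n} P(A \cap B_i)
  && \text{(the pieces } A \cap B_i \text{ are disjoint)} \\
     &= \sum_{i=1}^{n} P(A \mid B_i)\, P(B_i)
  && \text{(definition of conditional probability)}
\end{align*}
```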
In words: "total probability of A = sum of (probability of A in each scenario × probability of that scenario)"
Intuition: Weighted Average of Conditional Probabilities
Scenario: probability of a positive COVID test result in a population
Two groups:
B₁ = actually infected (40% of population)
B₂ = not infected (60% of population)
Given:
P(positive | infected) = 0.95 (sensitivity)
P(positive | not infected) = 0.02 (1 - specificity)
P(positive) = P(positive | infected) × P(infected)
+ P(positive | not infected) × P(not infected)
= 0.95 × 0.40 + 0.02 × 0.60
= 0.380 + 0.012
= 0.392
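One way to sanity-check the arithmetic above is a small Monte Carlo simulation. The sketch below is illustrative only: the function name `simulate_positive_rate`, the trial count, and the seed are assumptions, not part of the lesson; the rates are the ones given above.

```python
import random


def simulate_positive_rate(n_trials: int = 200_000, seed: int = 0) -> float:
    """Estimate P(positive) by sampling infection status, then the test result."""
    rng = random.Random(seed)
    positives = 0
    for _ in range(n_trials):
        # Step 1: pick the scenario (infected with probability 0.40).
        infected = rng.random() < 0.40
        # Step 2: sample the test result conditional on that scenario.
        p_positive = 0.95 if infected else 0.02
        if rng.random() < p_positive:
            positives += 1
    return positives / n_trials


estimate = simulate_positive_rate()
print(f"simulated P(positive) = {estimate:.3f}")
```

The estimate should land close to the analytic value 0.392; the two-step sampling mirrors the structure of the formula (choose a scenario, then sample the event within it).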
Almost 40% of tests are positive, mostly because 40% are infected
and the sensitivity is high.
Implementation
def law_of_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],
    p_a_given_partition: dict[str, float],
) -> float:
    """
    P(A) = Σ P(A|Bᵢ) × P(Bᵢ)
    partitions: list of partition labels
    p_partition: P(Bᵢ) for each partition
    p_a_given_partition: P(A|Bᵢ) for each partition
    """
    total = 0.0
    for b in partitions:
        total += p_a_given_partition[b] * p_partition[b]
    return total
# Medical example: P(sepsis readmission) across hospitals with different care quality
partitions = ["high_quality_hospital", "medium_quality_hospital", "low_quality_hospital"]
p_hospital = {
    "high_quality_hospital": 0.30,  # 30% of patients at high-quality hospitals
    "medium_quality_hospital": 0.50,
    "low_quality_hospital": 0.20,
}
p_readmission_given_hospital = {
    "high_quality_hospital": 0.05,
    "medium_quality_hospital": 0.12,
    "low_quality_hospital": 0.25,
}
p_readmission = law_of_total_probability(
    partitions, p_hospital, p_readmission_given_hospital
)
print(f"P(readmission) = {p_readmission:.4f}")  # 0.0150 + 0.0600 + 0.0500 = 0.1250
The Partition Requirement
Critical: the Bᵢ must partition Ω:
1. Mutually exclusive: Bᵢ ∩ Bⱼ = ∅ for i ≠ j (no overlap)
2. Collectively exhaustive: B₁ ∪ B₂ ∪ ... ∪ Bₙ = Ω (cover everything)
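When events are small finite sets of outcomes, both conditions above can be checked directly. A minimal sketch, assuming events are represented as Python sets; `is_partition` and the die-roll sample space are hypothetical names chosen for illustration:

```python
def is_partition(blocks: list[set], sample_space: set) -> bool:
    """Return True if the blocks are pairwise disjoint and their union is the sample space."""
    seen: set = set()
    for block in blocks:
        if seen & block:  # shares an outcome with an earlier block: not mutually exclusive
            return False
        seen |= block
    return seen == sample_space  # collectively exhaustive (and nothing outside the space)


omega = {1, 2, 3, 4, 5, 6}  # outcomes of one die roll (illustrative)
valid = is_partition([{1, 2}, {3, 4}, {5, 6}], omega)
overlapping = is_partition([{1, 2, 3}, {3, 4}, {5, 6}], omega)  # 3 appears twice
incomplete = is_partition([{1, 2}, {3, 4}], omega)  # 5 and 6 are missing
print(valid, overlapping, incomplete)
```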
Common mistakes:
Partitions that overlap → double-counting
Partitions that don't cover all cases → the sum misses part of A and underestimates P(A)
Simple check: P(B₁) + P(B₂) + ... + P(Bₙ) = 1
Connection to Bayes' Theorem
Total probability is the denominator in Bayes' theorem:
P(B | A) = P(A | B) × P(B) / P(A)
P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ) [law of total probability]
Therefore:
P(Bⱼ | A) = P(A | Bⱼ) × P(Bⱼ) / [Σᵢ P(A | Bᵢ) × P(Bᵢ)]
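Applying the formula above to every Bⱼ at once makes the role of the shared denominator visible. A minimal sketch, reusing the test figures from the earlier example; the function name `posteriors` is an assumption made for illustration:

```python
def posteriors(prior: dict[str, float], likelihood: dict[str, float]) -> dict[str, float]:
    """Return P(B_j | A) for every B_j, given priors P(B_j) and likelihoods P(A | B_j)."""
    # Joint terms P(A | B_j) * P(B_j): the Bayes numerators.
    joint = {b: likelihood[b] * prior[b] for b in prior}
    # Law of total probability: P(A) is the sum of the joint terms.
    p_a = sum(joint.values())
    if p_a == 0:
        raise ValueError("P(A) is zero, so the posterior is undefined")
    # Dividing every numerator by the same P(A) rescales them to sum to 1.
    return {b: j / p_a for b, j in joint.items()}


post = posteriors(
    prior={"infected": 0.40, "not_infected": 0.60},
    likelihood={"infected": 0.95, "not_infected": 0.02},
)
print(post)
print(sum(post.values()))
```

Because every posterior shares the same denominator, the numerators only need to be computed once; the division is a pure rescaling step.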
The denominator is what "normalises" the posterior to sum to 1.
def bayes_with_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],  # prior P(Bᵢ)
    p_a_given_partition: dict[str, float],  # likelihood P(A|Bᵢ)
    target: str,  # which Bⱼ to compute posterior for
) -> float:
    """P(target | A) using Bayes' theorem."""
    p_a = law_of_total_probability(partitions, p_partition, p_a_given_partition)
    # Bayes numerator
    numerator = p_a_given_partition[target] * p_partition[target]
    return numerator / p_a
# P(infected | positive test) from earlier example
p_infected_given_positive = bayes_with_total_probability(
    partitions=["infected", "not_infected"],
    p_partition={"infected": 0.40, "not_infected": 0.60},
    p_a_given_partition={"infected": 0.95, "not_infected": 0.02},
    target="infected",
)
print(f"P(infected | positive) = {p_infected_given_positive:.4f}")
# = (0.95 × 0.40) / 0.392 = 0.380 / 0.392 ≈ 0.9694
Total Probability in ML: Marginalising Over Latent Variables
# Example: Gaussian Mixture Model (GMM)
# p(x) = Σ_k p(x | cluster=k) × P(cluster=k)
# This is the law of total probability with clusters as the partition,
# applied to densities: for continuous x, p(x) is a density, not a probability.
from scipy.stats import norm
def gmm_probability(
    x: float,
    means: list[float],
    stds: list[float],
    weights: list[float],
) -> float:
    """Mixture density p(x) for a Gaussian Mixture Model (weights must sum to 1)."""
    return sum(
        w * norm(mu, sigma).pdf(x)
        for w, mu, sigma in zip(weights, means, stds)
    )
# Two-component GMM: healthy patients and sick patients
p_x = gmm_probability(
    x=5.0,  # lab value
    means=[3.0, 7.0],  # healthy mean=3, sick mean=7
    stds=[0.5, 1.0],
    weights=[0.7, 0.3],  # 70% healthy, 30% sick
)
print(f"p(x=5.0) = {p_x:.4f}")  # ≈ 0.0164 (a density value, dominated by the sick component)
Interview Answer
"The law of total probability states: P(A) = Σᵢ P(A|Bᵢ) × P(Bᵢ), where B₁,...,Bₙ partition the sample space. It decomposes the probability of an event into a weighted average of its conditional probabilities across mutually exclusive, collectively exhaustive scenarios, with the scenario probabilities as the weights. It's the denominator of Bayes' theorem: the normalisation constant that ensures posteriors sum to 1. In ML, it appears in Gaussian Mixture Models (p(x) = Σ_k p(x|cluster=k) × P(cluster=k)), in topic models that marginalise over latent topics, and whenever we need to reason about probabilities across subpopulations."