Learnixo
AI Systems · beginner

Law of Total Probability

The law of total probability, how to use it to decompose complex probability computations, and its connection to Bayes' theorem and ML.

Asma Hafeez Khan · May 21, 2026 · 4 min read
Probability · Total Probability · Partition · Bayes · Interview

The Law

If B₁, B₂, ..., Bₙ partition the sample space Ω (they are mutually exclusive and together cover all possibilities) and each P(Bᵢ) > 0, then for any event A:

P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ)

In words: "the total probability of A is the sum, over all scenarios, of (probability of A in that scenario × probability of that scenario)."
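The formula follows in two steps from the partition property: split A into the disjoint pieces A ∩ Bᵢ, then rewrite each piece with the definition of conditional probability. A short derivation:

```latex
\begin{aligned}
A &= \bigcup_{i=1}^{n} (A \cap B_i)
  && \text{the } B_i \text{ cover the sample space} \\
P(A) &= \sum_{i=1}^{n} P(A \cap B_i)
  && \text{the pieces } A \cap B_i \text{ are disjoint} \\
     &= \sum_{i=1}^{n} P(A \mid B_i)\, P(B_i)
  && \text{since } P(A \cap B_i) = P(A \mid B_i)\, P(B_i)
\end{aligned}
```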

Intuition: Weighted Average of Conditional Probabilities

Scenario: probability of a positive COVID test result in a population

Two groups:
  B₁ = actually infected (40% of population)
  B₂ = not infected (60% of population)

Given:
  P(positive | infected)    = 0.95  (sensitivity)
  P(positive | not infected) = 0.02  (1 - specificity)

P(positive) = P(positive | infected) × P(infected)
             + P(positive | not infected) × P(not infected)
            = 0.95 × 0.40 + 0.02 × 0.60
            = 0.380 + 0.012
            = 0.392

About 39% of tests come back positive, mostly because 40% of the population
is infected and sensitivity is high; false positives add only 0.012.
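One way to check the 0.392 figure is to simulate the two-stage process the formula describes: first draw which group a person belongs to, then draw the test result from that group's conditional probability. A minimal Monte Carlo sketch using only the standard library; the function name `simulate_positive_rate`, the sample size, and the seed are arbitrary illustrative choices:

```python
import random


def simulate_positive_rate(
    n: int = 200_000,
    p_infected: float = 0.40,
    sensitivity: float = 0.95,
    false_positive_rate: float = 0.02,
    seed: int = 0,
) -> float:
    """Estimate P(positive) by sampling the scenario first, then the test result."""
    rng = random.Random(seed)
    positives = 0
    for _ in range(n):
        # Step 1: pick the scenario B_i with probability P(B_i)
        infected = rng.random() < p_infected
        # Step 2: pick the test result with probability P(positive | B_i)
        p_pos = sensitivity if infected else false_positive_rate
        if rng.random() < p_pos:
            positives += 1
    return positives / n


estimate = simulate_positive_rate()
exact = 0.95 * 0.40 + 0.02 * 0.60  # law of total probability
print(f"simulated = {estimate:.4f}, exact = {exact:.4f}")
```

The simulation only approximates the answer; with 200,000 samples the estimate typically lands within a few thousandths of 0.392, while the formula gives it exactly.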

Implementation

Python
def law_of_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],
    p_a_given_partition: dict[str, float],
) -> float:
    """
    P(A) = Σ P(A|Bᵢ) × P(Bᵢ)
    
    partitions: list of partition labels
    p_partition: P(Bᵢ) for each partition
    p_a_given_partition: P(A|Bᵢ) for each partition
    """
    total = 0.0
    for b in partitions:
        total += p_a_given_partition[b] * p_partition[b]
    return total


# Medical example: P(sepsis readmission) across hospitals with different care quality
partitions = ["high_quality_hospital", "medium_quality_hospital", "low_quality_hospital"]
p_hospital = {
    "high_quality_hospital": 0.30,    # 30% of patients at high-quality hospitals
    "medium_quality_hospital": 0.50,
    "low_quality_hospital": 0.20,
}
p_readmission_given_hospital = {
    "high_quality_hospital": 0.05,
    "medium_quality_hospital": 0.12,
    "low_quality_hospital": 0.25,
}

p_readmission = law_of_total_probability(
    partitions, p_hospital, p_readmission_given_hospital
)
print(f"P(readmission) = {p_readmission:.4f}")  # 0.0150 + 0.0600 + 0.0500 = 0.1250
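Another way to see the same computation: build the joint table P(hospital, outcome) and sum out the hospital dimension. Summing a joint distribution over one variable is what "marginalising" means, and the readmission column total reproduces 0.125. A NumPy sketch reusing the example's numbers; the array layout (rows = hospitals, columns = readmitted / not readmitted) is an illustrative choice:

```python
import numpy as np

# P(B_i) for high, medium, low quality hospitals
p_hospital = np.array([0.30, 0.50, 0.20])
# P(readmission | B_i)
p_readmit_given_hospital = np.array([0.05, 0.12, 0.25])

# Joint table P(B_i, outcome): rows = hospitals,
# column 0 = readmitted, column 1 = not readmitted
joint = np.column_stack([
    p_hospital * p_readmit_given_hospital,
    p_hospital * (1.0 - p_readmit_given_hospital),
])

# Marginalise over hospitals by summing down the rows
p_outcome = joint.sum(axis=0)
p_readmission = p_outcome[0]

print(f"P(readmission) = {p_readmission:.4f}")  # 0.1250
print(f"table total    = {joint.sum():.4f}")    # 1.0000
```

Equivalently, P(readmission) is the dot product `p_hospital @ p_readmit_given_hospital`.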

The Partition Requirement

Critical: the Bᵢ must partition the sample space Ω:
  1. Mutually exclusive: Bᵢ ∩ Bⱼ = ∅ for i ≠ j (no overlap)
  2. Collectively exhaustive: B₁ ∪ B₂ ∪ ... ∪ Bₙ = Ω (cover everything)

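Common mistakes:
  Overlapping events → outcomes in the overlap are counted twice, so the sum can overestimate P(A)
  Events that don't cover all cases → the part of A outside B₁ ∪ ... ∪ Bₙ is missed, so the sum underestimates P(A)

Quick sanity check (necessary, but not sufficient on its own): P(B₁) + P(B₂) + ... + P(Bₙ) = 1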
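The sum check is easy to automate. A minimal sketch of a hypothetical helper, `check_partition_probs`; it can only test what the numbers reveal (non-negative values summing to 1 within floating-point tolerance). Whether the events are truly mutually exclusive and exhaustive has to be argued from how they are defined.

```python
import math


def check_partition_probs(p_partition: dict[str, float], tol: float = 1e-9) -> None:
    """Raise ValueError if the P(B_i) values cannot come from a partition.

    Passing is necessary but not sufficient: overlapping, non-exhaustive
    events can also have probabilities that happen to sum to 1.
    """
    for label, p in p_partition.items():
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"P({label}) = {p} is outside [0, 1]")
    total = math.fsum(p_partition.values())  # accurate float summation
    if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=tol):
        raise ValueError(f"partition probabilities sum to {total}, not 1")


# Passes: the hospital weights from the example above
check_partition_probs({"high": 0.30, "medium": 0.50, "low": 0.20})

# Fails: a forgotten category leaves 20% of patients unaccounted for
try:
    check_partition_probs({"high": 0.30, "medium": 0.50})
except ValueError as err:
    print(err)
```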

Connection to Bayes' Theorem

Total probability is the denominator in Bayes' theorem:

P(B | A) = P(A | B) × P(B) / P(A)

P(A) = Σᵢ P(A | Bᵢ) × P(Bᵢ)   [law of total probability]

Therefore:
P(Bⱼ | A) = P(A | Bⱼ) × P(Bⱼ) / [Σᵢ P(A | Bᵢ) × P(Bᵢ)]

The denominator "normalises" the posterior: it guarantees that P(B₁ | A) + ... + P(Bₙ | A) = 1.
Python
def bayes_with_total_probability(
    partitions: list[str],
    p_partition: dict[str, float],          # prior P(Bᵢ)
    p_a_given_partition: dict[str, float],  # likelihood P(A|Bᵢ)
    target: str,                            # which Bⱼ to compute posterior for
) -> float:
    """P(target | A) using Bayes' theorem."""
    p_a = law_of_total_probability(partitions, p_partition, p_a_given_partition)
    
    # Bayes numerator
    numerator = p_a_given_partition[target] * p_partition[target]
    
    return numerator / p_a


# P(infected | positive test) from earlier example
p_infected_given_positive = bayes_with_total_probability(
    partitions=["infected", "not_infected"],
    p_partition={"infected": 0.40, "not_infected": 0.60},
    p_a_given_partition={"infected": 0.95, "not_infected": 0.02},
    target="infected",
)
print(f"P(infected | positive) = {p_infected_given_positive:.4f}")
# = (0.95 × 0.40) / 0.392 = 0.380 / 0.392 ≈ 0.9694
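Because P(A) is the same for every Bⱼ, it is natural to compute it once and return the whole posterior distribution. A small sketch; `posterior_distribution` is a hypothetical helper, not part of any library, and it repeats the total-probability sum inline so the block runs on its own:

```python
def posterior_distribution(
    p_partition: dict[str, float],          # prior P(B_i)
    p_a_given_partition: dict[str, float],  # likelihood P(A | B_i)
) -> dict[str, float]:
    """Return P(B_i | A) for every partition label."""
    # Bayes numerators P(A | B_i) * P(B_i), one per partition
    numerators = {b: p_a_given_partition[b] * p_partition[b] for b in p_partition}
    # Law of total probability: P(A) is the sum of the numerators
    p_a = sum(numerators.values())
    return {b: num / p_a for b, num in numerators.items()}


posterior = posterior_distribution(
    p_partition={"infected": 0.40, "not_infected": 0.60},
    p_a_given_partition={"infected": 0.95, "not_infected": 0.02},
)
for label, p in posterior.items():
    print(f"P({label} | positive) = {p:.4f}")
print(f"sum = {sum(posterior.values()):.4f}")
```

The two posteriors come out to about 0.9694 and 0.0306, summing to 1 as the normalisation argument predicts.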

Total Probability in ML: Marginalising Over Latent Variables

Python
# Example: Gaussian Mixture Model (GMM)
# p(x) = Σ_k p(x | cluster=k) × P(cluster=k)
# This is the law of total probability with clusters as the partition;
# for continuous x, p(x) is a density rather than a probability.

from scipy.stats import norm

def gmm_probability(
    x: float,
    means: list[float],
    stds: list[float],
    weights: list[float],
) -> float:
    """Density p(x) for a one-dimensional Gaussian Mixture Model."""
    return sum(
        w * norm(mu, sigma).pdf(x)
        for w, mu, sigma in zip(weights, means, stds)
    )


# Two-component GMM: healthy patients and sick patients
p_x = gmm_probability(
    x=5.0,                            # lab value
    means=[3.0, 7.0],                 # healthy mean=3, sick mean=7
    stds=[0.5, 1.0],
    weights=[0.7, 0.3],               # 70% healthy, 30% sick
)
print(f"p(x=5.0) = {p_x:.4f}")  # ≈ 0.0164, dominated by the sick component
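Applying Bayes' theorem with the mixture density as the denominator gives each component's "responsibility" for a data point, P(cluster=k | x), which is the quantity the E-step of the EM algorithm computes when fitting a GMM. A sketch using only the standard library (the normal density is written out with `math` instead of SciPy so the block is self-contained); `normal_pdf` and `gmm_responsibilities` are illustrative names:

```python
import math


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    """Density of N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def gmm_responsibilities(
    x: float,
    means: list[float],
    stds: list[float],
    weights: list[float],
) -> list[float]:
    """P(cluster=k | x) for each component k of a 1-D GMM."""
    # Bayes numerators: P(cluster=k) * p(x | cluster=k)
    numerators = [
        w * normal_pdf(x, mu, sigma)
        for w, mu, sigma in zip(weights, means, stds)
    ]
    # Law of total probability: the mixture density p(x) is their sum
    p_x = sum(numerators)
    return [num / p_x for num in numerators]


resp = gmm_responsibilities(x=5.0, means=[3.0, 7.0], stds=[0.5, 1.0], weights=[0.7, 0.3])
print(f"P(healthy | x=5) = {resp[0]:.4f}, P(sick | x=5) = {resp[1]:.4f}")
```

Even though 70% of patients are healthy a priori, a lab value of 5.0 sits four standard deviations above the healthy mean but only two below the sick mean, so about 0.99 of the posterior mass goes to the sick component.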

Interview Answer

"The law of total probability states: P(A) = Σᵢ P(A|Bᵢ) × P(Bᵢ), where B₁,...,Bₙ partition the sample space. It decomposes the probability of an event into a weighted average of its conditional probabilities across mutually exclusive, exhaustive scenarios, with the scenario probabilities as weights. It's the denominator of Bayes' theorem: the normalisation constant that ensures the posteriors sum to 1. In ML, it appears in Gaussian Mixture Models (p(x) = Σ_k p(x|cluster=k)×P(cluster=k)), in topic models that marginalise over latent topics, and whenever we need to reason about probabilities across subpopulations."
