
pytest Mocking and Patching for Data Pipeline Tests

Master unittest.mock, pytest-mock, monkeypatch, freezegun, and HTTP mocking to isolate external dependencies in Snowflake, S3, REST API, and datetime-sensitive pipeline tests.

Learnixo · May 7, 2026 · 15 min read
pytest · python · mocking · pytest-mock · monkeypatch · data-engineering · snowflake · s3
Share:š•


Data pipelines touch everything: cloud data warehouses, object storage, REST APIs, message queues, and clocks. Running tests against real infrastructure is slow, expensive, and fragile. Mocking replaces those dependencies with controlled stand-ins. This lesson covers every mocking tool you need, with realistic examples of the exact scenarios data engineers face.

The Core Question: Mock or Not?

Before writing a mock, ask:

  • Is the dependency fast and deterministic? Use the real thing. A pandas transformation, a pure Python function, or an in-memory DuckDB connection does not need mocking.
  • Is the dependency slow, external, or stateful? Mock it. Snowflake queries, S3 operations, external HTTP APIs, and datetime.now() all belong in this category.
  • Can I use a fast equivalent? Consider it first. SQLite instead of PostgreSQL for schema tests. DuckDB instead of BigQuery for SQL logic tests. A local file instead of S3 for format tests.

The mocking hierarchy from lightest to heaviest:

  1. Real implementation — fastest, most confidence
  2. In-memory equivalent — DuckDB, SQLite, tmp_path
  3. monkeypatch — replace env vars, builtins, module attributes
  4. unittest.mock / pytest-mock — replace objects and functions with controllable fakes
  5. Docker containers — real services in isolation (testcontainers, next lesson)
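Level 2 of this hierarchy often removes the need to mock at all. A minimal sketch, using stdlib SQLite in place of a real warehouse to exercise aggregation SQL (the table and query here are illustrative, not from a specific pipeline):

```python
import sqlite3


def test_daily_revenue_aggregation():
    # In-memory SQLite: real SQL execution, no external database required
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE orders (order_date TEXT, revenue REAL)")
    conn.executemany(
        "INSERT INTO orders VALUES (?, ?)",
        [("2026-05-01", 100.0), ("2026-05-01", 50.0), ("2026-05-02", 200.0)],
    )
    rows = conn.execute(
        "SELECT order_date, SUM(revenue) FROM orders"
        " GROUP BY order_date ORDER BY order_date"
    ).fetchall()
    assert rows == [("2026-05-01", 150.0), ("2026-05-02", 200.0)]
```

The same pattern works with DuckDB when you need warehouse-flavored SQL (window functions, `DATEADD`-style date math).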

unittest.mock Fundamentals

Python's built-in unittest.mock module provides Mock, MagicMock, and patch.

Mock vs MagicMock

Python
from unittest.mock import Mock, MagicMock

# Mock: basic mock object
m = Mock()
m.some_method()          # Returns a new Mock (no AttributeError)
m.some_method.return_value = 42
assert m.some_method() == 42

# MagicMock: Mock + magic method support (__len__, __iter__, __enter__, etc.)
mm = MagicMock()
len(mm)           # Works — returns 0 by default
with mm:          # Works — __enter__/__exit__ are pre-configured
    pass

# In practice: use MagicMock for context managers, iterables
# Use Mock when you want stricter behavior (magic methods raise TypeError)

spec=True: Catch Typos in Your Mocks

Without spec, you can call any attribute on a mock without error. This hides bugs:

Python
from unittest.mock import MagicMock, create_autospec

# Without spec: typo goes unnoticed
conn = MagicMock()
conn.exeucte("SELECT 1")  # Typo: "exeucte" instead of "execute"
# No error! The mock happily creates conn.exeucte as a new attribute.

# With spec: typo raises AttributeError
import psycopg2
conn = create_autospec(psycopg2.extensions.connection)
conn.exeucte("SELECT 1")
# AttributeError: Mock object has no attribute 'exeucte'

Always use spec=True or create_autospec when mocking objects you own or know the interface of.
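spec also works directly on the MagicMock constructor. A minimal sketch against a hypothetical SnowflakeCursor interface — note that spec only validates attribute names; unlike create_autospec it does not check call signatures:

```python
from unittest.mock import MagicMock


class SnowflakeCursor:
    def execute(self, sql: str) -> None: ...
    def fetchall(self) -> list: ...


def typo_is_caught() -> bool:
    cursor = MagicMock(spec=SnowflakeCursor)
    cursor.execute("SELECT 1")       # Fine: 'execute' exists on the spec
    try:
        cursor.exeucte("SELECT 1")   # Typo: spec rejects unknown attributes
    except AttributeError:
        return True
    return False
```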

side_effect: Sequences and Exceptions

side_effect is more powerful than return_value:

Python
from unittest.mock import Mock

# Return different values on successive calls
paginator = Mock()
paginator.next_page.side_effect = [
    {"data": [1, 2, 3], "has_more": True},
    {"data": [4, 5, 6], "has_more": True},
    {"data": [7, 8], "has_more": False},
]

assert paginator.next_page()["data"] == [1, 2, 3]
assert paginator.next_page()["data"] == [4, 5, 6]
assert paginator.next_page()["data"] == [7, 8]


# Raise an exception on the Nth call
api_client = Mock()
api_client.fetch.side_effect = [
    {"result": "ok"},
    {"result": "ok"},
    ConnectionError("Simulated network failure"),
]

api_client.fetch()  # {"result": "ok"}
api_client.fetch()  # {"result": "ok"}
# api_client.fetch()  # Raises ConnectionError


# Side effect as a function — called with the same args as the mock
def validate_query(sql):
    if "DROP" in sql.upper():
        raise PermissionError("DROP statements not allowed")
    return [("row1",), ("row2",)]

cursor = Mock()
cursor.execute.side_effect = validate_query
cursor.execute("SELECT 1")  # Fine
cursor.execute("DROP TABLE users")  # Raises PermissionError

return_value Chaining

Python
# Deep chaining for object hierarchies
from unittest.mock import MagicMock

snowflake_connector = MagicMock()
snowflake_connector.connect.return_value.cursor.return_value.fetchall.return_value = [
    ("2026-01-01", 100.0),
    ("2026-01-02", 200.0),
]

# Access pattern: connector.connect().cursor().fetchall()
conn = snowflake_connector.connect()
cursor = conn.cursor()
rows = cursor.fetchall()
assert rows == [("2026-01-01", 100.0), ("2026-01-02", 200.0)]
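The same chain can be set up in a single call using dotted keyword names, which configure_mock (and the Mock constructor) accept:

```python
from unittest.mock import MagicMock

rows = [("2026-01-01", 100.0), ("2026-01-02", 200.0)]

# Dotted names walk the attribute/return-value chain in one call
connector = MagicMock()
connector.configure_mock(
    **{"connect.return_value.cursor.return_value.fetchall.return_value": rows}
)

assert connector.connect().cursor().fetchall() == rows
```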

unittest.mock.patch

patch replaces an object in the module under test for the duration of the test.

Critical rule: patch where the name is used, not where it is defined.

Python
# src/readers.py
import boto3  # boto3 defined in boto3 package

def list_s3_files(bucket: str, prefix: str) -> list:
    s3 = boto3.client("s3")  # boto3 used here, in src.readers
    response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
    return [obj["Key"] for obj in response.get("Contents", [])]
Python
# tests/unit/test_readers.py
from unittest.mock import patch, MagicMock

# CORRECT: patch boto3 as it is used in src.readers
def test_list_s3_files_returns_keys():
    with patch("src.readers.boto3") as mock_boto3:
        mock_client = MagicMock()
        mock_boto3.client.return_value = mock_client
        mock_client.list_objects_v2.return_value = {
            "Contents": [
                {"Key": "raw/orders/part-001.parquet"},
                {"Key": "raw/orders/part-002.parquet"},
            ]
        }

        from src.readers import list_s3_files
        result = list_s3_files(bucket="my-bucket", prefix="raw/orders/")

    assert len(result) == 2
    assert "raw/orders/part-001.parquet" in result
    mock_client.list_objects_v2.assert_called_once_with(
        Bucket="my-bucket",
        Prefix="raw/orders/"
    )


# FRAGILE: patching boto3 at its source. It happens to work here, because
# src.readers does 'import boto3' and looks up .client at call time — but if
# the import were 'from boto3 import client', this patch would miss the name
# already bound inside src.readers. Patching at the point of use works in
# both cases, which is why it is the rule.
def test_fragile_patch_location():
    with patch("boto3.client") as mock_client:
        ...

patch as decorator

Python
from unittest.mock import patch, MagicMock
import pytest

@patch("src.readers.boto3")
def test_list_s3_files_empty_prefix(mock_boto3):
    mock_client = MagicMock()
    mock_boto3.client.return_value = mock_client
    mock_client.list_objects_v2.return_value = {}  # No Contents key

    from src.readers import list_s3_files
    result = list_s3_files(bucket="my-bucket", prefix="empty/")

    assert result == []


# Multiple patches — innermost decorator = first argument
@patch("src.pipeline.boto3")
@patch("src.pipeline.snowflake_connector")
def test_pipeline_run(mock_snowflake, mock_boto3):
    # mock_snowflake corresponds to the innermost @patch (closest to function)
    # mock_boto3 corresponds to the outermost @patch
    ...

pytest-mock: The mocker Fixture

pytest-mock wraps unittest.mock in a pytest fixture called mocker. It is cleaner than with patch(...) blocks and automatically unpatches after the test.

Bash
pip install pytest-mock

mocker.patch vs patch

Python
# unittest.mock style — manual context manager
def test_with_unittest_mock():
    with patch("src.readers.boto3") as mock_boto3:
        mock_boto3.client.return_value.list_objects_v2.return_value = {"Contents": []}
        ...  # patch active here
    # patch removed here


# pytest-mock style — cleaner, auto-cleanup
def test_with_pytest_mock(mocker):
    mock_boto3 = mocker.patch("src.readers.boto3")
    mock_boto3.client.return_value.list_objects_v2.return_value = {"Contents": []}
    ...
    # patch automatically removed after test — no context manager needed

Mocking a Snowflake Connection

Python
# src/snowflake_reader.py
import os

import snowflake.connector


def fetch_recent_orders(days: int = 7) -> list:
    """Fetch orders from the last N days from Snowflake."""
    conn = snowflake.connector.connect(
        user=os.environ["SNOWFLAKE_USER"],
        password=os.environ["SNOWFLAKE_PASSWORD"],
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        warehouse="COMPUTE_WH",
        database="ANALYTICS",
        schema="SALES",
    )
    try:
        cursor = conn.cursor()
        cursor.execute(f"""
            SELECT order_id, customer_id, revenue, order_date
            FROM orders
            WHERE order_date >= DATEADD(day, -{days}, CURRENT_DATE())
        """)
        return cursor.fetchall()
    finally:
        conn.close()
Python
# tests/unit/test_snowflake_reader.py
import pytest


@pytest.fixture
def mock_snowflake_rows():
    return [
        ("ORD-001", "C001", 150.0, "2026-04-30"),
        ("ORD-002", "C002", 250.0, "2026-05-01"),
        ("ORD-003", "C001", 75.0, "2026-05-02"),
    ]


def test_fetch_recent_orders_returns_rows(mocker, mock_snowflake_rows):
    # Patch the connector at the point of use
    mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")

    mock_cursor = mocker.MagicMock()
    mock_cursor.fetchall.return_value = mock_snowflake_rows
    mock_connect.return_value.cursor.return_value = mock_cursor

    from src.snowflake_reader import fetch_recent_orders
    result = fetch_recent_orders(days=7)

    assert len(result) == 3
    assert result[0][0] == "ORD-001"


def test_fetch_recent_orders_uses_correct_days_parameter(mocker):
    mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")
    mock_cursor = mocker.MagicMock()
    mock_cursor.fetchall.return_value = []
    mock_connect.return_value.cursor.return_value = mock_cursor

    from src.snowflake_reader import fetch_recent_orders
    fetch_recent_orders(days=30)

    executed_sql = mock_cursor.execute.call_args[0][0]
    assert "30" in executed_sql


def test_fetch_recent_orders_closes_connection_on_error(mocker):
    """Connection must be closed even if the query raises an exception."""
    mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")
    mock_cursor = mocker.MagicMock()
    mock_cursor.execute.side_effect = Exception("Snowflake query failed")
    mock_connect.return_value.cursor.return_value = mock_cursor

    from src.snowflake_reader import fetch_recent_orders
    with pytest.raises(Exception, match="Snowflake query failed"):
        fetch_recent_orders(days=7)

    # Verify connection was still closed
    mock_connect.return_value.close.assert_called_once()

Mocking an S3 Upload

Python
# src/writers.py
import boto3
import logging
from io import BytesIO

logger = logging.getLogger(__name__)


def upload_dataframe_to_s3(df, bucket: str, key: str, format: str = "parquet") -> str:
    """Upload a DataFrame to S3 in the specified format. Returns the S3 URI."""
    buffer = BytesIO()
    if format == "parquet":
        df.to_parquet(buffer, index=False)
    elif format == "csv":
        df.to_csv(buffer, index=False)
    else:
        raise ValueError(f"Unsupported format: {format}")

    buffer.seek(0)
    s3 = boto3.client("s3")
    s3.put_object(Bucket=bucket, Key=key, Body=buffer.getvalue())

    uri = f"s3://{bucket}/{key}"
    logger.info(f"Uploaded {len(df)} rows to {uri}")
    return uri
Python
# tests/unit/test_writers.py
import pytest
import pandas as pd


@pytest.fixture
def sample_df():
    return pd.DataFrame({
        "order_id": ["O1", "O2"],
        "revenue": [100.0, 200.0],
    })


def test_upload_dataframe_returns_s3_uri(mocker, sample_df):
    mock_s3 = mocker.MagicMock()
    mocker.patch("src.writers.boto3").client.return_value = mock_s3

    from src.writers import upload_dataframe_to_s3
    result = upload_dataframe_to_s3(sample_df, bucket="my-bucket", key="output/data.parquet")

    assert result == "s3://my-bucket/output/data.parquet"


def test_upload_calls_put_object_with_correct_bucket_and_key(mocker, sample_df):
    mock_boto3 = mocker.patch("src.writers.boto3")
    mock_s3_client = mocker.MagicMock()
    mock_boto3.client.return_value = mock_s3_client

    from src.writers import upload_dataframe_to_s3
    upload_dataframe_to_s3(sample_df, bucket="analytics-bucket", key="raw/orders.parquet")

    mock_s3_client.put_object.assert_called_once()
    call_kwargs = mock_s3_client.put_object.call_args[1]
    assert call_kwargs["Bucket"] == "analytics-bucket"
    assert call_kwargs["Key"] == "raw/orders.parquet"
    assert len(call_kwargs["Body"]) > 0  # Non-empty bytes


def test_upload_rejects_unsupported_format(mocker, sample_df):
    mocker.patch("src.writers.boto3")  # Prevent real AWS calls

    from src.writers import upload_dataframe_to_s3
    with pytest.raises(ValueError, match="Unsupported format: json"):
        upload_dataframe_to_s3(sample_df, bucket="b", key="k", format="json")


def test_upload_csv_format(mocker, sample_df):
    mock_boto3 = mocker.patch("src.writers.boto3")
    mock_s3 = mocker.MagicMock()
    mock_boto3.client.return_value = mock_s3

    from src.writers import upload_dataframe_to_s3
    upload_dataframe_to_s3(sample_df, bucket="b", key="output.csv", format="csv")

    call_kwargs = mock_s3.put_object.call_args[1]
    # Verify the body is CSV content (starts with column header)
    body_text = call_kwargs["Body"].decode("utf-8")
    assert "order_id,revenue" in body_text

monkeypatch: The Surgical Tool

monkeypatch is pytest's built-in fixture for targeted replacements. It is simpler than patch for straightforward cases and automatically reverts all changes after the test.

Patching Environment Variables

Python
# src/config.py
import os


def get_database_url() -> str:
    url = os.environ.get("DATABASE_URL")
    if not url:
        raise EnvironmentError("DATABASE_URL environment variable is not set")
    return url


def get_batch_size() -> int:
    return int(os.environ.get("PIPELINE_BATCH_SIZE", "1000"))
Python
# tests/unit/test_config.py
import pytest


def test_get_database_url_reads_env_var(monkeypatch):
    monkeypatch.setenv("DATABASE_URL", "postgresql://localhost/testdb")

    from src.config import get_database_url
    assert get_database_url() == "postgresql://localhost/testdb"


def test_get_database_url_raises_when_unset(monkeypatch):
    monkeypatch.delenv("DATABASE_URL", raising=False)  # Remove if exists

    from src.config import get_database_url
    with pytest.raises(EnvironmentError, match="DATABASE_URL"):
        get_database_url()


def test_get_batch_size_uses_default(monkeypatch):
    monkeypatch.delenv("PIPELINE_BATCH_SIZE", raising=False)

    from src.config import get_batch_size
    assert get_batch_size() == 1000


def test_get_batch_size_reads_custom_value(monkeypatch):
    monkeypatch.setenv("PIPELINE_BATCH_SIZE", "5000")

    from src.config import get_batch_size
    assert get_batch_size() == 5000

Patching Module Attributes and Functions

Python
# src/utils.py
import uuid
import hashlib


def generate_record_id() -> str:
    return str(uuid.uuid4())


def hash_pii(value: str) -> str:
    return hashlib.sha256(value.encode()).hexdigest()[:16]
Python
def test_generate_record_id_uses_uuid(monkeypatch):
    import uuid as uuid_module

    fixed_uuid = uuid_module.UUID("12345678-1234-5678-1234-567812345678")
    monkeypatch.setattr("src.utils.uuid.uuid4", lambda: fixed_uuid)

    from src.utils import generate_record_id
    result = generate_record_id()
    assert result == "12345678-1234-5678-1234-567812345678"


def test_pipeline_uses_deterministic_ids_in_test(monkeypatch):
    """Make record IDs deterministic for assertion-heavy tests."""
    counter = {"n": 0}

    def sequential_id():
        counter["n"] += 1
        return f"TEST-ID-{counter['n']:04d}"

    monkeypatch.setattr("src.pipeline.generate_record_id", sequential_id)

    from src.pipeline import process_records
    result = process_records([{"name": "Alice"}, {"name": "Bob"}])

    assert result[0]["record_id"] == "TEST-ID-0001"
    assert result[1]["record_id"] == "TEST-ID-0002"

Patching Built-in open (File I/O)

Python
# src/readers.py
def read_config_file(path: str) -> dict:
    import json
    with open(path) as f:
        return json.load(f)
Python
from unittest.mock import mock_open
import json


def test_read_config_file(monkeypatch):
    config_data = {"batch_size": 500, "env": "test"}
    mock_file_content = json.dumps(config_data)

    monkeypatch.setattr("builtins.open", mock_open(read_data=mock_file_content))

    from src.readers import read_config_file
    result = read_config_file("/fake/path/config.json")
    assert result == config_data
    assert result["batch_size"] == 500

freezegun: Mocking datetime

Time-sensitive pipeline logic — daily partitions, SLA windows, incremental watermarks — is impossible to test without controlling the clock.

Bash
pip install freezegun
Python
# src/partitioner.py
from datetime import datetime, date


def get_current_partition() -> str:
    """Return the current date partition key: YYYY/MM/DD."""
    today = datetime.now().date()
    return f"{today.year:04d}/{today.month:02d}/{today.day:02d}"


def is_within_sla(created_at: datetime, sla_hours: int = 4) -> bool:
    """Check if a record is within its SLA window."""
    now = datetime.now()
    age = now - created_at
    return age.total_seconds() / 3600 <= sla_hours


def get_incremental_watermark(lookback_days: int = 1) -> date:
    """Return the date N days ago for incremental pipeline runs."""
    from datetime import timedelta
    return datetime.now().date() - timedelta(days=lookback_days)
Python
# tests/unit/test_partitioner.py
import pytest
from freezegun import freeze_time
from datetime import datetime, date


@freeze_time("2026-05-07 14:30:00")
def test_get_current_partition():
    from src.partitioner import get_current_partition
    assert get_current_partition() == "2026/05/07"


@freeze_time("2026-05-07 00:00:00")
def test_get_current_partition_midnight():
    from src.partitioner import get_current_partition
    assert get_current_partition() == "2026/05/07"


@pytest.mark.parametrize("created_at_str,sla_hours,expected", [
    ("2026-05-07 12:00:00", 4, True),   # 2 hours ago, 4h SLA
    ("2026-05-07 09:00:00", 4, False),  # 5 hours ago, 4h SLA
    ("2026-05-07 10:00:01", 4, True),   # Just under 4h, 4h SLA
    ("2026-05-07 10:00:00", 4, True),   # Exactly 4h — within SLA
    ("2026-05-07 12:30:00", 1, False),  # 90 minutes ago, 1h SLA — outside
])
@freeze_time("2026-05-07 14:00:00")
def test_is_within_sla(created_at_str, sla_hours, expected):
    from src.partitioner import is_within_sla
    created_at = datetime.strptime(created_at_str, "%Y-%m-%d %H:%M:%S")
    assert is_within_sla(created_at, sla_hours) == expected


@freeze_time("2026-05-07")
def test_incremental_watermark_lookback_1_day():
    from src.partitioner import get_incremental_watermark
    assert get_incremental_watermark(lookback_days=1) == date(2026, 5, 6)


@freeze_time("2026-05-07")
def test_incremental_watermark_lookback_7_days():
    from src.partitioner import get_incremental_watermark
    assert get_incremental_watermark(lookback_days=7) == date(2026, 4, 30)


# Use freeze_time as context manager for more control
def test_watermark_crosses_month_boundary():
    with freeze_time("2026-05-03"):
        from src.partitioner import get_incremental_watermark
        result = get_incremental_watermark(lookback_days=5)
        assert result == date(2026, 4, 28)

freezegun with tick=True (Real-Time Advancement)

Python
from datetime import datetime

from freezegun import freeze_time


@freeze_time("2026-05-07", tick=True)
def test_timestamps_still_advance():
    """
    With tick=True the clock starts at the frozen point and then advances
    in real time. Use it when the code under test needs time to move
    forward (elapsed-time checks, timestamp ordering) but you still want
    a deterministic starting date.
    """
    first = datetime.now()
    second = datetime.now()

    assert first.date().isoformat() == "2026-05-07"
    assert second >= first  # Not frozen solid — the clock ticks forward

HTTP Mocking with responses

For pipelines that call external REST APIs (data providers, internal microservices), use the responses library to mock HTTP calls.

Bash
pip install responses
Python
# src/api_client.py
import requests
from typing import List, Dict, Any


class WeatherAPIClient:
    BASE_URL = "https://api.weatherprovider.io/v2"

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.session = requests.Session()
        self.session.headers["Authorization"] = f"Bearer {api_key}"

    def get_daily_observations(self, station_id: str, start_date: str, end_date: str) -> List[Dict]:
        """Fetch daily weather observations — paginated API."""
        results = []
        page = 1

        while True:
            response = self.session.get(
                f"{self.BASE_URL}/observations",
                params={
                    "station_id": station_id,
                    "start_date": start_date,
                    "end_date": end_date,
                    "page": page,
                    "page_size": 100,
                },
            )
            response.raise_for_status()
            data = response.json()
            results.extend(data["observations"])

            if not data.get("has_next_page"):
                break
            page += 1

        return results
Python
# tests/unit/test_api_client.py
import pytest
import responses as responses_lib  # Renamed to avoid shadowing
from responses import matchers


@responses_lib.activate
def test_get_daily_observations_single_page():
    from src.api_client import WeatherAPIClient

    responses_lib.add(
        method=responses_lib.GET,
        url="https://api.weatherprovider.io/v2/observations",
        json={
            "observations": [
                {"date": "2026-05-01", "temp_c": 18.5, "precip_mm": 0.0},
                {"date": "2026-05-02", "temp_c": 21.2, "precip_mm": 2.5},
            ],
            "has_next_page": False,
        },
        status=200,
    )

    client = WeatherAPIClient(api_key="test-key-123")
    result = client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")

    assert len(result) == 2
    assert result[0]["date"] == "2026-05-01"
    assert result[1]["temp_c"] == 21.2


@responses_lib.activate
def test_get_daily_observations_paginates_correctly():
    """Verify pagination: client fetches all pages until has_next_page is False."""
    responses_lib.add(
        responses_lib.GET,
        "https://api.weatherprovider.io/v2/observations",
        json={
            "observations": [{"date": "2026-05-01", "temp_c": 18.0}] * 100,
            "has_next_page": True,
        },
        status=200,
    )
    responses_lib.add(
        responses_lib.GET,
        "https://api.weatherprovider.io/v2/observations",
        json={
            "observations": [{"date": "2026-06-10", "temp_c": 25.0}] * 50,
            "has_next_page": False,
        },
        status=200,
    )

    from src.api_client import WeatherAPIClient
    client = WeatherAPIClient(api_key="test-key")
    result = client.get_daily_observations("STATION-001", "2026-05-01", "2026-06-10")

    assert len(result) == 150  # 100 + 50
    assert len(responses_lib.calls) == 2  # Exactly two HTTP calls made


@responses_lib.activate
def test_get_daily_observations_raises_on_401():
    responses_lib.add(
        responses_lib.GET,
        "https://api.weatherprovider.io/v2/observations",
        status=401,
        json={"error": "Invalid API key"},
    )

    from src.api_client import WeatherAPIClient
    import requests

    client = WeatherAPIClient(api_key="bad-key")
    with pytest.raises(requests.HTTPError):
        client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")


@responses_lib.activate
def test_get_daily_observations_raises_on_server_error():
    responses_lib.add(
        responses_lib.GET,
        "https://api.weatherprovider.io/v2/observations",
        status=500,
        json={"error": "Internal server error"},
    )

    from src.api_client import WeatherAPIClient
    import requests

    client = WeatherAPIClient(api_key="test-key")
    with pytest.raises(requests.HTTPError) as exc_info:
        client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")

    assert exc_info.value.response.status_code == 500

Using respx for httpx-Based Clients

If your pipeline uses httpx (async or sync) instead of requests:

Bash
pip install respx httpx pytest-asyncio  # pytest-asyncio is needed for the async test below
Python
# tests/unit/test_httpx_client.py
import pytest
import httpx
import respx


@respx.mock
def test_httpx_client_fetches_data():
    respx.get("https://api.provider.io/data").mock(
        return_value=httpx.Response(200, json={"records": [1, 2, 3]})
    )

    from src.httpx_client import DataProviderClient
    client = DataProviderClient()
    result = client.fetch()

    assert result["records"] == [1, 2, 3]


@pytest.mark.asyncio
@respx.mock
async def test_async_client():
    respx.get("https://api.provider.io/data").mock(
        return_value=httpx.Response(200, json={"records": [1, 2, 3]})
    )

    from src.httpx_client import AsyncDataProviderClient
    async with AsyncDataProviderClient() as client:
        result = await client.fetch()

    assert result["records"] == [1, 2, 3]

Combining mocker and monkeypatch

Sometimes you need both: monkeypatch for env vars and mocker for the actual service call.

Python
# src/pipeline.py
import os
import snowflake.connector


class IncrementalPipeline:
    def run(self, date: str) -> dict:
        conn = snowflake.connector.connect(
            user=os.environ["SNOWFLAKE_USER"],
            password=os.environ["SNOWFLAKE_PASSWORD"],
            account=os.environ["SNOWFLAKE_ACCOUNT"],
        )
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) FROM orders WHERE order_date = %s",
            (date,)
        )
        count = cursor.fetchone()[0]
        conn.close()
        return {"date": date, "order_count": count}
Python
def test_incremental_pipeline_run(mocker, monkeypatch):
    # Set required env vars
    monkeypatch.setenv("SNOWFLAKE_USER", "testuser")
    monkeypatch.setenv("SNOWFLAKE_PASSWORD", "testpass")
    monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "testaccount.us-east-1")

    # Mock the Snowflake connection
    mock_cursor = mocker.MagicMock()
    mock_cursor.fetchone.return_value = (42,)

    mock_conn = mocker.MagicMock()
    mock_conn.cursor.return_value = mock_cursor

    mocker.patch(
        "src.pipeline.snowflake.connector.connect",
        return_value=mock_conn,
    )

    from src.pipeline import IncrementalPipeline
    pipeline = IncrementalPipeline()
    result = pipeline.run("2026-05-07")

    assert result == {"date": "2026-05-07", "order_count": 42}
    mock_conn.close.assert_called_once()

Verifying Mock Calls

Always assert that mocks were called correctly — otherwise your test only proves the function didn't crash, not that it behaved correctly.

Python
import pandas as pd


def test_upload_pipeline_calls_s3_once_per_partition(mocker):
    mock_boto3 = mocker.patch("src.writers.boto3")
    mock_s3 = mocker.MagicMock()
    mock_boto3.client.return_value = mock_s3

    from src.writers import upload_partitioned_output
    df = pd.DataFrame({"date": ["2026-05-01", "2026-05-02"], "value": [1, 2]})
    upload_partitioned_output(df, bucket="my-bucket", partition_col="date")

    # Verify S3 client was created once
    mock_boto3.client.assert_called_once_with("s3")

    # Verify put_object was called twice (once per partition)
    assert mock_s3.put_object.call_count == 2

    # Verify the correct keys were used
    call_args_list = mock_s3.put_object.call_args_list
    keys = [call[1]["Key"] for call in call_args_list]
    assert any("2026-05-01" in k for k in keys)
    assert any("2026-05-02" in k for k in keys)

    # Verify no other unexpected methods were called
    mock_s3.delete_object.assert_not_called()

Mock Assertion Reference

Python
from unittest.mock import MagicMock, call

m = MagicMock()

# Basic call assertions
m.assert_called()                     # Called at least once
m.assert_called_once()               # Called exactly once
m.assert_not_called()                # Never called
m.assert_called_with(arg1, key=val)  # Last call used these args
m.assert_called_once_with(arg1)      # Called once with these args

# Call count
assert m.call_count == 3

# Inspect all calls
assert m.call_args_list == [
    call("first"),
    call("second"),
    call("third"),
]

# Access last call args
last_call_args, last_call_kwargs = m.call_args
assert last_call_kwargs["Bucket"] == "my-bucket"

# Any-order call assertion
assert call("second") in m.call_args_list
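For multi-call checks, assert_has_calls verifies a sequence of expected calls; with any_order=True each call merely has to have happened:

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m.put_object(Bucket="b", Key="a.parquet")
m.put_object(Bucket="b", Key="b.parquet")

# Ordered: the expected calls must appear consecutively, in order
m.put_object.assert_has_calls([
    call(Bucket="b", Key="a.parquet"),
    call(Bucket="b", Key="b.parquet"),
])

# Unordered: each expected call must simply have occurred
m.put_object.assert_has_calls(
    [call(Bucket="b", Key="b.parquet"), call(Bucket="b", Key="a.parquet")],
    any_order=True,
)
```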

When NOT to Mock

Mock too aggressively and your tests pass while the real system breaks. Guidelines for not mocking:

  1. Pure functions: No side effects, no external calls — test with real inputs/outputs
  2. pandas/numpy operations: Use real DataFrames; they are fast and deterministic
  3. Configuration parsing: Use real config files in tests/fixtures/
  4. SQL logic: Use DuckDB or SQLite in-memory — mock-based SQL tests miss query bugs
  5. Business logic: The heart of your pipeline should have zero mocks in unit tests
Python
# DO NOT mock this — test the real logic
def test_revenue_tier_logic():
    from src.enrichment import assign_revenue_tier
    assert assign_revenue_tier(50.0) == "low"
    assert assign_revenue_tier(500.0) == "medium"
    assert assign_revenue_tier(5000.0) == "high"

# MOCK this — the database is external
def test_enrichment_fetches_correct_tier(mocker):
    mocker.patch("src.enrichment.get_revenue_tier_from_db", return_value="medium")
    ...

Summary

  • Patch at the import site: "src.module.boto3", not "boto3.client"
  • Use create_autospec or spec=True to catch interface mismatches at mock time
  • side_effect handles sequences of return values and exception injection
  • pytest-mock's mocker fixture is cleaner than context managers — prefer it
  • monkeypatch is the right tool for env vars, module attributes, and builtins
  • freezegun makes time-sensitive pipeline logic fully deterministic
  • responses or respx mock HTTP calls without requiring real network access
  • Always assert that mocks were called correctly, not just that no exception was raised
  • Do not mock pure functions or pandas operations — use real implementations

The next lesson covers testing complete pipelines: pandas assertions, exception handling, CLI testing, FastAPI endpoints, testcontainers, and CI integration.
