pytest Mocking and Patching for Data Pipeline Tests
Master unittest.mock, pytest-mock, monkeypatch, freezegun, and HTTP mocking to isolate external dependencies in Snowflake, S3, REST API, and datetime-sensitive pipeline tests.
Data pipelines touch everything: cloud data warehouses, object storage, REST APIs, message queues, and clocks. Running tests against real infrastructure is slow, expensive, and fragile. Mocking replaces those dependencies with controlled stand-ins. This lesson covers every mocking tool you need, with realistic examples of the exact scenarios data engineers face.
The Core Question: Mock or Not?
Before writing a mock, ask:
- Is the dependency fast and deterministic? Use the real thing. A pandas transformation, a pure Python function, or an in-memory DuckDB connection does not need mocking.
- Is the dependency slow, external, or stateful? Mock it. Snowflake queries, S3 operations, external HTTP APIs, and datetime.now() all belong in this category.
- Can I use a fast equivalent? Consider it first. SQLite instead of PostgreSQL for schema tests. DuckDB instead of BigQuery for SQL logic tests (see the sketch below). A local file instead of S3 for format tests.
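For instance, here is a minimal sketch of the DuckDB approach; the orders table and the rollup query are hypothetical, and only the duckdb package is assumed:
import duckdb
def test_daily_revenue_rollup():
    # Real SQL against an in-memory engine: no warehouse, no mocks
    conn = duckdb.connect(":memory:")
    conn.execute("CREATE TABLE orders (order_date DATE, revenue DOUBLE)")
    conn.execute("""
        INSERT INTO orders VALUES
            ('2026-05-01', 100.0),
            ('2026-05-01', 50.0),
            ('2026-05-02', 200.0)
    """)
    rows = conn.execute(
        "SELECT order_date, SUM(revenue) FROM orders GROUP BY 1 ORDER BY 1"
    ).fetchall()
    assert [r[1] for r in rows] == [150.0, 200.0]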
The mocking hierarchy from lightest to heaviest:
- Real implementation: fastest, most confidence
- In-memory equivalent: DuckDB, SQLite, tmp_path (see the sketch after this list)
- monkeypatch: replace env vars, builtins, module attributes
- unittest.mock / pytest-mock: replace objects and functions with controllable fakes
- Docker containers: real services in isolation (testcontainers, next lesson)
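As a taste of the second rung, a minimal sketch using pytest's built-in tmp_path fixture in place of S3 for a format round-trip test (assumes pandas with a parquet engine such as pyarrow):
import pandas as pd
def test_parquet_roundtrip(tmp_path):
    df = pd.DataFrame({"order_id": ["O1", "O2"], "revenue": [100.0, 200.0]})
    target = tmp_path / "report.parquet"  # local file stands in for S3
    df.to_parquet(target, index=False)
    restored = pd.read_parquet(target)
    pd.testing.assert_frame_equal(df, restored)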
unittest.mock Fundamentals
Python's built-in unittest.mock module provides Mock, MagicMock, and patch.
Mock vs MagicMock
from unittest.mock import Mock, MagicMock
# Mock: basic mock object
m = Mock()
m.some_method() # Returns a new Mock (no AttributeError)
m.some_method.return_value = 42
assert m.some_method() == 42
# MagicMock: Mock + magic method support (__len__, __iter__, __enter__, etc.)
mm = MagicMock()
len(mm) # Works: returns 0 by default
with mm: # Works: __enter__/__exit__ are pre-configured
pass
# In practice: use MagicMock for context managers, iterables
# Use Mock when you want stricter behavior (magic methods raise TypeError)spec=True: Catch Typos in Your Mocks
Without spec, you can call any attribute on a mock without error. This hides bugs:
from unittest.mock import MagicMock, create_autospec
# Without spec: typo goes unnoticed
conn = MagicMock()
conn.exeucte("SELECT 1") # Typo: "exeucte" instead of "execute"
# No error! The mock happily creates conn.exeucte as a new attribute.
# With spec: typo raises AttributeError
import psycopg2
conn = create_autospec(psycopg2.extensions.connection)
conn.exeucte("SELECT 1")
# AttributeError: Mock object has no attribute 'exeucte'
Always use spec=True or create_autospec when mocking objects you own or know the interface of.
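For completeness, here is the spec= form against a class you own; UploadClient is a hypothetical interface, and note that spec checks attribute names only (use create_autospec to also validate call signatures):
from unittest.mock import Mock
class UploadClient:
    """Hypothetical interface we control."""
    def put_object(self, Bucket: str, Key: str, Body: bytes) -> None: ...
client = Mock(spec=UploadClient)
client.put_object(Bucket="b", Key="k", Body=b"")  # OK: name exists on the spec
# client.putobject(Bucket="b", Key="k", Body=b"")  # AttributeError: typo caught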
side_effect: Sequences and Exceptions
side_effect is more powerful than return_value:
from unittest.mock import Mock
# Return different values on successive calls
paginator = Mock()
paginator.next_page.side_effect = [
{"data": [1, 2, 3], "has_more": True},
{"data": [4, 5, 6], "has_more": True},
{"data": [7, 8], "has_more": False},
]
assert paginator.next_page()["data"] == [1, 2, 3]
assert paginator.next_page()["data"] == [4, 5, 6]
assert paginator.next_page()["data"] == [7, 8]
# Raise an exception on the Nth call
api_client = Mock()
api_client.fetch.side_effect = [
{"result": "ok"},
{"result": "ok"},
ConnectionError("Simulated network failure"),
]
api_client.fetch() # {"result": "ok"}
api_client.fetch() # {"result": "ok"}
# api_client.fetch() # Raises ConnectionError
# Side effect as a function: called with the same args as the mock
def validate_query(sql):
if "DROP" in sql.upper():
raise PermissionError("DROP statements not allowed")
return [("row1",), ("row2",)]
cursor = Mock()
cursor.execute.side_effect = validate_query
cursor.execute("SELECT 1") # Fine
cursor.execute("DROP TABLE users") # Raises PermissionErrorreturn_value Chaining
# Deep chaining for object hierarchies
from unittest.mock import MagicMock
snowflake_connector = MagicMock()
snowflake_connector.connect.return_value.cursor.return_value.fetchall.return_value = [
("2026-01-01", 100.0),
("2026-01-02", 200.0),
]
# Access pattern: connector.connect().cursor().fetchall()
conn = snowflake_connector.connect()
cursor = conn.cursor()
rows = cursor.fetchall()
assert rows == [("2026-01-01", 100.0), ("2026-01-02", 200.0)]
unittest.mock.patch
patch replaces an object in the module under test for the duration of the test.
Critical rule: patch where the name is used, not where it is defined.
# src/readers.py
import boto3 # boto3 defined in boto3 package
def list_s3_files(bucket: str, prefix: str) -> list:
s3 = boto3.client("s3") # boto3 used here, in src.readers
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
    return [obj["Key"] for obj in response.get("Contents", [])]
# tests/unit/test_readers.py
from unittest.mock import patch, MagicMock
# CORRECT: patch boto3 as it is used in src.readers
def test_list_s3_files_returns_keys():
with patch("src.readers.boto3") as mock_boto3:
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_client.list_objects_v2.return_value = {
"Contents": [
{"Key": "raw/orders/part-001.parquet"},
{"Key": "raw/orders/part-002.parquet"},
]
}
from src.readers import list_s3_files
result = list_s3_files(bucket="my-bucket", prefix="raw/orders/")
assert len(result) == 2
assert "raw/orders/part-001.parquet" in result
mock_client.list_objects_v2.assert_called_once_with(
Bucket="my-bucket",
Prefix="raw/orders/"
)
# WRONG: patching boto3 at its source; it has no effect on src.readers
def test_wrong_patch_location():
with patch("boto3.client") as mock_client:
# This patch does NOT affect the 'import boto3' in src.readers
        ...
patch as decorator
from unittest.mock import patch, MagicMock
import pytest
@patch("src.readers.boto3")
def test_list_s3_files_empty_prefix(mock_boto3):
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_client.list_objects_v2.return_value = {} # No Contents key
from src.readers import list_s3_files
result = list_s3_files(bucket="my-bucket", prefix="empty/")
assert result == []
# Multiple patches: innermost decorator = first argument
@patch("src.pipeline.boto3")
@patch("src.pipeline.snowflake_connector")
def test_pipeline_run(mock_snowflake, mock_boto3):
# mock_snowflake corresponds to the innermost @patch (closest to function)
# mock_boto3 corresponds to the outermost @patch
    ...
pytest-mock: The mocker Fixture
pytest-mock wraps unittest.mock in a pytest fixture called mocker. It is cleaner than with patch(...) blocks and automatically unpatches after the test.
pip install pytest-mock
mocker.patch vs patch
# unittest.mock style: manual context manager
def test_with_unittest_mock():
with patch("src.readers.boto3") as mock_boto3:
mock_boto3.client.return_value.list_objects_v2.return_value = {"Contents": []}
... # patch active here
# patch removed here
# pytest-mock style: cleaner, auto-cleanup
def test_with_pytest_mock(mocker):
mock_boto3 = mocker.patch("src.readers.boto3")
mock_boto3.client.return_value.list_objects_v2.return_value = {"Contents": []}
...
    # patch automatically removed after test; no context manager needed
Mocking a Snowflake Connection
# src/snowflake_reader.py
import os
import snowflake.connector
def fetch_recent_orders(days: int = 7) -> list:
"""Fetch orders from the last N days from Snowflake."""
conn = snowflake.connector.connect(
user=os.environ["SNOWFLAKE_USER"],
password=os.environ["SNOWFLAKE_PASSWORD"],
account=os.environ["SNOWFLAKE_ACCOUNT"],
warehouse="COMPUTE_WH",
database="ANALYTICS",
schema="SALES",
)
try:
cursor = conn.cursor()
cursor.execute(f"""
SELECT order_id, customer_id, revenue, order_date
FROM orders
WHERE order_date >= DATEADD(day, -{days}, CURRENT_DATE())
""")
return cursor.fetchall()
finally:
        conn.close()
# tests/unit/test_snowflake_reader.py
import pytest
@pytest.fixture
def mock_snowflake_rows():
return [
("ORD-001", "C001", 150.0, "2026-04-30"),
("ORD-002", "C002", 250.0, "2026-05-01"),
("ORD-003", "C001", 75.0, "2026-05-02"),
]
def test_fetch_recent_orders_returns_rows(mocker, mock_snowflake_rows):
# Patch the connector at the point of use
mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")
mock_cursor = mocker.MagicMock()
mock_cursor.fetchall.return_value = mock_snowflake_rows
mock_connect.return_value.cursor.return_value = mock_cursor
from src.snowflake_reader import fetch_recent_orders
result = fetch_recent_orders(days=7)
assert len(result) == 3
assert result[0][0] == "ORD-001"
def test_fetch_recent_orders_uses_correct_days_parameter(mocker):
mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")
mock_cursor = mocker.MagicMock()
mock_cursor.fetchall.return_value = []
mock_connect.return_value.cursor.return_value = mock_cursor
from src.snowflake_reader import fetch_recent_orders
fetch_recent_orders(days=30)
executed_sql = mock_cursor.execute.call_args[0][0]
assert "30" in executed_sql
def test_fetch_recent_orders_closes_connection_on_error(mocker):
"""Connection must be closed even if the query raises an exception."""
mock_connect = mocker.patch("src.snowflake_reader.snowflake.connector.connect")
mock_cursor = mocker.MagicMock()
mock_cursor.execute.side_effect = Exception("Snowflake query failed")
mock_connect.return_value.cursor.return_value = mock_cursor
from src.snowflake_reader import fetch_recent_orders
with pytest.raises(Exception, match="Snowflake query failed"):
fetch_recent_orders(days=7)
# Verify connection was still closed
    mock_connect.return_value.close.assert_called_once()
Mocking an S3 Upload
# src/writers.py
import boto3
import logging
from io import BytesIO
logger = logging.getLogger(__name__)
def upload_dataframe_to_s3(df, bucket: str, key: str, format: str = "parquet") -> str:
"""Upload a DataFrame to S3 in the specified format. Returns the S3 URI."""
buffer = BytesIO()
if format == "parquet":
df.to_parquet(buffer, index=False)
elif format == "csv":
df.to_csv(buffer, index=False)
else:
raise ValueError(f"Unsupported format: {format}")
buffer.seek(0)
s3 = boto3.client("s3")
s3.put_object(Bucket=bucket, Key=key, Body=buffer.getvalue())
uri = f"s3://{bucket}/{key}"
logger.info(f"Uploaded {len(df)} rows to {uri}")
    return uri
# tests/unit/test_writers.py
import pytest
import pandas as pd
@pytest.fixture
def sample_df():
return pd.DataFrame({
"order_id": ["O1", "O2"],
"revenue": [100.0, 200.0],
})
def test_upload_dataframe_returns_s3_uri(mocker, sample_df):
mock_s3 = mocker.MagicMock()
mocker.patch("src.writers.boto3").client.return_value = mock_s3
from src.writers import upload_dataframe_to_s3
result = upload_dataframe_to_s3(sample_df, bucket="my-bucket", key="output/data.parquet")
assert result == "s3://my-bucket/output/data.parquet"
def test_upload_calls_put_object_with_correct_bucket_and_key(mocker, sample_df):
mock_boto3 = mocker.patch("src.writers.boto3")
mock_s3_client = mocker.MagicMock()
mock_boto3.client.return_value = mock_s3_client
from src.writers import upload_dataframe_to_s3
upload_dataframe_to_s3(sample_df, bucket="analytics-bucket", key="raw/orders.parquet")
mock_s3_client.put_object.assert_called_once()
call_kwargs = mock_s3_client.put_object.call_args[1]
assert call_kwargs["Bucket"] == "analytics-bucket"
assert call_kwargs["Key"] == "raw/orders.parquet"
assert len(call_kwargs["Body"]) > 0 # Non-empty bytes
def test_upload_rejects_unsupported_format(mocker, sample_df):
mocker.patch("src.writers.boto3") # Prevent real AWS calls
from src.writers import upload_dataframe_to_s3
with pytest.raises(ValueError, match="Unsupported format: json"):
upload_dataframe_to_s3(sample_df, bucket="b", key="k", format="json")
def test_upload_csv_format(mocker, sample_df):
mock_boto3 = mocker.patch("src.writers.boto3")
mock_s3 = mocker.MagicMock()
mock_boto3.client.return_value = mock_s3
from src.writers import upload_dataframe_to_s3
upload_dataframe_to_s3(sample_df, bucket="b", key="output.csv", format="csv")
call_kwargs = mock_s3.put_object.call_args[1]
# Verify the body is CSV content (starts with column header)
body_text = call_kwargs["Body"].decode("utf-8")
assert "order_id,revenue" in body_textmonkeypatch: The Surgical Tool
monkeypatch is pytest's built-in fixture for targeted replacements. It is simpler than patch for straightforward cases and automatically reverts all changes after the test.
Patching Environment Variables
# src/config.py
import os
def get_database_url() -> str:
url = os.environ.get("DATABASE_URL")
if not url:
raise EnvironmentError("DATABASE_URL environment variable is not set")
return url
def get_batch_size() -> int:
    return int(os.environ.get("PIPELINE_BATCH_SIZE", "1000"))
# tests/unit/test_config.py
import pytest
def test_get_database_url_reads_env_var(monkeypatch):
monkeypatch.setenv("DATABASE_URL", "postgresql://localhost/testdb")
from src.config import get_database_url
assert get_database_url() == "postgresql://localhost/testdb"
def test_get_database_url_raises_when_unset(monkeypatch):
monkeypatch.delenv("DATABASE_URL", raising=False) # Remove if exists
from src.config import get_database_url
with pytest.raises(EnvironmentError, match="DATABASE_URL"):
get_database_url()
def test_get_batch_size_uses_default(monkeypatch):
monkeypatch.delenv("PIPELINE_BATCH_SIZE", raising=False)
from src.config import get_batch_size
assert get_batch_size() == 1000
def test_get_batch_size_reads_custom_value(monkeypatch):
monkeypatch.setenv("PIPELINE_BATCH_SIZE", "5000")
from src.config import get_batch_size
    assert get_batch_size() == 5000
Patching Module Attributes and Functions
# src/utils.py
import uuid
import hashlib
def generate_record_id() -> str:
return str(uuid.uuid4())
def hash_pii(value: str) -> str:
    return hashlib.sha256(value.encode()).hexdigest()[:16]
def test_generate_record_id_uses_uuid(monkeypatch):
import uuid as uuid_module
fixed_uuid = uuid_module.UUID("12345678-1234-5678-1234-567812345678")
monkeypatch.setattr("src.utils.uuid.uuid4", lambda: fixed_uuid)
from src.utils import generate_record_id
result = generate_record_id()
assert result == "12345678-1234-5678-1234-567812345678"
def test_pipeline_uses_deterministic_ids_in_test(monkeypatch):
"""Make record IDs deterministic for assertion-heavy tests."""
counter = {"n": 0}
def sequential_id():
counter["n"] += 1
return f"TEST-ID-{counter['n']:04d}"
monkeypatch.setattr("src.pipeline.generate_record_id", sequential_id)
from src.pipeline import process_records
result = process_records([{"name": "Alice"}, {"name": "Bob"}])
assert result[0]["record_id"] == "TEST-ID-0001"
    assert result[1]["record_id"] == "TEST-ID-0002"
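The test above assumes a src/pipeline.py roughly like this hypothetical sketch; because process_records resolves generate_record_id through the src.pipeline namespace, the test patches it there, not in src.utils:
# Hypothetical src/pipeline.py assumed by the test above
from src.utils import generate_record_id
def process_records(records: list) -> list:
    # Each record gets a fresh ID from the (patchable) module-level name
    return [{**record, "record_id": generate_record_id()} for record in records]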
Patching Built-in open (File I/O)
# src/readers.py
def read_config_file(path: str) -> dict:
import json
with open(path) as f:
        return json.load(f)
from unittest.mock import mock_open
import json
def test_read_config_file(monkeypatch):
config_data = {"batch_size": 500, "env": "test"}
mock_file_content = json.dumps(config_data)
monkeypatch.setattr("builtins.open", mock_open(read_data=mock_file_content))
from src.readers import read_config_file
result = read_config_file("/fake/path/config.json")
assert result == config_data
assert result["batch_size"] == 500freezegun: Mocking datetime
Time-sensitive pipeline logic (daily partitions, SLA windows, incremental watermarks) is impossible to test without controlling the clock.
pip install freezegun
# src/partitioner.py
from datetime import datetime, date
def get_current_partition() -> str:
"""Return the current date partition key: YYYY/MM/DD."""
today = datetime.now().date()
return f"{today.year:04d}/{today.month:02d}/{today.day:02d}"
def is_within_sla(created_at: datetime, sla_hours: int = 4) -> bool:
"""Check if a record is within its SLA window."""
now = datetime.now()
age = now - created_at
return age.total_seconds() / 3600 <= sla_hours
def get_incremental_watermark(lookback_days: int = 1) -> date:
"""Return the date N days ago for incremental pipeline runs."""
from datetime import timedelta
    return datetime.now().date() - timedelta(days=lookback_days)
# tests/unit/test_partitioner.py
import pytest
from freezegun import freeze_time
from datetime import datetime, date
@freeze_time("2026-05-07 14:30:00")
def test_get_current_partition():
from src.partitioner import get_current_partition
assert get_current_partition() == "2026/05/07"
@freeze_time("2026-05-07 00:00:00")
def test_get_current_partition_midnight():
from src.partitioner import get_current_partition
assert get_current_partition() == "2026/05/07"
@pytest.mark.parametrize("created_at_str,sla_hours,expected", [
("2026-05-07 12:00:00", 4, True), # 2 hours ago, 4h SLA
("2026-05-07 09:00:00", 4, False), # 5 hours ago, 4h SLA
("2026-05-07 10:00:01", 4, True), # Just under 4h, 4h SLA
("2026-05-07 10:00:00", 4, True), # Exactly 4h ā within SLA
("2026-05-07 14:00:00", 1, False), # 30 min ago, 0.5h SLA ā wait no
])
@freeze_time("2026-05-07 14:00:00")
def test_is_within_sla(created_at_str, sla_hours, expected):
from src.partitioner import is_within_sla
created_at = datetime.strptime(created_at_str, "%Y-%m-%d %H:%M:%S")
assert is_within_sla(created_at, sla_hours) == expected
@freeze_time("2026-05-07")
def test_incremental_watermark_lookback_1_day():
from src.partitioner import get_incremental_watermark
assert get_incremental_watermark(lookback_days=1) == date(2026, 5, 6)
@freeze_time("2026-05-07")
def test_incremental_watermark_lookback_7_days():
from src.partitioner import get_incremental_watermark
assert get_incremental_watermark(lookback_days=7) == date(2026, 4, 30)
# Use freeze_time as context manager for more control
def test_watermark_crosses_month_boundary():
with freeze_time("2026-05-03"):
from src.partitioner import get_incremental_watermark
result = get_incremental_watermark(lookback_days=5)
        assert result == date(2026, 4, 28)
freezegun with tick=True (Real-Time Advancement)
from datetime import datetime
from freezegun import freeze_time
@freeze_time("2026-05-07", tick=True)
def test_pipeline_respects_timeout():
"""
With tick=True, time advances normally from the frozen start point.
    Use this when timeout logic needs a genuinely ticking clock but a deterministic start date.
"""
from src.pipeline import run_with_timeout
    assert datetime.now().year == 2026  # clock starts at the frozen date
# Simulate a pipeline that checks elapsed time
result = run_with_timeout(max_seconds=0.1)
assert result["timed_out"] is TrueHTTP Mocking with responses
HTTP Mocking with responses
For pipelines that call external REST APIs (data providers, internal microservices), use the responses library to mock HTTP calls.
pip install responses
# src/api_client.py
import requests
from typing import List, Dict, Any
class WeatherAPIClient:
BASE_URL = "https://api.weatherprovider.io/v2"
def __init__(self, api_key: str):
self.api_key = api_key
self.session = requests.Session()
self.session.headers["Authorization"] = f"Bearer {api_key}"
def get_daily_observations(self, station_id: str, start_date: str, end_date: str) -> List[Dict]:
"""Fetch daily weather observations ā paginated API."""
results = []
page = 1
while True:
response = self.session.get(
f"{self.BASE_URL}/observations",
params={
"station_id": station_id,
"start_date": start_date,
"end_date": end_date,
"page": page,
"page_size": 100,
},
)
response.raise_for_status()
data = response.json()
results.extend(data["observations"])
if not data.get("has_next_page"):
break
page += 1
        return results
# tests/unit/test_api_client.py
import pytest
import responses as responses_lib # Renamed to avoid shadowing
from responses import matchers
@responses_lib.activate
def test_get_daily_observations_single_page():
from src.api_client import WeatherAPIClient
responses_lib.add(
method=responses_lib.GET,
url="https://api.weatherprovider.io/v2/observations",
json={
"observations": [
{"date": "2026-05-01", "temp_c": 18.5, "precip_mm": 0.0},
{"date": "2026-05-02", "temp_c": 21.2, "precip_mm": 2.5},
],
"has_next_page": False,
},
status=200,
)
client = WeatherAPIClient(api_key="test-key-123")
result = client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")
assert len(result) == 2
assert result[0]["date"] == "2026-05-01"
assert result[1]["temp_c"] == 21.2
@responses_lib.activate
def test_get_daily_observations_paginates_correctly():
"""Verify pagination: client fetches all pages until has_next_page is False."""
responses_lib.add(
responses_lib.GET,
"https://api.weatherprovider.io/v2/observations",
json={
"observations": [{"date": "2026-05-01", "temp_c": 18.0}] * 100,
"has_next_page": True,
},
status=200,
)
responses_lib.add(
responses_lib.GET,
"https://api.weatherprovider.io/v2/observations",
json={
"observations": [{"date": "2026-06-10", "temp_c": 25.0}] * 50,
"has_next_page": False,
},
status=200,
)
from src.api_client import WeatherAPIClient
client = WeatherAPIClient(api_key="test-key")
result = client.get_daily_observations("STATION-001", "2026-05-01", "2026-06-10")
assert len(result) == 150 # 100 + 50
assert len(responses_lib.calls) == 2 # Exactly two HTTP calls made
@responses_lib.activate
def test_get_daily_observations_raises_on_401():
responses_lib.add(
responses_lib.GET,
"https://api.weatherprovider.io/v2/observations",
status=401,
json={"error": "Invalid API key"},
)
from src.api_client import WeatherAPIClient
import requests
client = WeatherAPIClient(api_key="bad-key")
with pytest.raises(requests.HTTPError):
client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")
@responses_lib.activate
def test_get_daily_observations_raises_on_server_error():
responses_lib.add(
responses_lib.GET,
"https://api.weatherprovider.io/v2/observations",
status=500,
json={"error": "Internal server error"},
)
from src.api_client import WeatherAPIClient
import requests
client = WeatherAPIClient(api_key="test-key")
with pytest.raises(requests.HTTPError) as exc_info:
client.get_daily_observations("STATION-001", "2026-05-01", "2026-05-02")
    assert exc_info.value.response.status_code == 500
Using respx for httpx-Based Clients
If your pipeline uses httpx (async or sync) instead of requests:
pip install respx httpx
# tests/unit/test_httpx_client.py
import pytest
import httpx
import respx
@respx.mock
def test_httpx_client_fetches_data():
respx.get("https://api.provider.io/data").mock(
return_value=httpx.Response(200, json={"records": [1, 2, 3]})
)
from src.httpx_client import DataProviderClient
client = DataProviderClient()
result = client.fetch()
assert result["records"] == [1, 2, 3]
@pytest.mark.asyncio  # requires the pytest-asyncio plugin
@respx.mock
async def test_async_client():
respx.get("https://api.provider.io/data").mock(
return_value=httpx.Response(200, json={"records": [1, 2, 3]})
)
from src.httpx_client import AsyncDataProviderClient
async with AsyncDataProviderClient() as client:
result = await client.fetch()
assert result["records"] == [1, 2, 3]Combining mocker and monkeypatch
Combining mocker and monkeypatch
Sometimes you need both: monkeypatch for env vars and mocker for the actual service call.
# src/pipeline.py
import os
import snowflake.connector
class IncrementalPipeline:
def run(self, date: str) -> dict:
conn = snowflake.connector.connect(
user=os.environ["SNOWFLAKE_USER"],
password=os.environ["SNOWFLAKE_PASSWORD"],
account=os.environ["SNOWFLAKE_ACCOUNT"],
)
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM orders WHERE order_date = %s",
(date,)
)
count = cursor.fetchone()[0]
conn.close()
return {"date": date, "order_count": count}def test_incremental_pipeline_run(mocker, monkeypatch):
# Set required env vars
monkeypatch.setenv("SNOWFLAKE_USER", "testuser")
monkeypatch.setenv("SNOWFLAKE_PASSWORD", "testpass")
monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "testaccount.us-east-1")
# Mock the Snowflake connection
mock_cursor = mocker.MagicMock()
mock_cursor.fetchone.return_value = (42,)
mock_conn = mocker.MagicMock()
mock_conn.cursor.return_value = mock_cursor
mocker.patch(
"src.pipeline.snowflake.connector.connect",
return_value=mock_conn,
)
from src.pipeline import IncrementalPipeline
pipeline = IncrementalPipeline()
result = pipeline.run("2026-05-07")
assert result == {"date": "2026-05-07", "order_count": 42}
    mock_conn.close.assert_called_once()
Verifying Mock Calls
Always assert that mocks were called correctly; otherwise your test only proves the function didn't crash, not that it behaved correctly.
def test_upload_pipeline_calls_s3_once_per_partition(mocker):
mock_boto3 = mocker.patch("src.writers.boto3")
mock_s3 = mocker.MagicMock()
mock_boto3.client.return_value = mock_s3
from src.writers import upload_partitioned_output
df = pd.DataFrame({"date": ["2026-05-01", "2026-05-02"], "value": [1, 2]})
upload_partitioned_output(df, bucket="my-bucket", partition_col="date")
# Verify S3 client was created once
mock_boto3.client.assert_called_once_with("s3")
# Verify put_object was called twice (once per partition)
assert mock_s3.put_object.call_count == 2
# Verify the correct keys were used
call_args_list = mock_s3.put_object.call_args_list
keys = [call[1]["Key"] for call in call_args_list]
assert any("2026-05-01" in k for k in keys)
assert any("2026-05-02" in k for k in keys)
# Verify no other unexpected methods were called
    mock_s3.delete_object.assert_not_called()
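upload_partitioned_output is assumed rather than shown; a hypothetical implementation that would satisfy every assertion above:
# Hypothetical src/writers.py sketch matching the assertions above
import boto3
from io import BytesIO
def upload_partitioned_output(df, bucket: str, partition_col: str) -> None:
    s3 = boto3.client("s3")  # one client for all partitions
    for value, partition in df.groupby(partition_col):
        buffer = BytesIO()
        partition.to_parquet(buffer, index=False)
        s3.put_object(
            Bucket=bucket,
            Key=f"output/{partition_col}={value}/part-000.parquet",
            Body=buffer.getvalue(),
        )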
Mock Assertion Reference
from unittest.mock import MagicMock, call
m = MagicMock()
# Basic call assertions
m.assert_called() # Called at least once
m.assert_called_once() # Called exactly once
m.assert_not_called() # Never called
m.assert_called_with(arg1, key=val) # Last call used these args
m.assert_called_once_with(arg1) # Called once with these args
# Call count
assert m.call_count == 3
# Inspect all calls
assert m.call_args_list == [
call("first"),
call("second"),
call("third"),
]
# Access last call args
last_call_args, last_call_kwargs = m.call_args
assert last_call_kwargs["Bucket"] == "my-bucket"
# Any-order call assertion
assert call("second") in m.call_args_listWhen NOT to Mock
Mock too aggressively and your tests pass while the real system breaks. Guidelines for not mocking:
- Pure functions: No side effects, no external calls; test with real inputs/outputs
- pandas/numpy operations: Use real DataFrames; they are fast and deterministic
- Configuration parsing: Use real config files in tests/fixtures/
- SQL logic: Use DuckDB or SQLite in-memory; mock-based SQL tests miss query bugs
- Business logic: The heart of your pipeline should have zero mocks in unit tests
# DO NOT mock this: test the real logic
def test_revenue_tier_logic():
from src.enrichment import assign_revenue_tier
assert assign_revenue_tier(50.0) == "low"
assert assign_revenue_tier(500.0) == "medium"
assert assign_revenue_tier(5000.0) == "high"
# MOCK this: the database is external
def test_enrichment_fetches_correct_tier(mocker):
mocker.patch("src.enrichment.get_revenue_tier_from_db", return_value="medium")
    ...
Summary
- Patch at the import site: "src.module.boto3", not "boto3.client"
- Use create_autospec or spec=True to catch interface mismatches at mock time
- side_effect handles sequences of return values and exception injection
- pytest-mock's mocker fixture is cleaner than context managers; prefer it
- monkeypatch is the right tool for env vars, module attributes, and builtins
- freezegun makes time-sensitive pipeline logic fully deterministic
- responses or respx mock HTTP calls without requiring real network access
- Always assert that mocks were called correctly, not just that no exception was raised
- Do not mock pure functions or pandas operations; use real implementations
The next lesson covers testing complete pipelines: pandas assertions, exception handling, CLI testing, FastAPI endpoints, testcontainers, and CI integration.