Data Engineering · Advanced

MLflow, Unity Catalog & Feature Store on Databricks

End-to-end ML on Databricks: MLflow experiment tracking with autolog, model registry, REST endpoint serving, Unity Catalog governance with GRANT/REVOKE and row/column security, Delta Sharing, and Feature Store workflows.

Learnixo · May 7, 2026 · 14 min read
Databricks · MLflow · Unity Catalog · Feature Store · MLOps · Model Serving · Data Governance

The ML Platform Problem

Training a model in a notebook is easy. Knowing which model is in production, why it was chosen over the last version, who deployed it, where it gets its features, and who has access to the underlying data: that's the hard part. Databricks solves this with three tightly integrated systems:

  • MLflow: experiment tracking, model registry, and serving
  • Unity Catalog: three-level namespace, access control, and data lineage across all data assets, including models
  • Feature Store: compute features once, use them in training and real-time inference with automatic lookup

This guide covers each in depth, then wires them together into a complete ML workflow.


MLflow on Databricks: Managed Tracking Server

Databricks includes a fully managed MLflow tracking server. There is no setup: experiments, runs, and artifacts are automatically stored in the workspace and linked to the notebooks and jobs that created them.

Experiment Setup

Python
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature

# Set the experiment (created automatically if it doesn't exist)
# Workspace path format: /Users/<email>/experiment-name
# Shared format: /Shared/team/experiment-name
mlflow.set_experiment("/Shared/churn-prediction/churn-model-v2")

# Get the experiment object (useful for querying past runs)
experiment = mlflow.get_experiment_by_name("/Shared/churn-prediction/churn-model-v2")
print(f"Experiment ID: {experiment.experiment_id}")

mlflow.start_run(): The Context Manager

Python
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder

# Load features (from the Feature Store, covered later)
df = spark.table("catalog.features.customer_features").toPandas()

X = df.drop(columns=["customer_id", "churned"])
y = df["churned"].astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

with mlflow.start_run(run_name="gbm-baseline") as run:
    run_id = run.info.run_id

    # ── Log hyperparameters ───────────────────────────────────────────────────
    params = {
        "n_estimators":   300,
        "max_depth":      5,
        "learning_rate":  0.05,
        "subsample":      0.8,
        "random_state":   42,
    }
    mlflow.log_params(params)

    # ── Train ─────────────────────────────────────────────────────────────────
    model = GradientBoostingClassifier(**params)
    model.fit(X_train, y_train)

    # ── Log metrics ───────────────────────────────────────────────────────────
    y_prob = model.predict_proba(X_test)[:, 1]
    y_pred = model.predict(X_test)

    metrics = {
        "auc":       roc_auc_score(y_test, y_prob),
        "precision": precision_score(y_test, y_pred),
        "recall":    recall_score(y_test, y_pred),
        "f1":        f1_score(y_test, y_pred),
    }
    mlflow.log_metrics(metrics)
    print(f"AUC: {metrics['auc']:.4f} | F1: {metrics['f1']:.4f}")

    # ── Log artifacts ─────────────────────────────────────────────────────────
    # Feature importance plot
    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 6))
    importances = pd.Series(model.feature_importances_, index=X.columns)
    importances.nlargest(15).sort_values().plot(kind="barh", ax=ax)
    ax.set_title("Top 15 Feature Importances")
    mlflow.log_figure(fig, "feature_importance.png")
    plt.close()

    # Log training data schema as a text artifact
    mlflow.log_text(str(X.dtypes.to_dict()), "feature_schema.txt")

    # ── Log model with signature ───────────────────────────────────────────────
    signature = infer_signature(X_train, y_pred)
    input_example = X_train.iloc[:5]

    mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="model",
        signature=signature,
        input_example=input_example,
        registered_model_name=None,   # register separately below
    )

    # ── Log tags for searchability ─────────────────────────────────────────────
    mlflow.set_tags({
        "team":         "data-science",
        "model_type":   "gradient_boosting",
        "data_version": "2026-05",
        "feature_set":  "v3",
    })

print(f"Run logged: {run_id}")

Autolog: Zero-Code Tracking

MLflow autolog captures parameters, metrics, and the model automatically for supported frameworks; no manual log_param calls are needed.

Python
# sklearn autolog
mlflow.sklearn.autolog(
    log_input_examples=True,
    log_model_signatures=True,
    log_models=True,
    silent=False,
)

with mlflow.start_run(run_name="autolog-gbm"):
    model = GradientBoostingClassifier(n_estimators=300, max_depth=5)
    model.fit(X_train, y_train)
    # MLflow automatically logs: all constructor params, train/test metrics, CV results, model artifact

# XGBoost autolog
import xgboost as xgb
mlflow.xgboost.autolog()

with mlflow.start_run(run_name="autolog-xgb"):
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest  = xgb.DMatrix(X_test,  label=y_test)
    params = {"max_depth": 6, "eta": 0.05, "objective": "binary:logistic", "eval_metric": "auc"}
    model_xgb = xgb.train(
        params, dtrain,
        num_boost_round=300,
        evals=[(dtest, "test")],
        verbose_eval=50,
    )
    # MLflow logs: all params, per-round eval metrics (learning curve), model artifact

# PyTorch autolog (hooks into PyTorch Lightning's Trainer.fit)
import torch
mlflow.pytorch.autolog()
# Captures: optimizer params, per-epoch metrics logged via self.log(), and model checkpoints
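
Since the PyTorch flavor autologs around PyTorch Lightning's Trainer.fit, here is a minimal sketch, assuming the pytorch_lightning package is installed; the tiny network and loader are illustrative only, reusing X_train and y_train from earlier.

Python
import torch
import pytorch_lightning as pl

class TinyChurnNet(pl.LightningModule):
    # Minimal binary classifier so Trainer.fit() has something to run
    def __init__(self, n_features):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_features, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 1),
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x).squeeze(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)
        self.log("train_loss", loss)   # autolog records this per epoch
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_train.values, dtype=torch.float32),
    torch.tensor(y_train.values, dtype=torch.float32),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

with mlflow.start_run(run_name="autolog-pytorch"):
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(TinyChurnNet(n_features=X_train.shape[1]), loader)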

Model Registry: Promotion Workflow

The MLflow Model Registry in Databricks (backed by Unity Catalog) tracks model versions and their lifecycle, from experiment to production.

Registering a Model to Unity Catalog

Python
# Point MLflow at the Unity Catalog registry (not the legacy workspace registry)
mlflow.set_registry_uri("databricks-uc")

# Register the model from a completed run
model_uri = f"runs:/{run_id}/model"

registered_model = mlflow.register_model(
    model_uri=model_uri,
    name="main.ml_models.churn_predictor",   # catalog.schema.model_name
)
print(f"Registered version: {registered_model.version}")

Adding Descriptions and Aliases

Python
from mlflow import MlflowClient

client = MlflowClient()

# Add a description to the version
client.update_model_version(
    name="main.ml_models.churn_predictor",
    version=registered_model.version,
    description=(
        "GBM churn predictor trained on May 2026 cohort. "
        "AUC=0.847, F1=0.761. Trained on features: v3."
    )
)

# Set aliases (in Unity Catalog, aliases replace the legacy stage-based workflow)
client.set_registered_model_alias(
    name="main.ml_models.churn_predictor",
    alias="champion",
    version=registered_model.version,
)

# Load model by alias
champion_model = mlflow.sklearn.load_model(
    "models:/main.ml_models.churn_predictor@champion"
)

Querying the Registry

Python
# List all versions of a model
versions = client.search_model_versions("name='main.ml_models.churn_predictor'")
for v in versions:
    aliases = ", ".join(v.aliases) if v.aliases else "none"
    print(f"v{v.version} | {v.run_id[:8]} | aliases: {aliases} | {v.description[:50]}")

# Compare the champion against a challenger run
champion_run = client.get_run(
    client.get_model_version_by_alias(
        "main.ml_models.churn_predictor", "champion"
    ).run_id
)
print("Champion metrics:", champion_run.data.metrics)

Model Serving: REST Endpoint Deployment

Databricks Model Serving creates a scalable REST endpoint backed by the model version in the registry. It handles versioning, traffic splitting, and auto-scaling with no infrastructure management.

Creating a Serving Endpoint

Python
import requests
import json

DATABRICKS_HOST  = "https://adb-1234567890.azuredatabricks.net"
DATABRICKS_TOKEN = dbutils.secrets.get("databricks-secrets", "pat-token")

endpoint_config = {
    "name": "churn-predictor-endpoint",
    "config": {
        "served_models": [
            {
                "name":                     "churn-v3",
                "model_name":               "main.ml_models.churn_predictor",
                "model_version":            str(registered_model.version),
                "workload_size":            "Small",    # Small / Medium / Large
                "scale_to_zero_enabled":    True,       # cost: scale down when idle
            }
        ],
        "traffic_config": {
            "routes": [{"served_model_name": "churn-v3", "traffic_percentage": 100}]
        }
    }
}

response = requests.post(
    f"{DATABRICKS_HOST}/api/2.0/serving-endpoints",
    headers={"Authorization": f"Bearer {DATABRICKS_TOKEN}"},
    json=endpoint_config
)
response.raise_for_status()
print(f"Endpoint created: {response.json()['name']}")

Querying the Serving Endpoint

Python
import pandas as pd

# Prepare a batch of customers to score
customers_to_score = pd.DataFrame({
    "tenure_months":       [12, 3, 48, 1],
    "monthly_charges":     [65.0, 120.0, 45.0, 85.0],
    "total_orders":        [8, 2, 35, 0],
    "days_since_last_order": [7, 45, 3, 90],
    "support_tickets":     [1, 4, 0, 6],
})

payload = {"dataframe_records": customers_to_score.to_dict(orient="records")}

response = requests.post(
    f"{DATABRICKS_HOST}/serving-endpoints/churn-predictor-endpoint/invocations",
    headers={
        "Authorization":  f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type":   "application/json",
    },
    json=payload,
)
response.raise_for_status()

predictions = response.json()["predictions"]
print(pd.DataFrame({"churn_probability": predictions}))

A/B Testing with Traffic Splitting

Python
# Update endpoint to split traffic: 80% champion, 20% challenger
traffic_update = {
    "served_models": [
        {
            "name":          "churn-v3-champion",
            "model_name":    "main.ml_models.churn_predictor",
            "model_version": "5",
            "workload_size": "Small",
        },
        {
            "name":          "churn-v4-challenger",
            "model_name":    "main.ml_models.churn_predictor",
            "model_version": "6",
            "workload_size": "Small",
        }
    ],
    "traffic_config": {
        "routes": [
            {"served_model_name": "churn-v3-champion",   "traffic_percentage": 80},
            {"served_model_name": "churn-v4-challenger",  "traffic_percentage": 20},
        ]
    }
}

requests.put(
    f"{DATABRICKS_HOST}/api/2.0/serving-endpoints/churn-predictor-endpoint/config",
    headers={"Authorization": f"Bearer {DATABRICKS_TOKEN}"},
    json=traffic_update,
).raise_for_status()

Unity Catalog: Three-Level Namespace

Unity Catalog replaces the legacy Hive metastore with a centralized governance layer that spans all workspaces in a Databricks account.

catalog           → top-level container (maps to a business unit or environment)
  └── schema      → logical grouping of tables (like a database)
        └── table → Delta table, view, model, volume, or function

Examples:
  prod.silver.orders              → production Silver orders table
  dev.silver.orders               → developer copy of Silver orders
  main.ml_models.churn_predictor  → registered ML model in the model registry
  main.features.customer_features → feature table

SQL
-- Create a catalog
CREATE CATALOG IF NOT EXISTS analytics
  COMMENT 'Analytics team data products';

-- Create a schema
CREATE SCHEMA IF NOT EXISTS analytics.reporting
  COMMENT 'BI-ready Gold tables for dashboard consumption';

-- Create a managed table
CREATE TABLE analytics.reporting.monthly_revenue
USING DELTA
COMMENT 'Monthly revenue summary, refreshed daily'
AS SELECT * FROM prod.gold.monthly_revenue;

Unity Catalog Access Control

GRANT and REVOKE

SQL
-- Grant SELECT on a table to a user and a group
GRANT SELECT ON TABLE prod.silver.customers
    TO `analyst@company.com`;
GRANT SELECT ON TABLE prod.silver.customers
    TO `data-analysts`;

-- Grant access to an entire schema
GRANT USE SCHEMA, SELECT
    ON SCHEMA prod.silver
    TO `data-engineers`;

-- Grant read access to all tables in a catalog (privileges inherit downward,
-- so SELECT must be granted alongside the USE privileges)
GRANT USE CATALOG, USE SCHEMA, SELECT
    ON CATALOG prod
    TO `data-platform-team`;

-- Revoke access
REVOKE SELECT ON TABLE prod.silver.customers
    FROM `analyst@company.com`;

-- Check current grants on a table
SHOW GRANTS ON TABLE prod.silver.customers;

Column-Level Security: Column Masks

Column masks intercept reads and replace sensitive values based on the current user's group membership, without changing the underlying data.

SQL
-- Create a masking function
CREATE OR REPLACE FUNCTION prod.security.mask_email(email STRING)
  RETURNS STRING
  RETURN CASE
    WHEN is_account_group_member('pii-data-access') THEN email
    ELSE CONCAT(LEFT(email, 2), '***@***.***')
  END;

-- Apply the mask to the email column
ALTER TABLE prod.silver.customers
  ALTER COLUMN email
  SET MASK prod.security.mask_email;

-- Now: users in 'pii-data-access' see real emails; everyone else sees masked values
-- The mask is transparent β€” no query changes required for consumers
SQL
-- Remove a column mask
ALTER TABLE prod.silver.customers
  ALTER COLUMN email
  DROP MASK;

Row-Level Security: Row Filters

Row filters restrict which rows each user can see, based on their identity or group membership.

SQL
-- Create a row filter: each user sees only their tenant's data
CREATE OR REPLACE FUNCTION prod.security.tenant_row_filter(tenant_id STRING)
  RETURNS BOOLEAN
  RETURN is_account_group_member('platform-admin')
      OR tenant_id = current_user_tenant();
-- Note: current_user_tenant() is a custom function returning the caller's tenant ID
-- In practice, you'd join against a user→tenant mapping table

-- Apply the filter to the table
ALTER TABLE prod.silver.orders
  SET ROW FILTER prod.security.tenant_row_filter ON (tenant_id);

-- Remove a row filter
ALTER TABLE prod.silver.orders
  DROP ROW FILTER;
SQL
-- Verify the filter from a test user's session (Databricks SQL has no
-- SET CURRENT USER impersonation; run these as the user you want to check)
SELECT current_user();                    -- confirm which identity the query runs as
SELECT COUNT(*) FROM prod.silver.orders;  -- returns only the rows the filter allows

Data Lineage and Audit Logs

Unity Catalog automatically captures column-level data lineage: which tables and columns flow into each derived table.

Python
# Query lineage via the Unity Catalog API
import requests

DATABRICKS_HOST  = "https://adb-1234567890.azuredatabricks.net"
DATABRICKS_TOKEN = dbutils.secrets.get("databricks-secrets", "pat-token")

# Get upstream lineage for a column
response = requests.get(
    f"{DATABRICKS_HOST}/api/2.0/lineage-tracking/column-lineage",
    headers={"Authorization": f"Bearer {DATABRICKS_TOKEN}"},
    params={
        "table_name":   "prod.gold.monthly_revenue",
        "column_name":  "revenue",
    }
)
lineage = response.json()
print("Upstream columns that feed into revenue:")
for item in lineage.get("upstream_cols", []):
    print(f"  {item['catalog_name']}.{item['schema_name']}.{item['table_name']}.{item['name']}")

Unity Catalog writes audit logs to a configurable destination (Azure Storage, S3). These logs record every data access, schema change, and permission grant, enabling compliance reporting and anomaly detection.
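
If system tables are enabled in your account, the same audit events are also queryable in-platform via system.access.audit. A minimal sketch; the column names follow the documented audit schema, but verify against your workspace:

Python
# Recent Unity Catalog audit events: who did what, and when
recent_activity = spark.sql("""
    SELECT event_time,
           user_identity.email AS actor,
           action_name,
           request_params
    FROM system.access.audit
    WHERE service_name = 'unityCatalog'
      AND event_date >= current_date() - INTERVAL 7 DAYS
    ORDER BY event_time DESC
    LIMIT 100
""")
display(recent_activity)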


Delta Sharing: Cross-Organization Data Sharing

Delta Sharing is an open protocol for sharing live Delta tables with external consumers: no data copy, no proprietary format, and no Databricks account required on the recipient side.

SQL
-- Provider side: create a share and add tables
CREATE SHARE revenue_share
  COMMENT 'Monthly revenue data shared with external analytics partner';

-- Add tables to the share (optionally with partition filters)
ALTER SHARE revenue_share
  ADD TABLE prod.gold.monthly_revenue
  PARTITION (report_year = 2026);

-- Add a view (share a pre-filtered, anonymized view)
ALTER SHARE revenue_share
  ADD VIEW prod.reporting.public_metrics;

-- Create a recipient (generates an activation link or credential)
CREATE RECIPIENT analytics_partner
  COMMENT 'External analytics partner, read-only access';

-- Grant the recipient access to the share
GRANT SELECT ON SHARE revenue_share TO RECIPIENT analytics_partner;

-- Get the activation link for the recipient
DESCRIBE RECIPIENT analytics_partner;
-- The output includes the activation link the partner uses with their Delta Sharing client
Python
# Recipient side (any Python environment, no Databricks required):
import delta_sharing

# Recipient downloads their profile file from the activation URL
profile_path = "/path/to/profile.share"
client = delta_sharing.SharingClient(profile_path)

# List available tables
tables = client.list_all_tables()
print(tables)

# Load a shared table as a pandas DataFrame
# (tables are exposed as <share>.<schema>.<table>; the source catalog name is not part of the share)
df = delta_sharing.load_as_pandas(f"{profile_path}#revenue_share.gold.monthly_revenue")

Databricks Feature Store

The Feature Store maintains a centralized catalog of feature definitions. Features are computed once in batch (or streaming), stored in a Delta table, and reused across all models, ensuring training/serving consistency.

Creating a Feature Table

Python
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup

fe = FeatureEngineeringClient()

# Define and compute features
def compute_customer_features(spark) -> "DataFrame":
    return spark.sql("""
        SELECT
            o.customer_id,
            COUNT(*)                                          AS total_orders_90d,
            SUM(o.total)                                      AS spend_90d,
            AVG(o.total)                                      AS avg_order_value_90d,
            MAX(o.created_at)                                 AS last_order_ts,
            DATEDIFF(current_date(), MAX(o.created_at))       AS days_since_last_order,
            COUNT(CASE WHEN o.status = 'cancelled' THEN 1 END) AS cancel_count_90d,
            COUNT(DISTINCT DATE(o.created_at))                AS active_days_90d,
            SUM(t.ticket_count)                               AS support_tickets_90d
        FROM prod.silver.orders o
        LEFT JOIN prod.silver.support_summary t
            ON o.customer_id = t.customer_id
           AND t.period = '90d'
        WHERE o.created_at >= current_date() - INTERVAL 90 DAYS
          AND o.status != 'pending'
        GROUP BY o.customer_id
    """)

features_df = compute_customer_features(spark)

# Create the feature table (first time only)
fe.create_table(
    name="main.features.customer_features_90d",
    primary_keys=["customer_id"],
    schema=features_df.schema,
    description="90-day rolling customer engagement features for churn prediction",
    tags={"team": "data-science", "refresh": "daily"},
)

# Write features (upsert mode)
fe.write_table(
    name="main.features.customer_features_90d",
    df=features_df,
    mode="merge",    # upsert: update existing, insert new
)
print("Feature table refreshed.")

Training with Feature Store Lookups

Python
import mlflow
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup

fe = FeatureEngineeringClient()

# Your training DataFrame only needs the primary key and the label
labels_df = spark.sql("""
    SELECT customer_id, churned
    FROM prod.silver.churn_labels
    WHERE label_date = current_date() - INTERVAL 7 DAYS
""")

# Declare which feature tables to join and on which key
feature_lookups = [
    FeatureLookup(
        table_name="main.features.customer_features_90d",
        lookup_key="customer_id",
        # Optionally select a subset of feature columns
        feature_names=[
            "total_orders_90d", "spend_90d", "avg_order_value_90d",
            "days_since_last_order", "cancel_count_90d",
            "active_days_90d", "support_tickets_90d",
        ]
    ),
    FeatureLookup(
        table_name="main.features.customer_demographics",
        lookup_key="customer_id",
        feature_names=["tenure_months", "plan_tier"],
    ),
]

# Feature Store builds the training set, joining features to labels automatically
training_set = fe.create_training_set(
    df=labels_df,
    feature_lookups=feature_lookups,
    label="churned",
    exclude_columns=["customer_id"],   # don't pass the key into the model
)

training_df = training_set.load_df().toPandas()
print(f"Training set shape: {training_df.shape}")

# Train and log the model; the Feature Store records which features were used
with mlflow.start_run(run_name="fs-gbm-v3"):
    X = training_df.drop(columns=["churned"])
    y = training_df["churned"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = GradientBoostingClassifier(n_estimators=300, max_depth=5, learning_rate=0.05)
    model.fit(X_train, y_train)

    mlflow.log_metric("auc", roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))

    # Log via the Feature Engineering Client to record feature lineage in the registry
    fe.log_model(
        model=model,
        artifact_path="model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="main.ml_models.churn_predictor",
    )

Batch Scoring with Feature Lookup

Python
# Score new customers; the Feature Store retrieves their features automatically at scoring time
customers_to_score = spark.sql("""
    SELECT customer_id FROM prod.silver.active_customers
    WHERE last_seen_at >= current_date() - INTERVAL 1 DAY
""")

# Retrieve features and score in one call
scored_df = fe.score_batch(
    model_uri="models:/main.ml_models.churn_predictor@champion",
    df=customers_to_score,         # only needs customer_id
)

# scored_df contains customer_id + prediction column
scored_df.select("customer_id", "prediction") \
    .write.format("delta") \
    .mode("overwrite") \
    .saveAsTable("prod.ml_outputs.churn_scores")

Complete ML Workflow: Train → Log → Register → Serve

Python
# ── 1. Refresh features ────────────────────────────────────────────────────────

fe = FeatureEngineeringClient()
features_df = compute_customer_features(spark)
fe.write_table(name="main.features.customer_features_90d", df=features_df, mode="merge")
print("Step 1: Features refreshed.")

# ── 2. Build training set ──────────────────────────────────────────────────────

labels_df = spark.table("prod.silver.churn_labels")
feature_lookups = [
    FeatureLookup(table_name="main.features.customer_features_90d", lookup_key="customer_id"),
    FeatureLookup(table_name="main.features.customer_demographics",  lookup_key="customer_id"),
]
training_set = fe.create_training_set(df=labels_df, feature_lookups=feature_lookups, label="churned")
training_df  = training_set.load_df().toPandas()
print(f"Step 2: Training set ready β€” {len(training_df):,} rows, {training_df.shape[1]} features.")

# ── 3. Train, evaluate, log ───────────────────────────────────────────────────

mlflow.set_registry_uri("databricks-uc")
mlflow.set_experiment("/Shared/churn-prediction/churn-model-v3")

with mlflow.start_run(run_name="production-candidate") as run:
    X = training_df.drop(columns=["churned"])
    y = training_df["churned"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

    params = {"n_estimators": 400, "max_depth": 6, "learning_rate": 0.04, "subsample": 0.85}
    mlflow.log_params(params)

    model = GradientBoostingClassifier(**params)
    model.fit(X_train, y_train)

    auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
    f1  = f1_score(y_test, model.predict(X_test))
    mlflow.log_metrics({"auc": auc, "f1": f1})
    print(f"Step 3: Train complete β€” AUC={auc:.4f}, F1={f1:.4f}")

    fe.log_model(
        model=model,
        artifact_path="model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="main.ml_models.churn_predictor",
    )
    # The Unity Catalog registry has no stages, so get_latest_versions is
    # unsupported; find the newest version via a registry search instead
    versions = mlflow.MlflowClient().search_model_versions(
        "name='main.ml_models.churn_predictor'"
    )
    new_version = max(int(v.version) for v in versions)

print(f"Step 3: Registered as version {new_version}.")

# ── 4. Promote to champion if AUC improves ───────────────────────────────────

client = mlflow.MlflowClient()

try:
    current_champion = client.get_model_version_by_alias(
        "main.ml_models.churn_predictor", "champion"
    )
    champion_run = client.get_run(current_champion.run_id)
    champion_auc = champion_run.data.metrics.get("auc", 0.0)
except Exception:
    champion_auc = 0.0   # no champion yet

if auc > champion_auc + 0.005:   # require meaningful improvement
    client.set_registered_model_alias(
        "main.ml_models.churn_predictor", "champion", new_version
    )
    print(f"Step 4: Version {new_version} promoted to champion (AUC {auc:.4f} > {champion_auc:.4f}).")
else:
    client.set_registered_model_alias(
        "main.ml_models.churn_predictor", "challenger", new_version
    )
    print(f"Step 4: Version {new_version} registered as challenger (AUC {auc:.4f}, no improvement over {champion_auc:.4f}).")

# ── 5. Update serving endpoint to latest champion ────────────────────────────

champion_version = client.get_model_version_by_alias(
    "main.ml_models.churn_predictor", "champion"
).version

endpoint_update = {
    "served_models": [{
        "name":          f"churn-v{champion_version}",
        "model_name":    "main.ml_models.churn_predictor",
        "model_version": str(champion_version),
        "workload_size": "Small",
        "scale_to_zero_enabled": True,
    }],
    "traffic_config": {
        "routes": [
            {"served_model_name": f"churn-v{champion_version}", "traffic_percentage": 100}
        ]
    }
}

requests.put(
    f"{DATABRICKS_HOST}/api/2.0/serving-endpoints/churn-predictor-endpoint/config",
    headers={"Authorization": f"Bearer {DATABRICKS_TOKEN}"},
    json=endpoint_update,
).raise_for_status()

print(f"Step 5: Serving endpoint updated to version {champion_version}.")

# ── 6. Batch score active customers ──────────────────────────────────────────

customers_to_score = spark.sql(
    "SELECT customer_id FROM prod.silver.active_customers"
)

scored_df = fe.score_batch(
    model_uri=f"models:/main.ml_models.churn_predictor@champion",
    df=customers_to_score,
)

from pyspark.sql.functions import current_timestamp, lit

(
    scored_df
    .withColumn("scored_at",     current_timestamp())
    .withColumn("model_version", lit(champion_version))
    .write.format("delta")
    .mode("overwrite")
    .saveAsTable("prod.ml_outputs.churn_scores")
)

print(f"Step 6: {scored_df.count():,} customers scored and written to prod.ml_outputs.churn_scores.")
print("ML workflow complete.")

Key Takeaways

  • MLflow on Databricks requires zero setup; experiments are workspace-native. Always use mlflow.set_registry_uri("databricks-uc") to store models in Unity Catalog instead of the legacy workspace registry.
  • fe.log_model (Feature Engineering Client) records feature lineage in the model version, enabling automatic feature retrieval at serving and scoring time.
  • Unity Catalog's column masks and row filters enforce security at the storage layer; consumers don't need to change their queries.
  • DESCRIBE HISTORY on a Delta table and the MLflow run history together give you a complete audit trail from raw data to deployed prediction (see the sketch after this list).
  • The A/B traffic split on serving endpoints lets you run champion/challenger tests without any application code changes: route 90/10 and compare online metrics before full promotion.
  • Delta Sharing is the right solution when external partners need live data access: no data copy, and no Databricks account required on their side.
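
To make the audit-trail point concrete, here is a minimal sketch pairing Delta table history with the MLflow runs from the experiment used above; the column selections are illustrative:

Python
import mlflow

# Delta side: who changed the feature table, when, and with which operation
history = spark.sql(
    "DESCRIBE HISTORY main.features.customer_features_90d LIMIT 10"
)
display(history.select("version", "timestamp", "operation", "userName"))

# MLflow side: the training runs recorded against the same data
experiment = mlflow.get_experiment_by_name("/Shared/churn-prediction/churn-model-v3")
runs = mlflow.search_runs(
    experiment_ids=[experiment.experiment_id],
    order_by=["start_time DESC"],
    max_results=10,
)
print(runs[["run_id", "start_time", "metrics.auc"]])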

Related: Delta Lake Deep Dive (ACID transactions, MERGE, time travel)
Related: Advanced PySpark on Databricks (DLT, Auto Loader, Kafka streaming)

Enjoyed this article?

Explore the Data Engineering learning path for more.
