Learnixo

PySpark & Apache Spark · Lesson 2 of 4

DataFrames & SQL: Joins, Window Functions & Delta MERGE

Column References: Three Equivalent Styles

PySpark gives you three ways to reference a column. Understand all three because you'll see them all in production code.

Python
from pyspark.sql import functions as F
from pyspark.sql import SparkSession

# Reuse the active session if there is one, otherwise build a new one
spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

df = spark.read.parquet("s3://my-bucket/silver/orders/")

# NOTE: the filter/select calls below only illustrate syntax. Each returns a new
# DataFrame that is not assigned to anything, so their results are discarded.

# Style 1: String column name — simplest, but no IDE autocomplete.
# filter() also accepts a whole SQL predicate as a string.
df.filter("status = 'active'")
df.select("order_id", "customer_id", "total")

# Style 2: df["column"] — bound to this specific DataFrame, so it stays unambiguous in joins
df.filter(df["status"] == "active")
df.select(df["order_id"], df["customer_id"])

# Style 3: F.col() — most common in production, chains well
df.filter(F.col("status") == "active")
df.select(F.col("order_id"), F.col("customer_id"))

# A bare col() (from pyspark.sql.functions import col) is the same function as F.col();
# prefer F.col() so it is explicit that it comes from the functions namespace

Rule of thumb: Use F.col() for expressions in withColumn/filter/agg. Use string column names for simple select and drop.

Core DataFrame Transformations

Python
# ─── select ───────────────────────────────────────────────────────────────────
# Project a subset of columns; computed expressions need .alias() to get a usable name
df_slim = df.select(
    "order_id",
    "customer_id",
    F.col("order_date").cast("date").alias("order_date"),
    (F.col("subtotal") + F.col("tax")).alias("total"),
)

# ─── filter / where (identical) ───────────────────────────────────────────────
# Parenthesize every comparison: in Python, & and | bind tighter than >= / ==
df_recent = df.filter(
    (F.col("order_date") >= "2026-01-01") &
    (F.col("status").isin("completed", "shipped")) &
    F.col("customer_id").isNotNull()
)

# ─── withColumn ───────────────────────────────────────────────────────────────
df_enriched = (
    df
    .withColumn("year",  F.year("order_date"))
    .withColumn("month", F.month("order_date"))
    .withColumn(
        "order_size",
        # Branches are tried top to bottom; the first true one wins. A NULL total
        # makes every comparison NULL (not true), so without the explicit guard it
        # would fall through to otherwise() and be mislabelled "large".
        F.when(F.col("total").isNull(), F.lit(None).cast("string"))
         .when(F.col("total") < 50,  "small")
         .when(F.col("total") < 200, "medium")
         .otherwise("large")
    )
)

# ─── drop ─────────────────────────────────────────────────────────────────────
# Names that don't exist are ignored, so this is safe on partially-populated schemas
df_clean = df.drop("_corrupt_record", "_ingested_at", "internal_notes")

# ─── rename ───────────────────────────────────────────────────────────────────
# No-op if "cust_id" is not present
df_renamed = df.withColumnRenamed("cust_id", "customer_id")

# ─── distinct / dropDuplicates ────────────────────────────────────────────────
df_distinct = df.distinct()                   # drop rows identical in EVERY column
df_deduped = df.dropDuplicates(["order_id"])  # one row per order_id
# Which row survives per key is not deterministic; use a row_number() window
# (see Window Functions) when you need "latest" or "highest" per key.
df_deduped_partial = df.dropDuplicates(["customer_id", "order_date"])  # keep one per customer per day

Built-in Functions You'll Use Every Day

Python
import pyspark.sql.functions as F  # same module as `from pyspark.sql import functions as F`

# ─── Conditional Logic ────────────────────────────────────────────────────────
# when() branches are evaluated in order and the first true condition wins;
# a NULL discount makes the first condition NULL, so evaluation moves on.
df = df.withColumn(
    "effective_price",
    F.when(F.col("discount") > 0, F.col("price") * (1 - F.col("discount")))
     .when(F.col("promo_code").isNotNull(), F.col("price") * 0.9)
     .otherwise(F.col("price"))
)

# ─── Null Handling ────────────────────────────────────────────────────────────
# coalesce returns the first non-null value
df = df.withColumn(
    "display_name",
    F.coalesce(F.col("preferred_name"), F.col("full_name"), F.lit("Anonymous"))
)

# fillna (alias: df.na.fill): per-column replacement values for nulls
df = df.fillna({"country_code": "US", "lifetime_value": 0.0})

# ─── String Functions ─────────────────────────────────────────────────────────
df = (
    df
    .withColumn("email", F.lower(F.trim(F.col("email"))))
    .withColumn("phone_digits", F.regexp_replace(F.col("phone"), "[^0-9]", ""))
    .withColumn(
        "area_code",
        # regexp_extract returns "" (not null) when the pattern does not match
        F.regexp_extract(F.col("phone"), r"^\(?(\d{3})\)?", 1)
    )
    # name_parts is a helper array column; it stays on the DataFrame unless dropped
    .withColumn("name_parts", F.split(F.col("full_name"), " "))
    .withColumn("first_name", F.col("name_parts").getItem(0))
    .withColumn(
        "name_length",
        F.length(F.col("full_name"))
    )
)

# ─── Date Functions ───────────────────────────────────────────────────────────
df = (
    df
    .withColumn("order_year",    F.year("order_date"))
    .withColumn("order_month",   F.month("order_date"))
    .withColumn("order_quarter", F.quarter("order_date"))
    .withColumn(
        "order_date_str",
        F.date_format(F.col("order_date"), "yyyy-MM-dd")
    )
    .withColumn(
        "days_since_order",
        F.datediff(F.current_date(), F.col("order_date"))
    )
    .withColumn(
        # Parses the string column built above back into a timestamp
        "order_ts",
        F.to_timestamp(F.col("order_date_str"), "yyyy-MM-dd")
    )
)

# ─── Array / Struct Functions ─────────────────────────────────────────────────
# explode turns an array column into multiple rows (one per element);
# rows whose array is null or empty produce no output rows at all
df_tags = df.withColumn("tag", F.explode(F.col("tags")))

# explode_outer keeps rows even when the array is null or empty (tag = null)
df_tags_safe = df.withColumn("tag", F.explode_outer(F.col("tags")))

# collect items into an array during aggregation
df_customer_tags = (
    df_tags
    .groupBy("customer_id")
    .agg(F.collect_list("tag").alias("all_tags"))
)

# F.struct builds a struct column from the given columns; field names come from
# the aliases (the Spark SQL equivalent is named_struct)
df_address = df.withColumn(
    "address",
    F.struct(
        F.col("street").alias("street"),
        F.col("city").alias("city"),
        F.col("zip").alias("zip"),
    )
)

GroupBy and Aggregations

Python
from pyspark.sql import functions as F

# ─── Basic groupBy + agg ──────────────────────────────────────────────────────
# One output row per (customer_id, year, month); every aggregate needs an alias
df_summary = (
    df_orders
    .groupBy("customer_id", "year", "month")
    .agg(
        F.count("*").alias("order_count"),                  # counts rows, nulls included
        F.sum("total").alias("total_revenue"),              # sum/avg/min/max skip nulls
        F.avg("total").alias("avg_order_value"),
        F.min("order_date").alias("first_order_date"),
        F.max("order_date").alias("last_order_date"),
        F.countDistinct("product_id").alias("unique_products"),
        F.collect_list("order_id").alias("order_ids"),      # array grows with group size
    )
)

# ─── Pivot ────────────────────────────────────────────────────────────────────
# Listing the pivot values explicitly fixes the output columns and saves Spark
# the extra pass it would otherwise run to discover the distinct statuses.
df_pivot = (
    df_orders
    .groupBy("customer_id")
    .pivot("status", ["pending", "completed", "cancelled"])
    .agg(F.count("order_id"))
    .fillna(0)  # customer/status combinations with no rows come out as null
)
# Result: customer_id | pending | completed | cancelled

# ─── Multiple aggregations on same column ─────────────────────────────────────
# percentile_approx is approximate by design (cheap on large groups)
df_revenue_stats = (
    df_orders
    .groupBy("country_code")
    .agg(
        F.expr("percentile_approx(total, 0.5)").alias("median_total"),
        F.expr("percentile_approx(total, 0.95)").alias("p95_total"),
        F.stddev("total").alias("stddev_total"),
    )
)

User-Defined Functions (UDFs)

Python UDF — Convenient but Slow

Python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import re

from typing import Optional


# Python UDF: each row crosses the JVM boundary — serialization overhead
def normalize_phone(phone: Optional[str]) -> Optional[str]:
    """Normalize a phone string to the "+1XXXXXXXXXX" form.

    Every non-digit character is stripped first. Then:
      * exactly 10 digits             -> "+1" + digits
      * 11 digits starting with "1"   -> "+" + digits
      * anything else                 -> None

    Null or empty input also returns None. Inside a Spark UDF, None becomes a
    null in the result column.
    """
    if not phone:
        return None
    digits = re.sub(r"\D", "", phone)
    if len(digits) == 10:
        return f"+1{digits}"
    if len(digits) == 11 and digits.startswith("1"):
        return f"+{digits}"
    return None

# Wrap the plain Python function as a Spark UDF with a StringType result
normalize_phone_udf = udf(normalize_phone, StringType())

# Apply it like any column expression. This does NOT register the UDF for SQL;
# use spark.udf.register("normalize_phone", normalize_phone, StringType()) for that.
df = df.withColumn("phone_normalized", normalize_phone_udf(F.col("phone")))

Performance warning: Python UDFs serialize each row to Python, call your function, then serialize results back to the JVM. For large datasets this is 10–100x slower than built-in functions.

Pandas UDF (Vectorized) — Fast

Python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType
import pandas as pd
import re

# Pandas UDF: Spark hands Python whole Arrow batches, so there is no per-row JVM
# round-trip. To really be fast, the body must also be vectorized: pandas .str
# methods process the batch in one go, whereas Series.apply would still call a
# Python function once per row.
@pandas_udf(StringType())
def normalize_phone_fast(phone_series: pd.Series) -> pd.Series:
    """Normalize a batch of phone strings to the "+1XXXXXXXXXX" form.

    Every non-digit character is stripped. Then:
      * exactly 10 digits             -> "+1" + digits
      * 11 digits starting with "1"   -> "+" + digits
      * anything else, including null or empty input -> None
    """
    # The "string" dtype turns None/NaN into <NA>, which the .str methods propagate
    digits = phone_series.astype("string").str.replace(r"\D", "", regex=True)
    length = digits.str.len()

    # Comparisons against <NA> yield <NA>; treat those as "no match"
    is_ten = (length == 10).to_numpy(dtype=bool, na_value=False)
    is_eleven = ((length == 11) & digits.str.startswith("1")).to_numpy(
        dtype=bool, na_value=False
    )

    # Start with all-None, then fill the two valid shapes (the masks are disjoint)
    result = pd.Series(None, index=phone_series.index, dtype="object")
    result[is_ten] = ("+1" + digits[is_ten]).to_numpy(dtype=object)
    result[is_eleven] = ("+" + digits[is_eleven]).to_numpy(dtype=object)
    return result

@pandas_udf(DoubleType())
def revenue_score(revenue: pd.Series, order_count: pd.Series) -> pd.Series:
    """Composite score: revenue per order, log-scaled.

    Computes log1p(revenue / order_count) element-wise over one batch.
    An order_count of 0 is replaced by 1 before dividing, so those rows score
    log1p(revenue) instead of dividing by zero.
    """
    import numpy as np  # local import: numpy is not imported at this snippet's top level
    ratio = revenue / order_count.replace(0, 1)
    return np.log1p(ratio)

# Use exactly like built-in functions: pass Columns in, get a Column back
df = df.withColumn("phone_normalized", normalize_phone_fast(F.col("phone")))
df = df.withColumn(
    "revenue_score",
    revenue_score(F.col("lifetime_value"), F.col("order_count"))
)

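Rule: Always prefer built-in F.* functions. Use pandas UDFs for complex logic that has no built-in equivalent, and keep their bodies vectorized (pandas/NumPy operations, not Series.apply). Avoid plain Python UDFs in production — they are typically the slowest option, because every row is serialized to a Python worker and processed one at a time.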

Join Types and Strategies

Python
# ─── Inner Join (default) ─────────────────────────────────────────────────────
# Joining on a column NAME (on="customer_id") keeps a single customer_id column
# in the output instead of one from each side.
df_joined = df_orders.join(df_customers, on="customer_id", how="inner")

# ─── Left Join ────────────────────────────────────────────────────────────────
# Every order is kept; customer columns are null where no customer matches
df_left = df_orders.join(df_customers, on="customer_id", how="left")

# ─── Semi Join (filter orders where matching customer exists, no customer cols)
df_semi = df_orders.join(df_customers, on="customer_id", how="left_semi")

# ─── Anti Join (orders with NO matching customer — orphan detection)
df_anti = df_orders.join(df_customers, on="customer_id", how="left_anti")

# ─── Multi-column join ────────────────────────────────────────────────────────
# A row matches only when ALL listed columns are equal
df_multi = df_orders.join(
    df_returns,
    on=["order_id", "product_id"],
    how="left"
)

# ─── Non-equi join (requires column expression syntax) ────────────────────────
# Range-only conditions cannot use a hash or sort-merge join, so Spark falls back
# to a nested-loop style join; keep one side small.
df_range = df_transactions.join(
    df_date_ranges,
    on=(
        (df_transactions["txn_date"] >= df_date_ranges["period_start"]) &
        (df_transactions["txn_date"] <  df_date_ranges["period_end"])
    ),
    how="inner"
)

# ─── Resolving ambiguous columns after join ────────────────────────────────────
# When both DataFrames have "created_at", use aliases
df_ord = df_orders.alias("ord")
df_cust = df_customers.alias("cust")

# An expression join keeps BOTH customer_id columns; qualify them via the aliases
df_joined = df_ord.join(df_cust, F.col("ord.customer_id") == F.col("cust.customer_id"))
df_result = df_joined.select(
    F.col("ord.order_id"),
    F.col("cust.email"),
    F.col("ord.created_at").alias("order_created_at"),
    F.col("cust.created_at").alias("customer_created_at"),
)

Broadcast Joins for Small Tables

Python
# When one side of a join fits in memory (< ~10MB by default),
# Spark can broadcast it to every executor — eliminates the shuffle.
# For a LEFT join only the right-hand side can be broadcast, which is the case here.

df_orders_enriched = df_orders.join(
    F.broadcast(df_country_lookup),    # broadcast the small lookup table
    on="country_code",
    how="left"
)

# Configure the auto-broadcast threshold (bytes):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)  # 50MB

# Disable auto-broadcast (useful for debugging join strategies).
# Explicit F.broadcast() hints are still honoured when this is -1.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Verify Spark chose a BroadcastHashJoin:
df_orders_enriched.explain()
# Look for: BroadcastHashJoin in the physical plan

# The setting is session-wide. Restore the default so joins that run later in
# this session are not silently forced into shuffle joins.
spark.conf.unset("spark.sql.autoBroadcastJoinThreshold")

Window Functions

Window functions operate over a "window" of rows relative to the current row, without collapsing the DataFrame like groupBy does.

Python
from pyspark.sql.window import Window
from pyspark.sql import functions as F

# ─── Define window specs ──────────────────────────────────────────────────────

# Partition by customer, order by date ascending
w_customer_time = (
    Window
    .partitionBy("customer_id")
    .orderBy("order_date")
)

# Partition by customer only. With no ordering, the frame is the whole partition
# (e.g. for a customer's all-time total).
w_customer = Window.partitionBy("customer_id")

# Rolling 30-day window. rangeBetween needs a numeric ORDER BY key, but casting a
# DATE to LONG yields NULL in Spark (and is an error under ANSI mode). Ordering by
# the day number (days since 1970-01-01) works for DATE and TIMESTAMP columns
# alike and is independent of the session time zone.
w_rolling_30d = (
    Window
    .partitionBy("customer_id")
    .orderBy(F.datediff(F.col("order_date"), F.lit("1970-01-01")))
    .rangeBetween(-30, 0)  # offsets in days: 30 days back through the current row's day
)

# ─── Ranking Functions ────────────────────────────────────────────────────────
df_ranked = df_orders.withColumn(
    "order_rank",          # dense_rank: ties share a rank, no gaps afterwards (1, 1, 2)
    F.dense_rank().over(w_customer_time)
)

df_top = df_orders.withColumn(
    "rank_by_value",       # rank: ties share a rank, then numbers skip (1, 1, 3)
    F.rank().over(Window.partitionBy("customer_id").orderBy(F.col("total").desc()))
)

# Get each customer's single most recent order.
# order_id breaks same-date ties, so the same row is chosen on every run.
df_latest = (
    df_orders
    .withColumn("row_num", F.row_number().over(
        Window.partitionBy("customer_id")
        .orderBy(F.col("order_date").desc(), F.col("order_id").desc())
    ))
    .filter(F.col("row_num") == 1)
    .drop("row_num")
)

# ─── Lag and Lead ─────────────────────────────────────────────────────────────
# lag/lead return null when there is no previous/next row in the partition
df_with_prev = (
    df_orders
    .withColumn(
        "prev_order_date",
        F.lag("order_date", 1).over(w_customer_time)
    )
    .withColumn(
        "next_order_date",
        F.lead("order_date", 1).over(w_customer_time)
    )
    .withColumn(
        "days_since_last_order",
        F.datediff(F.col("order_date"), F.col("prev_order_date"))
    )
)

# ─── Running Aggregations ─────────────────────────────────────────────────────
# An ordered window defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW,
# which gives all orders placed on the same date the same (end-of-day) total.
# An explicit ROWS frame with an order_id tiebreaker gives a true per-order running sum.
w_customer_running = (
    Window
    .partitionBy("customer_id")
    .orderBy("order_date", "order_id")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

df_running = df_orders.withColumn(
    "cumulative_revenue",
    F.sum("total").over(w_customer_running)  # running sum within customer
)

# Share of the customer's all-time revenue contributed by each order
df_running = df_running.withColumn(
    "pct_of_customer_total",
    F.col("total") / F.sum("total").over(w_customer)
)

# Revenue over the trailing 30 days, using the range-based window defined above
df_rolling = df_orders.withColumn(
    "revenue_30d",
    F.sum("total").over(w_rolling_30d)
)

Spark SQL with Temporary Views

Python
# Register a DataFrame as a SQL view (session-scoped; replaces any view of the same name)
df_orders.createOrReplaceTempView("orders")
df_customers.createOrReplaceTempView("customers")

# Full SQL — useful when logic is cleaner in SQL than in the DataFrame API.
# GROUP BY 1, 2 / ORDER BY 1, 2 refer to select-list positions (country_code, month).
df_sql_result = spark.sql("""
    SELECT
        c.country_code,
        DATE_FORMAT(o.order_date, 'yyyy-MM') AS month,
        COUNT(DISTINCT o.customer_id)        AS active_customers,
        SUM(o.total)                         AS total_revenue,
        AVG(o.total)                         AS avg_order_value,
        PERCENTILE_APPROX(o.total, 0.5)      AS median_order_value
    FROM orders o
    INNER JOIN customers c USING (customer_id)
    WHERE o.status = 'completed'
      AND o.order_date >= '2026-01-01'
    GROUP BY 1, 2
    ORDER BY 1, 2
""")

df_sql_result.show(20)

# Global temp views are shared by all SparkSessions in the same Spark application
# (useful in notebooks), live in the global_temp database, and are dropped when
# the application stops.
df_orders.createOrReplaceGlobalTempView("orders_global")
spark.sql("SELECT * FROM global_temp.orders_global LIMIT 5").show()

Delta Lake MERGE (Upsert)

Python
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# DeltaTable.forPath gives you the merge builder
delta_table = DeltaTable.forPath(spark, "s3://my-bucket/delta/customers/")

df_raw_updates = spark.read.parquet("s3://my-bucket/staging/customer_updates/")

# A MERGE fails when more than one source row matches the same target row, so a
# staging batch holding several updates for one customer is first reduced to
# that customer's most recent update.
df_updates = (
    df_raw_updates
    .withColumn(
        "_rn",
        F.row_number().over(
            Window.partitionBy("customer_id").orderBy(F.col("updated_at").desc())
        ),
    )
    .filter(F.col("_rn") == 1)
    .drop("_rn")
)

(
    delta_table.alias("target")
    .merge(
        source=df_updates.alias("source"),
        condition="target.customer_id = source.customer_id"
    )
    # Existing customer: overwrite the listed columns with the incoming values
    .whenMatchedUpdate(set={
        "email":          "source.email",
        "full_name":      "source.full_name",
        "lifetime_value": "source.lifetime_value",
        "updated_at":     "source.updated_at",
    })
    # New customer: insert a complete row
    .whenNotMatchedInsert(values={
        "customer_id":    "source.customer_id",
        "email":          "source.email",
        "full_name":      "source.full_name",
        "lifetime_value": "source.lifetime_value",
        "created_at":     "source.created_at",
        "updated_at":     "source.updated_at",
    })
    .execute()
)

Complete Example: Silver Layer with SCD Type 2

SCD Type 2 (Slowly Changing Dimension Type 2) preserves full history by closing the old row and inserting a new one when a record changes.

Python
from datetime import datetime, timezone

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType,
    DoubleType, TimestampType, BooleanType
)

spark = (
    SparkSession.builder
    .appName("SilverSCD2Pipeline")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

SILVER_PATH = "s3://my-bucket/delta/dim_customers_scd2/"
STAGING_PATH = "s3://my-bucket/staging/customer_updates/"

# One load timestamp for the whole run, so a closed row's end_date equals the
# start_date of the version that replaces it exactly (no gap between versions).
LOAD_TS = datetime.now(timezone.utc)
NOW = F.lit(LOAD_TS)


# ─── Read incoming changes ────────────────────────────────────────────────────

df_incoming = spark.read.parquet(STAGING_PATH)

# Deduplicate incoming: keep most recent change per customer.
# This is required, not cosmetic: a MERGE fails if several source rows match one target row.
w_latest = Window.partitionBy("customer_id").orderBy(F.col("updated_at").desc())

df_latest_incoming = (
    df_incoming
    .withColumn("_rn", F.row_number().over(w_latest))
    .filter(F.col("_rn") == 1)
    .drop("_rn")
)


# ─── Identify what actually changed (hash-based change detection) ─────────────

TRACKED_COLS = ["email", "full_name", "country_code", "tier"]

# Hash a JSON rendering of the tracked columns. concat_ws silently skips NULLs, so
# (email=NULL, full_name="x") and (email="x", full_name=NULL) would both hash "x"
# and the change would be missed. to_json keeps field names and escapes values,
# so different combinations always produce different hash inputs.
# NOTE: these hashes differ from ones computed with a concat_ws formula. On the
# first run after switching, every incoming customer that already has a current
# row is re-versioned once.
df_with_hash = df_latest_incoming.withColumn(
    "row_hash",
    F.md5(F.to_json(F.struct(*TRACKED_COLS)))
)


# ─── SCD Type 2 MERGE logic ───────────────────────────────────────────────────
# When a tracked column changes:
#   1. Close the existing current row (end_date = NOW, is_current = false)
#   2. Insert a new current row with the new values
#
# Both steps happen in ONE atomic MERGE by staging the source twice:
#   • every incoming row keyed by its customer_id. It matches the customer's
#     current row (which is closed only if the hash differs), or, for a
#     brand-new customer, matches nothing and is inserted;
#   • each changed row again with a NULL merge key. NULL never matches, so the
#     row falls through to the insert clause as the customer's new current version.
# Because the MERGE is one transaction, a failed run cannot leave a customer
# with no current row.

# Bookkeeping columns attached to every newly inserted version
SCD_COLUMNS = {
    "start_date": NOW,
    "end_date":   F.lit(None).cast(TimestampType()),
    "is_current": F.lit(True),
}

if DeltaTable.isDeltaTable(spark, SILVER_PATH):
    delta_target = DeltaTable.forPath(spark, SILVER_PATH)

    df_existing_current = (
        delta_target.toDF()
        .filter(F.col("is_current"))
        .select("customer_id", "row_hash")
    )

    # Incoming rows whose tracked values differ from the customer's current version,
    # tagged with a NULL merge key so they are always inserted
    key_type = df_with_hash.schema["customer_id"].dataType
    df_changed = (
        df_with_hash.alias("inc")
        .join(
            df_existing_current.alias("cur"),
            F.col("inc.customer_id") == F.col("cur.customer_id"),
            "inner",
        )
        .filter(F.col("inc.row_hash") != F.col("cur.row_hash"))
        .select("inc.*")
        .withColumn("_merge_key", F.lit(None).cast(key_type))
    )

    df_staged = (
        df_with_hash
        .withColumn("_merge_key", F.col("customer_id"))
        .unionByName(df_changed)
    )

    # Insert every incoming column as-is (the staging-only _merge_key is left out),
    # plus the SCD bookkeeping columns
    insert_values = {c: F.col(f"source.`{c}`") for c in df_with_hash.columns}
    insert_values.update(SCD_COLUMNS)

    (
        delta_target.alias("target")
        .merge(
            source=df_staged.alias("source"),
            condition=(
                "target.customer_id = source._merge_key AND target.is_current = true"
            ),
        )
        .whenMatchedUpdate(
            condition="target.row_hash <> source.row_hash",
            set={
                "end_date":   NOW,
                "is_current": F.lit(False),
            },
        )
        .whenNotMatchedInsert(values=insert_values)
        .execute()
    )
else:
    # First run: the table does not exist yet, so every incoming customer is net-new
    df_to_insert = df_with_hash
    for column_name, column_value in SCD_COLUMNS.items():
        df_to_insert = df_to_insert.withColumn(column_name, column_value)

    (
        df_to_insert
        .write
        .format("delta")
        .mode("append")
        .save(SILVER_PATH)
    )


# ─── Verify: only one active row per customer ─────────────────────────────────

df_silver = spark.read.format("delta").load(SILVER_PATH)

df_active_counts = (
    df_silver
    .filter(F.col("is_current"))
    .groupBy("customer_id")
    .agg(F.count("*").alias("active_row_count"))
)

duplicates = df_active_counts.filter(F.col("active_row_count") > 1).count()
# Raise instead of assert: assert statements are stripped when Python runs with -O
if duplicates:
    raise RuntimeError(f"SCD2 violation: {duplicates} customers have multiple active rows")

print("SCD2 Silver layer validated. Schema:")
df_silver.printSchema()
df_silver.orderBy("customer_id", "start_date").show(10, truncate=False)

Key Takeaways

  • Use F.col() style for complex expressions; string names are fine for simple selects.
  • Prefer built-in functions over UDFs. When you must use a UDF, use pandas UDF (vectorized) not Python UDF (row-by-row).
  • Broadcast small tables explicitly with F.broadcast() to avoid shuffle joins.
  • Window functions with partitionBy + orderBy give you running totals, lag/lead, and deduplication without collapsing rows.
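  • SCD Type 2 on Delta Lake can be done as two passes (close old rows, then insert new ones), but that is not atomic. A single MERGE over a staged source does both in one transaction: each changed record appears once to close the current row, and once more with a NULL merge key so it is inserted as the new version.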
  • Always validate your SCD logic with a check that no entity has more than one active row.