PySpark & Apache Spark · Lesson 2 of 4
DataFrames & SQL: Joins, Window Functions & Delta MERGE
Column References: Three Equivalent Styles
PySpark gives you three ways to reference a column. Understand all three because you'll see them all in production code.
# Demonstrates the three equivalent ways to reference a column. The calls below
# only illustrate syntax: each returns a new DataFrame that is not assigned.
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("DFTransformations").getOrCreate()
df = spark.read.parquet("s3://my-bucket/silver/orders/")
# Style 1: String column name — simplest, but no IDE autocomplete
# (filter() with a string takes a SQL expression, e.g. "status = 'active'")
df.filter("status = 'active'")
df.select("order_id", "customer_id", "total")
# Style 2: df["column"] — DataFrame-scoped, unambiguous in joins
df.filter(df["status"] == "active")
df.select(df["order_id"], df["customer_id"])
# Style 3: F.col() — most common in production, chains well
df.filter(F.col("status") == "active")
df.select(F.col("order_id"), F.col("customer_id"))
# A bare col() imported via `from pyspark.sql.functions import col` is the
# very same function as F.col() (F is pyspark.sql.functions)
# Use F.col() to be explicit about the functions namespace
Rule of thumb: Use F.col() for expressions in withColumn/filter/agg. Use string column names for simple select and drop.
Core DataFrame Transformations
# ─── select ───────────────────────────────────────────────────────────────────
# Keep only the needed columns, normalize order_date to DATE and derive total.
df_slim = df.select(
    "order_id",
    "customer_id",
    F.col("order_date").cast("date").alias("order_date"),
    (F.col("subtotal") + F.col("tax")).alias("total"),
)
# ─── filter / where (identical) ───────────────────────────────────────────────
df_recent = df.filter(
    (F.col("order_date") >= "2026-01-01") &
    (F.col("status").isin("completed", "shipped")) &
    F.col("customer_id").isNotNull()
)
# ─── withColumn ───────────────────────────────────────────────────────────────
# Built on df_slim: that is where `total` is derived and `order_date` is cast
# to a DATE, both of which the expressions below depend on.
df_enriched = (
    df_slim
    .withColumn("year", F.year("order_date"))
    .withColumn("month", F.month("order_date"))
    .withColumn(
        "order_size",
        F.when(F.col("total") < 50, "small")
        .when(F.col("total") < 200, "medium")
        .otherwise("large")
    )
)
# ─── drop ─────────────────────────────────────────────────────────────────────
df_clean = df.drop("_corrupt_record", "_ingested_at", "internal_notes")
# ─── rename ───────────────────────────────────────────────────────────────────
df_renamed = df.withColumnRenamed("cust_id", "customer_id")
# ─── distinct / dropDuplicates ────────────────────────────────────────────────
# distinct() compares every column; dropDuplicates([...]) compares only the
# listed ones and keeps one arbitrary row from each duplicate group.
df_deduped = df.dropDuplicates(["order_id"])
df_deduped_partial = df.dropDuplicates(["customer_id", "order_date"]) # keep one (arbitrary) row per customer per day
Built-in Functions You'll Use Every Day
import pyspark.sql.functions as F
# ─── Conditional Logic ────────────────────────────────────────────────────────
# when() branches are tested in order; the first matching one wins.
df = df.withColumn(
    "effective_price",
    F.when(F.col("discount") > 0, F.col("price") * (1 - F.col("discount")))
    .when(F.col("promo_code").isNotNull(), F.col("price") * 0.9)
    .otherwise(F.col("price"))
)
# ─── Null Handling ────────────────────────────────────────────────────────────
# coalesce returns the first non-null value
df = df.withColumn(
    "display_name",
    F.coalesce(F.col("preferred_name"), F.col("full_name"), F.lit("Anonymous"))
)
# fillna / fill
df = df.fillna({"country_code": "US", "lifetime_value": 0.0})
# ─── String Functions ─────────────────────────────────────────────────────────
df = (
    df
    .withColumn("email", F.lower(F.trim(F.col("email"))))
    .withColumn("phone_digits", F.regexp_replace(F.col("phone"), "[^0-9]", ""))
    # Extract the area code from the digits-only form so "(555) 123-4567",
    # "+1 555 123 4567" and "1-555-123-4567" all yield "555"; an optional
    # leading country code 1 is skipped. Anything else yields "".
    .withColumn(
        "area_code",
        F.regexp_extract(F.col("phone_digits"), r"^1?(\d{3})\d{7}$", 1)
    )
    # Trim and split on runs of whitespace so leading or doubled spaces do not
    # produce empty name parts (and an empty first_name).
    .withColumn("name_parts", F.split(F.trim(F.col("full_name")), r"\s+"))
    .withColumn("first_name", F.col("name_parts").getItem(0))
    .withColumn(
        "name_length",
        F.length(F.col("full_name"))
    )
)
# ─── Date Functions ───────────────────────────────────────────────────────────
df = (
    df
    .withColumn("order_year", F.year("order_date"))
    .withColumn("order_month", F.month("order_date"))
    .withColumn("order_quarter", F.quarter("order_date"))
    .withColumn(
        "order_date_str",
        F.date_format(F.col("order_date"), "yyyy-MM-dd")
    )
    .withColumn(
        "days_since_order",
        F.datediff(F.current_date(), F.col("order_date"))
    )
    .withColumn(
        "order_ts",
        F.to_timestamp(F.col("order_date_str"), "yyyy-MM-dd")
    )
)
# ─── Array / Struct Functions ─────────────────────────────────────────────────
# explode turns an array column into multiple rows (rows with a null or empty
# array are dropped)
df_tags = df.withColumn("tag", F.explode(F.col("tags")))
# explode_outer keeps rows even when the array is null or empty
df_tags_safe = df.withColumn("tag", F.explode_outer(F.col("tags")))
# collect items into an array during aggregation
df_customer_tags = (
    df_tags
    .groupBy("customer_id")
    .agg(F.collect_list("tag").alias("all_tags"))
)
# F.struct builds a struct column (the DataFrame-API counterpart of SQL's
# named_struct); each field takes the name of the column passed in.
df_address = df.withColumn(
    "address",
    F.struct("street", "city", "zip"),
)
GroupBy and Aggregations
from pyspark.sql import functions as F
# ─── Basic groupBy + agg ──────────────────────────────────────────────────────
df_summary = (
df_orders
.groupBy("customer_id", "year", "month")
.agg(
F.count("*").alias("order_count"),
F.sum("total").alias("total_revenue"),
F.avg("total").alias("avg_order_value"),
F.min("order_date").alias("first_order_date"),
F.max("order_date").alias("last_order_date"),
F.countDistinct("product_id").alias("unique_products"),
F.collect_list("order_id").alias("order_ids"),
)
)
# ─── Pivot ────────────────────────────────────────────────────────────────────
df_pivot = (
df_orders
.groupBy("customer_id")
.pivot("status", ["pending", "completed", "cancelled"])
.agg(F.count("order_id"))
.fillna(0)
)
# Result: customer_id | pending | completed | cancelled
# ─── Multiple aggregations on same column ─────────────────────────────────────
df_revenue_stats = (
df_orders
.groupBy("country_code")
.agg(
F.expr("percentile_approx(total, 0.5)").alias("median_total"),
F.expr("percentile_approx(total, 0.95)").alias("p95_total"),
F.stddev("total").alias("stddev_total"),
)
)User-Defined Functions (UDFs)
Python UDF — Convenient but Slow
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import re
# Python UDF: each row crosses the JVM boundary — serialization overhead
def normalize_phone(phone: "str | None") -> "str | None":
    """Normalize a North American phone number to E.164 form.

    Only the ASCII digits 0-9 are kept (the same "[^0-9]" rule used for the
    phone_digits column). A plain ``\\D`` would keep non-ASCII Unicode digits
    and emit a bogus "+1..." number.

    Returns "+1" + 10 digits, "+" + 11 digits when they start with the
    country code 1, or None for empty/None input and any other digit count.
    """
    if not phone:
        return None
    digits = re.sub(r"[^0-9]", "", phone)
    if len(digits) == 10:
        return f"+1{digits}"
    if len(digits) == 11 and digits.startswith("1"):
        return f"+{digits}"
    return None
normalize_phone_udf = udf(normalize_phone, StringType())
# Register and use
df = df.withColumn("phone_normalized", normalize_phone_udf(F.col("phone")))Performance warning: Python UDFs serialize each row to Python, call your function, then serialize results back to the JVM. For large datasets this is 10–100x slower than built-in functions.
Pandas UDF (Vectorized) — Fast
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType
import pandas as pd
import re
# Pandas UDF processes entire columnar batches — no per-row JVM boundary.
# The body stays batch-oriented too: pandas .str methods instead of a Python
# function applied element by element.
@pandas_udf(StringType())
def normalize_phone_fast(phone_series: pd.Series) -> pd.Series:
    """Vectorized counterpart of normalize_phone.

    Keeps only ASCII digits 0-9. Returns "+1" + 10 digits, "+" + 11 digits
    starting with 1, and None for null/empty input or any other digit count.
    """
    # Nullable string dtype keeps missing values as <NA> through the .str ops.
    digits = phone_series.astype("string").str.replace(r"[^0-9]", "", regex=True)
    length = digits.str.len()
    # Comparisons against <NA> give <NA>; treat those rows as "no match".
    is_ten = (length == 10).fillna(False).astype(bool)
    is_eleven = ((length == 11) & digits.str.startswith("1")).fillna(False).astype(bool)

    result = pd.Series(None, index=phone_series.index, dtype="object")
    result[is_ten] = "+1" + digits[is_ten]
    result[is_eleven] = "+" + digits[is_eleven]
    return result
@pandas_udf(DoubleType())
def revenue_score(revenue: pd.Series, order_count: pd.Series) -> pd.Series:
"""Composite score: revenue per order, log-scaled."""
import numpy as np
ratio = revenue / order_count.replace(0, 1)
return np.log1p(ratio)
# Use exactly like built-in functions
df = df.withColumn("phone_normalized", normalize_phone_fast(F.col("phone")))
df = df.withColumn(
"revenue_score",
revenue_score(F.col("lifetime_value"), F.col("order_count"))
)Rule: Always prefer built-in F.* functions. Use pandas UDFs for complex logic that has no built-in equivalent. Avoid Python UDFs in production — they are always slower than both alternatives.
Join Types and Strategies
# ─── Inner Join (default) ─────────────────────────────────────────────────────
df_joined = df_orders.join(df_customers, on="customer_id", how="inner")
# ─── Left Join ────────────────────────────────────────────────────────────────
df_left = df_orders.join(df_customers, on="customer_id", how="left")
# ─── Semi Join (filter orders where matching customer exists, no customer cols)
df_semi = df_orders.join(df_customers, on="customer_id", how="left_semi")
# ─── Anti Join (orders with NO matching customer — orphan detection)
df_anti = df_orders.join(df_customers, on="customer_id", how="left_anti")
# ─── Multi-column join ────────────────────────────────────────────────────────
df_multi = df_orders.join(
df_returns,
on=["order_id", "product_id"],
how="left"
)
# ─── Non-equi join (requires column expression syntax) ────────────────────────
df_range = df_transactions.join(
df_date_ranges,
on=(
(df_transactions["txn_date"] >= df_date_ranges["period_start"]) &
(df_transactions["txn_date"] < df_date_ranges["period_end"])
),
how="inner"
)
# ─── Resolving ambiguous columns after join ────────────────────────────────────
# When both DataFrames have "created_at", use aliases
df_ord = df_orders.alias("ord")
df_cust = df_customers.alias("cust")
df_joined = df_ord.join(df_cust, F.col("ord.customer_id") == F.col("cust.customer_id"))
df_result = df_joined.select(
F.col("ord.order_id"),
F.col("cust.email"),
F.col("ord.created_at").alias("order_created_at"),
F.col("cust.created_at").alias("customer_created_at"),
)Broadcast Joins for Small Tables
# When one side of a join fits in memory (< ~10MB by default),
# Spark can broadcast it to every executor — eliminates the shuffle
df_orders_enriched = df_orders.join(
    F.broadcast(df_country_lookup),  # broadcast the small lookup table
    on="country_code",
    how="left"
)
# autoBroadcastJoinThreshold is a session-wide setting. Remember the current
# value so the experiments below do not silently change every later join.
previous_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
# Configure the auto-broadcast threshold (bytes):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)  # 50MB
# Disable auto-broadcast (useful for debugging join strategies). Explicit
# F.broadcast() hints, like the one above, are still honored.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
# Done experimenting: restore the session's previous threshold.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", previous_threshold)
# Verify Spark chose a BroadcastHashJoin:
df_orders_enriched.explain()
# Look for: BroadcastHashJoin in the physical plan
Window Functions
Window functions operate over a "window" of rows relative to the current row, without collapsing the DataFrame like groupBy does.
from pyspark.sql.window import Window
from pyspark.sql import functions as F
# ─── Define window specs ──────────────────────────────────────────────────────
# Partition by customer, order by date ascending
w_customer_time = (
Window
.partitionBy("customer_id")
.orderBy("order_date")
)
# Partition by customer only (for ranking across all time)
w_customer = Window.partitionBy("customer_id")
# Rolling 30-day window (requires timestamp or numeric ordering)
w_rolling_30d = (
Window
.partitionBy("customer_id")
.orderBy(F.col("order_date").cast("long"))
.rangeBetween(-30 * 86400, 0) # 30 days in seconds, current row is 0
)
# ─── Ranking Functions ────────────────────────────────────────────────────────
df_ranked = df_orders.withColumn(
"order_rank", # dense_rank: no gaps in ranking numbers
F.dense_rank().over(w_customer_time)
)
df_top = df_orders.withColumn(
"rank_by_value", # rank: gaps when tied
F.rank().over(Window.partitionBy("customer_id").orderBy(F.col("total").desc()))
)
# Get each customer's single most recent order
df_latest = (
df_orders
.withColumn("row_num", F.row_number().over(
Window.partitionBy("customer_id").orderBy(F.col("order_date").desc())
))
.filter(F.col("row_num") == 1)
.drop("row_num")
)
# ─── Lag and Lead ─────────────────────────────────────────────────────────────
df_with_prev = (
df_orders
.withColumn(
"prev_order_date",
F.lag("order_date", 1).over(w_customer_time)
)
.withColumn(
"next_order_date",
F.lead("order_date", 1).over(w_customer_time)
)
.withColumn(
"days_since_last_order",
F.datediff(F.col("order_date"), F.col("prev_order_date"))
)
)
# ─── Running Aggregations ─────────────────────────────────────────────────────
df_running = df_orders.withColumn(
"cumulative_revenue",
F.sum("total").over(w_customer_time) # running sum within customer
)
df_running = df_running.withColumn(
"pct_of_customer_total",
F.col("total") / F.sum("total").over(w_customer)
)Spark SQL with Temporary Views
# Register a DataFrame as a SQL view
df_orders.createOrReplaceTempView("orders")
df_customers.createOrReplaceTempView("customers")
# Full SQL — useful when logic is cleaner in SQL than DataFrame API
df_sql_result = spark.sql("""
SELECT
c.country_code,
DATE_FORMAT(o.order_date, 'yyyy-MM') AS month,
COUNT(DISTINCT o.customer_id) AS active_customers,
SUM(o.total) AS total_revenue,
AVG(o.total) AS avg_order_value,
PERCENTILE_APPROX(o.total, 0.5) AS median_order_value
FROM orders o
INNER JOIN customers c USING (customer_id)
WHERE o.status = 'completed'
AND o.order_date >= '2026-01-01'
GROUP BY 1, 2
ORDER BY 1, 2
""")
df_sql_result.show(20)
# Global temp views persist across SparkSessions (useful in notebooks)
df_orders.createOrReplaceGlobalTempView("orders_global")
spark.sql("SELECT * FROM global_temp.orders_global LIMIT 5").show()Delta Lake MERGE (Upsert)
from delta.tables import DeltaTable
# DeltaTable.forPath gives you the merge builder
delta_table = DeltaTable.forPath(spark, "s3://my-bucket/delta/customers/")
df_updates = spark.read.parquet("s3://my-bucket/staging/customer_updates/")
(
delta_table.alias("target")
.merge(
source=df_updates.alias("source"),
condition="target.customer_id = source.customer_id"
)
.whenMatchedUpdate(set={
"email": "source.email",
"full_name": "source.full_name",
"lifetime_value": "source.lifetime_value",
"updated_at": "source.updated_at",
})
.whenNotMatchedInsert(values={
"customer_id": "source.customer_id",
"email": "source.email",
"full_name": "source.full_name",
"lifetime_value": "source.lifetime_value",
"created_at": "source.created_at",
"updated_at": "source.updated_at",
})
.execute()
)Complete Example: Silver Layer with SCD Type 2
SCD Type 2 (Slowly Changing Dimension Type 2) preserves full history by closing the old row and inserting a new one when a record changes.
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark.sql.types import (
StructType, StructField, IntegerType, StringType,
DoubleType, TimestampType, BooleanType
)
spark = (
SparkSession.builder
.appName("SilverSCD2Pipeline")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config("spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog")
.getOrCreate()
)
SILVER_PATH = "s3://my-bucket/delta/dim_customers_scd2/"
STAGING_PATH = "s3://my-bucket/staging/customer_updates/"
# ─── Read incoming changes ────────────────────────────────────────────────────
df_incoming = spark.read.parquet(STAGING_PATH)
# Deduplicate incoming: keep most recent change per customer
w_latest = Window.partitionBy("customer_id").orderBy(F.col("updated_at").desc())
df_latest_incoming = (
df_incoming
.withColumn("_rn", F.row_number().over(w_latest))
.filter(F.col("_rn") == 1)
.drop("_rn")
)
# ─── Identify what actually changed (hash-based change detection) ─────────────
TRACKED_COLS = ["email", "full_name", "country_code", "tier"]
df_with_hash = df_latest_incoming.withColumn(
"row_hash",
F.md5(F.concat_ws("|", *[F.col(c).cast("string") for c in TRACKED_COLS]))
)
# ─── SCD Type 2 MERGE logic ───────────────────────────────────────────────────
# When a tracked column changes:
# 1. Close the existing active row (set end_date = now, is_current = false)
# 2. Insert a new active row with new values
#
# Delta MERGE cannot do "close old + insert new" in one statement.
# Standard approach: two-pass or use MERGE with whenMatchedUpdateAll + whenNotMatchedInsert
NOW = F.current_timestamp()
# Step 1: Close changed active rows in the target
if DeltaTable.isDeltaTable(spark, SILVER_PATH):
delta_target = DeltaTable.forPath(spark, SILVER_PATH)
# Identify customer_ids whose hash changed
df_existing = delta_target.toDF().filter(F.col("is_current") == True)
df_changed = (
df_existing.alias("existing")
.join(
df_with_hash.alias("incoming"),
on="customer_id",
how="inner"
)
.filter(
F.col("existing.row_hash") != F.col("incoming.row_hash")
)
.select(F.col("existing.customer_id"))
)
changed_ids = [row["customer_id"] for row in df_changed.collect()]
if changed_ids:
# Close the old rows
(
delta_target.alias("target")
.merge(
source=df_with_hash.filter(
F.col("customer_id").isin(changed_ids)
).alias("source"),
condition=(
"target.customer_id = source.customer_id AND target.is_current = true"
)
)
.whenMatchedUpdate(set={
"end_date": "current_timestamp()",
"is_current": "false",
})
.execute()
)
# Step 2: Insert new rows for changed and net-new customers
df_existing_current = (
DeltaTable.forPath(spark, SILVER_PATH).toDF()
.filter(F.col("is_current") == True)
.select("customer_id", "row_hash")
) if DeltaTable.isDeltaTable(spark, SILVER_PATH) else spark.createDataFrame([], schema=StructType([
StructField("customer_id", IntegerType()),
StructField("row_hash", StringType()),
]))
df_to_insert = (
df_with_hash.alias("inc")
.join(
df_existing_current.alias("cur"),
on="customer_id",
how="left"
)
.filter(
F.col("cur.customer_id").isNull() | # net new
(F.col("inc.row_hash") != F.col("cur.row_hash")) # changed
)
.select("inc.*")
.withColumn("start_date", NOW)
.withColumn("end_date", F.lit(None).cast(TimestampType()))
.withColumn("is_current", F.lit(True))
)
(
df_to_insert
.write
.format("delta")
.mode("append")
.save(SILVER_PATH)
)
# ─── Verify: only one active row per customer ─────────────────────────────────
df_silver = spark.read.format("delta").load(SILVER_PATH)
df_active_counts = (
df_silver
.filter(F.col("is_current") == True)
.groupBy("customer_id")
.agg(F.count("*").alias("active_row_count"))
)
duplicates = df_active_counts.filter(F.col("active_row_count") > 1).count()
assert duplicates == 0, f"SCD2 violation: {duplicates} customers have multiple active rows"
print("SCD2 Silver layer validated. Schema:")
df_silver.printSchema()
df_silver.orderBy("customer_id", "start_date").show(10, truncate=False)Key Takeaways
- Use F.col() style for complex expressions; string names are fine for simple selects.
- Prefer built-in functions over UDFs. When you must use a UDF, use a pandas UDF (vectorized), not a Python UDF (row-by-row).
- Broadcast small tables explicitly with F.broadcast() to avoid shuffle joins.
- Window functions with partitionBy + orderBy give you running totals, lag/lead, and deduplication without collapsing rows.
- SCD Type 2 on Delta Lake is commonly done in two passes (close the old rows, then insert the new ones). Those two commits are not atomic; a single MERGE using the staged-updates pattern can do both at once.
- Always validate your SCD logic with a check that each entity has exactly one active row.