PySpark DataFrames & Spark SQL: Transformations, Joins, and Window Functions
Deep dive into PySpark DataFrame operations: UDFs, built-in functions, all join types, broadcast joins, window functions, and a real Silver-layer SCD Type 2 transformation.
Column References: Three Equivalent Styles
PySpark gives you three ways to reference a column. Understand all three because you'll see them all in production code.
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("DFTransformations").getOrCreate()
df = spark.read.parquet("s3://my-bucket/silver/orders/")
# Style 1: String column name. Simplest, but no IDE autocomplete.
df.filter("status = 'active'")
df.select("order_id", "customer_id", "total")
# Style 2: df["column"]. DataFrame-scoped, so unambiguous in joins.
df.filter(df["status"] == "active")
df.select(df["order_id"], df["customer_id"])
# Style 3: F.col(). Most common in production, chains well.
df.filter(F.col("status") == "active")
df.select(F.col("order_id"), F.col("customer_id"))
# col() and F.col() are the same function; importing the functions
# module as F just keeps the namespace explicit
Rule of thumb: Use F.col() for expressions in withColumn/filter/agg. Use string column names for simple select and drop.
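Applying the rule of thumb to the orders DataFrame above (a tiny illustrative sketch; the variable names are hypothetical, the columns are the ones used throughout this post):
df_totaled = df.withColumn("total_with_tax", F.col("subtotal") + F.col("tax"))  # expression: F.col()
df_slimmed = df_totaled.select("order_id", "customer_id", "total_with_tax")     # simple select: strings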
Core DataFrame Transformations
# ─── select ───────────────────────────────────────────────────────────────────
df_slim = df.select(
"order_id",
"customer_id",
F.col("order_date").cast("date").alias("order_date"),
(F.col("subtotal") + F.col("tax")).alias("total"),
)
# ─── filter / where (identical) ───────────────────────────────────────────────
df_recent = df.filter(
(F.col("order_date") >= "2026-01-01") &
(F.col("status").isin("completed", "shipped")) &
F.col("customer_id").isNotNull()
)
# ─── withColumn ───────────────────────────────────────────────────────────────
df_enriched = (
df
.withColumn("year", F.year("order_date"))
.withColumn("month", F.month("order_date"))
.withColumn(
"order_size",
F.when(F.col("total") < 50, "small")
.when(F.col("total") < 200, "medium")
.otherwise("large")
)
)
# ─── drop ─────────────────────────────────────────────────────────────────────
df_clean = df.drop("_corrupt_record", "_ingested_at", "internal_notes")
# ─── rename ───────────────────────────────────────────────────────────────────
df_renamed = df.withColumnRenamed("cust_id", "customer_id")
# ─── distinct / dropDuplicates ────────────────────────────────────────────────
df_deduped = df.dropDuplicates(["order_id"])
df_deduped_partial = df.dropDuplicates(["customer_id", "order_date"])  # keep one row per customer per day (which row survives is arbitrary)
Built-in Functions You'll Use Every Day
import pyspark.sql.functions as F
# ─── Conditional Logic ────────────────────────────────────────────────────────
df = df.withColumn(
"effective_price",
F.when(F.col("discount") > 0, F.col("price") * (1 - F.col("discount")))
.when(F.col("promo_code").isNotNull(), F.col("price") * 0.9)
.otherwise(F.col("price"))
)
# ─── Null Handling ────────────────────────────────────────────────────────────
# coalesce returns the first non-null value
df = df.withColumn(
"display_name",
F.coalesce(F.col("preferred_name"), F.col("full_name"), F.lit("Anonymous"))
)
# fillna / fill
df = df.fillna({"country_code": "US", "lifetime_value": 0.0})
# ─── String Functions ─────────────────────────────────────────────────────────
df = (
df
.withColumn("email", F.lower(F.trim(F.col("email"))))
.withColumn("phone_digits", F.regexp_replace(F.col("phone"), "[^0-9]", ""))
.withColumn(
"area_code",
F.regexp_extract(F.col("phone"), r"^\(?(\d{3})\)?", 1)
)
.withColumn("name_parts", F.split(F.col("full_name"), " "))
.withColumn("first_name", F.col("name_parts").getItem(0))
.withColumn(
"name_length",
F.length(F.col("full_name"))
)
)
# ─── Date Functions ───────────────────────────────────────────────────────────
df = (
df
.withColumn("order_year", F.year("order_date"))
.withColumn("order_month", F.month("order_date"))
.withColumn("order_quarter", F.quarter("order_date"))
.withColumn(
"order_date_str",
F.date_format(F.col("order_date"), "yyyy-MM-dd")
)
.withColumn(
"days_since_order",
F.datediff(F.current_date(), F.col("order_date"))
)
.withColumn(
"order_ts",
F.to_timestamp(F.col("order_date_str"), "yyyy-MM-dd")
)
)
# ─── Array / Struct Functions ─────────────────────────────────────────────────
# explode turns an array column into multiple rows
df_tags = df.withColumn("tag", F.explode(F.col("tags")))
# explode_outer keeps rows even when the array is null or empty
df_tags_safe = df.withColumn("tag", F.explode_outer(F.col("tags")))
# collect items into an array during aggregation
df_customer_tags = (
df_tags
.groupBy("customer_id")
.agg(F.collect_list("tag").alias("all_tags"))
)
# F.struct creates a struct (nested) column; alias each field to name it
df_address = df.withColumn(
"address",
F.struct(
F.col("street").alias("street"),
F.col("city").alias("city"),
F.col("zip").alias("zip"),
)
)
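Struct fields are then addressed with dot notation. A quick illustrative peek at the nested values:
df_address.select(F.col("address.city"), F.col("address.zip")).show(5)
GroupBy and Aggregations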
from pyspark.sql import functions as F
# ─── Basic groupBy + agg ──────────────────────────────────────────────────────
df_summary = (
df_orders
.groupBy("customer_id", "year", "month")
.agg(
F.count("*").alias("order_count"),
F.sum("total").alias("total_revenue"),
F.avg("total").alias("avg_order_value"),
F.min("order_date").alias("first_order_date"),
F.max("order_date").alias("last_order_date"),
F.countDistinct("product_id").alias("unique_products"),
F.collect_list("order_id").alias("order_ids"),
)
)
# ─── Pivot ────────────────────────────────────────────────────────────────────
df_pivot = (
df_orders
.groupBy("customer_id")
.pivot("status", ["pending", "completed", "cancelled"])
.agg(F.count("order_id"))
.fillna(0)
)
# Result: customer_id | pending | completed | cancelled
# ─── Multiple aggregations on same column ─────────────────────────────────────
df_revenue_stats = (
df_orders
.groupBy("country_code")
.agg(
F.expr("percentile_approx(total, 0.5)").alias("median_total"),
F.expr("percentile_approx(total, 0.95)").alias("p95_total"),
F.stddev("total").alias("stddev_total"),
)
)
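On Spark 3.1+, percentile_approx is also exposed as a native function, so the F.expr wrapper above is optional:
F.percentile_approx("total", 0.5).alias("median_total")
User-Defined Functions (UDFs)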
Python UDF: Convenient but Slow
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import re
# Python UDF: each row crosses the JVM boundary, paying serialization overhead
def normalize_phone(phone: str) -> str:
if not phone:
return None
digits = re.sub(r"\D", "", phone)
if len(digits) == 10:
return f"+1{digits}"
elif len(digits) == 11 and digits.startswith("1"):
return f"+{digits}"
return None
normalize_phone_udf = udf(normalize_phone, StringType())
# Register and use
df = df.withColumn("phone_normalized", normalize_phone_udf(F.col("phone")))Performance warning: Python UDFs serialize each row to Python, call your function, then serialize results back to the JVM. For large datasets this is 10ā100x slower than built-in functions.
Pandas UDF (Vectorized): Fast
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType
import pandas as pd
import re
# Pandas UDF processes entire columnar batches, avoiding the per-row JVM boundary
@pandas_udf(StringType())
def normalize_phone_fast(phone_series: pd.Series) -> pd.Series:
def _normalize(phone):
if not phone or pd.isna(phone):
return None
digits = re.sub(r"\D", "", str(phone))
if len(digits) == 10:
return f"+1{digits}"
elif len(digits) == 11 and digits.startswith("1"):
return f"+{digits}"
return None
return phone_series.apply(_normalize)
@pandas_udf(DoubleType())
def revenue_score(revenue: pd.Series, order_count: pd.Series) -> pd.Series:
"""Composite score: revenue per order, log-scaled."""
import numpy as np
ratio = revenue / order_count.replace(0, 1)
return np.log1p(ratio)
# Use exactly like built-in functions
df = df.withColumn("phone_normalized", normalize_phone_fast(F.col("phone")))
df = df.withColumn(
"revenue_score",
revenue_score(F.col("lifetime_value"), F.col("order_count"))
)
Rule: Always prefer built-in F.* functions. Use pandas UDFs for complex logic that has no built-in equivalent. Avoid Python UDFs in production: they are always slower than both alternatives.
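In fact, the phone normalization above needs no UDF at all. A sketch of the same logic with built-ins only (reusing the phone column from earlier; treat this as illustrative rather than a drop-in for every edge case):
digits = F.regexp_replace(F.col("phone"), "[^0-9]", "")
df = df.withColumn(
    "phone_normalized",
    F.when(F.length(digits) == 10, F.concat(F.lit("+1"), digits))
    .when((F.length(digits) == 11) & digits.startswith("1"), F.concat(F.lit("+"), digits))
    .otherwise(F.lit(None).cast("string"))
)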
Join Types and Strategies
# ─── Inner Join (default) ─────────────────────────────────────────────────────
df_joined = df_orders.join(df_customers, on="customer_id", how="inner")
# ─── Left Join ────────────────────────────────────────────────────────────────
df_left = df_orders.join(df_customers, on="customer_id", how="left")
# ─── Semi Join (filter orders where a matching customer exists; no customer cols)
df_semi = df_orders.join(df_customers, on="customer_id", how="left_semi")
# ─── Anti Join (orders with NO matching customer: orphan detection) ───────────
df_anti = df_orders.join(df_customers, on="customer_id", how="left_anti")
# ─── Multi-column join ────────────────────────────────────────────────────────
df_multi = df_orders.join(
df_returns,
on=["order_id", "product_id"],
how="left"
)
# ─── Non-equi join (requires column expression syntax) ───────────────────────
df_range = df_transactions.join(
df_date_ranges,
on=(
(df_transactions["txn_date"] >= df_date_ranges["period_start"]) &
(df_transactions["txn_date"] < df_date_ranges["period_end"])
),
how="inner"
)
# ─── Resolving ambiguous columns after join ───────────────────────────────────
# When both DataFrames have "created_at", use aliases
df_ord = df_orders.alias("ord")
df_cust = df_customers.alias("cust")
df_joined = df_ord.join(df_cust, F.col("ord.customer_id") == F.col("cust.customer_id"))
df_result = df_joined.select(
F.col("ord.order_id"),
F.col("cust.email"),
F.col("ord.created_at").alias("order_created_at"),
F.col("cust.created_at").alias("customer_created_at"),
)
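One more wrinkle: because the join condition was written as an expression, both customer_id columns survive in df_joined. Selecting only what you need (as above) sidesteps the ambiguity; alternatively, drop one side's key, e.g. df_joined.drop(df_cust["customer_id"]).
Broadcast Joins for Small Tables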
# When one side of a join fits in memory (< ~10MB by default),
# Spark can broadcast it to every executor, which eliminates the shuffle
df_orders_enriched = df_orders.join(
F.broadcast(df_country_lookup), # broadcast the small lookup table
on="country_code",
how="left"
)
# Configure the auto-broadcast threshold (bytes):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024) # 50MB
# Disable auto-broadcast (useful for debugging join strategies):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
# Verify Spark chose a BroadcastHashJoin:
df_orders_enriched.explain()
# Look for BroadcastHashJoin in the physical plan
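On Spark 3.0+, explain also accepts a mode argument that produces a more readable dump (optional; same DataFrame as above):
df_orders_enriched.explain(mode="formatted")
# "formatted" separates the plan tree from per-node details,
# which makes the BroadcastHashJoin operator easier to spot
Window Functions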
Window functions operate over a "window" of rows relative to the current row, without collapsing the DataFrame like groupBy does.
from pyspark.sql.window import Window
from pyspark.sql import functions as F
# ─── Define window specs ──────────────────────────────────────────────────────
# Partition by customer, order by date ascending
w_customer_time = (
Window
.partitionBy("customer_id")
.orderBy("order_date")
)
# Partition by customer only (for ranking across all time)
w_customer = Window.partitionBy("customer_id")
# Rolling 30-day window (requires timestamp or numeric ordering)
w_rolling_30d = (
Window
.partitionBy("customer_id")
.orderBy(F.col("order_date").cast("long"))
    .rangeBetween(-30 * 86400, 0)  # from 30 days (in seconds) back, up to and including the current row
)
# ─── Ranking Functions ────────────────────────────────────────────────────────
df_ranked = df_orders.withColumn(
"order_rank", # dense_rank: no gaps in ranking numbers
F.dense_rank().over(w_customer_time)
)
df_top = df_orders.withColumn(
"rank_by_value", # rank: gaps when tied
F.rank().over(Window.partitionBy("customer_id").orderBy(F.col("total").desc()))
)
# Get each customer's single most recent order
df_latest = (
df_orders
.withColumn("row_num", F.row_number().over(
Window.partitionBy("customer_id").orderBy(F.col("order_date").desc())
))
.filter(F.col("row_num") == 1)
.drop("row_num")
)
# ─── Lag and Lead ─────────────────────────────────────────────────────────────
df_with_prev = (
df_orders
.withColumn(
"prev_order_date",
F.lag("order_date", 1).over(w_customer_time)
)
.withColumn(
"next_order_date",
F.lead("order_date", 1).over(w_customer_time)
)
.withColumn(
"days_since_last_order",
F.datediff(F.col("order_date"), F.col("prev_order_date"))
)
)
# ─── Running Aggregations ─────────────────────────────────────────────────────
df_running = df_orders.withColumn(
"cumulative_revenue",
F.sum("total").over(w_customer_time) # running sum within customer
)
df_running = df_running.withColumn(
"pct_of_customer_total",
F.col("total") / F.sum("total").over(w_customer)
)
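One subtlety worth knowing: when a window spec has an orderBy, Spark's default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so rows that tie on order_date share the same cumulative value. To force a strict row-by-row running total, pin the frame explicitly (a minimal sketch using the same columns as above):
w_running_rows = (
    Window
    .partitionBy("customer_id")
    .orderBy("order_date")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df_running_rows = df_orders.withColumn(
    "cumulative_revenue",
    F.sum("total").over(w_running_rows)  # distinct value per row, even on date ties
)
Spark SQL with Temporary Views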
# Register a DataFrame as a SQL view
df_orders.createOrReplaceTempView("orders")
df_customers.createOrReplaceTempView("customers")
# Full SQL ā useful when logic is cleaner in SQL than DataFrame API
df_sql_result = spark.sql("""
SELECT
c.country_code,
DATE_FORMAT(o.order_date, 'yyyy-MM') AS month,
COUNT(DISTINCT o.customer_id) AS active_customers,
SUM(o.total) AS total_revenue,
AVG(o.total) AS avg_order_value,
PERCENTILE_APPROX(o.total, 0.5) AS median_order_value
FROM orders o
INNER JOIN customers c USING (customer_id)
WHERE o.status = 'completed'
AND o.order_date >= '2026-01-01'
GROUP BY 1, 2
ORDER BY 1, 2
""")
df_sql_result.show(20)
# Global temp views are shared across SparkSessions in the same application
# (useful in notebooks); they are dropped when the application stops
df_orders.createOrReplaceGlobalTempView("orders_global")
spark.sql("SELECT * FROM global_temp.orders_global LIMIT 5").show()Delta Lake MERGE (Upsert)
from delta.tables import DeltaTable
# DeltaTable.forPath gives you the merge builder
delta_table = DeltaTable.forPath(spark, "s3://my-bucket/delta/customers/")
df_updates = spark.read.parquet("s3://my-bucket/staging/customer_updates/")
(
delta_table.alias("target")
.merge(
source=df_updates.alias("source"),
condition="target.customer_id = source.customer_id"
)
.whenMatchedUpdate(set={
"email": "source.email",
"full_name": "source.full_name",
"lifetime_value": "source.lifetime_value",
"updated_at": "source.updated_at",
})
.whenNotMatchedInsert(values={
"customer_id": "source.customer_id",
"email": "source.email",
"full_name": "source.full_name",
"lifetime_value": "source.lifetime_value",
"created_at": "source.created_at",
"updated_at": "source.updated_at",
})
.execute()
)
Complete Example: Silver Layer with SCD Type 2
SCD Type 2 (Slowly Changing Dimension Type 2) preserves full history by closing the old row and inserting a new one when a record changes.
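For example, if customer 42 upgrades from tier "silver" to "gold" on 2026-03-01, the dimension goes from one active row to two rows (hypothetical values, purely for illustration):
customer_id | tier   | start_date | end_date   | is_current
42          | silver | 2025-07-14 | 2026-03-01 | false
42          | gold   | 2026-03-01 | null       | true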
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark.sql.types import (
StructType, StructField, IntegerType, StringType,
DoubleType, TimestampType, BooleanType
)
spark = (
SparkSession.builder
.appName("SilverSCD2Pipeline")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config("spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog")
.getOrCreate()
)
SILVER_PATH = "s3://my-bucket/delta/dim_customers_scd2/"
STAGING_PATH = "s3://my-bucket/staging/customer_updates/"
# ─── Read incoming changes ────────────────────────────────────────────────────
df_incoming = spark.read.parquet(STAGING_PATH)
# Deduplicate incoming: keep most recent change per customer
w_latest = Window.partitionBy("customer_id").orderBy(F.col("updated_at").desc())
df_latest_incoming = (
df_incoming
.withColumn("_rn", F.row_number().over(w_latest))
.filter(F.col("_rn") == 1)
.drop("_rn")
)
# ─── Identify what actually changed (hash-based change detection) ─────────────
TRACKED_COLS = ["email", "full_name", "country_code", "tier"]
df_with_hash = df_latest_incoming.withColumn(
"row_hash",
F.md5(F.concat_ws("|", *[F.col(c).cast("string") for c in TRACKED_COLS]))
)
# ─── SCD Type 2 MERGE logic ───────────────────────────────────────────────────
# When a tracked column changes:
# 1. Close the existing active row (set end_date = now, is_current = false)
# 2. Insert a new active row with new values
#
# Delta MERGE cannot close the old row AND insert its replacement from the
# same source row in one statement. Standard approaches: two passes (used
# below), or a single MERGE against a staged source that carries each changed
# row twice, once under a null merge key so it falls through to the insert clause.
NOW = F.current_timestamp()
# Step 1: Close changed active rows in the target
if DeltaTable.isDeltaTable(spark, SILVER_PATH):
delta_target = DeltaTable.forPath(spark, SILVER_PATH)
# Identify customer_ids whose hash changed
df_existing = delta_target.toDF().filter(F.col("is_current") == True)
df_changed = (
df_existing.alias("existing")
.join(
df_with_hash.alias("incoming"),
on="customer_id",
how="inner"
)
.filter(
F.col("existing.row_hash") != F.col("incoming.row_hash")
)
.select(F.col("existing.customer_id"))
)
changed_ids = [row["customer_id"] for row in df_changed.collect()]
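    # NOTE: collect() pulls the changed keys onto the driver; fine for modest
    # change volumes, but consider a join-based close step at larger scale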
if changed_ids:
# Close the old rows
(
delta_target.alias("target")
.merge(
source=df_with_hash.filter(
F.col("customer_id").isin(changed_ids)
).alias("source"),
condition=(
"target.customer_id = source.customer_id AND target.is_current = true"
)
)
.whenMatchedUpdate(set={
"end_date": "current_timestamp()",
"is_current": "false",
})
.execute()
)
# Step 2: Insert new rows for changed and net-new customers
if DeltaTable.isDeltaTable(spark, SILVER_PATH):
    df_existing_current = (
        DeltaTable.forPath(spark, SILVER_PATH).toDF()
        .filter(F.col("is_current") == True)
        .select("customer_id", "row_hash")
    )
else:
    # First run: no target table yet, so compare against an empty frame
    df_existing_current = spark.createDataFrame([], schema=StructType([
        StructField("customer_id", IntegerType()),
        StructField("row_hash", StringType()),
    ]))
df_to_insert = (
df_with_hash.alias("inc")
.join(
df_existing_current.alias("cur"),
on="customer_id",
how="left"
)
    .filter(
        F.col("cur.row_hash").isNull() |                  # net new: no current row
        (F.col("inc.row_hash") != F.col("cur.row_hash"))  # tracked columns changed
    )
.select("inc.*")
.withColumn("start_date", NOW)
.withColumn("end_date", F.lit(None).cast(TimestampType()))
.withColumn("is_current", F.lit(True))
)
(
df_to_insert
.write
.format("delta")
.mode("append")
.save(SILVER_PATH)
)
# ─── Verify: only one active row per customer ─────────────────────────────────
df_silver = spark.read.format("delta").load(SILVER_PATH)
df_active_counts = (
df_silver
.filter(F.col("is_current") == True)
.groupBy("customer_id")
.agg(F.count("*").alias("active_row_count"))
)
duplicates = df_active_counts.filter(F.col("active_row_count") > 1).count()
assert duplicates == 0, f"SCD2 violation: {duplicates} customers have multiple active rows"
print("SCD2 Silver layer validated. Schema:")
df_silver.printSchema()
df_silver.orderBy("customer_id", "start_date").show(10, truncate=False)
Key Takeaways
- Use the F.col() style for complex expressions; string names are fine for simple selects.
- Prefer built-in functions over UDFs. When you must use a UDF, use a pandas UDF (vectorized), not a Python UDF (row-by-row).
- Broadcast small tables explicitly with F.broadcast() to avoid shuffle joins.
- Window functions with partitionBy + orderBy give you running totals, lag/lead, and deduplication without collapsing rows.
- SCD Type 2 on Delta Lake takes a close-then-insert sequence: close old rows, then append the new versions.
- Always validate your SCD logic with an assertion that each entity has exactly one active row.