Back to blog
Data Engineering · Intermediate

PySpark DataFrames & Spark SQL: Transformations, Joins, and Window Functions

Deep dive into PySpark DataFrame operations — UDFs, built-in functions, all join types, broadcast joins, window functions, and a real Silver-layer SCD Type 2 transformation.

Learnixo · May 7, 2026 · 10 min read
PySparkApache SparkDataFramesSpark SQLWindow FunctionsDelta LakeSCD Type 2
Share:š•

Column References: Three Equivalent Styles

PySpark gives you three ways to reference a column. Understand all three because you'll see them all in production code.

Python
from pyspark.sql import functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DFTransformations").getOrCreate()

df = spark.read.parquet("s3://my-bucket/silver/orders/")

# Style 1: String column name — simplest, but no IDE autocomplete
# (string predicates like "status = 'active'" are parsed as SQL expressions)
df.filter("status = 'active'")
df.select("order_id", "customer_id", "total")

# Style 2: df["column"] — DataFrame-scoped, unambiguous in joins
df.filter(df["status"] == "active")
df.select(df["order_id"], df["customer_id"])

# Style 3: F.col() — most common in production, chains well
df.filter(F.col("status") == "active")
df.select(F.col("order_id"), F.col("customer_id"))

# In simple cases col() (imported from pyspark.sql.functions) and F.col()
# are identical. Use F.col() to be explicit about the functions namespace.

Rule of thumb: Use F.col() for expressions in withColumn/filter/agg. Use string column names for simple select and drop.

Core DataFrame Transformations

Python
# ─── select ───────────────────────────────────────────────────────────────────
df_slim = df.select(
    "order_id",
    "customer_id",
    F.col("order_date").cast("date").alias("order_date"),  # cast to date truncates any time-of-day
    (F.col("subtotal") + F.col("tax")).alias("total"),
)

# ─── filter / where (identical) ───────────────────────────────────────────────
# Combine predicates with & and | — Python's `and`/`or` do not work on Columns,
# and each comparison needs its own parentheses because of operator precedence.
df_recent = df.filter(
    (F.col("order_date") >= "2026-01-01") &
    (F.col("status").isin("completed", "shipped")) &
    F.col("customer_id").isNotNull()
)

# ─── withColumn ───────────────────────────────────────────────────────────────
df_enriched = (
    df
    .withColumn("year",  F.year("order_date"))
    .withColumn("month", F.month("order_date"))
    .withColumn(
        "order_size",
        # when/otherwise branches evaluate in order; the first match wins
        F.when(F.col("total") < 50,  "small")
         .when(F.col("total") < 200, "medium")
         .otherwise("large")
    )
)

# ─── drop ─────────────────────────────────────────────────────────────────────
# drop is a no-op for column names that don't exist — it never raises
df_clean = df.drop("_corrupt_record", "_ingested_at", "internal_notes")

# ─── rename ───────────────────────────────────────────────────────────────────
df_renamed = df.withColumnRenamed("cust_id", "customer_id")

# ─── distinct / dropDuplicates ────────────────────────────────────────────────
# NOTE: dropDuplicates keeps an arbitrary row per key — which row survives
# is not deterministic across runs.
df_deduped = df.dropDuplicates(["order_id"])
df_deduped_partial = df.dropDuplicates(["customer_id", "order_date"])  # keep one per customer per day

Built-in Functions You'll Use Every Day

Python
import pyspark.sql.functions as F

# ─── Conditional Logic ────────────────────────────────────────────────────────
df = df.withColumn(
    "effective_price",
    # branches are checked top-down: a discount wins over promo_code when both apply
    F.when(F.col("discount") > 0, F.col("price") * (1 - F.col("discount")))
     .when(F.col("promo_code").isNotNull(), F.col("price") * 0.9)
     .otherwise(F.col("price"))
)

# ─── Null Handling ────────────────────────────────────────────────────────────
# coalesce returns the first non-null value
df = df.withColumn(
    "display_name",
    F.coalesce(F.col("preferred_name"), F.col("full_name"), F.lit("Anonymous"))
)

# fillna / fill — only fills columns whose type matches the replacement value
df = df.fillna({"country_code": "US", "lifetime_value": 0.0})

# ─── String Functions ─────────────────────────────────────────────────────────
df = (
    df
    .withColumn("email", F.lower(F.trim(F.col("email"))))
    .withColumn("phone_digits", F.regexp_replace(F.col("phone"), "[^0-9]", ""))
    .withColumn(
        "area_code",
        # capture group 1; regexp_extract returns "" when there is no match
        F.regexp_extract(F.col("phone"), r"^\(?(\d{3})\)?", 1)
    )
    .withColumn("name_parts", F.split(F.col("full_name"), " "))
    .withColumn("first_name", F.col("name_parts").getItem(0))
    .withColumn(
        "name_length",
        F.length(F.col("full_name"))
    )
)

# ─── Date Functions ───────────────────────────────────────────────────────────
df = (
    df
    .withColumn("order_year",    F.year("order_date"))
    .withColumn("order_month",   F.month("order_date"))
    .withColumn("order_quarter", F.quarter("order_date"))
    .withColumn(
        "order_date_str",
        F.date_format(F.col("order_date"), "yyyy-MM-dd")
    )
    .withColumn(
        "days_since_order",
        # current_date() is fixed per query, so all rows see the same "today"
        F.datediff(F.current_date(), F.col("order_date"))
    )
    .withColumn(
        "order_ts",
        F.to_timestamp(F.col("order_date_str"), "yyyy-MM-dd")
    )
)

# ─── Array / Struct Functions ─────────────────────────────────────────────────
# explode turns an array column into multiple rows
# (rows whose array is null or empty are dropped entirely)
df_tags = df.withColumn("tag", F.explode(F.col("tags")))

# explode_outer keeps rows even when the array is null or empty
df_tags_safe = df.withColumn("tag", F.explode_outer(F.col("tags")))

# collect items into an array during aggregation
# (collect_list keeps duplicates and gives no ordering guarantee)
df_customer_tags = (
    df_tags
    .groupBy("customer_id")
    .agg(F.collect_list("tag").alias("all_tags"))
)

# F.struct creates a struct (nested) column — SQL equivalent: named_struct
df_address = df.withColumn(
    "address",
    F.struct(
        F.col("street").alias("street"),
        F.col("city").alias("city"),
        F.col("zip").alias("zip"),
    )
)

GroupBy and Aggregations

Python
from pyspark.sql import functions as F

# ─── Basic groupBy + agg ──────────────────────────────────────────────────────
df_summary = (
    df_orders
    .groupBy("customer_id", "year", "month")
    .agg(
        F.count("*").alias("order_count"),
        F.sum("total").alias("total_revenue"),
        F.avg("total").alias("avg_order_value"),
        F.min("order_date").alias("first_order_date"),
        F.max("order_date").alias("last_order_date"),
        F.countDistinct("product_id").alias("unique_products"),  # exact distinct — can be costly at scale
        F.collect_list("order_id").alias("order_ids"),
    )
)

# ─── Pivot ────────────────────────────────────────────────────────────────────
# Passing the value list explicitly avoids the extra pass over the data
# Spark would otherwise need to discover the distinct statuses.
df_pivot = (
    df_orders
    .groupBy("customer_id")
    .pivot("status", ["pending", "completed", "cancelled"])
    .agg(F.count("order_id"))
    .fillna(0)
)
# Result: customer_id | pending | completed | cancelled

# ─── Multiple aggregations on same column ─────────────────────────────────────
# percentile_approx is approximate by design — cheap on large data
df_revenue_stats = (
    df_orders
    .groupBy("country_code")
    .agg(
        F.expr("percentile_approx(total, 0.5)").alias("median_total"),
        F.expr("percentile_approx(total, 0.95)").alias("p95_total"),
        F.stddev("total").alias("stddev_total"),
    )
)

User-Defined Functions (UDFs)

Python UDF — Convenient but Slow

Python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import re

# Plain Python UDF: every row is serialized across the JVM↔Python boundary.
def normalize_phone(phone: str) -> str:
    """Normalize a raw phone string to E.164 (+1XXXXXXXXXX); None if invalid."""
    if not phone:
        return None
    digits = re.sub(r"\D", "", phone)
    if len(digits) == 11 and digits.startswith("1"):
        return f"+{digits}"
    if len(digits) == 10:
        return f"+1{digits}"
    return None

# Wrap the plain function; the return type must be declared explicitly
normalize_phone_udf = udf(normalize_phone, StringType())

# Register and use
df = df.withColumn("phone_normalized", normalize_phone_udf(F.col("phone")))

Performance warning: Python UDFs serialize each row to Python, call your function, then serialize results back to the JVM. For large datasets this is 10–100x slower than built-in functions.

Pandas UDF (Vectorized) — Fast

Python
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType
import pandas as pd
import re

# Pandas UDF: executed once per Arrow batch, so the JVM boundary is crossed
# per column chunk rather than per row.
@pandas_udf(StringType())
def normalize_phone_fast(phone_series: pd.Series) -> pd.Series:
    """Vectorized phone normalizer; yields E.164 strings or None."""
    def _clean(raw):
        if not raw or pd.isna(raw):
            return None
        digits = re.sub(r"\D", "", str(raw))
        if len(digits) == 11 and digits.startswith("1"):
            return f"+{digits}"
        if len(digits) == 10:
            return f"+1{digits}"
        return None
    return phone_series.map(_clean)

@pandas_udf(DoubleType())
def revenue_score(revenue: pd.Series, order_count: pd.Series) -> pd.Series:
    """Composite score: revenue per order, log-scaled."""
    import numpy as np
    # replace(0, 1) guards against division by zero for zero-order customers
    ratio = revenue / order_count.replace(0, 1)
    return np.log1p(ratio)

# Use exactly like built-in functions
df = df.withColumn("phone_normalized", normalize_phone_fast(F.col("phone")))
df = df.withColumn(
    "revenue_score",
    revenue_score(F.col("lifetime_value"), F.col("order_count"))
)

Rule: Always prefer built-in F.* functions. Use pandas UDFs for complex logic that has no built-in equivalent. Avoid plain Python UDFs in production — they are almost always the slowest of the three options.

Join Types and Strategies

Python
# ─── Inner Join (default) ─────────────────────────────────────────────────────
df_joined = df_orders.join(df_customers, on="customer_id", how="inner")

# ─── Left Join ────────────────────────────────────────────────────────────────
df_left = df_orders.join(df_customers, on="customer_id", how="left")

# ─── Semi Join (filter orders where matching customer exists, no customer cols)
df_semi = df_orders.join(df_customers, on="customer_id", how="left_semi")

# ─── Anti Join (orders with NO matching customer — orphan detection)
df_anti = df_orders.join(df_customers, on="customer_id", how="left_anti")

# ─── Multi-column join ────────────────────────────────────────────────────────
df_multi = df_orders.join(
    df_returns,
    on=["order_id", "product_id"],
    how="left"
)

# ─── Non-equi join (requires column expression syntax) ────────────────────────
# NOTE: non-equi conditions rule out hash joins, so Spark falls back to
# nested-loop strategies — expect these to be far more expensive.
df_range = df_transactions.join(
    df_date_ranges,
    on=(
        (df_transactions["txn_date"] >= df_date_ranges["period_start"]) &
        (df_transactions["txn_date"] <  df_date_ranges["period_end"])
    ),
    how="inner"
)

# ─── Resolving ambiguous columns after join ────────────────────────────────────
# When both DataFrames have "created_at", use aliases.
# (A column-expression join also keeps BOTH customer_id columns,
# unlike on="customer_id", which deduplicates the join key.)
df_ord = df_orders.alias("ord")
df_cust = df_customers.alias("cust")

df_joined = df_ord.join(df_cust, F.col("ord.customer_id") == F.col("cust.customer_id"))
df_result = df_joined.select(
    F.col("ord.order_id"),
    F.col("cust.email"),
    F.col("ord.created_at").alias("order_created_at"),
    F.col("cust.created_at").alias("customer_created_at"),
)

Broadcast Joins for Small Tables

Python
# When one side of a join fits in memory (< ~10MB by default),
# Spark can broadcast it to every executor — eliminates the shuffle

df_orders_enriched = df_orders.join(
    F.broadcast(df_country_lookup),    # broadcast the small lookup table
    on="country_code",
    how="left"
)

# Configure the auto-broadcast threshold (bytes):
# (the broadcast side must fit in driver AND executor memory — keep it modest)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)  # 50MB

# Disable auto-broadcast (useful for debugging join strategies):
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Verify Spark chose a BroadcastHashJoin:
df_orders_enriched.explain()
# Look for: BroadcastHashJoin in the physical plan

Window Functions

Window functions operate over a "window" of rows relative to the current row, without collapsing the DataFrame like groupBy does.

Python
from pyspark.sql.window import Window
from pyspark.sql import functions as F

# ─── Define window specs ──────────────────────────────────────────────────────

# Partition by customer, order by date ascending
w_customer_time = (
    Window
    .partitionBy("customer_id")
    .orderBy("order_date")
)

# Partition by customer only (for ranking across all time)
w_customer = Window.partitionBy("customer_id")

# Rolling 30-day window. rangeBetween needs a numeric ordering column — here,
# epoch seconds. A DateType column cannot be cast straight to long in Spark
# (AnalysisException: cannot cast date to bigint), so cast through timestamp
# first; if order_date is already a timestamp the extra cast is a no-op.
w_rolling_30d = (
    Window
    .partitionBy("customer_id")
    .orderBy(F.col("order_date").cast("timestamp").cast("long"))
    .rangeBetween(-30 * 86400, 0)  # 30 days in seconds, current row is 0
)

# ─── Ranking Functions ────────────────────────────────────────────────────────
df_ranked = df_orders.withColumn(
    "order_rank",          # dense_rank: no gaps in ranking numbers
    F.dense_rank().over(w_customer_time)
)

df_top = df_orders.withColumn(
    "rank_by_value",       # rank: gaps when tied
    F.rank().over(Window.partitionBy("customer_id").orderBy(F.col("total").desc()))
)

# Get each customer's single most recent order
# (row_number breaks ties arbitrarily, guaranteeing exactly one row per customer)
df_latest = (
    df_orders
    .withColumn("row_num", F.row_number().over(
        Window.partitionBy("customer_id").orderBy(F.col("order_date").desc())
    ))
    .filter(F.col("row_num") == 1)
    .drop("row_num")
)

# ─── Lag and Lead ─────────────────────────────────────────────────────────────
# lag/lead return null at partition edges (no previous / next row),
# so days_since_last_order is null for each customer's first order
df_with_prev = (
    df_orders
    .withColumn(
        "prev_order_date",
        F.lag("order_date", 1).over(w_customer_time)
    )
    .withColumn(
        "next_order_date",
        F.lead("order_date", 1).over(w_customer_time)
    )
    .withColumn(
        "days_since_last_order",
        F.datediff(F.col("order_date"), F.col("prev_order_date"))
    )
)

# ─── Running Aggregations ─────────────────────────────────────────────────────
# With an orderBy present, the default frame runs from the partition start
# up to the current row — which is what makes this a cumulative sum.
df_running = df_orders.withColumn(
    "cumulative_revenue",
    F.sum("total").over(w_customer_time)  # running sum within customer
)

df_running = df_running.withColumn(
    "pct_of_customer_total",
    F.col("total") / F.sum("total").over(w_customer)  # no orderBy → whole-partition total
)

Spark SQL with Temporary Views

Python
# Register a DataFrame as a SQL view
# (temp views are lazy pointers to the DataFrame, scoped to this SparkSession)
df_orders.createOrReplaceTempView("orders")
df_customers.createOrReplaceTempView("customers")

# Full SQL — useful when logic is cleaner in SQL than DataFrame API
df_sql_result = spark.sql("""
    SELECT
        c.country_code,
        DATE_FORMAT(o.order_date, 'yyyy-MM') AS month,
        COUNT(DISTINCT o.customer_id)        AS active_customers,
        SUM(o.total)                         AS total_revenue,
        AVG(o.total)                         AS avg_order_value,
        PERCENTILE_APPROX(o.total, 0.5)      AS median_order_value
    FROM orders o
    INNER JOIN customers c USING (customer_id)
    WHERE o.status = 'completed'
      AND o.order_date >= '2026-01-01'
    GROUP BY 1, 2
    ORDER BY 1, 2
""")

df_sql_result.show(20)

# Global temp views persist across SparkSessions (useful in notebooks)
# — but only within the same Spark application; they die with the driver
df_orders.createOrReplaceGlobalTempView("orders_global")
spark.sql("SELECT * FROM global_temp.orders_global LIMIT 5").show()

Delta Lake MERGE (Upsert)

Python
from delta.tables import DeltaTable

# DeltaTable.forPath gives you the merge builder
delta_table = DeltaTable.forPath(spark, "s3://my-bucket/delta/customers/")

df_updates = spark.read.parquet("s3://my-bucket/staging/customer_updates/")

# NOTE: MERGE requires each target row to match at most ONE source row —
# deduplicate df_updates on customer_id first if staging data may repeat.
(
    delta_table.alias("target")
    .merge(
        source=df_updates.alias("source"),
        condition="target.customer_id = source.customer_id"
    )
    # right-hand values are SQL expression strings evaluated per matched row
    .whenMatchedUpdate(set={
        "email":          "source.email",
        "full_name":      "source.full_name",
        "lifetime_value": "source.lifetime_value",
        "updated_at":     "source.updated_at",
    })
    .whenNotMatchedInsert(values={
        "customer_id":    "source.customer_id",
        "email":          "source.email",
        "full_name":      "source.full_name",
        "lifetime_value": "source.lifetime_value",
        "created_at":     "source.created_at",
        "updated_at":     "source.updated_at",
    })
    .execute()
)

Complete Example: Silver Layer with SCD Type 2

SCD Type 2 (Slowly Changing Dimension Type 2) preserves full history by closing the old row and inserting a new one when a record changes.

Python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType,
    DoubleType, TimestampType, BooleanType
)

# Delta Lake extensions must be configured before the session is created
spark = (
    SparkSession.builder
    .appName("SilverSCD2Pipeline")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

SILVER_PATH = "s3://my-bucket/delta/dim_customers_scd2/"
STAGING_PATH = "s3://my-bucket/staging/customer_updates/"


# ─── Read incoming changes ────────────────────────────────────────────────────

df_incoming = spark.read.parquet(STAGING_PATH)

# Deduplicate incoming: keep most recent change per customer
# (the later MERGE would fail if one target row matched multiple source rows)
w_latest = Window.partitionBy("customer_id").orderBy(F.col("updated_at").desc())

df_latest_incoming = (
    df_incoming
    .withColumn("_rn", F.row_number().over(w_latest))
    .filter(F.col("_rn") == 1)
    .drop("_rn")
)


# ─── Identify what actually changed (hash-based change detection) ─────────────

TRACKED_COLS = ["email", "full_name", "country_code", "tier"]

# Cast each tracked column to string and hash the "|"-joined result.
# NOTE(review): concat_ws skips nulls, so ("a", null, "b") and ("a", "b")
# hash identically — consider coalescing nulls to a sentinel before hashing.
df_with_hash = df_latest_incoming.withColumn(
    "row_hash",
    F.md5(F.concat_ws("|", *[F.col(c).cast("string") for c in TRACKED_COLS]))
)


# ─── SCD Type 2 MERGE logic ───────────────────────────────────────────────────
# When a tracked column changes:
#   1. Close the existing active row (set end_date = now, is_current = false)
#   2. Insert a new active row with new values
#
# Delta MERGE cannot do "close old + insert new" in one statement.
# Standard approach: two-pass or use MERGE with whenMatchedUpdateAll + whenNotMatchedInsert

# Column expression, not a value — evaluated when each query below runs
NOW = F.current_timestamp()

# Step 1: Close changed active rows in the target
if DeltaTable.isDeltaTable(spark, SILVER_PATH):
    delta_target = DeltaTable.forPath(spark, SILVER_PATH)

    # Identify customer_ids whose hash changed
    df_existing = delta_target.toDF().filter(F.col("is_current") == True)

    df_changed = (
        df_existing.alias("existing")
        .join(
            df_with_hash.alias("incoming"),
            on="customer_id",
            how="inner"
        )
        .filter(
            F.col("existing.row_hash") != F.col("incoming.row_hash")
        )
        .select(F.col("existing.customer_id"))
    )

    # NOTE(review): collect() pulls every changed id to the driver and isin()
    # inlines them into the query plan — fine for moderate change volumes,
    # but a very large changed set would be better handled with a join.
    changed_ids = [row["customer_id"] for row in df_changed.collect()]

    if changed_ids:
        # Close the old rows
        (
            delta_target.alias("target")
            .merge(
                source=df_with_hash.filter(
                    F.col("customer_id").isin(changed_ids)
                ).alias("source"),
                condition=(
                    "target.customer_id = source.customer_id AND target.is_current = true"
                )
            )
            .whenMatchedUpdate(set={
                "end_date":   "current_timestamp()",
                "is_current": "false",
            })
            .execute()
        )

# Step 2: Insert new rows for changed and net-new customers
# (changed customers were closed in Step 1, so they no longer appear in the
#  is_current=true set and fall out of the left join as "net new")
df_existing_current = (
    DeltaTable.forPath(spark, SILVER_PATH).toDF()
    .filter(F.col("is_current") == True)
    .select("customer_id", "row_hash")
) if DeltaTable.isDeltaTable(spark, SILVER_PATH) else spark.createDataFrame([], schema=StructType([
    StructField("customer_id", IntegerType()),
    StructField("row_hash", StringType()),
]))

df_to_insert = (
    df_with_hash.alias("inc")
    .join(
        df_existing_current.alias("cur"),
        on="customer_id",
        how="left"
    )
    .filter(
        F.col("cur.customer_id").isNull() |              # net new
        (F.col("inc.row_hash") != F.col("cur.row_hash")) # changed
    )
    .select("inc.*")
    .withColumn("start_date",  NOW)
    .withColumn("end_date",    F.lit(None).cast(TimestampType()))
    .withColumn("is_current",  F.lit(True))
)

# Append-only write: history rows are never rewritten, only closed in Step 1
(
    df_to_insert
    .write
    .format("delta")
    .mode("append")
    .save(SILVER_PATH)
)


# ─── Verify: only one active row per customer ─────────────────────────────────

df_silver = spark.read.format("delta").load(SILVER_PATH)

# Count active rows per customer; a correct SCD2 table has exactly one
# is_current=true row per business key.
df_active_counts = (
    df_silver
    .filter(F.col("is_current") == True)
    .groupBy("customer_id")
    .agg(F.count("*").alias("active_row_count"))
)

duplicates = df_active_counts.filter(F.col("active_row_count") > 1).count()
# Raise explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would silently disable this data-quality gate.
# AssertionError is kept so existing callers that catch it still work.
if duplicates > 0:
    raise AssertionError(
        f"SCD2 violation: {duplicates} customers have multiple active rows"
    )

print("SCD2 Silver layer validated. Schema:")
df_silver.printSchema()
df_silver.orderBy("customer_id", "start_date").show(10, truncate=False)

Key Takeaways

  • Use F.col() style for complex expressions; string names are fine for simple selects.
  • Prefer built-in functions over UDFs. When you must use a UDF, use pandas UDF (vectorized) not Python UDF (row-by-row).
  • Broadcast small tables explicitly with F.broadcast() to avoid shuffle joins.
  • Window functions with partitionBy + orderBy give you running totals, lag/lead, and deduplication without collapsing rows.
  • SCD Type 2 on Delta Lake requires a two-pass approach: close old rows, then insert new ones.
  • Always validate your SCD logic with an assertion that each entity has exactly one active row.

Enjoyed this article?

Explore the Data Engineering learning path for more.

Found this helpful?

Share:š•

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.