Data Engineering · Advanced

PySpark Performance Optimization: Partitions, Skew, AQE, and Delta Tuning

Diagnose and fix slow Spark jobs — understand the Spark UI, tune partitions, eliminate skew with salting, use AQE, leverage Delta Lake optimizations, and read explain() plans like a pro.

Learnixo · May 7, 2026 · 14 min read
PySparkApache SparkPerformanceOptimizationDelta LakeAQEData Skew
Share:š•

Why Your Spark Job Is Slow

In practice, nearly every slow Spark job traces back to one of five root causes:

  1. Too many or too few partitions — Too many creates task scheduling overhead and small file problems. Too few means underutilized parallelism.
  2. Shuffle — Moving data across the network between stages is the single most expensive operation.
  3. Data skew — When 5% of partitions hold 80% of the data, 95% of your executors sit idle while 5% work.
  4. Unnecessary recomputation — Not caching DataFrames that are used multiple times.
  5. Reading too much data — Not using partition pruning or column pruning.

Everything below is a fix for one of these five problems.

Understanding the Spark UI

The Spark UI (port 4040 by default) is your primary diagnostic tool. Never guess at performance problems — read the UI first.

Jobs → Stages → Tasks
  │        │        └── Individual task metrics: duration, GC time,
  │        │                                     shuffle read/write bytes
  │        └── Stage DAG: which stages depend on which,
  │                       where shuffles happen (ShuffleMapStage)
  └── Job timeline: which jobs ran, how long they took
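
If you're working in a notebook and aren't sure where the UI is running, the driver exposes its own URL — a quick check, assuming a live session:

Python
print(spark.sparkContext.uiWebUrl)   # e.g., http://driver-host:4040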

What to Look For

In the Stages tab:

  • Shuffle Read/Write bytes: Large shuffle (>1GB) = a rewrite opportunity
  • Spill (disk): Data that didn't fit in memory. Usually means too-large partitions or too little executor memory
  • GC Time: If GC time is >10% of task time, you have memory pressure

In the Tasks tab (for a specific stage):

  • Task duration variance: If some tasks take 10x longer than others → data skew
  • Shuffle Read Size: Wildly unequal sizes across tasks → partition skew after a shuffle

In the SQL tab (for DataFrame queries):

  • Click on any DataFrame action to see the physical plan with actual row counts and timing at each operator
  • Look for: large row count estimates, skipped partitions
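
You can also pull these same metrics programmatically through Spark's monitoring REST API — a minimal sketch, assuming the driver UI is reachable on the default port; the top-5 ranking by shuffle volume is just one way to surface rewrite targets:

Python
import requests

UI = "http://localhost:4040/api/v1"   # assumption: driver UI on the default port

app_id = requests.get(f"{UI}/applications").json()[0]["id"]
stages = requests.get(f"{UI}/applications/{app_id}/stages", params={"status": "complete"}).json()

# Rank completed stages by shuffle read volume — the top entries are rewrite targets
for s in sorted(stages, key=lambda s: s["shuffleReadBytes"], reverse=True)[:5]:
    print(s["stageId"], f'{s["shuffleReadBytes"] / 1e9:.2f} GB shuffled', s["name"][:60])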

Reading explain() Output

Python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("Optimization").getOrCreate()

df_orders    = spark.read.parquet("s3://my-bucket/silver/orders/")
df_customers = spark.read.parquet("s3://my-bucket/silver/customers/")

df_result = (
    df_orders
    .filter(F.col("status") == "completed")
    .join(df_customers, on="customer_id", how="inner")
    .groupBy("country_code")
    .agg(F.sum("total").alias("total_revenue"))
)

# Physical plan — what Spark will actually execute
df_result.explain(mode="formatted")

Sample output and what it means:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[country_code], functions=[sum(total)])        ← Gold aggregation
   +- Exchange hashpartitioning(country_code, 200)                   ← SHUFFLE here
      +- HashAggregate(keys=[country_code], functions=[partial_sum]) ← Partial agg
         +- Project [country_code, total]                            ← Column pruning
            +- BroadcastHashJoin [customer_id]                       ← No shuffle join!
               :- Filter (status = completed)
               :  +- FileScan parquet [customer_id, total, status]   ← Predicate pushdown
               :     PartitionFilters: []
               :     PushedFilters: [IsNotNull(customer_id), EqualTo(status,completed)]
               +- BroadcastExchange HashedRelationBroadcastMode
                  +- FileScan parquet [customer_id, country_code]    ← Only 2 cols read

Key observations from this plan:

  • BroadcastHashJoin — good, no shuffle for the join
  • PushedFilters — the status = 'completed' filter is pushed to the file reader
  • FileScan parquet [customer_id, country_code] — only 2 columns read from customers file (column pruning)
  • Exchange hashpartitioning — one shuffle for the final aggregation (unavoidable)
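
explain() also has modes beyond formatted — all available in Spark 3.x, depending on how much detail you want:

Python
df_result.explain(mode="simple")     # physical plan only (the default)
df_result.explain(mode="extended")   # parsed → analyzed → optimized → physical plans
df_result.explain(mode="cost")       # optimized logical plan with size/row estimates
df_result.explain(mode="codegen")    # the generated Java code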

Partitions and Parallelism

Python
# ─── Check current partition count ───────────────────────────────────────────
print(df_orders.rdd.getNumPartitions())  # e.g., 400

# ─── Rule of thumb: 128MB – 1GB per partition ─────────────────────────────────
# If your data is 20GB and you want ~200MB partitions: 20GB / 200MB = 100 partitions

# ─── Configure default shuffle partitions ─────────────────────────────────────
# Default is 200 — fine for medium data, too high for small data
spark.conf.set("spark.sql.shuffle.partitions", "100")

# ─── repartition vs coalesce ──────────────────────────────────────────────────

# repartition(N): full shuffle — creates exactly N equal-ish partitions
# Use when: increasing partition count, or redistributing after a skewed join
df_balanced = df_orders.repartition(200)

# repartition(N, col): hash-partitions by column — co-locates same key in same partition
# Use when: preparing for a join or groupBy on the same key
df_by_customer = df_orders.repartition(200, "customer_id")

# coalesce(N): no shuffle — merges partitions on the same executor
# Use when: reducing partition count before writing (avoids small files)
# WARNING: can create unequal partitions if data isn't evenly distributed
df_small = df_orders.coalesce(10)

# ─── Setting parallelism ──────────────────────────────────────────────────────
# For RDD operations (less common):
spark.conf.set("spark.default.parallelism", "200")

# For DataFrame shuffle operations:
spark.conf.set("spark.sql.shuffle.partitions", "200")

# With AQE enabled, Spark can auto-tune this at runtime (see AQE section)
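
Note that spark.sql.shuffle.partitions only governs shuffle outputs — the read-side partition count is driven by input file splits. A small sketch; the 256MB value is an arbitrary example, not a recommendation:

Python
# Spark splits input files into partitions of roughly this many bytes (default 128MB)
spark.conf.set("spark.sql.files.maxPartitionBytes", str(256 * 1024 * 1024))
df_reloaded = spark.read.parquet("s3://my-bucket/silver/orders/")
print(df_reloaded.rdd.getNumPartitions())   # fewer, larger partitions than the default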

The Small Files Problem

Writing 1,000 partitions creates 1,000 files. Reading those files later requires 1,000 file-open operations — expensive in object storage.

Python
# Before writing, consolidate to a reasonable file count
(
    df_large
    .coalesce(50)           # reduce from 1,000 partitions to 50 files
    .write
    .mode("overwrite")
    .parquet("s3://my-bucket/silver/output/")
)

# Or use repartition if you need balanced sizes (shuffle is OK on write path)
(
    df_large
    .repartition(50)
    .write
    .mode("overwrite")
    .parquet("s3://my-bucket/silver/output/")
)

Minimizing Shuffle

Shuffle is the most expensive operation in Spark — data crosses the network. Every groupBy, join (non-broadcast), distinct, repartition, and orderBy causes a shuffle.

Python
# ─── Avoid multiple shuffles on the same key ──────────────────────────────────

# BAD: two separate shuffles (two groupBys on the same key), then a join to stitch them
df_count   = df_orders.groupBy("customer_id").agg(F.count("*").alias("order_count"))
df_revenue = df_orders.groupBy("customer_id").agg(F.sum("total").alias("total_revenue"))
df_result  = df_count.join(df_revenue, on="customer_id")   # scans df_orders twice!

# GOOD: one shuffle, compute both aggregations together
df_result = (
    df_orders
    .groupBy("customer_id")
    .agg(
        F.count("*").alias("order_count"),
        F.sum("total").alias("total_revenue"),
    )
)

# ─── Avoid orderBy on large DataFrames ────────────────────────────────────────
# orderBy alone does a full global sort (a range-partitioning shuffle)
# If you only need top-N, combine it with limit:
df_top10 = df_orders.orderBy(F.col("total").desc()).limit(10)
# Catalyst rewrites orderBy + limit into TakeOrderedAndProject: each partition
# computes a local top-10 and the driver merges them — no full sort
# The equivalent SQL gets the same optimized plan:
spark.sql("SELECT * FROM orders ORDER BY total DESC LIMIT 10")

# ─── Filter before joining ────────────────────────────────────────────────────
# Push filters as early as possible to reduce data moved in shuffle

# BAD: join everything, then filter
df_bad = df_orders.join(df_customers, "customer_id").filter(F.col("status") == "completed")

# GOOD: filter before join
df_good = (
    df_orders.filter(F.col("status") == "completed")
    .join(df_customers, "customer_id")
)

Broadcast Joins

Python
# ─── Manual broadcast ─────────────────────────────────────────────────────────
df_enriched = df_orders.join(
    F.broadcast(df_country_lookup),
    on="country_code",
    how="left"
)

# ─── Auto-broadcast threshold ─────────────────────────────────────────────────
# Spark auto-broadcasts tables smaller than this threshold (bytes)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 100 * 1024 * 1024)  # 100MB

# Disable (useful for diagnosing which joins Spark is broadcasting)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# ─── Check if broadcast is happening ──────────────────────────────────────────
df_enriched.explain()
# Look for: BroadcastHashJoin or BroadcastNestedLoopJoin in the physical plan
# SortMergeJoin instead means both sides are being shuffled

Data Skew and Salting

Data skew is when one key appears disproportionately often. After a shuffle, one task processes 80% of the data while others process 1%.

Python
# ─── Detect skew ──────────────────────────────────────────────────────────────
# Check value distribution on the join/groupBy key
df_orders.groupBy("customer_id").count().orderBy(F.col("count").desc()).show(20)
# If one customer_id has 10M rows and the average is 100 → severe skew

# ─── Salting technique ────────────────────────────────────────────────────────
# Artificially split the heavy key into N sub-keys (salt), then unsalt after aggregation

import random

SALT_FACTOR = 10  # split heavy keys into 10 buckets

# Step 1: Add a random salt to the left side
df_orders_salted = df_orders.withColumn(
    "salted_customer_id",
    F.concat(
        F.col("customer_id").cast("string"),
        F.lit("_"),
        (F.rand() * SALT_FACTOR).cast("int").cast("string")
    )
)

# Step 2: Explode the right side to match all salt values
salt_values = list(range(SALT_FACTOR))

df_customers_exploded = (
    df_customers
    .crossJoin(
        spark.createDataFrame([(i,) for i in salt_values], ["salt"])
    )
    .withColumn(
        "salted_customer_id",
        F.concat(
            F.col("customer_id").cast("string"),
            F.lit("_"),
            F.col("salt").cast("string")
        )
    )
    .drop("salt", "customer_id")   # drop customer_id to avoid a duplicate column after the join
)

# Step 3: Join on the salted key — skew is now spread across salt_factor tasks
df_joined_salted = df_orders_salted.join(
    df_customers_exploded,
    on="salted_customer_id",
    how="left"
)

# Step 4: Drop the salt column, aggregate as normal
df_result = (
    df_joined_salted
    .drop("salted_customer_id")
    .groupBy("customer_id", "country_code")
    .agg(F.sum("total").alias("total_revenue"))
)

When to use salting: Only when the skew is so severe that some tasks take 10x longer than others. Salting adds complexity (building salted keys on one side and replicating the other with a crossJoin) — use it as a last resort after trying broadcast joins and AQE's skew-join handling.
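
One way to put a number on "severe" before reaching for salt — a rough sketch using a max-to-median key-count ratio (the ~10x threshold is a judgment call, not a Spark constant):

Python
counts = df_orders.groupBy("customer_id").count()
stats = counts.agg(
    F.max("count").alias("max_count"),
    F.percentile_approx("count", 0.5).alias("median_count"),
).first()
print(f"skew ratio: {stats['max_count'] / stats['median_count']:.0f}x")  # >>10 → AQE skew join or salting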

Caching and Persistence

Python
from pyspark import StorageLevel

# ─── cache() — alias for persist(MEMORY_AND_DISK) ────────────────────────────
df_expensive = (
    spark.read.parquet("s3://my-bucket/silver/orders/")
    .join(spark.read.parquet("s3://my-bucket/silver/customers/"), "customer_id")
    .filter(F.col("status") == "completed")
)

df_expensive.cache()       # lazy — cache is populated on first action
df_expensive.count()       # triggers the cache materialization

# Now multiple downstream actions reuse the cached data
df_by_country  = df_expensive.groupBy("country_code").agg(F.sum("total"))
df_by_month    = df_expensive.groupBy(F.month("order_date")).agg(F.count("*"))
df_top_customers = (
    df_expensive.groupBy("customer_id").agg(F.sum("total").alias("total"))
    .orderBy(F.col("total").desc()).limit(100)
)

# ─── persist() with explicit storage level ────────────────────────────────────
# MEMORY_ONLY: fastest; partitions that don't fit are simply not cached and get recomputed
df_expensive.persist(StorageLevel.MEMORY_ONLY)

# MEMORY_AND_DISK: spills to disk if memory full (default for cache())
df_expensive.persist(StorageLevel.MEMORY_AND_DISK)

# DISK_ONLY: slow, but doesn't use heap memory
df_expensive.persist(StorageLevel.DISK_ONLY)

# Note: the *_SER (serialized) levels are a JVM-only distinction — PySpark data
# is always pickled, so there is no separate MEMORY_AND_DISK_SER in Python

# OFF_HEAP: stores in Tungsten off-heap memory (no GC pressure)
# requires spark.memory.offHeap.enabled=true and spark.memory.offHeap.size to be set
df_expensive.persist(StorageLevel.OFF_HEAP)

# ─── Always unpersist when done ───────────────────────────────────────────────
df_expensive.unpersist()

# ─── When to cache ────────────────────────────────────────────────────────────
# Cache when: same DataFrame is used 2+ times in the same job
# Don't cache when: DataFrame is used only once, or it's too large to fit in memory
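
A related sanity check: DataFrame.storageLevel reports whether a cache is actually in effect — useful because cache() is lazy and easy to misplace:

Python
df_expensive.cache()
df_expensive.count()               # materialize the cache
print(df_expensive.storageLevel)   # e.g., Disk Memory Serialized 1x Replicated
df_expensive.unpersist()
print(df_expensive.storageLevel)   # Serialized 1x Replicated — i.e., not cached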

Predicate Pushdown and Column Pruning

These happen automatically with Parquet and Delta — but only if your schema and filter expressions allow it.

Python
# ─── Column pruning: read only the columns you need ───────────────────────────
# Select the needed columns as early as possible in the chain:
df_slim = (
    spark.read.parquet("s3://wide-table/")
    .select("id", "name", "amount")
)
# With Parquet, Catalyst pushes this projection down to the scan, so the reader
# physically reads only these 3 columns. The optimizer usually prunes even if
# you select later — but an early, explicit select documents the intent and
# keeps pruning intact through refactors

# ─── Predicate pushdown ───────────────────────────────────────────────────────
# Filters on partition columns → partition pruning (skips entire directories)
df_jan = spark.read.parquet("s3://my-bucket/orders/").filter("year = 2026 AND month = 1")
# Spark will only open the year=2026/month=1/ subdirectory

# Filters on non-partition columns → file-level statistics pruning
# Parquet stores min/max stats for each row group — Spark skips row groups that can't match
df_large = (
    spark.read.parquet("s3://my-bucket/orders/")
    .filter(F.col("total") > 10000)
)
# verify pushdown happened:
df_large.explain()
# Look for: PushedFilters: [IsNotNull(total), GreaterThan(total,10000.0)]

# ─── Why explicit schemas matter for pushdown ─────────────────────────────────
# If you read CSV without a schema, all columns are strings
# filter(F.col("total") > 10000) on a string column does STRING comparison
# "9999" > "10000" is True in string ordering!
# Always cast or use explicit schema
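
To make that last point concrete — a small sketch with an explicit schema so total is numeric; the path and columns are illustrative:

Python
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

csv_schema = StructType([
    StructField("order_id", StringType(), True),
    StructField("total",    DoubleType(), True),  # numeric → comparisons behave
])
df_csv = spark.read.schema(csv_schema).csv("s3://raw/orders/", header=True)
df_csv.filter(F.col("total") > 10000).explain()   # now a numeric GreaterThan filter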

Adaptive Query Execution (AQE)

AQE was introduced in Spark 3.0 and enabled by default in Spark 3.2+. It re-optimizes the physical plan at runtime using actual statistics from completed stages.

Python
# ─── Enable AQE (on by default in Spark 3.2+) ────────────────────────────────
spark.conf.set("spark.sql.adaptive.enabled", "true")

# ─── Feature 1: Dynamic partition coalescing ──────────────────────────────────
# After a shuffle, if partitions are too small, AQE merges them
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionNum", "1")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256mb")
# AQE will merge shuffle partitions to achieve ~256MB per partition

# ─── Feature 2: Dynamic join strategy switching ───────────────────────────────
# At runtime, if a join input is smaller than expected, AQE switches to broadcast
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
# Spark 3.2+ gives AQE its own runtime threshold (falls back to the static one if unset)
spark.conf.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "10mb")
# AQE uses actual stats, not estimates — catches cases where estimates were wrong

# ─── Feature 3: Skew join optimization ───────────────────────────────────────
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256mb")
# AQE splits skewed partitions and replicates the matching partition from the other side
# This is automatic salting — less code, same effect

# ─── Verify AQE is working ────────────────────────────────────────────────────
df_result.explain(mode="formatted")
# Look for: AdaptiveSparkPlan isFinalPlan=true in the completed plan
# The "final" plan shows what AQE actually chose at runtime

Delta Lake Optimizations

Python
# ─── OPTIMIZE: compact small files into larger ones ──────────────────────────
spark.sql("OPTIMIZE delta.`s3://my-bucket/delta/orders/`")

# With partitioning: optimize only recent partitions (much faster)
spark.sql("""
    OPTIMIZE delta.`s3://my-bucket/delta/orders/`
    WHERE event_date >= '2026-04-01'
""")

# ─── Z-ORDER: co-locate related data in the same files ───────────────────────
# Z-ORDER dramatically improves query performance for high-cardinality filter columns
# (columns that appear in WHERE clauses but are not partition columns)
spark.sql("""
    OPTIMIZE delta.`s3://my-bucket/delta/orders/`
    ZORDER BY (customer_id, store_id)
""")
# After Z-ORDER, queries like WHERE customer_id = 'X' will skip >90% of files

# ─── VACUUM: remove old files no longer in the current snapshot ───────────────
# Delta retains old files for 7 days by default (time travel)
spark.sql("VACUUM delta.`s3://my-bucket/delta/orders/` RETAIN 168 HOURS")  # 7 days

# Shorter retention destroys time travel history — Delta refuses RETAIN < 7 days
# unless you first disable the safety check:
spark.conf.set("spark.databricks.delta.retentionDurationCheck.enabled", "false")
spark.sql("""
    VACUUM delta.`s3://my-bucket/delta/orders/`
    RETAIN 0 HOURS
""")  # only do this when storage costs are the concern

# ─── Data skipping statistics ─────────────────────────────────────────────────
# Delta automatically collects min/max stats for the first 32 columns
# You can control which columns get stats:
spark.sql("""
    ALTER TABLE delta.`s3://my-bucket/delta/orders/`
    SET TBLPROPERTIES ('delta.dataSkippingNumIndexedCols' = '10')
""")

# ─── Bloom filter indexes (Databricks) ────────────────────────────────────────
# Created with a dedicated index statement, not table properties:
spark.sql("""
    CREATE BLOOMFILTER INDEX
    ON TABLE delta.`s3://my-bucket/delta/orders/`
    FOR COLUMNS (order_id OPTIONS (fpp = 0.1, numItems = 10000000))
""")

Cluster Configuration for Production

Python
# These configs belong in your SparkSession or cluster config:
spark = (
    SparkSession.builder
    .appName("ProductionJob")
    # ── Executor sizing ───────────────────────────────────────────────────────
    # Rule: 5 cores per executor, ~20GB RAM per executor
    .config("spark.executor.cores", "5")
    .config("spark.executor.memory", "20g")
    .config("spark.executor.memoryOverhead", "4g")   # native memory (Python, JVM overhead)
    # ── Driver ────────────────────────────────────────────────────────────────
    .config("spark.driver.memory", "8g")
    .config("spark.driver.maxResultSize", "2g")      # max size of .collect() result
    # ── Shuffle ───────────────────────────────────────────────────────────────
    .config("spark.sql.shuffle.partitions", "200")   # override at runtime with AQE
    .config("spark.shuffle.compress", "true")
    .config("spark.shuffle.spill.compress", "true")
    # ── AQE ───────────────────────────────────────────────────────────────────
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    # ── Serialization (Kryo is faster than Java for RDD ops) ──────────────────
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    # ── Dynamic allocation (scale executors based on workload) ────────────────
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.minExecutors", "2")
    .config("spark.dynamicAllocation.maxExecutors", "50")
    .config("spark.dynamicAllocation.initialExecutors", "10")
    .getOrCreate()
)
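
To make the sizing rule concrete, a back-of-envelope calculation for a hypothetical 32-core / 128GB worker node (the OS/daemon reservations are assumptions, not Spark requirements):

Python
cores_per_node, ram_per_node_gb = 32, 128   # hypothetical worker size
usable_cores = cores_per_node - 1           # reserve 1 core for OS/daemons
executors_per_node = usable_cores // 5      # 5 cores each → 6 executors
ram_per_executor = (ram_per_node_gb - 8) / executors_per_node  # reserve 8GB → ~20GB each
print(executors_per_node, ram_per_executor)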

Before / After Optimization Example

Before: Unoptimized Join

Python
# BAD: large shuffle join, reading all columns, no filter pushdown
df_result_slow = (
    spark.read.option("header", "true").option("inferSchema", "true").csv("s3://raw/orders/")   # inferSchema scans the data twice!
    .join(
        spark.read.option("header", "true").option("inferSchema", "true").csv("s3://raw/customers/"),
        on="customer_id"     # SortMergeJoin — both sides shuffle
    )
    )
    .filter(F.col("status") == "completed")  # filter AFTER join
    .groupBy("country")
    .agg(F.sum("total"))
)
df_result_slow.explain()

Explain output will show:

SortMergeJoin [customer_id]          ← full shuffle on both sides
:- Sort [customer_id]
:  +- Exchange hashpartitioning       ← shuffle 1
:     +- Filter (status = completed)  ← filter applied after shuffle
:        +- Scan CSV [*]              ← ALL columns read (no pruning)
+- Sort [customer_id]
   +- Exchange hashpartitioning       ← shuffle 2
      +- Scan CSV [*]                 ← ALL columns read

Total shuffles: 2 (join) + 1 (groupBy) = 3

After: Optimized

Python
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, DateType

ORDERS_SCHEMA = StructType([
    StructField("order_id",    StringType(), True),
    StructField("customer_id", StringType(), True),
    StructField("status",      StringType(), True),
    StructField("total",       DoubleType(), True),
    StructField("order_date",  DateType(),   True),
])

CUSTOMERS_SCHEMA = StructType([
    StructField("customer_id", StringType(), True),
    StructField("country",     StringType(), True),
])

# Read only needed columns, with explicit schema, filter before join
df_orders_filtered = (
    spark.read
    .schema(ORDERS_SCHEMA)
    .parquet("s3://silver/orders/")             # Parquet (not CSV!)
    .filter(F.col("status") == "completed")     # filter BEFORE join → less data in shuffle
    .select("order_id", "customer_id", "total") # only 3 columns
)

df_customers_slim = (
    spark.read
    .schema(CUSTOMERS_SCHEMA)
    .parquet("s3://silver/customers/")          # small table → broadcast
    .select("customer_id", "country")
)

# Broadcast the small table, AQE handles shuffle partitions
df_result_fast = (
    df_orders_filtered
    .join(F.broadcast(df_customers_slim), on="customer_id")   # BroadcastHashJoin
    .groupBy("country")
    .agg(F.sum("total").alias("total_revenue"))
)

df_result_fast.explain(mode="formatted")

Explain output will show:

HashAggregate [country]
+- Exchange hashpartitioning(country, 200)      ← 1 shuffle (groupBy only)
   +- HashAggregate [country] partial_sum
      +- Project [country, total]
         +- BroadcastHashJoin [customer_id]     ← NO shuffle for join!
            :- Filter (status = completed)       ← predicate pushdown
            :  +- FileScan parquet [customer_id,total,status]  ← 3 cols only
            +- BroadcastExchange
               +- FileScan parquet [customer_id,country]       ← 2 cols only

Total shuffles: 1 (groupBy only) — 3x fewer shuffles.

Optimization Checklist

Use this before every production job:

| Check | What to Look For | Fix |
|---|---|---|
| explain() | SortMergeJoin on small table | Add F.broadcast() |
| explain() | Missing PushedFilters | Filter before join, use Parquet |
| Spark UI → Tasks | High task duration variance | Investigate skew, apply salting |
| Spark UI → Stages | Large shuffle read/write | Combine groupBys, filter earlier |
| Spark UI → Tasks | High GC time | Increase executor memory, use persist() |
| Output | Thousands of small files | Coalesce before write |
| Delta table | Slow point queries | Run OPTIMIZE + ZORDER |
| Cache usage | Same DF used 3+ times | .cache() + .unpersist() |
| Partition count | Too few (< CPU cores) or too many (>10,000) | repartition(N) |

Key Takeaways

  • Read the Spark UI first — every optimization decision should be data-driven.
  • explain(mode="formatted") is your fastest diagnostic tool before running at scale.
  • AQE handles many problems automatically in Spark 3.2+: enable it and let it work.
  • Broadcast small tables explicitly — don't rely on auto-broadcast for critical jobs.
  • Filter and project early — every row and column eliminated before a shuffle saves network I/O.
  • Salt skewed keys only as a last resort; try spark.sql.adaptive.skewJoin.enabled=true first.
  • For Delta Lake: run OPTIMIZE + ZORDER on your most-queried tables weekly; VACUUM monthly.
  • Cache thoughtfully: only cache when a DataFrame is reused, and always unpersist() when done.

Enjoyed this article?

Explore the Data Engineering learning path for more.
