The functionality you're looking for is called bucketing and is used via the bucketBy method. It's slightly different from regular partitioning: it means you can load the datasets in & join them without needing a shuffle. The other answer notes that Spark can read partition information for a regular partitioned set of Parquet files. While this is technically true, that information is used primarily for partition pruning. Calling repartition at read time is itself a shuffle, so it doesn't achieve what you want.
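As a quick illustration of that last point (the path here is just a placeholder), repartitioning on the join key at read time shows up in the plan as an Exchange hashpartitioning step, i.e. a full shuffle:
import org.apache.spark.sql.functions.col

// Hash-partitioning at read time appears as Exchange hashpartitioning in the plan,
// which is exactly the shuffle we're trying to avoid.
spark.read.parquet("/some/plain/parquet/path").repartition(200, col("category")).explain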
In contrast, for joins or aggregations, where a hash partitioning into a specific number of partitions needs to happen, bucketBy is the right choice.
We can demonstrate the difference with example code in the Spark shell.
import org.apache.spark.sql.functions._
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
val data = Seq(
(1, "A"),
(2, "B"),
(3, "C"),
(4, "A"),
(5, "B")
)
val df = spark.createDataFrame(data).toDF("id", "category")
Since we're using a small test dataset, we disable broadcast joins so that Spark's planner will use a standard sort-merge join, which typically requires a shuffle.
Our test query will be to join the test dataset to itself on the category column,
val dfSelfJoin = df.join(df, Seq("category"))
Using dfSelfJoin.explain shows that the plan includes Exchange hashpartitioning (a hash partitioning + shuffle, using Spark's default of 200 partitions) & a sort within each partition before the two sides are joined:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [category#5, id#4, id#19]
+- SortMergeJoin [category#5], [category#20], Inner
:- Sort [category#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(category#5, 200), ENSURE_REQUIREMENTS, [plan_id=77]
: +- LocalTableScan [id#4, category#5]
+- Sort [category#20 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(category#20, 200), ENSURE_REQUIREMENTS, [plan_id=78]
+- LocalTableScan [id#19, category#20]
Now, if we repartition the DataFrame on the join key, write it to Parquet, read it back in, and try the self-join again,
val repartitionedDF = df.repartition(col("category"))
repartitionedDF.write.format("parquet").mode("overwrite").save("testDF")
val readDF = spark.read.parquet("testDF")
val readDFSelfJoin = readDF.join(readDF, Seq("category"))
we can see the planner still wants to reshuffle the data before joining:
readDFSelfJoin.explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [category#25, id#24, id#28]
+- SortMergeJoin [category#25], [category#29], Inner
:- Sort [category#25 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(category#25, 200), ENSURE_REQUIREMENTS, [plan_id=117]
: +- Filter isnotnull(category#25)
: +- FileScan parquet [id#24,category#25] Batched: true, DataFilters: [isnotnull(category#25)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/testDF], PartitionFilters: [], PushedFilters: [IsNotNull(category)], ReadSchema: struct<id:int,category:string>
+- Sort [category#29 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(category#29, 200), ENSURE_REQUIREMENTS, [plan_id=118]
+- Filter isnotnull(category#29)
+- FileScan parquet [id#28,category#29] Batched: true, DataFilters: [isnotnull(category#29)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/testDF], PartitionFilters: [], PushedFilters: [IsNotNull(category)], ReadSchema: struct<id:int,category:string>
This shows that Spark doesn't retain or infer any partitioner when reading plain Parquet files, even though the data was repartitioned on the join key before writing.
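The same holds if you lay the files out on disk with partitionBy: that layout helps with partition pruning, but the join still shuffles. A minimal sketch (the path is a placeholder):
// Disk partitioning creates category=... directories, useful for pruning only.
df.write.partitionBy("category").mode("overwrite").parquet("testDFPartitioned")
val readPartitioned = spark.read.parquet("testDFPartitioned")

// The plan for this join still contains Exchange hashpartitioning on both sides.
readPartitioned.join(readPartitioned, Seq("category")).explain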
However, if we save the table with bucketBy, read it in, and try the same self-join (disabling the automatic bucketed-scan heuristic so the planner always uses the bucketed layout for this read),
spark.conf.set("spark.sql.sources.bucketing.autoBucketedScan.enabled", "false")
readDF.write.format("parquet").bucketBy(5, "category").sortBy("category").option("path", "/testBucketedDF1").saveAsTable("testBucketedTable1")
val bucketedTable = spark.read.table("testBucketedTable1")
val selfJoinBucketed = bucketedTable.join(bucketedTable, Seq("category"))
we see a different query plan:
selfJoinBucketed.explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [category#73, id#72, id#76]
+- SortMergeJoin [category#73], [category#77], Inner
:- Sort [category#73 ASC NULLS FIRST], false, 0
: +- Filter isnotnull(category#73)
: +- FileScan parquet spark_catalog.default.testbucketedtable1[id#72,category#73] Batched: true, Bucketed: true, DataFilters: [isnotnull(category#73)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/testBucketedDF1], PartitionFilters: [], PushedFilters: [IsNotNull(category)], ReadSchema: struct<id:int,category:string>, SelectedBucketsCount: 5 out of 5
+- Sort [category#77 ASC NULLS FIRST], false, 0
+- Filter isnotnull(category#77)
+- FileScan parquet spark_catalog.default.testbucketedtable1[id#76,category#77] Batched: true, Bucketed: true, DataFilters: [isnotnull(category#77)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/testBucketedDF1], PartitionFilters: [], PushedFilters: [IsNotNull(category)], ReadSchema: struct<id:int,category:string>, SelectedBucketsCount: 5 out of 5
Notice how there's no longer a hash-partition & exchange step. Spark reads the files in, filters out nulls on the join key, sorts within each bucket, and joins.
To make this work with two different datasets, you'll need to ensure both are bucketed by the same column you want to join them on, and that the number of buckets set in bucketBy is the same, as in the sketch below.
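A minimal sketch of that setup, where ordersDF and customersDF are hypothetical DataFrames to be joined on a hypothetical customer_id column:
// Both sides bucketed on the join key with the same bucket count (16 here, chosen arbitrarily).
ordersDF.write.format("parquet").bucketBy(16, "customer_id").sortBy("customer_id").saveAsTable("orders_bucketed")
customersDF.write.format("parquet").bucketBy(16, "customer_id").sortBy("customer_id").saveAsTable("customers_bucketed")

val orders = spark.read.table("orders_bucketed")
val customers = spark.read.table("customers_bucketed")

// With matching bucketing on both sides, the sort-merge join needs no Exchange.
orders.join(customers, Seq("customer_id")).explain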
One more tip: for the case where you have one dataset bucketed & want to join it to another that isn't bucketed or partitioned, you can still avoid shuffling the bucketed table by setting spark.sql.shuffle.partitions to the bucketed table's number of buckets. This optimization can make sense when you have a large table you know in advance will be joined to many other tables on the same key.
This way, Spark will shuffle the non-bucketed table into the same number of partitions as the bucketed table. At that point, both sides have the same number of partitions & the same hash key, so the join happens without ever shuffling the bucketed dataset. You may need to turn off Spark's adaptive query execution so that spark.sql.shuffle.partitions is not overridden at runtime. A rough sketch of that setup follows.
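Continuing the hypothetical example above (orders_bucketed written with 16 buckets, eventsDF a non-bucketed DataFrame with a customer_id column):
// Match the shuffle partition count to the bucketed table's bucket count,
// and disable AQE so it doesn't change that number at runtime.
spark.conf.set("spark.sql.shuffle.partitions", "16")
spark.conf.set("spark.sql.adaptive.enabled", "false")

val orders = spark.read.table("orders_bucketed")

// Only eventsDF is shuffled, into 16 hash partitions on customer_id;
// the bucketed side is read as-is, one task per bucket.
orders.join(eventsDF, Seq("customer_id")).explain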