Spark repartitioning by column with dynamic number of partitions per column

How can a DataFrame be partitioned based on the count of items in a column? Suppose we have a DataFrame with 100 people (the columns are first_name and country) and we'd like to create a partition for every 10 people in a country.

If our dataset contains 80 people from China, 15 people from France, and 5 people from Cuba, then we'll want 8 partitions for China, 2 partitions for France, and 1 partition for Cuba.

Here is code that will not work:

  • df.repartition($"country"): This will create 1 partition for China, one partition for France, and one partition for Cuba
  • df.repartition(8, $"country", rand): This will create up to 8 partitions for each country, so it should create 8 partitions for China, but the France & Cuba partitions are unknown. France's rows could be spread across up to 8 partitions and Cuba's across up to 5 partitions. See this answer for more details.

Here's the repartition() documentation:

repartition documentation

When I look at the repartition() method, I don't even see a method that takes three arguments, so it looks like some of this behavior isn't documented.

Is there any way to dynamically set the number of partitions for each column? It would make creating partitioned data sets way easier.

Calves answered 8/10, 2019 at 12:36 Comment(1)
Regarding 3 arguments, $"country" and rand go together as partitionExprs in the second invocation.Cousteau

You're not going to be able to accomplish that exactly, due to the way Spark partitions data. Spark takes the columns you specified in repartition, hashes those values, and then takes the hash modulo the number of partitions. This way the number of partitions is deterministic. The reason it works this way is that joins need a matching number of partitions on the left and right side of a join, in addition to assuring that the hashing is the same on both sides.
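
As a rough sketch of that hash-and-modulo behaviour (assuming df is the question's DataFrame, and using 8 partitions just as an example), the target partition for each row can be computed with the SQL hash function, which approximates what hash partitioning does internally:

import org.apache.spark.sql.functions.expr

// Every row with the same country hashes to the same value, so all of a
// country's rows land in the same one of the 8 partitions.
df.withColumn("target_partition", expr("pmod(hash(country), 8)"))
  .select("country", "target_partition")
  .distinct()
  .show()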

"we'd like to create a partition for every 10 people in a country."

What exactly are you trying to accomplish here? Having only 10 rows in a partition is likely terrible for performance. Are you trying to create a partitioned table where each of the files in a partition is guaranteed to have at most x rows?

"df.repartition($"country"): This will create 1 partition for China, one partition for France, and one partition for Cuba"

This will actually create a DataFrame with the default number of shuffle partitions, hashed by country:

  def repartition(partitionExprs: Column*): Dataset[T] = {
    repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
  }
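
A quick way to confirm this (a sketch, assuming df is the question's DataFrame and spark is the active SparkSession):

import org.apache.spark.sql.functions.col

// The partition count equals spark.sql.shuffle.partitions (200 by default),
// not the number of distinct countries.
spark.conf.set("spark.sql.shuffle.partitions", "200")
val byCountry = df.repartition(col("country"))
println(byCountry.rdd.getNumPartitions)  // prints 200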

"df.repartition(8, $"country", rand): This will create up to 8 partitions for each country, so it should create 8 partitions for China, but the France & Cuba partitions are unknown. France could be in 8 partitions and Cuba could be in up to 5 partitions. See this answer for more details."

Likewise, this is subtly wrong. There are only 8 partitions, with the countries essentially randomly shuffled among those 8 partitions.
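
To see where rows actually land, spark_partition_id() can be inspected after the repartition (a sketch, again assuming the question's df and country column):

import org.apache.spark.sql.functions.{col, count, countDistinct, rand, spark_partition_id}

// Exactly 8 partition ids come back, and each partition can hold a mix of countries.
df.repartition(8, col("country"), rand())
  .groupBy(spark_partition_id().as("partition_id"))
  .agg(countDistinct("country").as("countries"), count("*").as("rows"))
  .orderBy("partition_id")
  .show()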

EDIT: One last point of clarification. DataFrame repartitioning works differently from what happens at write time when you use the partitionBy(...) method. With partitionBy, Spark first takes each of the existing Spark partitions, slices it by the partitionBy columns, and then writes each slice to the folder that corresponds to those partitionBy column values.
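
A minimal sketch of that difference (the path and column name are illustrative): the DataFrame below has 5 in-memory partitions that each contain a mix of countries, and the write-time partitionBy slices each of those 5 partitions by country, so every country=... folder can receive up to 5 files.

df.repartition(5)                  // 5 in-memory Spark partitions, countries mixed
  .write
  .partitionBy("country")          // one output folder per distinct country value
  .csv("./tmp/partition_by_demo")  // each folder can end up with up to 5 files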

Soot answered 8/10, 2019 at 22:51 Comment(5)
Thanks for pointing out my subtle errors. For 10 rows, this code isn't needed, but this is really important when creating partitioned data lakes on large datasets that are skewed.Calves
@Andrew Long there is no "sessionState" in sparkSession, so where do we find "sessionState"?Douro
@Douro are you using an older version of spark 1.x?Soot
@AndrewLong no, I am using Spark 2.4.5 and now 3.3.1Douro
Spark has a hidden session state that is private to Spark. github.com/apache/spark/blob/master/sql/core/src/main/scala/org/…Soot

Here's the code that'll create ten rows per data file (sample dataset is here):

import org.apache.spark.sql.functions.col

val outputPath = new java.io.File("./tmp/partitioned_lake5/").getCanonicalPath
df
  .repartition(col("person_country"))
  .write
  .option("maxRecordsPerFile", 10)
  .partitionBy("person_country")
  .csv(outputPath)
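
If I'm not mistaken, the same cap can also be set globally through the spark.sql.files.maxRecordsPerFile setting (Spark 2.2+) instead of the per-write option:

// Global equivalent of .option("maxRecordsPerFile", 10); 0 means no limit.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 10L)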

Here's the pre-Spark 2.2 code that'll create roughly ten rows per data file:

import org.apache.spark.sql.functions.{col, rand}
import org.apache.spark.sql.types.IntegerType

val desiredRowsPerPartition = 10

// countDF supplies the "count" column used below: the number of rows per country
val countDF = df.groupBy("person_country").count()

val joinedDF = df
  .join(countDF, Seq("person_country"))
  .withColumn(
    "my_secret_partition_key",
    // rand(10) * count / desiredRowsPerPartition yields values in [0, count / 10),
    // so each country is split into roughly count / 10 buckets of ~10 rows each
    (rand(10) * col("count") / desiredRowsPerPartition).cast(IntegerType)
  )

val outputPath = new java.io.File("./tmp/partitioned_lake6/").getCanonicalPath
joinedDF
  .repartition(col("person_country"), col("my_secret_partition_key"))
  .drop("count", "my_secret_partition_key")
  .write
  .partitionBy("person_country")
  .csv(outputPath)
Calves answered 19/10, 2019 at 9:29 Comment(3)
where are you getting this col("count") column ?Douro
is there any way we can handle skewed data using a hash function?Douro
yes, like @Douro said, there is a potential issue of data skew, which could cause OOM on the executor side.Grethel
