How to sort an array of struct type in a Spark DataFrame by a particular field?

Given the following code:

import java.sql.Date
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object SortQuestion extends App{

  val spark = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
  import spark.implicits._
  case class ABC(a: Int, b: Int, c: Int)

  val first = Seq(
    ABC(1, 2, 3),
    ABC(1, 3, 4),
    ABC(2, 4, 5),
    ABC(2, 5, 6)
  ).toDF("a", "b", "c")

  val second = Seq(
    (1, 2, (Date.valueOf("2018-01-02"), 30)),
    (1, 3, (Date.valueOf("2018-01-01"), 20)),
    (2, 4, (Date.valueOf("2018-01-02"), 50)),
    (2, 5, (Date.valueOf("2018-01-01"), 60))
  ).toDF("a", "b", "c")

  first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b")).groupBy("a").agg(sort_array(collect_list("c2")))
    .show(false)

}

Spark produces the following result:

+---+----------------------------------+
|a  |sort_array(collect_list(c2), true)|
+---+----------------------------------+
|1  |[[2018-01-01,20], [2018-01-02,30]]|
|2  |[[2018-01-01,60], [2018-01-02,50]]|
+---+----------------------------------+

This shows that Spark sorts the array by date (since it is the first field of the struct), but I want to instruct Spark to sort by a specific field of that nested struct.

I know I can reshape the array to (value, date), but that seems inconvenient. I want a general solution (imagine I have a big nested struct, 5 layers deep, and I want to sort that structure by a particular column).

Is there a way to do that? Am I missing something?

Dace asked 5/4, 2018 at 11:34
What I can suggest is to just collect the list and then use a udf for sorting, where you can pass the index or column to sort by. – Canoewood
@RameshMaharjan Hi, I want to avoid collecting the list, since this sorting is an intermediate step of a DataFrame transformation and I want to minimize serialization/deserialization. Also, if I collect, I will receive Array[Seq[(Date, Int)]], which may not fit on one machine (due to the large DataFrame). – Dace
I was suggesting that you use the collect_list function with the sort_array function ;) Is it clear now? – Canoewood
@RameshMaharjan Now I get it, thanks. If you provide an answer, I'll accept it. – Dace

If you have complex objects, it is much better to use a statically typed Dataset.

case class Result(a: Int, b: Int, c: Int, c2: (java.sql.Date, Int))

val joined = first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
joined.as[Result]
  .groupByKey(_.a)
  .mapGroups((key, xs) => (key, xs.map(_.c2).toSeq.sortBy(_._2)))
  .show(false)

// +---+----------------------------------+            
// |_1 |_2                                |
// +---+----------------------------------+
// |1  |[[2018-01-01,20], [2018-01-02,30]]|
// |2  |[[2018-01-02,50], [2018-01-01,60]]|
// +---+----------------------------------+
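The body of `mapGroups` is ordinary Scala collection code, which is why this approach scales to deeply nested types: you sort by whatever field path the compiler can see. Below is a Spark-free sketch of the same group-then-sort logic on local collections. The nested case classes `Measurement`, `Payload` and `Joined` and the object `GroupAndSortSketch` are hypothetical stand-ins for a deeply nested struct, and ISO date strings replace `java.sql.Date` for simplicity.

```scala
// Hypothetical nested types standing in for a deeply nested struct column.
final case class Measurement(value: Int)
final case class Payload(date: String, measurement: Measurement)
final case class Joined(a: Int, b: Int, payload: Payload)

object GroupAndSortSketch {
  val rows: Seq[Joined] = Seq(
    Joined(1, 2, Payload("2018-01-02", Measurement(30))),
    Joined(1, 3, Payload("2018-01-01", Measurement(20))),
    Joined(2, 4, Payload("2018-01-02", Measurement(50))),
    Joined(2, 5, Payload("2018-01-01", Measurement(60)))
  )

  // Same shape as groupByKey(_.a).mapGroups(...): group by key, then sort each
  // group's payloads by a field several levels down.
  def sortedByValue(rs: Seq[Joined]): Map[Int, Seq[Payload]] =
    rs.groupBy(_.a).map { case (key, xs) =>
      key -> xs.map(_.payload).sortBy(_.measurement.value)
    }

  def main(args: Array[String]): Unit =
    sortedByValue(rows).toSeq.sortBy(_._1).foreach(println)
}
```

With a Dataset of the same (hypothetical) case classes, the equivalent call would be `.mapGroups((key, xs) => (key, xs.map(_.payload).toSeq.sortBy(_.measurement.value)))`, assuming encoders are in scope via `spark.implicits._`.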

In simple cases it is also possible to use a udf, but in general this leads to inefficient and fragile code that quickly gets out of control as the complexity of the objects grows.

Sylphid answered 5/4, 2018 at 12:42

According to the Hive Wiki:

sort_array(Array<T>) : Sorts the input array in ascending order according to the natural ordering of the array elements and returns it (as of version 0.9.0).

This means that the array is sorted lexicographically (field by field, starting with the first field), and this holds true even for complex data types such as structs.
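Scala's built-in tuple ordering behaves the same way, comparing element by element, so it makes a handy local analogy for how the structs above are ordered. The sketch below uses ISO date strings (which sort like dates) in place of `java.sql.Date`; the object name `StructOrderingDemo` is hypothetical.

```scala
object StructOrderingDemo {
  // Stand-ins for the (date, value) structs of group a = 2.
  val group2: Seq[(String, Int)] = Seq(("2018-01-02", 50), ("2018-01-01", 60))

  // Natural tuple ordering: the first field decides, the second only breaks ties.
  // This matches what sort_array did in the question.
  val byFirstField: Seq[(String, Int)] = group2.sorted

  // Sorting by an explicit key, which is what the question actually wants.
  val bySecondField: Seq[(String, Int)] = group2.sortBy(_._2)

  def main(args: Array[String]): Unit = {
    println(byFirstField)
    println(bySecondField)
  }
}
```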

Alternatively, you can create a UDF that sorts the array by the second element (at the cost of some performance degradation):

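import org.apache.spark.sql.Row

// Sort the collected structs by their second field (the Int) and convert them back
// to tuples, because a typed udf cannot return Row values directly.
val sortUdf = udf { (xs: Seq[Row]) =>
  xs.sortBy(_.getAs[Int](1))
    .map { case Row(x: java.sql.Date, y: Int) => (x, y) }
}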

first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
     .groupBy("a")
     .agg(sortUdf(collect_list("c2")))
     .show(false)

//+---+----------------------------------+
//|a  |UDF(collect_list(c2, 0, 0))       |
//+---+----------------------------------+
//|1  |[[2018-01-01,20], [2018-01-02,30]]|
//|2  |[[2018-01-02,50], [2018-01-01,60]]|
//+---+----------------------------------+
Gladiolus answered 5/4, 2018 at 12:43
The question is clear about the array type: sort_array does not work with structs. You will get the error "sort_array does not support sorting array of type struct", so your comment about sort_array is not valid. – Avestan

For Spark 3+, you can pass a custom comparator function to array_sort:

The comparator will take two arguments representing two elements of the array. It returns -1, 0, or 1 as the first element is less than, equal to, or greater than the second element. If the comparator function returns other values (including null), the function will fail and raise an error.
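The contract is the familiar three-way comparison. As a local illustration only (plain Scala, no Spark; `ComparatorSketch` and its methods are hypothetical names), the SQL lambda used in the code that follows corresponds to this function, and sorting with it orders the elements by their second field:

```scala
object ComparatorSketch {
  // Three-way comparison on field _2, returning -1, 0 or 1 like the SQL lambda
  // `case when left._2 < right._2 then -1 when left._2 > right._2 then 1 else 0 end`.
  def compareBySecond(left: (String, Int), right: (String, Int)): Int =
    if (left._2 < right._2) -1
    else if (left._2 > right._2) 1
    else 0

  // Sort using the comparator: `l` goes first whenever the comparator returns -1.
  def sortWithComparator(xs: Seq[(String, Int)]): Seq[(String, Int)] =
    xs.sortWith((l, r) => compareBySecond(l, r) < 0)
}
```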

val df = first
  .join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
  .groupBy("a")
  .agg(collect_list("c2").alias("list"))

val df2 = df.withColumn(
  "list",
  expr(
    "array_sort(list, (left, right) -> case when left._2 < right._2 then -1 when left._2 > right._2 then 1 else 0 end)"
  )
)

df2.show(false)
//+---+------------------------------------+
//|a  |list                                |
//+---+------------------------------------+
//|1  |[[2018-01-01, 20], [2018-01-02, 30]]|
//|2  |[[2018-01-02, 50], [2018-01-01, 60]]|
//+---+------------------------------------+

Here, _2 is the name of the struct field you want to use for sorting.
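Because the lambda is plain SQL text, the field reference can also be a nested path, which covers the deeply nested case from the question. A minimal sketch of a hypothetical helper (`ArraySortByField` is not a Spark API) that builds the expression for any field path, assuming field names that need no backtick quoting:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.expr

object ArraySortByField {
  // Hypothetical helper: builds an array_sort call that orders `arrayCol` by `fieldPath`,
  // e.g. "_2" or a nested path such as "payload.measurement.value".
  def comparatorSql(arrayCol: String, fieldPath: String): String =
    s"array_sort($arrayCol, (left, right) -> " +
      s"case when left.$fieldPath < right.$fieldPath then -1 " +
      s"when left.$fieldPath > right.$fieldPath then 1 else 0 end)"

  // Wraps the SQL text as a Column (the lambda syntax needs Spark 3.0+).
  def sortByField(arrayCol: String, fieldPath: String): Column =
    expr(comparatorSql(arrayCol, fieldPath))
}
```

For example, `df.withColumn("list", ArraySortByField.sortByField("list", "_2"))` builds the same expression as the answer above.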

Hintze answered 14/2, 2022 at 9:07

This is similar to what blackbishop showed in the previous answer. However, it provides a little more type safety, and I find it easier to read than a long SQL expression inside expr.

val first = Seq(
    ABC(1, 2, 3),
    ABC(1, 3, 4),
    ABC(2, 4, 5),
    ABC(2, 5, 6)
).toDF("a", "b", "c")

val second = Seq(
    (1, 2, (Date.valueOf("2018-01-02"), 30)),
    (1, 3, (Date.valueOf("2018-01-01"), 20)),
    (2, 4, (Date.valueOf("2018-01-02"), 50)),
    (2, 5, (Date.valueOf("2018-01-01"), 60))
).toDF("a", "b", "c")

import org.apache.spark.sql.Column

// Three-way comparator on the struct field _2: returns -1, 0 or 1.
def sortElems(p1: Column, p2: Column): Column = {
    when(p1.getField("_2") < p2.getField("_2"), -1)
        .when(p1.getField("_2") > p2.getField("_2"), 1)
        .otherwise(0)
}

// array_sort with a Scala comparator function requires Spark 3.4+.
first
    .join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
    .groupBy("a")
    .agg(collect_list("c2").as("grouped_c2"))
    .withColumn("sorted_c2", array_sort($"grouped_c2", sortElems))
    .show(false)

This produces the output below, where sorted_c2 is ordered by the second field of each struct.

+---+------------------------------------+------------------------------------+
|a  |grouped_c2                          |sorted_c2                           |
+---+------------------------------------+------------------------------------+
|1  |[{2018-01-02, 30}, {2018-01-01, 20}]|[{2018-01-01, 20}, {2018-01-02, 30}]|
|2  |[{2018-01-02, 50}, {2018-01-01, 60}]|[{2018-01-02, 50}, {2018-01-01, 60}]|
+---+------------------------------------+------------------------------------+
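The same idea generalizes to any field path by passing the key extractor in, so the comparator works unchanged for deeply nested structs. A minimal sketch; `ComparatorColumns.compareBy` is a hypothetical helper, and the `array_sort` overload that takes a Scala comparator function requires Spark 3.4 or later:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

object ComparatorColumns {
  // Hypothetical helper: builds a three-way comparator (-1, 0, 1) from a key extractor.
  def compareBy(key: Column => Column)(p1: Column, p2: Column): Column =
    when(key(p1) < key(p2), -1)
      .when(key(p1) > key(p2), 1)
      .otherwise(0)
}
```

Usage would be `array_sort($"grouped_c2", ComparatorColumns.compareBy(_.getField("_2")))`, or, for a deeper path, `compareBy(_.getField("payload").getField("measurement").getField("value"))`, where those nested field names are hypothetical.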
Stirrup answered 6/6 at 10:18