Spark custom aggregation: collect_list+UDF vs UDAF

I often need to perform custom aggregations on DataFrames in Spark 2.1, and have used these two approaches:

  • Using groupBy/collect_list to get all the values in a single row, then applying a UDF to aggregate them
  • Writing a custom UDAF (User defined aggregate function)

I generally prefer the first option, as it's easier to implement and more readable than the UDAF implementation. I would assume that the first option is generally slower, because more data is sent across the network (there is no partial aggregation), but in my experience UDAFs are generally the slower ones. Why is that?

Concrete example: calculating histograms.

The data is in a Hive table (1E6 random double values).

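import org.apache.spark.sql.functions.{collect_list, udf}
import spark.implicits._ // for the $"col" syntax (already imported in spark-shell)

val df = spark.table("testtable")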

def roundToMultiple(d:Double,multiple:Double) = Math.round(d/multiple)*multiple

UDF approach:

val udf_histo = udf((xs:Seq[Double]) => xs.groupBy(x => roundToMultiple(x,0.25)).mapValues(_.size))

df.groupBy().agg(collect_list($"x").as("xs")).select(udf_histo($"xs")).show(false)

+--------------------------------------------------------------------------------+
|UDF(xs)                                                                         |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

UDAF approach:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.mutable

class HistoUDAF(binWidth:Double) extends UserDefinedAggregateFunction {

  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) :: Nil
    )

  override def bufferSchema: StructType =
    new StructType()
      .add("histo", MapType(DoubleType, IntegerType))

  override def deterministic: Boolean = true
  override def dataType: DataType = MapType(DoubleType, IntegerType)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[Double, Int]()
  }
  
  private def mergeMaps(a: Map[Double, Int], b: Map[Double, Int]) = {
    a ++ b.map { case (k,v) => k -> (v + a.getOrElse(k, 0)) }
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val oldBuffer = buffer.getAs[Map[Double, Int]](0)
    val newInput = Map(roundToMultiple(input.getDouble(0), binWidth) -> 1)
    buffer(0) = mergeMaps(oldBuffer, newInput)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val a = buffer1.getAs[Map[Double, Int]](0)
    val b = buffer2.getAs[Map[Double, Int]](0)
    buffer1(0) = mergeMaps(a, b)
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[Double, Int]](0)
  }
}

val histo = new HistoUDAF(0.25)

df.groupBy().agg(histo($"x")).show(false)

+--------------------------------------------------------------------------------+
|histoudaf(x)                                                                    |
+--------------------------------------------------------------------------------+
|Map(0.0 -> 125122, 1.0 -> 124772, 0.75 -> 250819, 0.5 -> 248696, 0.25 -> 250591)|
+--------------------------------------------------------------------------------+

My tests show that the collect_list/UDF approach is about twice as fast as the UDAF approach. Is this a general rule, or are there cases where a UDAF is really much faster and the rather awkward implementation is justified?

Asked by Vaso on 15/3, 2018 at 8:08. Comments (5):
Kurdish: Were you able to figure out the reason for this?
Jackinthepulpit: You may be asking the wrong question. collect_list will pull everything into one executor, so the question should be: is there any chance that this will blow up on you? If so, you should use a UDAF. If there is no chance of collect_list blowing up, then use a UDF + collect_list.
Vaso: @RobertBeatty your comment may be misunderstood; by "everything" you mean all records of one group. So with many groups and little skew (no very big group), the collect_list approach is feasible.
Jackinthepulpit: @RaphaelRoth, you are correct, it is on a per-group basis. But my main point is that collect_list+UDF pulls the data of each group into memory all at once. That is why its performance is so much better, but also why it's so dangerous: you have to be absolutely sure that none of those groups will make you run out of memory, which is not trivial in many cases. UDAFs are better for most use cases because they are safer. They aren't black boxes like UDFs, and they don't pull large amounts of data into memory. UDAFs aren't that hard to write and are well worth learning if that is a main part of your job.
Hintze: UDAFs can also shuffle less data at times, because they compute a per-group pre-aggregation on each partition, then shuffle those pre-aggregated results and merge them (using the merge method). Shuffling less data can offset the in-memory speed advantage of collect_list+UDF.
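The pre-aggregation described in the last comment can be illustrated without Spark. Below is a plain-Scala simulation, not Spark's actual implementation; the object PartialAggDemo, its helpers and the two hard-coded "partitions" are hypothetical. Each partition first collapses its raw values into a small bin -> count map (the map-side partial aggregate), and only those maps are merged, which is the role the UDAF's merge method plays. With collect_list, every raw value would be shipped instead.

```scala
// Plain Scala, no Spark: a simulation of partial aggregation for the histogram.
object PartialAggDemo {
  // Same rounding as roundToMultiple in the question.
  def roundToMultiple(d: Double, multiple: Double): Double =
    Math.round(d / multiple) * multiple

  // "Map side": one partition's raw values become a small bin -> count map.
  def partialHisto(partition: Seq[Double], binWidth: Double): Map[Double, Int] =
    partition.groupBy(roundToMultiple(_, binWidth)).map { case (bin, xs) => bin -> xs.size }

  // "Reduce side": add up counts per bin, like HistoUDAF.merge does.
  def mergeHistos(a: Map[Double, Int], b: Map[Double, Int]): Map[Double, Int] =
    b.foldLeft(a) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v) }

  def main(args: Array[String]): Unit = {
    // Hypothetical data split into two partitions.
    val partitions = Seq(Seq(0.1, 0.2, 0.3), Seq(0.9, 0.95, 0.26))
    val partials = partitions.map(partialHisto(_, 0.25))

    // Only the partial maps would cross the network, not the raw values.
    val shuffled = partials.map(_.size).sum
    val raw = partitions.map(_.size).sum
    println(s"shuffled $shuffled entries instead of $raw raw values") // shuffled 4 entries instead of 6 raw values
    println(partials.reduce(mergeHistos)) // bins 0.0, 0.25, 1.0 with counts 1, 3, 2
  }
}
```

The saving grows with the ratio of rows to distinct bins: with 1E6 values and five bins, each partition would send at most five entries.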

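A UDAF is slower because it deserializes the aggregation buffer from Spark's internal format and serializes it back on each update, i.e. on every row, which is quite expensive; with a Map buffer like the one above, the whole map is converted twice per row (some more details). Instead, you should use Aggregator (in fact, UserDefinedAggregateFunction has been deprecated since Spark 3.0).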
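A minimal sketch of the Aggregator route for the same histogram, assuming Spark 3.0+ (for functions.udaf). HistoAgg and HistoAggDemo are hypothetical names, counts are Long rather than Int, and the Map encoder comes from ExpressionEncoder in Spark's catalyst package, which is public but not a stable user-facing API. As I understand it, udaf wraps an Aggregator as a typed imperative aggregate, so the Map buffer stays a plain JVM object between rows and is only serialized when partial results are shuffled.

```scala
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Hypothetical typed replacement for HistoUDAF.
// IN = one Double value, BUF = OUT = Map(bin -> count).
class HistoAgg(binWidth: Double) extends Aggregator[Double, Map[Double, Long], Map[Double, Long]] {

  // Same rounding as roundToMultiple in the question.
  private def bin(d: Double): Double = math.round(d / binWidth) * binWidth

  // Empty histogram for each new group / partition.
  def zero: Map[Double, Long] = Map.empty

  // Called once per row on the Scala Map itself, with no per-row conversion.
  def reduce(acc: Map[Double, Long], x: Double): Map[Double, Long] = {
    val b = bin(x)
    acc.updated(b, acc.getOrElse(b, 0L) + 1L)
  }

  // Combines partial histograms coming from different partitions.
  def merge(a: Map[Double, Long], b: Map[Double, Long]): Map[Double, Long] =
    b.foldLeft(a) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0L) + v) }

  def finish(acc: Map[Double, Long]): Map[Double, Long] = acc

  // Assumption: ExpressionEncoder encodes Map[Double, Long] as a MapType column.
  def bufferEncoder: Encoder[Map[Double, Long]] = ExpressionEncoder[Map[Double, Long]]()
  def outputEncoder: Encoder[Map[Double, Long]] = ExpressionEncoder[Map[Double, Long]]()
}

object HistoAggDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("histo-agg").getOrCreate()
    import spark.implicits._ // for toDF on a local Seq

    val df = Seq(0.1, 0.2, 0.3, 0.9).toDF("x")
    val histo = udaf(new HistoAgg(0.25)) // wrap the typed Aggregator as a column function
    df.groupBy().agg(histo(col("x")).as("histo")).show(false)

    spark.stop()
  }
}
```

For the four sample values the result should contain bins 0.0, 0.25 and 1.0 with counts 1, 2 and 1. Using immutable Maps keeps the sketch close to the question's code; a mutable map as the buffer would avoid allocating a new map per row.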

Answered by Alcantara on 6/1, 2023 at 0:50.
