pyspark: aggregate on the most frequent value in a column

from pyspark.sql.functions import count, sum

aggregated_table = df_input.groupBy('city', 'income_bracket') \
    .agg(
        count('suburb').alias('suburb'),
        sum('population').alias('population'),
        sum('gross_income').alias('gross_income'),
        sum('no_households').alias('no_households'))

I would like to group by city and income bracket, but within each city some suburbs have different income brackets. How do I group by the most frequently occurring income bracket per city?

for example:

city1 suburb1 income_bracket_10 
city1 suburb1 income_bracket_10 
city1 suburb2 income_bracket_10 
city1 suburb3 income_bracket_11 
city1 suburb4 income_bracket_10 

city1 would be grouped under income_bracket_10.

Sludgy answered 11/8, 2017 at 11:59 Comment(1)
Can you show us the desired output? – Indigence

Using a window function before aggregating might do the trick:

from pyspark.sql import Window
import pyspark.sql.functions as psf

# count occurrences of each (city, income_bracket) pair
w = Window.partitionBy('city', 'income_bracket')

aggregated_table = df_input.withColumn(
    "count",
    psf.count("*").over(w)
).withColumn(
    # keep, for each city, a row from its most frequent income bracket
    "rn",
    psf.row_number().over(Window.partitionBy('city').orderBy(psf.desc("count")))
).filter("rn = 1").groupBy('city', 'income_bracket').agg(
    psf.count('suburb').alias('suburb'),
    psf.sum('population').alias('population'),
    psf.sum('gross_income').alias('gross_income'),
    psf.sum('no_households').alias('no_households'))

You can also apply the window function after aggregating, since the aggregation itself keeps a count of (city, income_bracket) occurrences.
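A minimal sketch of that variant, assuming the same df_input columns as above (the suburb count produced by the aggregation doubles as the frequency of each (city, income_bracket) pair; aggregated_after is just an illustrative name):

from pyspark.sql import Window
import pyspark.sql.functions as psf

# aggregate first, then keep the most frequent (city, income_bracket) row per city
w = Window.partitionBy('city').orderBy(psf.desc('suburb'))

aggregated_after = df_input.groupBy('city', 'income_bracket').agg(
    psf.count('suburb').alias('suburb'),
    psf.sum('population').alias('population'),
    psf.sum('gross_income').alias('gross_income'),
    psf.sum('no_households').alias('no_households')
).withColumn(
    'rn', psf.row_number().over(w)
).filter('rn = 1').drop('rn')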

Suave answered 16/8, 2017 at 21:46 Comment(1)
Perfect - thanks! I did have some issues with null values that take precedence over actual values, but used your solution in combination with #35142716 and it works! – Sludgy

You don't necessarily need Window functions:

from pyspark.sql import functions as F

aggregated_table = (
    df_input.groupby("city", "suburb", "income_bracket")
    .count()
    # pair each count with its income_bracket, then take the max pair per group
    .withColumn("count_income", F.array("count", "income_bracket"))
    .groupby("city", "suburb")
    .agg(F.max("count_income").getItem(1).alias("most_common_income_bracket"))
)

I think this does what you require, though I don't know whether it performs better than the window-based solution.
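A hedged variant of the same idea, not part of the original answer: packing the pair into a struct instead of an array keeps the count as an integer, so F.max compares the counts numerically rather than as strings:

from pyspark.sql import functions as F

aggregated_struct = (
    df_input.groupby("city", "suburb", "income_bracket")
    .count()
    # struct fields keep their original types, so "count" is compared as a number
    .withColumn("count_income", F.struct("count", "income_bracket"))
    .groupby("city", "suburb")
    .agg(F.max("count_income").getField("income_bracket").alias("most_common_income_bracket"))
)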

Audette answered 24/9, 2018 at 8:23 Comment(1)
The solution by mfcabrera is better for very large datasets, where you won't force the entire dataset onto a single node. – Frow

For PySpark version >= 3.4 you can use the mode function directly to get the most frequent element per group:

from pyspark.sql import functions as f

df = spark.createDataFrame([
    ("Java", 2012, 20000), ("dotNET", 2012, 5000),
    ("Java", 2012, 20000), ("dotNET", 2012, 5000),
    ("dotNET", 2013, 48000), ("Java", 2013, 30000)],
    schema=("course", "year", "earnings"))

df.groupby("course").agg(f.mode("year")).show()
+------+----------+
|course|mode(year)|
+------+----------+
|  Java|      2012|
|dotNET|      2012|
+------+----------+

https://github.com/apache/spark/blob/7f1b6fe02bdb2c68d5fb3129684ca0ed2ae5b534/python/pyspark/sql/functions.py#L379
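
Applied to the question's data, a minimal sketch (assuming the same df_input columns as in the question):

from pyspark.sql import functions as f

aggregated_table = df_input.groupBy('city').agg(
    f.mode('income_bracket').alias('income_bracket'),
    f.count('suburb').alias('suburb'),
    f.sum('population').alias('population'),
    f.sum('gross_income').alias('gross_income'),
    f.sum('no_households').alias('no_households'))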

Extremadura answered 14/2, 2023 at 9:3 Comment(0)

The solution by mfcabrera gave wrong results when F.max was used on the F.array column, because the array values are treated as strings and the integer max didn't work as expected.

The solution below worked for me.

from pyspark.sql import Window
from pyspark.sql import functions as f

# rank income brackets within each (city, suburb) by how often they occur
w = Window.partitionBy("city", "suburb").orderBy(f.desc("count"))

aggregated_table = (
    input_df.groupby("city", "suburb", "income_bracket")
    .count()
    .withColumn("max_income", f.row_number().over(w))
    .filter(f.col("max_income") == 1).drop("max_income")
)
aggregated_table.display()  # Databricks notebook display; use .show() elsewhere

Deft answered 4/2, 2022 at 9:47 Comment(0)
