Return null in SUM if some values are null

I have a case where I may have null values in the column that needs to be summed up in a group.

If a group contains a null, I want the sum for that group to be null. By default, though, PySpark ignores the null rows and sums the remaining non-null values.

For example:

[image: sample input data, not preserved]

from pyspark.sql import functions as f
# Group by product only, then sum price within each group
dataframe = dataframe.groupBy('product') \
                     .agg(f.sum('price'))

Expected output is:

[image: expected output, not preserved]

But I am getting:

[image: actual output, not preserved]

Catgut answered 18/1, 2021 at 6:34 Comment(0)

The sum function returns NULL only if all values in the column are null; otherwise, nulls are simply ignored.

You can use conditional aggregation. count(price) counts only non-null values, so if count(price) == count(*) the group has no nulls and we return sum(price). Otherwise, null is returned:

from pyspark.sql import functions as F

df.groupby("product").agg(
    F.when(F.count("price") == F.count("*"), F.sum("price")).alias("sum_price")
).show()

#+-------+---------+
#|product|sum_price|
#+-------+---------+
#|      B|      200|
#|      C|     null|
#|      A|      250|
#+-------+---------+
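
To make the rule concrete without a Spark session, here is a plain-Python model of the same check. It is only a sketch: the function name `null_aware_sum` and the sample rows are hypothetical (the original data was shown only in a screenshot), and `None` stands in for SQL null.

```python
from collections import defaultdict

def null_aware_sum(rows):
    """Model of when(count(price) == count(*), sum(price)) grouped by product."""
    groups = defaultdict(list)
    for product, price in rows:
        groups[product].append(price)

    result = {}
    for product, prices in groups.items():
        non_null = [p for p in prices if p is not None]
        # count(price) counts non-null values, count(*) counts every row
        if len(non_null) == len(prices):
            result[product] = sum(non_null)
        else:
            # no branch of the condition matched, so the result is null
            result[product] = None
    return result

# Hypothetical rows chosen to reproduce the output shown above
rows = [("A", 100), ("A", 150), ("B", 200), ("C", None), ("C", 50)]
totals = null_aware_sum(rows)
print(totals)  # {'A': 250, 'B': 200, 'C': None}
```

A group made up entirely of nulls also ends up as null here, which matches what a plain sum would return for it.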

From Spark 3.0 onward, you can also use the any function:

df.groupby("product").agg(
    F.when(~F.expr("any(price is null)"), F.sum("price")).alias("sum_price")
).show()
Effulgence answered 18/1, 2021 at 8:6 Comment(0)

You can replace nulls with NaNs using coalesce. NaN propagates through the sum, so any group containing a null yields NaN:

df2 = df.groupBy('product').agg(
    F.sum(
        F.coalesce(F.col('price'), F.lit(float('nan')))
    ).alias('sum(price)')
).orderBy('product')

df2.show()
+-------+----------+
|product|sum(price)|
+-------+----------+
|      A|     250.0|
|      B|     200.0|
|      C|       NaN|
+-------+----------+

To keep the integer type, convert NaN back to null with nanvl and cast the result to int:

df2 = df.groupBy('product').agg(
    F.nanvl(
        F.sum(
            F.coalesce(F.col('price'), F.lit(float('nan')))
        ),
        F.lit(None)
    ).cast('int').alias('sum(price)')
).orderBy('product')

df2.show()
+-------+----------+
|product|sum(price)|
+-------+----------+
|      A|       250|
|      B|       200|
|      C|      null|
+-------+----------+
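
The reason this trick works is that NaN is contagious in floating-point addition. A plain-Python sketch of the same three steps (mark nulls as NaN, sum, map NaN back to null) is shown below; `nan_marked_sum` is a hypothetical helper, not a Spark function, and `None` stands in for SQL null.

```python
import math

def nan_marked_sum(prices):
    # Step 1, like coalesce(price, NaN): replace each null with NaN
    marked = [float("nan") if p is None else float(p) for p in prices]
    # Step 2: a single NaN makes the whole sum NaN
    total = sum(marked)
    # Step 3, like nanvl(total, null) plus the int cast
    return None if math.isnan(total) else int(total)

print(nan_marked_sum([100, 150]))  # 250
print(nan_marked_sum([None, 50]))  # None
```

One caveat of this approach: a price that is already NaN in the data is indistinguishable from a null after step 1, so such a group would also come out as null.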
Oratorical answered 18/1, 2021 at 8:53 Comment(0)
