PySpark count rows on condition

I have a DataFrame:

test = spark.createDataFrame([('bn', 12452, 221), ('mb', 14521, 330), ('bn', 2, 220), ('mb', 14520, 331)], ['x', 'y', 'z'])
test.show()
# +---+-----+---+
# |  x|    y|  z|
# +---+-----+---+
# | bn|12452|221|
# | mb|14521|330|
# | bn|    2|220|
# | mb|14520|331|
# +---+-----+---+

I need to count the rows based on a condition:

from pyspark.sql.functions import col, count

test.groupBy("x").agg(count(col("y") > 12453), count(col("z") > 230)).show()

which gives

+---+------------------+----------------+
|  x|count((y > 12453))|count((z > 230))|
+---+------------------+----------------+
| bn|                 2|               2|
| mb|                 2|               2|
+---+------------------+----------------+

It just gives the count of all rows in each group, not the count of rows that satisfy each condition.

Sachasachem asked 28/2, 2018 at 4:25

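count doesn't sum True values; it only counts the number of non-null values. A condition such as col("y") > 12453 evaluates to True or False (both non-null) for every row with a non-null y, so every such row is counted. To count only the True values, convert each condition to 1/0 and then sum: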

import pyspark.sql.functions as F

cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'), 
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+
Rale answered 28/2, 2018 at 4:47
From the show table, is there a way I could extract the values into a Python variable? #53690009 - Buoyage
Can I just check my PySpark understanding here: the lambda function here is all in Spark, so this never has to create a user-defined Python function, with the associated slowdowns. Correct? This looks very handy. - Ardell
@Psidom, could you help me with my conditional count problem? #64470531 - Darladarlan

Based on @Psidom's answer, my answer is as follows:

from pyspark.sql.functions import col, when, count

test.groupBy("x").agg(
    count(when(col("y") > 12453, True)),
    count(when(col("z") > 230, True))
).show()
Sachasachem answered 28/2, 2018 at 5:37
Note that the True value here is not necessary: any non-null value would achieve the same result, as count() counts non-null values. - Sandstrom

The count function skips null values, so you can try this:

import pyspark.sql.functions as F

def count_with_condition(cond):
    return F.count(F.when(cond, True))

See also the function in this repo: kolang

Cartel answered 26/10, 2021 at 9:03

Since Spark 3.0.0 there is count_if(expr); see the Spark function documentation.

Nunuance answered 11/6, 2022 at 13:55
I tried count_if(exp) in PySpark 3.1.2, but it is not in pyspark.sql.functions. According to spark.apache.org/docs/3.1.1/sql-ref.html, it is one of the built-in aggregate functions for SQL queries, so it would be better to explain it and give an example in the answer, because answers that are barely more than a link to an external site may be deleted. - Syrup

Spark 3.5+ has count_if in the Python API:

from pyspark.sql import functions as F

test.groupBy('x').agg(
    F.count_if(F.col('y') > 12453).alias('y_cnt'),
    F.count_if(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+

Spark 3.0+ has it too, but there it must be called through expr:

test.groupBy('x').agg(
    F.expr("count_if(y > 12453) y_cnt"),
    F.expr("count_if(z > 230) z_cnt")
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+
Resentful answered 25/9, 2023 at 23:18

© 2022 - 2025 — McMap. All rights reserved.