Another solution is to number the rows with row_number() over a window partitioned by A and ordered by B descending, so that the first row in each partition is the one with the maximum B.
This solution is close to the one by @pault, but when several rows share the maximum value it keeps only one of them, which I find preferable (see the tie example below).
Given the same example:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = spark.createDataFrame(data, ["A", "B"])
df.show()
The row_number solution is:
w = Window.partitionBy('A').orderBy(F.col('B').desc())
df.withColumn('row_number', F.row_number().over(w)) \
    .filter(F.col('row_number') == 1) \
    .drop('row_number') \
    .show()
+---+---+
| A| B|
+---+---+
| a| 8|
| b| 3|
+---+---+
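To make the tie behaviour concrete, here is a minimal sketch (the data and the name df_ties are illustrative; it assumes the same session and imports as above). With two rows tied at the maximum in partition a, row_number keeps exactly one of them, whereas the == max approach below would keep both:
# Two rows tied at the maximum for partition 'a'
data_ties = [('a', 8), ('a', 8), ('b', 3)]
df_ties = spark.createDataFrame(data_ties, ["A", "B"])

w = Window.partitionBy('A').orderBy(F.col('B').desc())
df_ties.withColumn('row_number', F.row_number().over(w)) \
    .filter(F.col('row_number') == 1) \
    .drop('row_number') \
    .show()
+---+---+
|  A|  B|
+---+---+
|  a|  8|
|  b|  3|
+---+---+
Only one of the two ('a', 8) rows survives.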
I also extended the benchmark from @Fernando Wittmann; both window-based solutions run in about the same time:
The dataframe:
import random

N_SAMPLES = 600000
N_PARTITIONS = 1000
MAX_VALUE = 100
data = zip(
    [random.randint(0, N_PARTITIONS - 1) for i in range(N_SAMPLES)],
    [random.randint(0, MAX_VALUE) for i in range(N_SAMPLES)],
    list(range(N_SAMPLES))
)
df = spark.createDataFrame(data, ["A", "B", "C"])
The row_number approach:
%%timeit
w = Window.partitionBy('A').orderBy(F.col('B').desc())
df_collect = df.withColumn('row_number', F.row_number().over(w)) \
.filter(F.col('row_number') == 1) \
.drop('row_number') \
.collect()
313 ms ± 19.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The == max approach:
%%timeit
w = Window.partitionBy('A')
df_collect = df.withColumn('maxB', F.max('B').over(w))\
.where(F.col('B') == F.col('maxB'))\
.drop('maxB')\
.collect()
328 ms ± 24.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The leftsemi approach (joining on both A and B, so each row is matched against its own group's maximum):
%%timeit
df_collect = df.join(df.groupBy('A').agg(F.max('B').alias('B')), on=['A', 'B'], how='leftsemi').collect()
516 ms ± 19.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
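As a sanity check (a sketch, not part of the timed benchmark), the == max and leftsemi approaches should return exactly the same rows, since both keep every row tied at its partition's maximum:
w = Window.partitionBy('A')
max_rows = df.withColumn('maxB', F.max('B').over(w)) \
    .where(F.col('B') == F.col('maxB')) \
    .drop('maxB')
semi_rows = df.join(df.groupBy('A').agg(F.max('B').alias('B')),
                    on=['A', 'B'], how='leftsemi')
# Both keep all tied rows, so the collected results should match exactly.
assert sorted(max_rows.collect()) == sorted(semi_rows.collect())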