How do I collect a single column in Spark?

I would like to perform an action on a single column. Unfortunately, after I transform that column it is no longer part of the DataFrame it came from but a Column object, and as such it cannot be collected.

Here is an example:

from pyspark.sql import Row

df = sqlContext.createDataFrame([Row(array=[1,2,3])])
df['array'].collect()

This produces the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'Column' object is not callable

How can I use the collect() function on a single column?

Quidnunc answered 19/2, 2016 at 0:32 Comment(0)

Spark >= 2.0

Starting from Spark 2.0.0 you need to go through .rdd explicitly in order to use flatMap, since DataFrame no longer exposes it directly:

df.select("array").rdd.flatMap(lambda x: x).collect()

Spark < 2.0

Just select and flatMap:

df.select("array").flatMap(lambda x: x).collect()
## [[1, 2, 3]] 
Rapparee answered 19/2, 2016 at 0:44 Comment(5)
So using select instead of subsetting essentially turns this into a one-column DataFrame instead of a Column. – Quidnunc
That's right. Column is just a SQL DSL expression, not a standalone data structure. – Rapparee
What is the equivalent in Spark 2.0? I can't see flatMap as a method on DataFrame. – Daciadacie
@Daciadacie you need to add .rdd explicitly now; it used to be wrapped in, e.g. df.select("array").rdd.flatMap(lambda x: x).collect() – Henghold
Converting a DataFrame to an RDD creates overhead. Try avoiding it with something like data = list(map(lambda x: x[0], df.select("array").collect())) and flatten the list using normal Python code. – Hardly
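
A small sketch of what the last comment suggests, assuming the df from the question's example; x[0] pulls the single selected field out of each collected Row:

data = list(map(lambda x: x[0], df.select("array").collect()))
## [[1, 2, 3]]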

In 2024, being mindful of Python serialization overhead for RDD-level actions, the accepted answer's

.rdd.flatMap(lambda x: x)

is not ideal for high-performance Spark.

As long as you use only the DataFrame API, the actual work is executed inside the JVM; the Python code is purely declarative, with the exception of UDFs.

As soon as you leave the DataFrame API and drop down to the RDD world, the data actually has to be serialized out of the JVM to the Python process and back again (in the default scenario where Apache Arrow is not enabled). This is a performance overhead that can be avoided.

To avoid this you can perform the following:

array_var = [row['array'] for row in df.select("array").collect()]

Here you unwrap each Row in a plain Python list comprehension instead of going through Spark RDD invocations.
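
A minimal end-to-end sketch of this approach, assuming a Spark 2.x+ SparkSession rather than the sqlContext used in the question:

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([Row(array=[1, 2, 3])])

# collect() stays in the DataFrame API; the flattening happens in plain Python
array_var = [row['array'] for row in df.select("array").collect()]
## [[1, 2, 3]]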

Canice answered 15/8 at 7:37 Comment(0)
