Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames

I have some data in the following format (either RDD or Spark DataFrame):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

What I would like to do is 'reshape' the data, converting certain rows in Country (specifically US, UK and CA) into columns:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

Essentially, I need something along the lines of the pandas pivot workflow:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

My dataset is rather large, so I can't really collect() it and ingest the data into memory to do the reshaping in Python itself. Is there a way to convert Python's .pivot() into an invokable function while mapping either an RDD or a Spark DataFrame? Any help would be appreciated!

Explosion answered 15/5, 2015 at 12:51 Comment(0)

Since Spark 1.6 you can use the pivot function on GroupedData and provide an aggregate expression.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

The list of levels can be omitted, but providing it both boosts performance and serves as an internal filter.

This method is still relatively slow, but it certainly beats passing data manually between the JVM and Python.
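
For completeness, here is a minimal sketch of the .agg(expr) form mentioned in the code comment above, using pyspark.sql.functions and the df from the question; it is equivalent to the .sum shortcut:

from pyspark.sql import functions as F

# same pivot, but with an explicit aggregate expression instead of the .sum shortcut
pivoted = (df
    .groupBy("ID", "Age")
    .pivot("Country", ['US', 'UK', 'CA'])
    .agg(F.sum("Score")))

pivoted.show()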

Dupondius answered 1/6, 2016 at 21:16 Comment(0)

First up, this is probably not a good idea, because you are not getting any extra information, but you are binding yourself to a fixed schema (i.e. you must know in advance how many countries you are expecting, and of course, an additional country means a change in code).

Having said that, this is a SQL problem, which is shown below. But in case you feel it is not "software-like" enough (seriously, I have heard this!!), then you can refer to the first solution.

Solution 1:

def reshape(t):
    # emit ((ID, Age), (CA, UK, US, XX)) with the score placed in its country's
    # slot (following the order of the broadcast list) and zeros everywhere else
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0], out[1]), (out[2], out[3], out[4], out[5])

def cntryFilter(t):
    # keep only records whose country is in the broadcast list
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1, t2):
    # element-wise addition of two tuples
    j = ()
    for k, v in enumerate(t1):
        j = j + (t1[k] + t2[k],)
    return j

def seq(tIntrm, tNext):
    return addtup(tIntrm, tNext)

def comb(tP, tF):
    return addtup(tP, tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

Now, Solution 2: of course better, as SQL is the right tool for this.

from pyspark.sql import Row

callRow = calls.map(lambda t:
    Row(userid=t[0], age=int(t[1]), country=t[2], nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

Data set up:

data = [('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
calls = sc.parallelize(data, 1)
countries = ['CA', 'UK', 'US', 'XX']

Result:

From 1st solution

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

From 2nd solution:

root
 |-- age: long (nullable = true)
 |-- country: string (nullable = true)
 |-- nbrCalls: long (nullable = true)
 |-- userid: string (nullable = true)

userid age ca uk us xx
X02    72  7  6  4  8
X01    41  2  1  3  0

Kindly let me know if this works, or not :)

Best Ayan

Tavern answered 16/5, 2015 at 17:19 Comment(5)
thanks.. your solutions work and, more importantly, they are scalable!Explosion
Are you able to expand this to a more generic case? For example, one time in my data I might have 3 countries. Another time I might have 5. What you have above seems to be hard coded to 4 specific countries. I get that I need to know what countries I have ahead of time, but that might change as time goes on. How could I pass a list of countries in as a parameter and still make this work? This is a pretty common thing to do in working with data so I would hope this would be built in functionality pretty soon.Ferment
As I noted, this is a problem with schema design. You "cannot" just pass in a list of countries, because your schema will change downstream. However, you might just get by with returning a generalized tuple from reshape and setting up zero values for aggregateByKey. In the SQL method, you need to basically "generate" the SQL programmatically following the pattern described here (see the sketch after these comments).Tavern
This is a pretty common functionality that exists in most data languages/frameworks: SAS, Scalding, Pandas, etc. Hope this makes it into Spark soon.Ferment
I created a flexible version of this based on your answer above. You can view it here: https://mcmap.net/q/235437/-how-to-pivot-spark-dataframe. I hope Spark implements a solution for this soon as it is pretty basic functionality in most other data manipulation languages/tools (Pandas, Scalding, SAS, Excel, etc.)Ferment
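
Following up on the comment above about generating the SQL programmatically: below is a rough sketch (not part of the original answer) that builds the same CASE WHEN query from an arbitrary list of countries. The build_pivot_sql helper and its parameters are made up for illustration; ssc is the SQLContext from Solution 2.

def build_pivot_sql(countries, table="calls"):
    # one "case when country=... then nbrCalls else 0 end" column per country
    cases = ",".join(
        "case when country='{0}' then nbrCalls else 0 end {1}".format(c, c.lower())
        for c in countries)
    # aggregate each generated column with max(), as in Solution 2
    aggs = ",".join("max({0}) as {0}".format(c.lower()) for c in countries)
    return ("select userid, age, {0} from "
            "(select userid, age, {1} from {2}) x "
            "group by userid, age").format(aggs, cases, table)

countries = ['CA', 'UK', 'US', 'XX']
res = ssc.sql(build_pivot_sql(countries))
res.show()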

Here's a native Spark approach that doesn't hardwire the column names. It's based on aggregateByKey, and uses a dictionary to collect the columns that appear for each key. Then we gather all the column names to create the final dataframe. [Prior version used jsonRDD after emitting a dictionary for each record, but this is more efficient.] Restricting to a specific list of columns, or excluding ones like XX would be an easy modification.

The performance seems good even on quite large tables. I'm using a variation which counts the number of times each of a variable number of events occurs for each ID, generating one column per event type. The code is basically the same, except it uses a collections.Counter instead of a dict in the seqFn to count the occurrences (a sketch of that variation appears after the output below).

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Produces:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null
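
As mentioned above, the counting variation swaps the dict for a collections.Counter in the seqFn. A rough sketch under assumed names (the Event column and events_df are hypothetical, not from the original answer):

from collections import Counter

def seqCount(u, v):
    # count occurrences of each event instead of storing a score
    if u is None:
        u = Counter()
    u[v.Event] += 1
    return u

def cmbCount(u1, u2):
    # Counter.update adds the counts from u2 into u1
    u1.update(u2)
    return u1

counts = (
    events_df            # hypothetical DataFrame with columns ID and Event
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqCount, cmbCount)
)

From there, the column names can be gathered and the final DataFrame built exactly as above.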
Ret answered 23/6, 2015 at 15:53 Comment(3)
Nice writeup - b.t.w spark 1.6 dataframes support easy pivots github.com/apache/spark/pull/7841Beak
Cool - spark is getting better fast.Ret
What if the reshaped output is too big to fit on memory. How can I do it directly on disk?Grapher

So first off, I had to make this correction to your RDD (which matches your actual output):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

Once I made that correction, this did the trick (in Scala, pasted from the spark-shell):

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

Not nearly as elegant as your pivot, for sure.

Thynne answered 15/5, 2015 at 16:48 Comment(2)
David, I couldn't get this to work. First, Spark did not accept $ as a way to reference columns. After removing all the $ signs, I still get a syntax error pointing to the .select expression in the last line of your code aboveExplosion
Sorry, I am using Scala. It was cut and pasted directly from spark-shell. If you take the last select() out, you should get the correct results just with too many columns. Can you do that and post the results?Thynne
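
For PySpark users following along: the $"col" syntax above is Scala. A rough, untested PySpark translation of the same join-per-country idea (using the df from the question) might look like this:

from pyspark.sql import functions as F

# one row per ID with its Age
base = df.groupBy("ID").agg(F.first("Age").alias("Age"))

# one small DataFrame per country, renaming the join key to avoid ambiguity
us = df.filter(df.Country == "US").select(df.ID.alias("usID"), df.Score.alias("US"))
uk = df.filter(df.Country == "UK").select(df.ID.alias("ukID"), df.Score.alias("UK"))
ca = df.filter(df.Country == "CA").select(df.ID.alias("caID"), df.Score.alias("CA"))

result = (base
    .join(us, base.ID == us.usID)
    .join(uk, base.ID == uk.ukID)
    .join(ca, base.ID == ca.caID)
    .select("ID", "Age", "US", "UK", "CA"))

result.show()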

Just some comments on the very helpful answer of patricksurry:

  • the column Age is missing, so just add u["Age"] = v.Age to the function seqPivot
  • it turned out that the two loops over the elements of columns returned the elements in different orders. The values of the columns were correct, but not their names. To avoid this behavior, just sort the column list.

Here is the slightly modified code:

from pyspark.sql import Row
from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionary mapping country -> score (plus the Age entry added below)
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1 and u2 are partial dictionaries built on different partitions;
# merge them into one
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
    schema=StructType(
        [StructField('ID', StringType())] +
        [StructField(c, IntegerType()) for c in columns_ord]
    )
)

result.show()

Finally, the output should be

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+
Phina answered 24/9, 2015 at 9:31 Comment(0)

There is a JIRA in Hive for PIVOT to do this natively, without a huge CASE statement for each value:

https://issues.apache.org/jira/browse/HIVE-3776

Please vote that JIRA up so it will be implemented sooner. Once it is in Hive SQL, Spark usually doesn't lag too far behind, and eventually it will be implemented in Spark as well.

Prudie answered 1/9, 2015 at 19:12 Comment(0)
