Here is a little helper function, inspired by the R `recode` function, that abstracts the previous answers. As a bonus, it adds the option for a default value.
from itertools import chain
from pyspark.sql.functions import col, create_map, lit, when, isnull
from pyspark.sql.column import Column
# Sample data: one string column covering mapped values, an unmapped value
# ('Other') and a null.
df = spark.createDataFrame(
    [("Tablet",), ("Phone",), ("PC",), ("Other",), (None,)],
    ["device_type"],
)

# Lookup table: raw device type -> device category.
deviceDict = {"Tablet": "Mobile", "Phone": "Mobile", "PC": "Desktop"}

df.show()
+-----------+
|device_type|
+-----------+
|     Tablet|
|      Phone|
|         PC|
|      Other|
|       null|
+-----------+
Here is the definition of `recode`:
def recode(col_name, map_dict, default=None):
    """Recode a column's values through a Python dict, in the spirit of R's ``recode``.

    Args:
        col_name: Either a column name string or a ``Column`` instance.
        map_dict: Dict whose keys are the original values and whose values
            are the replacements.
        default: Replacement used wherever the map lookup yields null.
            When ``None`` (the default), the lookup result is returned as-is.

    Returns:
        A ``Column`` expression holding the recoded values.
    """
    # Accept either a column name or a ready-made Column.
    source = col_name if isinstance(col_name, Column) else col(col_name)

    # create_map takes a flat key1, value1, key2, value2, ... sequence of columns.
    flat_pairs = [lit(item) for item in chain.from_iterable(map_dict.items())]
    looked_up = create_map(flat_pairs).getItem(source)

    if default is None:
        return looked_up

    # Keep the mapped value where the lookup found one; otherwise use the default.
    return when(~isnull(looked_up), looked_up).otherwise(default)
Creating a column without a default gives `null`/`None` for all unmatched values.
# No default supplied: values not present in deviceDict (and nulls) come out as null.
df.withColumn("device_type", recode('device_type', deviceDict)).show()
+-----------+
|device_type|
+-----------+
|     Mobile|
|     Mobile|
|    Desktop|
|       null|
|       null|
+-----------+
On the other hand, specifying a value for `default` replaces all unmatched values (including nulls) with this default.
# With default='Other', every value not found in deviceDict (nulls included) becomes 'Other'.
df.withColumn("device_type", recode('device_type', deviceDict, default='Other')).show()
+-----------+
|device_type|
+-----------+
|     Mobile|
|     Mobile|
|    Desktop|
|      Other|
|      Other|
+-----------+