pytest assert for pyspark dataframe comparison

I have two PySpark DataFrames, expected_df and actual_df (see the attached image).


In my unit test I am trying to check whether both are equal.

My code is:

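# Collect both DataFrames as lists of dicts (a bare map() in Python 3 is a lazy iterator, not a list)
expected = [row.asDict() for row in expected_df.collect()]
actual = [row.asDict() for row in actual_df.collect()]
# List comparison is order-sensitive, so this fails when rows come back in a different order
assert expected == actual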

Both DataFrames contain the same rows, but in a different order, so the assert fails. What is the best way to compare such DataFrames?

Jacaranda answered 3/10, 2018 at 2:59 Comment(2)
Have you tried sorting them? – Hassan
I could sort on 'period_start_time', but is there no way to compare them without doing that? – Jacaranda

You can try pyspark-test

https://pypi.org/project/pyspark-test/

It is inspired by the pandas testing module and built for PySpark.

Usage is simple:

from pyspark_test import assert_pyspark_df_equal

assert_pyspark_df_equal(df_1, df_2)

Apart from comparing DataFrames, it also accepts many optional parameters, just like the pandas testing module; you can find them in the documentation.

Note:

  1. The data types in pandas and PySpark differ somewhat, which is why converting directly with .toPandas() and using the pandas testing module might not be the right approach.
  2. This package is for unit/integration testing, so it is meant to be used with small DataFrames.
Liebfraumilch answered 3/11, 2020 at 21:29 Comment(1)
This lib does not scale. For small datasets we can use pd.testing.assert_frame_equal. – Buddybuderus

This approach is used in some of the PySpark documentation:

assert sorted(expected_df.collect()) == sorted(actual_df.collect())
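A caveat with sorting: Row is a tuple subclass, so sorted() compares rows field by field, and in Python 3 that raises TypeError when two rows tie on their leading fields and then meet None against a non-None value. If you want to avoid sorting altogether, one option (not part of the answer above) is to compare the collected rows as multisets with collections.Counter. This is a minimal sketch, assuming every row is hashable (scalar columns only; list or map columns would need converting first). The function name rows_equal_ignoring_order is hypothetical, and plain tuples stand in for collected Row objects; with Spark you would call rows_equal_ignoring_order(expected_df.collect(), actual_df.collect()).

```python
from collections import Counter
from typing import Hashable, Iterable


def rows_equal_ignoring_order(expected: Iterable[Hashable],
                              actual: Iterable[Hashable]) -> bool:
    """Return True if both collections hold the same rows, in any order.

    Rows are compared as multisets: Counter maps each distinct row to the
    number of times it occurs, so duplicate rows must match in number too.
    Every row must be hashable (e.g. a tuple of scalars).
    """
    return Counter(expected) == Counter(actual)


# Hypothetical data: plain tuples standing in for collected Row objects.
expected_rows = [(1, "a"), (2, None), (2, "b")]
actual_rows = [(2, "b"), (1, "a"), (2, None)]

# sorted(expected_rows) would raise TypeError: (2, None) vs (2, "b") compares None with "b".
print(rows_equal_ignoring_order(expected_rows, actual_rows))  # True
```

Unlike sorting, this never compares values across rows, so mixed None and non-None values are not a problem; the trade-off is that it needs hashable values.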

Diluent answered 8/6, 2020 at 19:38 Comment(0)

We solved this by hashing each row with Spark's hash function and then summing the resultant column.

from pyspark.sql import DataFrame
import pyspark.sql.functions as F

def hash_df(df):
    """Hashes a DataFrame for comparison.

    Arguments:
        df (DataFrame): A dataframe to generate a hash from

    Returns:
        int: Summed value of hashed rows of an input DataFrame
    """
    # Hash every row into a new hash column
    df = df.withColumn('hash_value', F.hash(*sorted(df.columns))).select('hash_value')

    # Sum the hashes, see https://shortest.link/28YE
    value = df.agg(F.sum('hash_value')).collect()[0][0]

    return value

expected_hash = hash_df(expected_df)
actual_hash = hash_df(actual_df)
assert expected_hash == actual_hash
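This works regardless of row order because addition is commutative: the sum of per-row hashes does not depend on the order in which rows arrive, and hashing the columns in sorted name order also makes column order irrelevant. The pure-Python sketch below mirrors that logic on dict-like rows, assuming Python's built-in hash() as a stand-in for Spark's F.hash (the values will not match Spark's, and str hashes are randomized per process, so only compare results computed in the same run). The function name hash_rows is hypothetical, and returning the row count alongside the sum is an extra cheap check that is not in the original answer. Equal sums remain strong evidence rather than proof, since different rows can in principle collide.

```python
from typing import Any, Iterable, Mapping, Tuple


def hash_rows(rows: Iterable[Mapping[str, Any]]) -> Tuple[int, int]:
    """Return (row_count, sum_of_row_hashes) for dict-like rows.

    Mirrors hash_df above: each row's values are hashed in sorted
    column-name order and the per-row hashes are summed, so neither row
    order nor column order affects the result.
    """
    count = 0
    total = 0
    for row in rows:
        # Equivalent of F.hash(*sorted(df.columns)) for a single row
        total += hash(tuple(row[col] for col in sorted(row)))
        count += 1
    return count, total


# Hypothetical rows, e.g. from [r.asDict() for r in df.collect()]
expected = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
actual = [{"name": "b", "id": 2}, {"name": "a", "id": 1}]  # rows and columns reordered

print(hash_rows(expected) == hash_rows(actual))  # True
```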

Chatoyant answered 18/10, 2022 at 0:38 Comment(1)
+1 for creativity. And I imagine that while the probability of a hash collision increases with the number of columns in the PySpark DataFrame, it's essentially 0 in almost all situations. – Balder

If the overhead of an additional library such as pyspark_test is a problem, you could try sorting both dataframes by the same columns, converting them to pandas, and using pd.testing.assert_frame_equal.

I know that the .toPandas method for pyspark dataframes is generally discouraged because the data is loaded into the driver's memory (see the pyspark documentation here), but this solution works for relatively small unit tests.

For example:

import pandas as pd

# Sort both DataFrames by every column so their row order is deterministic
sort_cols = actual_df.columns

pd.testing.assert_frame_equal(
    actual_df.sort(sort_cols).toPandas(),
    expected_df.sort(sort_cols).toPandas()
)
Balder answered 15/12, 2022 at 3:2 Comment(0)

Unfortunately, this cannot be done without sorting on some column (especially the key column), because there is no guarantee about the order of records in a DataFrame: you cannot predict the order in which the records will appear. The approach below works fine for me:

expected = expected_df.orderBy('period_start_time').collect()
actual = actual_df.orderBy('period_start_time').collect()
assert expected == actual
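One caveat: ordering by a single key column only yields a deterministic order if that key is unique. When several rows share the same period_start_time, their relative order is still unspecified, so the comparison can fail even though both DataFrames hold the same rows. Sorting on all columns (for example orderBy(*expected_df.columns) in Spark) avoids this. The pure-Python sketch below illustrates the difference on plain tuples, with the first field standing in for the key column; the data are made up for illustration.

```python
from operator import itemgetter

# Same rows, but the two rows sharing key 1 arrive in a different order.
expected_rows = [(1, "x"), (1, "y"), (2, "z")]
actual_rows = [(1, "y"), (2, "z"), (1, "x")]

# Sorting on the key alone: Python's sort is stable, so tied rows keep their
# arrival order and the two lists still differ.
by_key_expected = sorted(expected_rows, key=itemgetter(0))
by_key_actual = sorted(actual_rows, key=itemgetter(0))

# Sorting on every column gives one canonical order for identical rows.
by_all_expected = sorted(expected_rows)
by_all_actual = sorted(actual_rows)

print(by_key_expected == by_key_actual)  # False
print(by_all_expected == by_all_actual)  # True
```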
Delinquency answered 14/3, 2019 at 20:3 Comment(0)

One way is to use the chispa package.

from chispa.dataframe_comparer import assert_df_equality

assert_df_equality(actual_df, expected_df, ignore_row_order=True)

You can also ignore the column order and set other arguments. Here is a quick look at the function signature.

Signature:
assert_df_equality(
    df1,
    df2,
    ignore_nullable=False,
    transforms=None,
    allow_nan_equality=False,
    ignore_column_order=False,
    ignore_row_order=False,
    underline_cells=False,
    ignore_metadata=False,
)

You can check the documentation here.

Celenacelene answered 30/12, 2023 at 0:21 Comment(0)

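As of Spark 3.5.0, PySpark includes assertDataFrameEqual in pyspark.testing (from pyspark.testing import assertDataFrameEqual, then assertDataFrameEqual(actual_df, expected_df)); by default it ignores row order (checkRowOrder=False). You can read more about it here.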

Allocation answered 9/3 at 14:22 Comment(0)

I have two DataFrames with the same row order. To compare the two, I use:

def test_df(df1, df2):
    # .values is a pandas attribute; a Spark DataFrame needs .toPandas() first
    assert df1.values.tolist() == df2.values.tolist()
Taskmaster answered 19/10, 2022 at 9:46 Comment(0)

Use "==" instead of "=": assert expected = actual is a syntax error in Python, while the comparison is written assert expected == actual

Plateau answered 19/2, 2020 at 16:45 Comment(0)

Another way to go about it, ensuring the sort order, would be:

from pandas.testing import assert_frame_equal

def assert_frame_with_sort(results, expected, key_columns):
    results_sorted = results.sort_values(by=key_columns).reset_index(drop=True)
    expected_sorted = expected.sort_values(by=key_columns).reset_index(drop=True)
    assert_frame_equal(results_sorted, expected_sorted)
Simplicity answered 31/1, 2019 at 11:8 Comment(1)
That's for pandas DataFrame objects, not Spark DataFrames. – Tower
