Getting min/max column name in Polars

In Polars I can get the horizontal max (the maximum value across a set of columns for each row) like this:

df = pl.DataFrame(
    {
        "a": [1, 8, 3],
        "b": [4, 5, None],
    }
)

df.with_columns(max = pl.max_horizontal("a", "b"))
┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ i64 │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ 4   │
│ 8   ┆ 5    ┆ 8   │
│ 3   ┆ null ┆ 3   │
└─────┴──────┴─────┘

This corresponds to Pandas df[["a", "b"]].max(axis=1).

Now, how do I get the column names instead of the actual max value? In other words, what is the Polars version of Pandas' df[["a", "b"]].idxmax(axis=1)?

The expected output would be:

┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ str │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ b   │
│ 8   ┆ 5    ┆ a   │
│ 3   ┆ null ┆ a   │
└─────┴──────┴─────┘
Tiffany answered 9/2, 2024 at 10:9 Comment(0)
S
3

You can concatenate the elements into a list using pl.concat_list, get the index of the largest element using pl.Expr.list.arg_max, and replace the index with the column name using pl.Expr.replace.

mapping = {0: "a", 1: "b"}
(
    df
    .with_columns(
        pl.concat_list(["a", "b"]).list.arg_max().replace(mapping).alias("max_col")
    )
)

This can all be wrapped into a function to also handle the creation of the mapping dict.

def max_col(cols) -> pl.Expr:
    mapping = dict(enumerate(cols))
    return pl.concat_list(cols).list.arg_max().replace(mapping)

df.with_columns(max_col(["a", "b"]).alias("max_col"))

Output.

shape: (3, 3)
┌─────┬──────┬─────────┐
│ a   ┆ b    ┆ max_col │
│ --- ┆ ---  ┆ ---     │
│ i64 ┆ i64  ┆ str     │
╞═════╪══════╪═════════╡
│ 1   ┆ 4    ┆ b       │
│ 8   ┆ 5    ┆ a       │
│ 3   ┆ null ┆ a       │
└─────┴──────┴─────────┘
Suzette answered 9/2, 2024 at 10:33 Comment(4)
maybe you can use {v:i for i,v in enumerate(df.columns)} to cover all the columns - Dorkas
@RomanPekar Yes, that's a good idea! I've just added a more general function to handle the creation of the mapping dict. - Suzette
dict(enumerate(cols)) should work as well instead of zip - Dorkas
@RomanPekar Thanks - I've edited the answer to use enumerate instead of zip. - Suzette
K
5

You can build a "multiple choice" of when/then and pl.coalesce them to create a single column of results.

df.with_columns(max_col = 
   pl.coalesce(
      pl.when(pl.col(name) == pl.max_horizontal(df.columns))
        .then(pl.lit(name))
      for name in df.columns
   )
)
shape: (3, 3)
┌─────┬──────┬─────────┐
│ a   ┆ b    ┆ max_col │
│ --- ┆ ---  ┆ ---     │
│ i64 ┆ i64  ┆ str     │
╞═════╪══════╪═════════╡
│ 1   ┆ 4    ┆ b       │
│ 8   ┆ 5    ┆ a       │
│ 3   ┆ null ┆ a       │
└─────┴──────┴─────────┘
Kibosh answered 9/2, 2024 at 12:50 Comment(1)
Yours is the fastest with df.shape==(1000000, 6) - Solidify
S
3

Here's a way to get both the max and the column name at once using a fold (which is how the horizontal functions work behind the scenes anyway).

df.lazy().with_columns(
    max_col = (max_struct:=pl.fold(
        acc=pl.struct(value=-1e20, name=pl.lit("default low value")), 
        function = lambda x,y: (
            pl.when(x.struct.field('value')>y)
            .then(x)
            .otherwise(pl.struct(value=y, name=pl.lit(y.name)))),
        exprs=pl.all()
        )).struct.field('name'),
    max_value=max_struct.struct.field('value')
    ).collect()

A reduce is an operation that takes two columns at a time, passes them to a function, and keeps the output. If there are more than two columns, it passes that output together with the next column to the function, and so on. A fold is the same thing except that it starts with an accumulator, which is just a default first column. So the first pair is the accumulator and the actual first column, and after that it works the same way.
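
To make the reduce/fold distinction concrete, here is a plain-Python analogy for a single row. It uses functools.reduce rather than Polars, and keep_larger and the sample row are made up for illustration:

```python
from functools import reduce

# One row, written as column name -> value (made-up numbers)
row = {"a": 3.0, "b": -1.5, "c": 4.2}


def keep_larger(best, item):
    # best and item are (value, name) pairs; keep the pair with the larger value
    return best if best[0] > item[0] else item


pairs = [(value, name) for name, value in row.items()]

# reduce: the first pair itself is the starting point
reduced = reduce(keep_larger, pairs)

# fold: start from an explicit accumulator, like the acc struct in the answer
folded = reduce(keep_larger, pairs, (-1e20, "default low value"))

print(reduced, folded)
```

Both calls end up with the pair for column c here; the accumulator only matters when nothing in the row beats it.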

The function passed to the fold receives Series, which carry their own names, so it can return a struct holding the larger value together with the name of the Series it came from. That gives both outputs at once.

In the above I use the walrus operator so that the struct can be taken apart in the same context that creates it. I make the df lazy so that common subexpression elimination (CSE) caches the fold result rather than computing it twice.

Performance comparison

Starting with

import numpy as np
import polars as pl

n=int(1e6)
df=pl.DataFrame({
    'a':np.random.uniform(-5,5,n),
    'b':np.random.uniform(-5,5,n),
    'c':np.random.uniform(-5,5,n),
    'd':np.random.uniform(-5,5,n),
    'e':np.random.uniform(-5,5,n),
    'f':np.random.uniform(-5,5,n)
})

The results with lazy operations

shape: (3, 2)
┌──────────┬───────────┐
│ user     ┆ timeit_ms │
│ ---      ┆ ---       │
│ str      ┆ i64       │
╞══════════╪═══════════╡
│ jqurious ┆ 103       │ # coalesce
│ Dean     ┆ 257       │ # fold
│ Hericks  ┆ 1110      │ # concat_list 
└──────────┴───────────┘
Solidify answered 9/2, 2024 at 12:24 Comment(3)
nice one, I've tried to do this as well with list instead of struct, but stopped when it became too complicated - Dorkas
@RomanPekar It doesn't work with list b/c a list has to be one data type. You could do a list of min/max values since those are both floats but value/name pair has to be a struct. - Solidify
yep true. I like the idea but at the end solution gets too complicated. It's a pity there's no horizontal_arg_max(). - Dorkas
