Find the next value >= the current value plus 50% using Polars

I have the following dataframe:

import polars as pl

df = pl.DataFrame({
    "Column A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "Column B": [2, 3, 1, 4, 1, 7, 3, 2, 12, 0],
})

I want to create a new column C that holds the distance, in rows, from the current row to the next row whose value in column B is greater than or equal to 1.5 × the current row's B value.

The end result should look like this:

df = pl.DataFrame({
    "Column A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "Column B": [2, 3, 1, 4, 1, 7, 3, 2, 12, 0],
    "Column C": [1, 4, 1, 2, 1, 3, 2, 1, None, None],
})

How can I efficiently achieve this using Polars, especially since I'm working with a large DataFrame?
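To make the rule concrete, here is a naive O(n²) loop (just a sketch of the definition, not an efficient solution) that produces exactly Column C:

```python
def next_ge_distance(values):
    """For each position i, the distance to the next j > i with
    values[j] >= 1.5 * values[i], or None if no such j exists."""
    out = []
    for i, v in enumerate(values):
        dist = None
        for j in range(i + 1, len(values)):
            if values[j] >= 1.5 * v:
                dist = j - i
                break
        out.append(dist)
    return out

print(next_ge_distance([2, 3, 1, 4, 1, 7, 3, 2, 12, 0]))
# [1, 4, 1, 2, 1, 3, 2, 1, None, None]
```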

Distrust answered 30/4, 2024 at 0:18 Comment(2)
Just a note: your Column B is different in the input and output tables; I think the one in the output table is correct. You might want to change it to make the question easier to answer. — Coatee
Sorry, changed. Thanks. — Distrust

OK, so first I should say: this one looks like it requires a join on inequalities across multiple columns, and from what I've found, pure Polars is not great at that. It's probably possible to do it with join_asof, but I couldn't make it pretty.

I'd probably use duckdb integration with polars to achieve the results:

import duckdb

duckdb.sql("""
    select
        d."Column A",
        d."Column B",
        (
            select tt."Column A"
            from df as tt
            where tt."Column A" > d."Column A" and tt."Column B" >= d."Column B" * 1.5
            order by tt."Column A" asc
            limit 1
        ) - d."Column A" as "Column C"
    from df as d
""").pl()

┌──────────┬──────────┬──────────┐
│ Column A ┆ Column B ┆ Column C │
│ ---      ┆ ---      ┆ ---      │
│ i64      ┆ i64      ┆ i64      │
╞══════════╪══════════╪══════════╡
│ 1        ┆ 2        ┆ 1        │
│ 2        ┆ 3        ┆ 4        │
│ 3        ┆ 1        ┆ 1        │
│ 4        ┆ 4        ┆ 2        │
│ 5        ┆ 1        ┆ 1        │
│ 6        ┆ 7        ┆ 3        │
│ 7        ┆ 3        ┆ 2        │
│ 8        ┆ 2        ┆ 1        │
│ 9        ┆ 12       ┆ null     │
│ 10       ┆ 0        ┆ null     │
└──────────┴──────────┴──────────┘
Coatee answered 30/4, 2024 at 7:30 Comment(0)

An alternative that relies purely on polars' native expression API could be to

  1. use a cross join and a suitable filter to create output rows of interest,
  2. join these rows back to the original (lazy) dataframe to obtain the final result.

First, we convert the pl.DataFrame to a pl.LazyFrame so that the joins can benefit from query-plan optimisation.

lf = df.lazy()

(
    lf
    # join back to original dataframe
    .join(
        (
            lf
            # perform cross join and filter for rows that satisfy conditions
            .join(lf, how="cross")
            .filter(
                pl.col("Column A_right") > pl.col("Column A"),
                pl.col("Column B_right") >= 1.5 * pl.col("Column B"),
            )
            # select first row within each group defined by Column A
            .group_by("Column A").agg(pl.all().first())
            # create output column and select only columns of interest
            .select(
                "Column A",
                (pl.col("Column A_right") - pl.col("Column A")).alias("Column C")
            )
        ),
        on="Column A",
        how="left",
    )
    .collect(streaming=True)
)
shape: (10, 3)
┌──────────┬──────────┬──────────┐
│ Column A ┆ Column B ┆ Column C │
│ ---      ┆ ---      ┆ ---      │
│ i64      ┆ i64      ┆ i64      │
╞══════════╪══════════╪══════════╡
│ 1        ┆ 2        ┆ 1        │
│ 2        ┆ 3        ┆ 4        │
│ 3        ┆ 1        ┆ 1        │
│ 4        ┆ 4        ┆ 2        │
│ 5        ┆ 1        ┆ 1        │
│ 6        ┆ 7        ┆ 3        │
│ 7        ┆ 3        ┆ 2        │
│ 8        ┆ 2        ┆ 1        │
│ 9        ┆ 12       ┆ null     │
│ 10       ┆ 0        ┆ null     │
└──────────┴──────────┴──────────┘
Becker answered 30/4, 2024 at 14:10 Comment(0)

As far as I am aware, doing this really efficiently would require implementing it natively in Rust.

I'm not sure whether this type of method will be added to Polars, although it is related to non-equi joins, as @RomanPekar has stated.

Timings

Method                   Time
DuckDB (min)             9.24s
Polars (cross-join)*     8.87s
Polars (slice/filter)*   3.61s
Numba**                  0.2s
Polars (Rust Plugin)     0.01s

(*) method finds rows with matches, need to join result back

(**) Numba removed due to not handling nulls: https://github.com/pola-rs/polars/issues/14811


Dummy data

For comparison, I used a single column with 40k rows.

import polars as pl
import numpy as np

ROWS = 40_000

(pl.from_numpy(np.random.random((ROWS, 1)), schema=["Column B"])
   .with_row_index()
   .write_parquet("random.parquet")
)

Template

To time each approach, I ran each separately in their own file.

import duckdb
import time
import polars as pl

lf = pl.scan_parquet("random.parquet")

start = time.perf_counter()

# df = ...

end = time.perf_counter() - start

DuckDB (min)

@RomanPekar's approach modified:

df = duckdb.sql("""
from lf as self
select *, (
   from lf select min(index)
   where "Column B" >= self."Column B" * 1.5 
   and index > self.index
) as "Column C"
order by index
""").pl()
(Note: the original query took 59.71s; using min() was much faster.)

Polars (cross-join)

@Hericks's approach modified:

df = (
    lf.join(lf, how="cross")
      .filter(
          pl.col("Column B") * 1.5 <= pl.col("Column B_right"),
          pl.col("index") < pl.col("index_right")
      )
      .group_by("index")
      .agg(pl.col("index_right").min())
      .sort("index")
      .collect(streaming=True)
)

Polars (slice/filter)

A naive approach; it requires materialising the entire column in memory.

col_b = lf.select("Column B").collect().to_series()

slices = (
    lf.slice(start + 1)
      .filter(pl.col("Column B") >= value * 1.5)
      .limit(1)
      .select(
         pl.lit(start).alias("index"), 
         pl.col("index").alias("Column C")
      )
      for start, value in enumerate(col_b)
)

# pl.collect_all() was slower
df = pl.concat(slices).collect() # streaming=True was slower

Plugins

Writing a plugin is a bit more involved and requires learning a bit about internals.

It uses the same naive slice/filter approach as above.

#[polars_expr(output_type=UInt32)]
fn find_next_value(inputs: &[Series]) -> PolarsResult<Series> {
    let lhs = inputs[0].f64()?;
    let rhs = inputs[1].f64()?;

    let (lhs, rhs) = align_chunks_binary(lhs, rhs);

    let chunks = lhs
        .downcast_iter()
        .zip(rhs.downcast_iter())
        .map(|(lhs_arr, rhs_arr)| {
            let mut positions = MutablePrimitiveArray::with_capacity(lhs_arr.len());
            lhs_arr.iter().enumerate().for_each(|(idx, left)| {
                let pos = rhs_arr
                    .iter()
                    .skip(idx + 1)
                    .position(|right| left.is_some() && right >= left);
                match pos {
                    Some(pos) => positions.push(Some(pos as u32 + 1)),
                    _ => positions.push(None),
                }
            });
            positions.freeze().boxed()
        })
        .collect();

    let out: UInt32Chunked = unsafe { ChunkedArray::from_chunks(lhs.name(), chunks) };

    Ok(out.into_series())
}

I imagine there is an actual name for this task and more efficient algorithms exist?
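One candidate name is the "next greater element" family of problems. With a per-row threshold (1.5 × B), each query can still be answered in logarithmic time by descending a max segment tree to the first qualifying position to the right, giving roughly O(n log n) overall. A minimal Python sketch (all names here are mine, not from any answer above):

```python
import math

def build_max_tree(values):
    """Segment tree where each node stores the max of its range."""
    size = 1
    while size < len(values):
        size *= 2
    tree = [-math.inf] * (2 * size)
    tree[size:size + len(values)] = values
    for i in range(size - 1, 0, -1):
        tree[i] = max(tree[2 * i], tree[2 * i + 1])
    return tree, size

def first_at_least(tree, size, start, threshold):
    """Smallest leaf index >= start whose value >= threshold, else None."""
    def descend(node, lo, hi):
        # Prune ranges entirely left of `start` or whose max is too small.
        if hi < start or tree[node] < threshold:
            return None
        if lo == hi:
            return lo
        mid = (lo + hi) // 2
        left = descend(2 * node, lo, mid)
        if left is not None:
            return left
        return descend(2 * node + 1, mid + 1, hi)
    return descend(1, 0, size - 1)

def next_ge_distance(b):
    tree, size = build_max_tree(b)
    out = []
    for i, v in enumerate(b):
        j = first_at_least(tree, size, i + 1, 1.5 * v)
        out.append(j - i if j is not None and j < len(b) else None)
    return out

print(next_ge_distance([2, 3, 1, 4, 1, 7, 3, 2, 12, 0]))
# [1, 4, 1, 2, 1, 3, 2, 1, None, None]
```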

Holmes answered 5/5, 2024 at 12:58 Comment(3)
Thanks for the detailed analysis! Minor: I think something went wrong with the table formatting. — Becker
@Becker D'oh, my bad - thanks. It renders fine in the preview. Edit: it seems it was the * characters, and they need to be escaped. — Holmes
@Becker It also seems like there could be another way to do this: if we sort Column B and search_sorted for Column B * 1.5, we find next_value_index, but we would somehow need to include the >= current_index constraint (and then search again to find the nearest). — Holmes

Since the previous answers to this question were written, Polars has added native support for non-equi joins.

Specifically, the PR introducing pl.DataFrame.join_where has been merged.

(
    df
    # assumes a row index column, e.g. df = df.with_row_index()
    .join_where(
        df,
        pl.col("Column B") * 1.5 <= pl.col("Column B_right"),
        pl.col("index") < pl.col("index_right")
    )
    .group_by("index")
    .agg(pl.col("index_right").min())
    .sort("index")
)
Becker answered 13/9, 2024 at 9:30 Comment(0)

© 2022 - 2025 — McMap. All rights reserved.