As far as I am aware, doing this really efficiently would require implementing it natively in Rust. I'm not sure whether this type of method will be added to Polars, although it is related to non-equi joins, as @RomanPekar has stated.
## Timings

| Method | Time |
|---|---|
| DuckDB (min) | 9.24s |
| Polars (cross-join)* | 8.87s |
| Polars (slice/filter)* | 3.61s |
| Numba** | 0.2s |
| Polars (Rust Plugin) | 0.01s |

(\*) method finds rows with matches only; the result needs to be joined back
(\*\*) Numba was removed due to not handling nulls: https://github.com/pola-rs/polars/issues/14811
## Dummy data

For comparison, I used a single column with 40k rows.

```python
import polars as pl
import numpy as np

ROWS = 40_000

(pl.from_numpy(np.random.random((ROWS, 1)), schema=["Column B"])
   .with_row_index()
   .write_parquet("random.parquet")
)
```
## Template

To time each approach, I ran each one separately in its own file.

```python
import duckdb
import time

import polars as pl

lf = pl.scan_parquet("random.parquet")

start = time.perf_counter()
# df = ...
end = time.perf_counter() - start
```
## DuckDB (min)

@RomanPekar's approach modified:

```python
df = duckdb.sql("""
from lf as self
select *, (
   from lf select min(index)
   where "Column B" >= self."Column B" * 1.5
     and index > self.index
) as "Column C"
order by index
""").pl()
```

(Note: the original query took 59.71s; using `min()` was much faster.)
## Polars (cross-join)

@Hericks's approach modified:

```python
df = (
    lf.join(lf, how="cross")
    .filter(
        pl.col("Column B") * 1.5 <= pl.col("Column B_right"),
        pl.col("index") < pl.col("index_right"),
    )
    .group_by("index")
    .agg(pl.col("index_right").min())
    .sort("index")
    .collect(streaming=True)
)
```
## Polars (slice/filter)

A naive approach; it requires materializing the entire column in memory.

```python
col_b = lf.select("Column B").collect().to_series()

slices = (
    lf.slice(start + 1)
    .filter(pl.col("Column B") >= value * 1.5)
    .limit(1)
    .select(
        pl.lit(start).alias("index"),
        pl.col("index").alias("Column C"),
    )
    for start, value in enumerate(col_b)
)

# pl.collect_all() was slower
df = pl.concat(slices).collect()  # streaming=True was slower
```
## Plugins

Writing a plugin is a bit more involved and requires learning a bit about the internals. It implements the same naive slice/filter approach.

```rust
#[polars_expr(output_type=UInt32)]
fn find_next_value(inputs: &[Series]) -> PolarsResult<Series> {
    let lhs = inputs[0].f64()?;
    let rhs = inputs[1].f64()?;
    let (lhs, rhs) = align_chunks_binary(lhs, rhs);
    let chunks = lhs
        .downcast_iter()
        .zip(rhs.downcast_iter())
        .map(|(lhs_arr, rhs_arr)| {
            let mut positions = MutablePrimitiveArray::with_capacity(lhs_arr.len());
            lhs_arr.iter().enumerate().for_each(|(idx, left)| {
                // find the first value after `idx` that satisfies the comparison;
                // null `left` values never match
                let pos = rhs_arr
                    .iter()
                    .skip(idx + 1)
                    .position(|right| left.is_some() && right >= left);
                match pos {
                    // `pos` is relative to the skipped iterator, so shift it back
                    // to get the absolute index
                    Some(pos) => positions.push(Some((idx + 1 + pos) as u32)),
                    _ => positions.push(None),
                }
            });
            positions.freeze().boxed()
        })
        .collect();
    let out: UInt32Chunked = unsafe { ChunkedArray::from_chunks(lhs.name(), chunks) };
    Ok(out.into_series())
}
```
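In Python terms, the plugin body does roughly this (a sketch of my own; the caller is assumed to pass `Column B * 1.5` as the first input and `Column B` as the second):

```python
def find_next_value_py(lhs, rhs):
    """For each i, index of the first rhs[j] with j > i and rhs[j] >= lhs[i].

    None entries (nulls) never match, mirroring the Rust version.
    """
    out = []
    for i, left in enumerate(lhs):
        found = None
        if left is not None:
            for j in range(i + 1, len(rhs)):
                right = rhs[j]
                if right is not None and right >= left:
                    found = j
                    break
        out.append(found)
    return out


# lhs is "Column B" * 1.5, rhs is "Column B"
result = find_next_value_py([1.5, 3.0, None, 4.5], [1.0, 2.0, None, 3.0])
```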
I imagine there is an actual name for this task and more efficient algorithms exist?
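For what it's worth, one sub-quadratic option (a sketch of my own, ignoring nulls for brevity) is a sparse table of range maxima plus a binary search per row: the answer for row `i` is the smallest `j > i` with `max(B[i+1..j]) >= B[i] * 1.5`, which gives O(n log n) overall.

```python
def next_threshold_index(values, factor=1.5):
    """For each i, smallest j > i with values[j] >= values[i] * factor, else None."""
    n = len(values)
    if n == 0:
        return []
    # Sparse table of range maxima: table[k][i] = max(values[i : i + 2**k]).
    table = [list(values)]
    k = 1
    while (1 << k) <= n:
        prev, half = table[k - 1], 1 << (k - 1)
        table.append([max(prev[i], prev[i + half]) for i in range(n - (1 << k) + 1)])
        k += 1

    def range_max(lo, hi):  # max over values[lo:hi], requires lo < hi
        k = (hi - lo).bit_length() - 1
        return max(table[k][lo], table[k][hi - (1 << k)])

    out = []
    for i, v in enumerate(values):
        target = v * factor
        lo, hi = i + 1, n
        if lo >= n or range_max(lo, n) < target:
            out.append(None)
            continue
        # Binary search for the smallest j with max(values[i+1 : j+1]) >= target.
        while lo < hi:
            mid = (lo + hi) // 2
            if range_max(i + 1, mid + 1) >= target:
                hi = mid
            else:
                lo = mid + 1
        out.append(lo)
    return out
```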