Non-equi join in polars

If you come from the future, hopefully this PR has already been merged.

If you don't come from the future, hopefully this answer solves your problem.

I want to solve my problem using only polars (where I am no expert, but I can follow what is going on) before just copy-pasting the DuckDB integration suggested above, and then comparing the results on my real data.

I have a list of events (name and timestamp), and a list of time windows. I want to count how many of each event occur in each time window.

I feel like I am close to getting something that works correctly, but I have been stuck for a couple of hours now:

import polars as pl

events = {
    "name": ["a", "b", "a", "b", "a", "c", "b", "a", "b", "a", "b", "a", "b", "a", "b", "a", "b", "a", "b"],
    "time": [0.0, 1.0, 1.5, 2.0, 2.25, 2.26, 2.45, 2.5, 3.0, 3.4, 3.5, 3.6, 3.65, 3.7, 3.8, 4.0, 4.5, 5.0, 6.0],
}

windows = {
    "start_time": [1.0, 2.0, 3.0, 4.0],
    "stop_time": [3.5, 2.5, 3.7, 5.0],
}

events_df = pl.DataFrame(events).sort("time").with_row_index()
windows_df = (
    pl.DataFrame(windows)
    .sort("start_time")
    .join_asof(events_df, left_on="start_time", right_on="time", strategy="forward")
    .drop("name", "time")
    .rename({"index": "first_index"})
    .sort("stop_time")
    .join_asof(events_df, left_on="stop_time", right_on="time", strategy="backward")
    .drop("name", "time")
    .rename({"index": "last_index"})
)

print(windows_df)
"""
shape: (4, 4)
┌────────────┬───────────┬─────────────┬────────────┐
│ start_time ┆ stop_time ┆ first_index ┆ last_index │
│ ---        ┆ ---       ┆ ---         ┆ ---        │
│ f64        ┆ f64       ┆ u32         ┆ u32        │
╞════════════╪═══════════╪═════════════╪════════════╡
│ 2.0        ┆ 2.5       ┆ 3           ┆ 7          │
│ 1.0        ┆ 3.5       ┆ 1           ┆ 10         │
│ 3.0        ┆ 3.7       ┆ 8           ┆ 13         │
│ 4.0        ┆ 5.0       ┆ 15          ┆ 17         │
└────────────┴───────────┴─────────────┴────────────┘
"""

So far, for each time window, I can get the index of the first and last events that I care about. Now I "just" need to count how many of these are of each type. Can I get some help on how to do this?

The output I am looking for should look like:

shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘

I feel like using something like int_ranges(), gather(), and explode() can get me a dataframe with each time window and all its corresponding events. Finally, something like group_by(), count(), and pivot() can get me to the dataframe I want. But I have been struggling with this for a while.

Confetti answered 28/8 at 7:29 Comment(2)
One of the easy approaches to do it in polars is to just run cross join and filter out events which are outside of the windowCarruth
@RomanPekar Yes, in the DuckDB answer I quoted above (which is yours hehe), note that I was asking because I had solved it with a cross join, but the size of the dataframe explodes and my process gets killed.Confetti
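For reference, the cross join plus filter approach from the comment above would look roughly like this (a sketch using the dicts from the question; it is simple, but it materializes len(windows) × len(events) rows before filtering, which is what exhausts memory on large data):

events_df = pl.DataFrame(events)
windows_df = pl.DataFrame(windows)

counts = (
    windows_df
    .join(events_df, how="cross")  # len(windows) * len(events) rows!
    .filter(
        (pl.col("time") >= pl.col("start_time"))
        & (pl.col("time") <= pl.col("stop_time"))
    )
    .pivot(on="name", index=["start_time", "stop_time"], aggregate_function="len", values="time")
    .fill_null(0)
)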

Update: join_where() was released in version 1.7.0:

(
    windows_df
    .join_where(
        events_df,
        pl.col.time >= pl.col.start_time,
        pl.col.time <= pl.col.stop_time,
    )
    .sort("name", "start_time")
    .pivot(on="name", index=["start_time","stop_time"], aggregate_function="len", values="time")
    .fill_null(0)
)

┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘
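Note that the two predicates passed to join_where() above are combined with AND. The join step can also be run lazily, which may help on larger data (a sketch, assuming join_where() is available on LazyFrame in your version; pivot() is eager-only, so collect() first):

joined = (
    windows_df.lazy()
    .join_where(
        events_df.lazy(),
        pl.col.time >= pl.col.start_time,  # predicates are ANDed together
        pl.col.time <= pl.col.stop_time,
    )
    .collect()  # pivot() only exists on eager DataFrames
)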

Previously: I was looking at your double join_asof and thought that maybe you could use another approach, one which would not require an explode().

The thing is, you don't really need the data from events_df, only the counts. So if we do a join_asof for every possible value of name, then we can calculate the counts by simple arithmetic. First, let's prepare our DataFrames.

events_df = pl.DataFrame(events)

windows_df = (
    pl.DataFrame(windows)
    .join(
        events_df.select(pl.col.name.unique()),
        how="cross",
    )
)

┌────────────┬───────────┬──────┐
│ start_time ┆ stop_time ┆ name │
│ ---        ┆ ---       ┆ ---  │
│ f64        ┆ f64       ┆ str  │
╞════════════╪═══════════╪══════╡
│ 1.0        ┆ 3.5       ┆ c    │
│ 1.0        ┆ 3.5       ┆ a    │
│ 1.0        ┆ 3.5       ┆ b    │
│ 2.0        ┆ 2.5       ┆ c    │
│ 2.0        ┆ 2.5       ┆ a    │
│ …          ┆ …         ┆ …    │
│ 3.0        ┆ 3.7       ┆ a    │
│ 3.0        ┆ 3.7       ┆ b    │
│ 4.0        ┆ 5.0       ┆ c    │
│ 4.0        ┆ 5.0       ┆ a    │
│ 4.0        ┆ 5.0       ┆ b    │
└────────────┴───────────┴──────┘
events_df = (
    events_df
    .with_columns(index = pl.int_range(pl.len()).over("name"))
)

┌──────┬──────┬───────┐
│ name ┆ time ┆ index │
│ ---  ┆ ---  ┆ ---   │
│ str  ┆ f64  ┆ i64   │
╞══════╪══════╪═══════╡
│ a    ┆ 0.0  ┆ 0     │
│ b    ┆ 1.0  ┆ 0     │
│ a    ┆ 1.5  ┆ 1     │
│ b    ┆ 2.0  ┆ 1     │
│ a    ┆ 2.25 ┆ 2     │
│ …    ┆ …    ┆ …     │
│ b    ┆ 3.8  ┆ 6     │
│ a    ┆ 4.0  ┆ 7     │
│ b    ┆ 4.5  ┆ 7     │
│ a    ┆ 5.0  ┆ 8     │
│ b    ┆ 6.0  ┆ 8     │
└──────┴──────┴───────┘

Now we can do the same join you did, but add the by parameter, so the join is done within each name group:

result_df = (
    windows_df
    .sort("name", "start_time")
    .join_asof(events_df, left_on="start_time", right_on="time", strategy="forward", by="name")
    .drop("time")
    .rename({"index": "first_index"})
    .sort("name", "stop_time")
    .join_asof(events_df, left_on="stop_time", right_on="time", strategy="backward", by="name")
    .drop("time")
    .rename({"index": "last_index"})
)

┌────────────┬───────────┬──────┬─────────────┬────────────┐
│ start_time ┆ stop_time ┆ name ┆ first_index ┆ last_index │
│ ---        ┆ ---       ┆ ---  ┆ ---         ┆ ---        │
│ f64        ┆ f64       ┆ str  ┆ i64         ┆ i64        │
╞════════════╪═══════════╪══════╪═════════════╪════════════╡
│ 2.0        ┆ 2.5       ┆ a    ┆ 2           ┆ 3          │
│ 1.0        ┆ 3.5       ┆ a    ┆ 1           ┆ 4          │
│ 3.0        ┆ 3.7       ┆ a    ┆ 4           ┆ 6          │
│ 4.0        ┆ 5.0       ┆ a    ┆ 7           ┆ 8          │
│ 2.0        ┆ 2.5       ┆ b    ┆ 1           ┆ 2          │
│ …          ┆ …         ┆ …    ┆ …           ┆ …          │
│ 4.0        ┆ 5.0       ┆ b    ┆ 7           ┆ 7          │
│ 2.0        ┆ 2.5       ┆ c    ┆ 0           ┆ 0          │
│ 1.0        ┆ 3.5       ┆ c    ┆ 0           ┆ 0          │
│ 3.0        ┆ 3.7       ┆ c    ┆ null        ┆ 0          │
│ 4.0        ┆ 5.0       ┆ c    ┆ null        ┆ 0          │
└────────────┴───────────┴──────┴─────────────┴────────────┘

And now you can calculate the result with a simple last_index - first_index + 1:

(
    result_df
    .with_columns(index = pl.col.last_index - pl.col.first_index + 1)
    .pivot(on="name", index=["start_time","stop_time"], values="index")
    .fill_null(0)
    .sort("start_time", "stop_time")
)

┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘
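One edge case worth noting (see the comments below): for a window that contains no events of a given name, the forward join lands on the event just after the window and the backward join on the event just before it, so first_index == last_index + 1 and the difference comes out as exactly 0, while a missing event on either side produces a null that fill_null(0) handles. If you prefer an explicit guard over relying on that arithmetic, clamping at zero is cheap (a sketch):

# hypothetical guard: clamp negative counts from empty windows to zero
index = (pl.col.last_index - pl.col.first_index + 1).clip(lower_bound=0)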
Carruth answered 28/8 at 14:4 Comment(5)
This is very interesting, but I am having some issues getting it to work with my real data. Let me try to understand why and I'll get back to you.Confetti
In your second step, when you do the double join_asof, is there a typo somewhere/something missing? I don't see how you get those values, maybe there is some transformation missing to events_df inside each join_asof? I don't see where your index comes from, and it certainly is not the plain events_df = pl.DataFrame(events).sort("time").with_row_index().Confetti
I think the correct thing would be events_df.sort("name", "time").with_row_index() inside each join_asof. This works with the example in my original question, but still complains with real data. I'll keep digging :)Confetti
This also works with my real data. I was worried at intermediate steps when sometimes first_index is larger than last_index (the window is empty), but overflow handles that OK.Confetti
@Confetti sorry, I haven't copied one important step - we of course need the index within name, which we get with pl.int_range(pl.len()).over("name")Carruth

Update: join_where() was released in version 1.7.0:

(
    windows_df
    .join_where(
        events_df,
        pl.col.time >= pl.col.start_time,
        pl.col.time <= pl.col.stop_time,
    )
    .sort("name", "start_time")
    .pivot(on="name", index=["start_time","stop_time"], aggregate_function="len", values="time")
    .fill_null(0)
)

┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘

Previously: Not sure if it will be more performant, but you can transform your windows_df into the desired output with:

  • int_ranges() to create a list of indexes from first_index to last_index.
  • explode() to explode the rows.
  • join() to join back to events_df.
  • pivot() to transform rows to columns.
(
    windows_df
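    # note: int_ranges() is end-exclusive, so this skips the event at
    # last_index; use pl.col.last_index + 1 if last_index is inclusive
    # (see the first comment below)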
    .with_columns(index = pl.int_ranges(pl.col.first_index, pl.col.last_index, dtype=pl.UInt32))
    .explode("index")
    .join(events_df, on="index", how="inner")
    .pivot(on="name", index=["start_time","stop_time"], aggregate_function="len", values="index")
    .fill_null(0)
)

┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 2.0        ┆ 2.5       ┆ 1   ┆ 2   ┆ 1   │
│ 1.0        ┆ 3.5       ┆ 4   ┆ 4   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 2   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 1   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘
Carruth answered 28/8 at 7:51 Comment(8)
You have no idea how long I was struggling with this. Thanks. Note that pl.col.last_index should be pl.col.last_index + 1 instead. But that is just because I was getting the last inclusive index. Thanks again.Confetti
I have another idea on how that might be calculated a bit different (and maybe faster), but i’m on mobile now, will test it laterCarruth
do you mind pointing me to the docs for pl.col.first_index, as I searched but couldnt get anything (i am probably searching wrong)Schoenberg
@Schoenberg It's just the same as pl.col("first_index") (polars added attribute -> column name dispatching to match pandas df.column_name behaviour after several requests)Handy
@Confetti it was a bit too much to put alternative approach to this answer, so I've added another answer. Curious to see if it will work for you.Carruth
This (inherited from OP) approach assumes the events happen without gaps such that the difference between last and first index is the correct count.Pennyweight
@DeanMacGregor Can you expand on what you mean by "events happen without gaps"? How would "events with gaps" look like for the events_df and windows_df from my original question? I just want to make sure that I am not missing some weird edge case.Confetti
I'm not sure, I just wonder if there's an edge case where making that assumption bites you.Pennyweight

You can avoid the multiple join_asofs and use search_sorted instead - which is how you would handle a range join anyway if your data fits the scenario. Your data is sorted on the time column, so take advantage of binary search to get your starts and ends. Unfortunately, the explode cannot be avoided (a more performant route may exist that I am not aware of) - it would be much better/faster if the explode could be skipped and the aggregation done directly between the starts and ends, but that would likely require dropping into Rust.

# build DataFrames from the dicts in the question
# (events must be sorted on time for the binary search to be valid)
events = pl.DataFrame(events).sort('time')
windows = pl.DataFrame(windows)

time = events.get_column('time')
stop_time = windows.get_column('stop_time')
start_time = windows.get_column('start_time')
starts = time.search_sorted(start_time, side='left')
ends = time.search_sorted(stop_time, side='right')
indices = pl.int_ranges(starts, ends, dtype=pl.UInt32)
(windows
.with_columns(indices=indices)
# ideally would love to avoid the explosion here
# iterating in a low level language and aggregating
# should offer more performance
.explode('indices')
.join(events.with_row_index(name='indices'), on='indices')
.pivot(
    index=['start_time', 'stop_time'], 
    on='name', 
    aggregate_function='len',
    values='time')
.fill_null(0)
)

shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ b   ┆ a   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 5   ┆ 4   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 1   ┆ 2   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘

You can also achieve the results via gather with implode and explode - again, explode rears its head - but I'd go with the join route, since it is more explicit and clear what we are trying to achieve:

imploded=events.select(pl.struct(pl.all()).implode())
(windows
.with_columns(imploded, indices=indices)
.select('start_time','stop_time',
        pl.col.name.list.gather(pl.col.indices))
.explode('name')
.unnest('name')
.pivot(on='name',
       values='time',
       aggregate_function='len')
.fill_null(0)
)

shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ b   ┆ a   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 5   ┆ 4   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 1   ┆ 2   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘

At any rate, use a binary search if you can for range joins; it shows in the performance.
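As a quick illustration of how the sides above work: side='left' returns the first position at which the value could be inserted (so events exactly at start_time are included), while side='right' returns one past the last (so events exactly at stop_time are included). A minimal sketch:

s = pl.Series([1.0, 2.0, 2.0, 3.0])
s.search_sorted(2.0, side='left')   # 1 -> position of the first 2.0
s.search_sorted(2.0, side='right')  # 3 -> one past the last 2.0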

Update: One way to reduce the number of rows to explode - which should improve performance, especially as your dataset grows - is to do the counting within the lists:

imploded=events.select(pl.col.name.implode())
(windows
.with_columns(imploded, indices=indices)
.select('start_time',
        'stop_time',
        pl.col.name  
        .list.gather(pl.col.indices)
        # aggregation happens here
        # maybe turn on parallel here 
        # for more performance
        .list.eval(pl.element()
                     .value_counts()
                  )
      )
.explode('name')
.with_columns(name=pl.col.name
                         .struct
                         .rename_fields(['name','counts'])
              )
.unnest('name')
.pivot(on='name',values='counts')
.fill_null(0)
)

shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ b   ┆ a   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 5   ┆ 4   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 1   ┆ 2   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘
Schoenberg answered 28/8 at 12:6 Comment(0)

This answer avoids exploding or cross joining the events frame, on the assumption that it is large while the windows frame is small enough to cross join once. The idea is to mimic a SQL-style range join by converting the windows into a new frame that has a begin time for every "mode" rather than a begin/end time. It should also have a column to track which windows apply from that begin time onward.

Let's make our initial dataframes

events = pl.DataFrame(events).sort("time")
windows = pl.DataFrame(windows)

Now let's make begins. It starts out as just the unique values of the start and stop times. You've got to add a small number to stop_time since you want events inclusive of the exact stop_time. We cross join that with the original windows, then do a group_by/agg so that there's one row per time, with a list of the windows that apply until the next time.

begins = (
    windows["start_time"]
    .extend(windows["stop_time"]+.00001) # add a small number so the exact stop_time stays inclusive
    .sort()
    .unique()
    .alias("time")
    .to_frame()
    .join(windows, how="cross")
    .group_by("time", maintain_order=True)
    .agg(i=pl.struct('start_time','stop_time').filter(
        (pl.col("start_time") <= pl.col("time"))
        & (pl.col("stop_time") > pl.col("time"))
    ))
)
shape: (8, 2)
┌─────────┬────────────────────────┐
│ time    ┆ i                      │
│ ---     ┆ ---                    │
│ f64     ┆ list[struct[2]]        │
╞═════════╪════════════════════════╡
│ 1.0     ┆ [{1.0,3.5}]            │ # at 1.0 only 1.0-3.5
│ 2.0     ┆ [{1.0,3.5}, {2.0,2.5}] │ # at 2.0 both 1.0-3.5 and 2.0-2.5
│ 2.50001 ┆ [{1.0,3.5}]            │ # at 2.5 only 1.0-3.5
│ 3.0     ┆ [{1.0,3.5}, {3.0,3.7}] │
│ 3.50001 ┆ [{3.0,3.7}]            │
│ 3.70001 ┆ []                     │ # at 3.7 no windows are in force
│ 4.0     ┆ [{4.0,5.0}]            │
│ 5.00001 ┆ []                     │
└─────────┴────────────────────────┘

With that, we can use join_asof, after which we explode the i column, then group_by i and name to agg a count (or rather a len). We then need a pivot to turn the groups into columns, and an unnest of i to get the original window start/stop times back. A select at the end puts the columns in order.

(
    events
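    # join_asof defaults to strategy="backward": each event picks the
    # latest begin time at or before its own time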
    .join_asof(begins, on="time")
    .explode("i")
    .group_by("i","name",maintain_order=True)
    .agg(pl.len())
    .pivot(on='name',index='i', values='len')
    .unnest('i')
    .filter(pl.col('start_time').is_not_null())
    .select('start_time','stop_time',pl.exclude('start_time','stop_time','i').fill_null(0))
)
shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘
Pennyweight answered 28/8 at 19:8 Comment(0)
