Efficient rolling trimmed mean with Python

What's the most efficient way to calculate a rolling (aka moving window) trimmed mean with Python?

For example, for a data set of 50K rows and a window size of 50, for each row I need to take the last 50 rows, remove the top and bottom 3 values (5% of the window size, rounded up), and get the average of the remaining 44 values.

Currently for each row I'm slicing to get the window, sorting the window and then slicing to trim it. It works, slowly, but there has to be a more efficient way.

Example

[10,12,8,13,7,18,19,9,15,14] # example data; the real data is a 50K-row DataFrame

Example data set and results for a window size of 5. For each row we look at the last 5 rows, sort them and discard 1 top and 1 bottom row (5% of 5 = 0.25, rounded up to 1). Then we average the remaining middle rows.
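The trimming rule (5% of the window, rounded up, dropped from each end) can be checked on a single window with a small helper. This is a sketch that assumes the trim fraction is a whole-number percentage; `n_to_trim` and `trimmed_mean` are hypothetical names, not part of any library.

```python
def n_to_trim(window_size, percent=5):
    """Values to drop from EACH end: percent of the window size, rounded up.

    Assumes an integer percent; integer ceiling division avoids float surprises.
    """
    return (window_size * percent + 99) // 100


def trimmed_mean(window, percent=5):
    """Sort one window, drop n_to_trim values at each end, average the rest."""
    k = n_to_trim(len(window), percent)
    middle = sorted(window)[k:len(window) - k]
    return sum(middle) / len(middle)


print(n_to_trim(5), n_to_trim(50))        # 1 3
print(trimmed_mean([10, 12, 8, 13, 7]))   # 10.0
```

For a window of 50 this drops 3 values from each end and averages the remaining 44, matching the numbers in the question.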

Code to generate this example set as a DataFrame

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'value': [10, 12, 8, 13, 7, 18, 19, 9, 15, 14],
    'window_of_last_5_values': [
        np.nan, np.nan, np.nan, np.nan, '10,12,8,13,7', '12,8,13,7,18',
        '8,13,7,18,19', '13,7,18,19,9', '7,18,19,9,15', '18,19,9,15,14'
    ],
    'values that are counting for average': [
        np.nan, np.nan, np.nan, np.nan, '10,12,8', '12,8,13', '8,13,18',
        '13,18,9', '18,9,15', '18,15,14'
    ],
    'result': [
        np.nan, np.nan, np.nan, np.nan, 10.0, 11.0, 13.0, 13.333333333333334,
        14.0, 15.666666666666666
    ]
})

Example code for the naive implementation

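# df is the example DataFrame above; only its 'value' column is used
window_size = 5
outliers_to_remove = 1

results = [np.nan] * (window_size - 1)  # no full window yet for the first rows
for index in range(window_size - 1, len(df)):
    current_window = df.iloc[index - window_size + 1:index + 1]
    # sort the window's values, then slice positionally to drop the extremes
    trimmed_mean = current_window['value'].sort_values().iloc[
        outliers_to_remove:window_size - outliers_to_remove].mean()
    results.append(trimmed_mean)  # save the result (and the window content, if needed)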

A note about DataFrame vs list vs NumPy array

Just by moving the data from a DataFrame to a list, I'm getting a 3.5x speed boost with the same algorithm. Interestingly, using a NumPy array also gives almost the same speed boost. Still, there must be a better way to implement this and achieve an orders-of-magnitude boost.
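For reference, a sketch of the same slice, sort, slice algorithm applied to a NumPy array (the array variant mentioned above) might look like this; `rolling_trimmed_mean_naive` is a hypothetical name, `k` is the number of values dropped from each end, and this exact code has not been timed.

```python
import numpy as np


def rolling_trimmed_mean_naive(values, window_size, k):
    """Slice each window, sort it, drop k values from each end, and average the rest.

    The first window_size - 1 entries are NaN, like pandas' rolling() output.
    """
    values = np.asarray(values, dtype=float)
    result = np.full(len(values), np.nan)
    for end in range(window_size, len(values) + 1):
        window = np.sort(values[end - window_size:end])
        result[end - 1] = window[k:window_size - k].mean()
    return result


data = np.array([10, 12, 8, 13, 7, 18, 19, 9, 15, 14])
print(rolling_trimmed_mean_naive(data, window_size=5, k=1))
```

With the DataFrame from the question, it would be called as `rolling_trimmed_mean_naive(df['value'].to_numpy(), 50, 3)`.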

Herlindaherm answered 2/9, 2018 at 9:26 Comment(2)
@roganjosh how would you include discarding the top/bottom 1% (of the window size) of values from the rolling window? Is that possible? - Cercaria
I doubt that there is much room for optimization, because the calculation itself is too complicated (e.g. it is not a linear transformation). You could also try Cython. - Agadir

One observation that could come in handy is that you do not need to sort all the values at each step. Rather, if you ensure that the window is always sorted, all you need to do is insert the new value at the relevant spot and remove the old one from where it was. Finding both positions takes O(log_2(window_size)) using bisect; the list insertion and deletion themselves still shift up to window_size elements, though for small windows that shift is cheap. In practice, this would look something like

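import bisect

import numpy as np

def rolling_mean(data):
    # window of 50, dropping 3 values from each end (both hardcoded)
    x = sorted(data[:49])  # first 49 values; the 50th is inserted in the loop
    res = np.repeat(np.nan, len(data))
    for i in range(49, len(data)):
        if i != 49:
            # remove the value that just left the window
            del x[bisect.bisect_left(x, data[i - 50])]
        # insert the new value, keeping x sorted
        bisect.insort_right(x, data[i])
        res[i] = np.mean(x[3:47])
    return res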

In this case, though, the gain turns out to be smaller than what is gained by the vectorization that scipy.stats.trim_mean relies on, so this is still slower than @ChrisA's solution. It is nonetheless a useful starting point for further performance optimization.

> data = pd.Series(np.random.randint(0, 1000, 50000))
> %timeit data.rolling(50).apply(lambda w: trim_mean(w, 0.06))
727 ms ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit rolling_mean(data.values)
812 ms ± 42.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Notably, Numba's JIT compiler, which is often useful in situations like these, provides no benefit here either (it is in fact slower):

> from numba import jit
> rolling_mean_jit = jit(rolling_mean)
> %timeit rolling_mean_jit(data.values)
1.05 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The following, seemingly far-from-optimal, approach outperforms both of the other approaches considered above:

def rolling_mean_np(data):
    res = np.repeat(np.nan, len(data))
    for i in range(len(data)-49):
        x = np.sort(data[i:i+50])
        res[i+49] = x[3:47].mean()
    return res

Timing:

> %timeit rolling_mean_np(data.values)
564 ms ± 4.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

What is more, this time around, JIT compilation does help:

> rolling_mean_np_jit = jit(rolling_mean_np)
> %timeit rolling_mean_np_jit(data.values)
94.9 ms ± 605 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

While we're at it, let's just quickly verify that this actually does what we expect it to:

> np.all(rolling_mean_np_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

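In fact, by helping the sorter out a little (overwriting the outgoing value with the incoming one in the already sorted array, so that only that one element is out of place before re-sorting), we can squeeze out roughly another factor of 2, taking the total time down to 57 ms: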

def rolling_mean_np_manual(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old = np.searchsorted(x, data[i-50])
            x[idx_old] = data[i]
            x.sort()
    return res

> %timeit rolling_mean_np_manual(data.values)
580 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_manual_jit = jit(rolling_mean_np_manual)
> %timeit rolling_mean_np_manual_jit(data.values)
57 ms ± 5.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_manual_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

Now, the "sorting" going on in this example of course just boils down to placing the new element in the right place while shifting everything in between by one. Doing this by hand makes the pure Python code slower, but the jitted version gains another factor of 2, taking us below 30 ms:

def rolling_mean_np_shift(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old, idx_new = np.searchsorted(x, [data[i-50], data[i]])
            if idx_old < idx_new:
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                x[idx_new] = data[i]
    return res

> %timeit rolling_mean_np_shift(data.values)
937 ms ± 97.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_shift_jit = jit(rolling_mean_np_shift)
> %timeit rolling_mean_np_shift_jit(data.values)
26.4 ms ± 693 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_shift_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

At this point, most of the time is spent in np.searchsorted, so let us make the search itself JIT-friendly. Adapting the source code of bisect.bisect_left, we define

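from numba import jit

@jit
def binary_search(a, x):
    # adapted from bisect.bisect_left: leftmost insertion point for x in sorted a
    lo = 0
    hi = len(a)  # equals 50 for the window used here
    while lo < hi:
        mid = (lo+hi)//2
        if a[mid] < x:
            lo = mid+1
        else:
            hi = mid
    return lo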

@jit
def rolling_mean_np_jitted_search(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old = binary_search(x, data[i-50])
            idx_new = binary_search(x, data[i])
            if idx_old < idx_new:
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                x[idx_new] = data[i]
    return res

This takes us down to 12 ms, a 60x improvement over the raw pandas + SciPy approach:

> %timeit rolling_mean_np_jitted_search(data.values)
12 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
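The functions above hardcode a window of 50 and a trim of 3 values per end (e.g. the x[3:47] slice). Below is a plain-NumPy sketch of the same shift-based update with both as parameters; the name `rolling_trimmed_mean_shift` is hypothetical, it uses np.searchsorted as in rolling_mean_np_shift, and it has not been timed. Reproducing the speed-ups reported above would presumably require the same `@jit` treatment (and the hand-written search), which is not tested here.

```python
import numpy as np


def rolling_trimmed_mean_shift(data, window_size, k):
    """Shift-based rolling trimmed mean with window size and trim as parameters.

    x always holds the current window in sorted order; each step overwrites the
    outgoing value with the incoming one by shifting the elements in between.
    """
    data = np.asarray(data, dtype=np.float64)
    x = np.sort(data[:window_size])
    res = np.full(len(data), np.nan)
    for i in range(window_size, len(data) + 1):
        res[i - 1] = x[k:window_size - k].mean()
        if i == len(data):
            break
        old, new = data[i - window_size], data[i]
        idx_old = np.searchsorted(x, old)  # leftmost copy of the outgoing value
        idx_new = np.searchsorted(x, new)  # leftmost slot where new fits
        if idx_old < idx_new:
            # new belongs after old: move the values in between one step left
            x[idx_old:idx_new - 1] = x[idx_old + 1:idx_new].copy()
            x[idx_new - 1] = new
        elif idx_new < idx_old:
            # new belongs before old: move the values in between one step right
            x[idx_new + 1:idx_old + 1] = x[idx_new:idx_old].copy()
            x[idx_new] = new
        else:
            x[idx_new] = new  # same slot: overwrite in place
    return res
```

Because x is sorted before and after every step, the result should match the naive sort-every-window approach exactly, duplicates included.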
Ronda answered 2/9, 2018 at 10:32 Comment(12)
So much interesting stuff here! Numba looks amazing. I'm trying to reproduce the JIT improvement of rolling_mean_np, and I do get a comparable improvement with a window size of 50, but not with a window size of 1000. Why could that be? - Herlindaherm
Answering myself: I assume it's because the sort becomes more and more expensive with larger windows. I'll try to dig into integrating the rest of the improvements soon. - Herlindaherm
Sounds plausible; the vectorization in the non-Numba approach probably becomes more effective as size increases. If you want to see exactly where the approaches differ, and you're not already using it, I recommend line_profiler (and here's a good guide); it's lightweight compared to many other profilers and easy to get started with. - Ronda
@Alex Friedman The sorting function within Numba is in many cases slower than the NumPy version. It may make sense to do the sorting outside the JIT-compiled function. - Ruler
@Ronda Hmm, it doesn't look like line_profiler can profile jitted functions. I'm using it to analyze the non-jitted code, though, thanks. - Herlindaherm
Ah, yes, sorry, I should have been clearer on that. In the Numba case, you can use its annotate-html output to understand what is going on behind the scenes (and in particular where Numba fails to provide improvements); see e.g. this tutorial. Somewhat less straightforward (but on the other hand really cool). - Ronda
@Ronda Fascinating! Also, for reference, I'm getting a 15x improvement even for a window size of 1000 with rolling_mean_np_jitted_search over pandas + SciPy (67 ms vs 1 s on my laptop). Very impressive. And it gives me the freedom to do custom stuff there, like saving the trimmed window contents and sorting by two columns instead of one. I'll have to test whether the performance is still good, though. - Herlindaherm
Good to know. And yes, in an intermediate version I also stored all windows along the way to see if one would gain a speed-up by taking the means only at the end. That wasn't the case. - Ronda
@Ronda I realize this is probably pushing it, but is there a way to get similar improvements when using the decimal data type, or when using floats but converting to decimal for precise mean calculations? It looks like that version doesn't benefit from JIT. - Herlindaherm
Good question; you mean decimal.Decimal, right? I don't know of anything tailor-made for this purpose, but if you know the required number of decimal places, you could always convert the decimal array to integers first (that is, convert [Decimal('1.2'), Decimal('2.5')] to [12, 25]), calculate the trimmed means of that, and convert back. That will definitely be slower, but chances are that the binary search is still the bottleneck. - Ronda
The answers to this question have some ideas that could be used to get the minimal exponent. Depending on where you get your data, you may be able to solve this upstream, though. - Ronda
@Ronda Yup, that works well. Another interesting observation: if we replace mean() with sum() in rolling_mean_np_jitted_search, we get a further 2-3x improvement for large window sizes like 100 and 1000. Then we can do a vectorized division on the resulting array, which is very fast. - Herlindaherm
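A sketch combining the two ideas from these comments: scale `decimal.Decimal` values to integers using a known number of decimal places, keep exact integer sums of each trimmed window, and divide only at the end. `to_scaled_ints` and `rolling_trimmed_means_decimal` are hypothetical names, and for brevity each window is re-sorted rather than maintained incrementally as in the answer.

```python
from decimal import Decimal


def to_scaled_ints(values, places):
    """Exact integer form of Decimals that have at most `places` decimal places."""
    scale = Decimal(10) ** places
    scaled = [v * scale for v in values]
    if any(s != s.to_integral_value() for s in scaled):
        raise ValueError("a value has more than `places` decimal places")
    return [int(s) for s in scaled]


def rolling_trimmed_means_decimal(values, window_size, k, places):
    """Trimmed means from exact integer sums; the division happens only at the end."""
    ints = to_scaled_ints(values, places)
    n_kept = window_size - 2 * k
    sums = [
        sum(sorted(ints[end - window_size:end])[k:window_size - k])  # exact ints
        for end in range(window_size, len(ints) + 1)
    ]
    # one Decimal division per window, rounded to the context precision (28 digits by default)
    divisor = n_kept * Decimal(10) ** places
    return [None] * (window_size - 1) + [Decimal(s) / divisor for s in sums]


prices = [Decimal("1.2"), Decimal("2.5"), Decimal("0.3"), Decimal("4.1")]
print(rolling_trimmed_means_decimal(prices, window_size=3, k=1, places=1))
```

The sums are exact; only the final division can round, and only when the mean has more significant digits than the decimal context allows.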

You might try using scipy.stats.trim_mean:

from scipy.stats import trim_mean

df['value'].rolling(5).apply(lambda x: trim_mean(x, 0.2))

[output]

0          NaN
1          NaN
2          NaN
3          NaN
4    10.000000
5    11.000000
6    13.000000
7    13.333333
8    14.000000
9    15.666667

Note that I had to use rolling(5) and proportiontocut=0.2 for your toy data set.

For your real data, you should use rolling(50) and trim_mean(x, 0.06) to remove the top and bottom 3 values from each rolling window (trim_mean cuts int(0.06 * 50) = 3 values from each end).
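Because trim_mean cuts int(proportiontocut * n) values from each end (it rounds down), choosing the proportion by hand is fragile. A small helper can pick one that yields exactly k per end; `proportion_for_exact_trim` is a hypothetical name, and aiming at k + 0.5 is just one way to stay clear of float rounding at the integer boundary.

```python
def proportion_for_exact_trim(window_size, k):
    """proportiontocut for which scipy.stats.trim_mean drops exactly k values per end.

    trim_mean drops int(proportiontocut * n) values from each end, i.e. it rounds
    down, so aiming at k + 0.5 keeps float error from turning k into k - 1.
    """
    if not 0 <= k < window_size / 2:
        raise ValueError("k must leave at least one value in the window")
    return (k + 0.5) / window_size


for window_size, k in [(5, 1), (50, 3), (1000, 50)]:
    p = proportion_for_exact_trim(window_size, k)
    print(window_size, k, p, int(p * window_size))
```

It would then be used as, e.g., `df['value'].rolling(50).apply(lambda x: trim_mean(x, proportion_for_exact_trim(50, 3)), raw=True)`; `raw=True` passes each window to the function as a NumPy array instead of a Series, which is usually faster for NumPy-based reductions.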

Counterweight answered 2/9, 2018 at 9:43 Comment(6)
Is it just me, or does this not actually give the expected result? That is, with rolling(50) and trim_mean(x, 0.05), the first non-NaN value isn't actually np.mean(sorted(df.value[:50])[3:47]). - Ronda
Interesting! trim_mean slices off conservatively (rounding down the number of elements to slice), but it should be possible to raise proportiontocut to get the needed number. I'll do some tests. - Herlindaherm
@ChrisA: Yep, that looks better! - Ronda
@ChrisA rolling+trim_mean works amazingly quickly compared to what I have! Is there a way to use rolling+trimboth somehow to get the trimmed content of each window too? It looks like rolling cannot return an array, but I'm hoping there's a way around it. I'll dive into fuglede's amazingly deep answer too. - Herlindaherm
@AlexFriedman Sorry, I don't know of any method off the top of my head to achieve that. You should ask it as a separate question here, though. Someone is bound to have a solution :) - Counterweight
@ChrisA I figured I could simply rolling.apply a function that runs trimboth, saves the result somewhere, and then returns the mean (or null or whatever). This does actually work, but the performance isn't that great. Continuing with fuglede's manual method :) - Herlindaherm
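On getting the trimmed content of each window: assuming NumPy 1.20 or later (which added `numpy.lib.stride_tricks.sliding_window_view`), all windows can be sorted at once, yielding both the trimmed contents and the means without `rolling.apply`. This is a sketch that was not benchmarked in this thread; `rolling_trimmed` is a hypothetical name and `k` is the number of values dropped from each end.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_trimmed(values, window_size, k):
    """Return (trimmed_windows, means) for every full window.

    trimmed_windows[j] holds the sorted, trimmed contents of the window ending at
    position j + window_size - 1, and means[j] is their average.
    """
    values = np.asarray(values, dtype=float)
    windows = sliding_window_view(values, window_size)          # read-only view, no copy
    trimmed = np.sort(windows, axis=1)[:, k:window_size - k]    # sorting makes a copy
    return trimmed, trimmed.mean(axis=1)


trimmed, means = rolling_trimmed([10, 12, 8, 13, 7, 18, 19, 9, 15, 14], 5, 1)
print(trimmed)
print(means)
```

The sorted copy has (n - window_size + 1) x window_size elements, about 20 MB of float64 for 50K rows and a window of 50, so memory grows with the window size. To line the means up with pandas' rolling output, prepend `window_size - 1` NaNs, e.g. `np.concatenate([np.full(window_size - 1, np.nan), means])`.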

I bet slicing and sorting with every move of the window is the slow part. Instead of slicing every time, keep a separate list of 50 (or 5) values. Sort it once at the start; then, when moving the window, remove the outgoing value and add the new one in the correct place so as to preserve the sort order (much like in the insertion sort algorithm). Then calculate the trimmed mean from the middle subset of that list. You will need a way to keep track of where your window is relative to the entire data set; I think a single int variable will suffice.
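A minimal plain-Python sketch of this approach, using `bisect.insort` for the insertion-sort-style step and a single int `start` to track the window's position; the function name is hypothetical and `k` is the number of values dropped from each end.

```python
import bisect


def rolling_trimmed_means_sorted_list(values, window_size, k):
    """Rolling trimmed means from one sorted list that is updated, not re-sorted."""
    if len(values) < window_size:
        return [float("nan")] * len(values)
    window = sorted(values[:window_size])  # sort once at the start
    n_kept = window_size - 2 * k
    means = [float("nan")] * (window_size - 1)
    start = 0  # index in `values` of the oldest value still in the window
    while True:
        means.append(sum(window[k:window_size - k]) / n_kept)
        incoming = start + window_size  # index of the next value to enter
        if incoming == len(values):
            return means
        # drop the outgoing value (any equal copy will do) ...
        del window[bisect.bisect_left(window, values[start])]
        # ... and insert the incoming one where it keeps the list sorted
        bisect.insort(window, values[incoming])
        start += 1


print(rolling_trimmed_means_sorted_list([10, 12, 8, 13, 7, 18, 19, 9, 15, 14], 5, 1))
```

Summing the middle slice still costs O(window_size) per step, the same trade-off the bisect-based answer makes.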

Gig answered 2/9, 2018 at 10:39 Comment(0)
